KSTER

Code for our EMNLP 2021 paper "Learning Kernel-Smoothed Machine Translation with Retrieved Examples" [paper].

Usage

Download the processed datasets from this site. You can also download the built databases from this site and download the model checkpoints from this site.

Train a general-domain base model

Take English -> German translation as an example.

export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m joeynmt train configs/transformer_base_wmt14_en2de.yaml

Finetune the trained base model on domain-specific datasets

Take English -> German translation in the Koran domain as an example.

export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m joeynmt train configs/transformer_base_koran_en2de.yaml

Build database

Take English -> German translation in the Koran domain as an example; wmt14_en_de.transformer.ckpt is the path to the trained general-domain base model checkpoint.

mkdir database/koran_en_de_base
export CUDA_VISIBLE_DEVICES=0
python3 -m joeynmt build_database configs/transformer_base_koran_en2de.yaml \
        --ckpt wmt14_en_de.transformer.ckpt \
        --division train \
        --index_path database/koran_en_de_base/trained.index \
        --token_map_path database/koran_en_de_base/token_map \
        --embedding_path database/koran_en_de_base/embeddings.npy
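
For intuition, the build_database step runs the general-domain model over the domain training data, records a decoder hidden state (key) and the aligned gold target token (value) for every target position, and stores the keys in a FAISS index alongside a token map. The sketch below shows this idea with FAISS and NumPy; the arrays hidden_states and target_tokens, the flat index type, and the .npy token map format are illustrative assumptions, not the exact joeynmt build_database internals.

# Sketch: turn per-token decoder states into a searchable retrieval database.
# hidden_states: (N, d) float32 decoder outputs collected over the training set.
# target_tokens: (N,) int ids of the gold target tokens aligned with those states.
# Both arrays are stand-ins here; the real script extracts them from the model.
import numpy as np
import faiss

hidden_states = np.random.rand(10_000, 512).astype("float32")   # stand-in data
target_tokens = np.random.randint(0, 32_000, size=10_000)       # stand-in data

index = faiss.IndexFlatL2(hidden_states.shape[1])  # exact L2 search; the repo's trained.index may use a trained/quantized index
index.add(hidden_states)                           # keys
faiss.write_index(index, "database/koran_en_de_base/trained.index")
np.save("database/koran_en_de_base/embeddings.npy", hidden_states)  # raw keys, reused by KSTER
np.save("database/koran_en_de_base/token_map.npy", target_tokens)   # values (file format assumed)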

Train the bandwidth estimator and weight estimator in KSTER

Take English -> German translation in the Koran domain as an example.

export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m joeynmt combiner_train configs/transformer_base_koran_en2de.yaml \
        --ckpt wmt14_en_de.transformer.ckpt \
        --combiner dynamic_combiner \
        --top_k 16 \
        --kernel laplacian \
        --index_path database/koran_en_de_base/trained.index \
        --token_map_path database/koran_en_de_base/token_map \
        --embedding_path database/koran_en_de_base/embeddings.npy \
        --in_memory True
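
To make this step concrete: combiner_train keeps the base translation model frozen and only learns two small estimators, one for the kernel bandwidth and one for the mixing weight between the model distribution and the retrieval distribution, both conditioned on the current decoder query and the retrieved keys. The PyTorch sketch below is an illustrative reading of that idea; the exact inputs, layer sizes, and training details in the paper and code may differ.

# Illustrative sketch of KSTER's two learned scalars (not the exact architecture).
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicCombinerSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.bandwidth_net = nn.Linear(2 * dim, 1)  # query + mean retrieved key -> positive scalar
        self.weight_net = nn.Linear(2 * dim, 1)     # query + kernel-averaged key -> scalar in (0, 1)

    def forward(self, query, keys):
        # query: (B, d) decoder state; keys: (B, K, d) retrieved example representations.
        bandwidth = F.softplus(self.bandwidth_net(torch.cat([query, keys.mean(dim=1)], dim=-1)))
        dist = torch.cdist(query.unsqueeze(1), keys).squeeze(1)   # (B, K) L2 distances
        kernel = torch.softmax(-dist / bandwidth, dim=-1)         # normalized laplacian-style weights
        avg_key = (kernel.unsqueeze(-1) * keys).sum(dim=1)        # (B, d)
        mixing_weight = torch.sigmoid(self.weight_net(torch.cat([query, avg_key], dim=-1)))
        return bandwidth, kernel, mixing_weight

Only these estimator parameters would be updated during combiner_train, with the usual cross-entropy loss applied to the mixed output distribution.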

Inference

We unify inference for the base model, the finetuned or joint-trained models, kNN-MT, and KSTER under the concept of a combiner (see joeynmt/combiners.py).

Combiner type   | Methods                          | Description
NoCombiner      | Base, Finetuning, Joint-training | Direct inference without retrieval.
StaticCombiner  | kNN-MT                           | Retrieve similar examples during inference; mixing_weight and bandwidth are pre-specified.
DynamicCombiner | KSTER                            | Retrieve similar examples during inference; mixing_weight and bandwidth are dynamically estimated.
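
Conceptually, all three combiners produce the next-token distribution from the same two ingredients: the NMT softmax and, when retrieval is used, a kernel-smoothed distribution over the retrieved examples. A minimal NumPy sketch of that mixture, under our reading of the paper (the exact kernel and weighting conventions in joeynmt/combiners.py may differ):

# Sketch: p = (1 - w) * p_nmt + w * p_retrieval.
# NoCombiner: w = 0. StaticCombiner (kNN-MT): w and bandwidth come from flags.
# DynamicCombiner (KSTER): w and bandwidth are estimated at every decoding step.
import numpy as np

def retrieval_distribution(query, keys, values, vocab_size, bandwidth, kernel="gaussian"):
    """Kernel-smoothed distribution over retrieved (key, value) pairs."""
    dist = np.linalg.norm(keys - query, axis=-1)   # (K,) L2 distances to the query
    if kernel == "gaussian":
        scores = np.exp(-dist ** 2 / bandwidth)
    else:                                          # "laplacian"
        scores = np.exp(-dist / bandwidth)
    scores /= scores.sum()
    p = np.zeros(vocab_size)
    np.add.at(p, values, scores)                   # scatter kernel weights onto token ids
    return p

def combine(p_nmt, p_retrieval, mixing_weight):
    return (1.0 - mixing_weight) * p_nmt + mixing_weight * p_retrieval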

Inference with NoCombiner for Base model

Take English -> German translation in the Koran domain as an example.

export CUDA_VISIBLE_DEVICES=0
python3 -m joeynmt test configs/transformer_base_koran_en2de.yaml \
        --ckpt wmt14_en_de.transformer.ckpt \
        --combiner no_combiner

Inference with StaticCombiner for kNN-MT

Take English -> German translation in the Koran domain as an example.

export CUDA_VISIBLE_DEVICES=0
python3 -m joeynmt test configs/transformer_base_koran_en2de.yaml \
        --ckpt wmt14_en_de.transformer.ckpt \
        --combiner static_combiner \
        --top_k 16 \
        --mixing_weight 0.7 \
        --bandwidth 10 \
        --kernel gaussian \
        --index_path database/koran_en_de_base/trained.index \
        --token_map_path database/koran_en_de_base/token_map

Inference with DynamicCombiner for KSTER

Take English -> German translation in the Koran domain as an example; koran_en_de.laplacian.combiner.ckpt is the path to the trained bandwidth and weight estimators for the Koran domain.
The --in_memory option specifies whether to load the example embeddings into memory. Set in_memory to True for faster inference, or to False for a lower memory demand.

export CUDA_VISIBLE_DEVICES=0
python3 -m joeynmt test configs/transformer_base_koran_en2de.yaml \
        --ckpt wmt14_en_de.transformer.ckpt \
        --combiner dynamic_combiner \
        --combiner_path koran_en_de.laplacian.combiner.ckpt \
        --top_k 16 \
        --kernel laplacian \
        --index_path database/koran_en_de_base/trained.index \
        --token_map_path database/koran_en_de_base/token_map \
        --embedding_path database/koran_en_de_base/embeddings.npy \
        --in_memory True
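
As a rough guide to the --in_memory trade-off: the embeddings file stores one d-dimensional vector per target token of the domain training data, so loading it costs roughly N x d x 4 bytes of RAM assuming float32 embeddings. The numbers below are hypothetical and only illustrate the arithmetic.

# Back-of-the-envelope RAM estimate for --in_memory True (hypothetical sizes).
num_tokens = 500_000   # target tokens in the domain training set (assumed)
dim = 512              # decoder hidden size of a Transformer-base model
ram_gib = num_tokens * dim * 4 / 1024 ** 3   # float32 = 4 bytes per value
print(f"~{ram_gib:.2f} GiB to keep the example embeddings in memory")   # ~0.95 GiB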

See bash_scripts/test_*.sh for reproducing our results, and logs/*.log for the corresponding logs.

Acknowledgements

Our models are built on the joeynmt codebase.
