NJUNMT-tf is a general-purpose sequence modeling tool built on TensorFlow, with neural machine translation as its main target task.
NJUNMT-tf builds NMT models almost from scratch, without the high-level TensorFlow APIs that often hide the details of network components and lead to obscure code structures that are difficult to understand and manipulate. NJUNMT-tf depends only on basic TensorFlow modules, like array_ops, math_ops and nn_ops, so every operation in the code is under control.
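As an illustration of this low-level style (a sketch in TF 1.x idiom, not code from the repository), a linear projection built directly from basic ops might look like:

```python
import tensorflow as tf  # TF 1.x

def linear(x, output_size, scope="linear"):
    """y = xW + b, composed from basic ops rather than high-level layer APIs."""
    with tf.variable_scope(scope):
        input_size = x.get_shape()[-1].value
        w = tf.get_variable("W", [input_size, output_size])
        b = tf.get_variable("b", [output_size],
                            initializer=tf.zeros_initializer())
        return tf.nn.bias_add(tf.matmul(x, w), b)
```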
NJUNMT-tf focuses on modularity and extensibility, using standard TensorFlow modules and practices to support advanced modeling capabilities:
and all of the above can be used simultaneously to train novel and complex architectures.
The code also supports:
NJUNMT-tf requires:

- tensorflow (>=1.6)
- pyyaml
Here is a minimal workflow to get you started with NJUNMT-tf. The example uses a toy Chinese-English machine translation dataset with a toy-scale configuration.
1. Build the word vocabularies:
```bash
python -m bin.generate_vocab testdata/toy.zh --max_vocab_size 100 > testdata/vocab.zh
python -m bin.generate_vocab testdata/toy.en0 --max_vocab_size 100 > testdata/vocab.en
```
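For reference, here is a minimal sketch of what a frequency-based vocabulary builder such as bin.generate_vocab presumably does; the real script may differ in tokenization and in how special symbols are handled:

```python
import collections

def build_vocab(corpus_path, max_vocab_size=100):
    """Collect the most frequent whitespace-separated tokens from a corpus."""
    counter = collections.Counter()
    with open(corpus_path, encoding="utf-8") as f:
        for line in f:
            counter.update(line.split())
    return [token for token, _ in counter.most_common(max_vocab_size)]
```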
2. Train with preset sequence-to-sequence parameters:
```bash
export CUDA_VISIBLE_DEVICES=
python -m bin.train --model_dir test_models \
  --config_paths "
      ./njunmt/example_configs/toy_seq2seq.yml,
      ./njunmt/example_configs/toy_training_options.yml,
      ./default_configs/default_optimizer.yml"
```
3. Translate a test file with the latest checkpoint:
```bash
export CUDA_VISIBLE_DEVICES=
python -m bin.infer --model_dir test_models \
  --infer "
      beam_size: 4
      source_words_vocabulary: testdata/vocab.zh
      target_words_vocabulary: testdata/vocab.en" \
  --infer_data "
      - features_file: testdata/toy.zh
        labels_file: testdata/toy.en
        output_file: toy.trans
        output_attention: false"
```
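The strings passed to flags like --infer and --infer_data are plain YAML, so a multi-line flag value is parsed into a nested structure. A quick check (assuming only that PyYAML parses the flag values, as the pyyaml dependency suggests):

```python
import yaml

# The value of --infer above, parsed the way the tool presumably parses it.
infer_options = yaml.safe_load("""
beam_size: 4
source_words_vocabulary: testdata/vocab.zh
target_words_vocabulary: testdata/vocab.en
""")
print(infer_options["beam_size"])  # -> 4
```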
Note: do not expect any good translation results with this toy example. Consider training on larger parallel datasets instead.
As you can see, there are two ways to set the hyperparameters of the process:
For example, here is a config file specifying the datasets for the training procedure:
```yaml
# datasets.yml
data:
  train_features_file: testdata/toy.zh
  train_labels_file: testdata/toy.en0
  eval_features_file: testdata/toy.zh
  eval_labels_file: testdata/toy.en
  source_words_vocabulary: testdata/vocab.zh
  target_words_vocabulary: testdata/vocab.en
```
You can either use the command:
```bash
python -m bin.train --config_paths "datasets.yml" ...
```
or
```bash
python -m bin.train --data "
      train_features_file: testdata/toy.zh
      train_labels_file: testdata/toy.en0
      eval_features_file: testdata/toy.zh
      eval_labels_file: testdata/toy.en
      source_words_vocabulary: testdata/vocab.zh
      target_words_vocabulary: testdata/vocab.en" ...
```
The two have the same effect.
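Internally, the options from all sources must be combined into a single configuration. The sketch below shows one plausible merge scheme, in which later values recursively override earlier ones; it is illustrative only, not the repository's actual code:

```python
def merge_configs(base, override):
    """Recursively merge two config dicts; values in `override` win."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value
    return merged

defaults = {"data": {"train_features_file": "testdata/toy.zh"}}
overrides = {"data": {"train_labels_file": "testdata/toy.en0"}}
print(merge_configs(defaults, overrides))
# {'data': {'train_features_file': 'testdata/toy.zh',
#           'train_labels_file': 'testdata/toy.en0'}}
```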
The available FLAGS (i.e., the top-level keys of the YAML configs) for bin.train are as follows:
The available FLAGS (i.e., the top-level keys of the YAML configs) for bin.infer are as follows:
Note that:
The RNN benchmarks are performed on a single GTX 1080 Ti GPU with the predefined configurations:

- default_configs/adam_loss_decay.yml
- default_configs/default_metrics.yml
- default_configs/default_training_options.yml
- default_configs/seq2seq_cgru.yml
The Transformer benchmarks are performed on a single GTX 1080 Ti GPU with the predefined configurations:

- default_configs/transformer_base.yml
- default_configs/transformer_training_options.yml
Note that for the Transformer model we set `batch_tokens_size=2500` with `update_cycle=10` to realize pseudo parallel training: gradients are accumulated over 10 batches before each parameter update, giving an effective batch size of about 25,000 tokens.
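The following TF 1.x sketch illustrates the gradient-accumulation idea behind `update_cycle`; it is an illustration under that assumption, not the repository's implementation:

```python
import tensorflow as tf  # TF 1.x

update_cycle = 10

# Toy stand-in for the translation loss.
w = tf.get_variable("w", [8], initializer=tf.ones_initializer())
loss = tf.reduce_sum(tf.square(w))

opt = tf.train.AdamOptimizer(1e-4)
grads_and_vars = opt.compute_gradients(loss)

# One (non-trainable) accumulator per variable.
accum = [tf.Variable(tf.zeros_like(v), trainable=False)
         for _, v in grads_and_vars]
accum_op = tf.group(*[a.assign_add(g)
                      for a, (g, _) in zip(accum, grads_and_vars)])
apply_op = opt.apply_gradients(
    [(a / update_cycle, v) for a, (_, v) in zip(accum, grads_and_vars)])
zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accum])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(update_cycle):
        sess.run(accum_op)   # accumulate micro-batch gradients
    sess.run(apply_op)       # single optimizer step on the averaged gradient
    sess.run(zero_op)        # reset accumulators for the next cycle
```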
The beam sizes for the RNN and Transformer models are 10 and 4, respectively.
The datasets are preprocessed with fetch_wmt2017_ende.sh and fetch_wmt2018_zhen.sh, following Edinburgh's report.
The BLEU scores are computed by the wrapper script run_mteval.sh; for the EN-ZH experiments, BLEU is evaluated at the character level, while the others are evaluated at the word level.
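Character-level evaluation simply segments the Chinese text into single characters before scoring; a minimal illustration (the actual preprocessing is done inside run_mteval.sh and may differ):

```python
def to_char_level(line):
    """Split a Chinese sentence into space-separated characters for BLEU."""
    return " ".join(ch for ch in line.replace(" ", "") if not ch.isspace())

print(to_char_level("今天 天气 很好"))  # -> 今 天 天 气 很 好
```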
| Dataset | Model | BLEU (newstest2016, dev) | BLEU (newstest2017) |
|---|---|---|---|
| WMT17 EN-DE | RNN | 29.6 | 23.6 |
| WMT17 EN-DE | Transformer | 33.5 | 27.0 |
| WMT17 DE-EN | RNN | 34.0 | 29.6 |
| WMT17 DE-EN | Transformer | 37.6 | 33.1 |
| Dataset | Model | BLEU (newsdev2017, dev) | BLEU (newstest2017) |
|---|---|---|---|
| WMT17 ZH-EN | RNN | 19.7 | 21.2 |
| WMT17 ZH-EN | Transformer | 22.7 | 25.0 |
| WMT17 EN-ZH | RNN | 30.0 | 30.2 |
| WMT17 EN-ZH | Transformer | 34.9 | 35.0 |
The following features remain unimplemented:
The implementation is inspired by the following:
Any comments or suggestions are welcome.
Please email [email protected].