diff --git a/examples/translation/README.md b/examples/translation/README.md
index 055a508a28..0494f6777d 100644
--- a/examples/translation/README.md
+++ b/examples/translation/README.md
@@ -116,7 +116,13 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train \
     --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
     --dropout 0.3 --weight-decay 0.0001 \
     --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
-    --max-tokens 4096
+    --max-tokens 4096 \
+    --eval-bleu \
+    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
+    --eval-bleu-detok moses \
+    --eval-bleu-remove-bpe \
+    --eval-bleu-print-samples \
+    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric
 ```
 
 Finally we can evaluate our trained model:
diff --git a/fairseq/meters.py b/fairseq/meters.py
index 96125580a9..ad5f993c44 100644
--- a/fairseq/meters.py
+++ b/fairseq/meters.py
@@ -230,7 +230,11 @@ def get_smoothed_value(self, key: str) -> float:
 
     def get_smoothed_values(self) -> Dict[str, float]:
         """Get all smoothed values."""
-        return OrderedDict([(key, self.get_smoothed_value(key)) for key in self.keys()])
+        return OrderedDict([
+            (key, self.get_smoothed_value(key))
+            for key in self.keys()
+            if not key.startswith("_")
+        ])
 
     def reset(self):
         """Reset Meter instances."""
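For context on the `meters.py` change: keys that start with an underscore are meant to be aggregated (and synced across workers) but never shown in the logged stats, and the raw BLEU statistics added in `translation.py` below rely on exactly this. A minimal standalone sketch of the convention, using a plain `OrderedDict` with made-up values rather than fairseq's meters container:

```python
from collections import OrderedDict

# smoothed values as a meters container might hold them (illustrative numbers)
smoothed = OrderedDict([
    ('loss', 4.91),
    ('bleu', 27.3),            # derived, user-facing metric
    ('_bleu_counts_0', 1523),  # raw n-gram statistics, internal only
    ('_bleu_totals_0', 1840),
])

# what get_smoothed_values() now returns: underscore-prefixed keys are hidden
visible = OrderedDict(
    (key, value) for key, value in smoothed.items()
    if not key.startswith('_')
)
print(visible)  # OrderedDict([('loss', 4.91), ('bleu', 27.3)])
```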
diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py
index fbda301363..dad310ff78 100644
--- a/fairseq/tasks/translation.py
+++ b/fairseq/tasks/translation.py
@@ -3,15 +3,20 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from argparse import Namespace
+import json
 import itertools
 import logging
 import os
 
-from fairseq import options, utils
+import numpy as np
+
+from fairseq import metrics, options, utils
 from fairseq.data import (
     AppendTokenDataset,
     ConcatDataset,
     data_utils,
+    encoders,
     indexed_dataset,
     LanguagePairDataset,
     PrependTokenDataset,
@@ -19,7 +24,9 @@
     TruncateDataset,
 )
 
-from . import FairseqTask, register_task
+from fairseq.tasks import FairseqTask, register_task
+
+EVAL_BLEU_ORDER = 4
 
 
 logger = logging.getLogger(__name__)
@@ -155,6 +162,26 @@ def add_args(parser):
                             help='amount to upsample primary dataset')
         parser.add_argument('--truncate-source', action='store_true', default=False,
                             help='truncate source to max-source-positions')
+
+        # options for reporting BLEU during validation
+        parser.add_argument('--eval-bleu', action='store_true',
+                            help='evaluation with BLEU scores')
+        parser.add_argument('--eval-bleu-detok', type=str, default="space",
+                            help='detokenizer before computing BLEU (e.g., "moses"); '
+                                 'required if using --eval-bleu; use "space" to '
+                                 'disable detokenization; see fairseq.data.encoders '
+                                 'for other options')
+        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
+                            help='args for building the tokenizer, if needed')
+        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
+                            help='if set, compute tokenized BLEU instead of sacrebleu')
+        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
+                            help='remove BPE before computing BLEU')
+        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
+                            help='generation args for BLEU scoring, '
+                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
+        parser.add_argument('--eval-bleu-print-samples', action='store_true',
+                            help='print sample generations during validation')
         # fmt: on
 
     def __init__(self, args, src_dict, tgt_dict):
@@ -219,6 +246,75 @@ def load_dataset(self, split, epoch=0, combine=False, **kwargs):
 
     def build_dataset_for_inference(self, src_tokens, src_lengths):
         return LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary)
 
+    def build_model(self, args):
+        if getattr(args, 'eval_bleu', False):
+            assert getattr(args, 'eval_bleu_detok', None) is not None, (
+                '--eval-bleu-detok is required if using --eval-bleu; '
+                'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
+                'to disable detokenization, e.g., when using sentencepiece)'
+            )
+            detok_args = json.loads(getattr(args, 'eval_bleu_detok_args', '{}') or '{}')
+            self.tokenizer = encoders.build_tokenizer(Namespace(
+                tokenizer=getattr(args, 'eval_bleu_detok', None),
+                **detok_args
+            ))
+
+            gen_args = json.loads(getattr(args, 'eval_bleu_args', '{}') or '{}')
+            self.sequence_generator = self.build_generator(Namespace(**gen_args))
+        return super().build_model(args)
+
+    def valid_step(self, sample, model, criterion):
+        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
+        if self.args.eval_bleu:
+            bleu = self._inference_with_bleu(self.sequence_generator, sample, model)
+            logging_output['_bleu_sys_len'] = bleu.sys_len
+            logging_output['_bleu_ref_len'] = bleu.ref_len
+            # we split counts into separate entries so that they can be
+            # summed efficiently across workers using fast-stat-sync
+            assert len(bleu.counts) == EVAL_BLEU_ORDER
+            for i in range(EVAL_BLEU_ORDER):
+                logging_output['_bleu_counts_' + str(i)] = bleu.counts[i]
+                logging_output['_bleu_totals_' + str(i)] = bleu.totals[i]
+        return loss, sample_size, logging_output
+
+    def reduce_metrics(self, logging_outputs, criterion):
+        super().reduce_metrics(logging_outputs, criterion)
+        if self.args.eval_bleu:
+
+            def sum_logs(key):
+                return sum(log.get(key, 0) for log in logging_outputs)
+
+            counts, totals = [], []
+            for i in range(EVAL_BLEU_ORDER):
+                counts.append(sum_logs('_bleu_counts_' + str(i)))
+                totals.append(sum_logs('_bleu_totals_' + str(i)))
+
+            if max(totals) > 0:
+                # log counts as numpy arrays -- log_scalar will sum them correctly
+                metrics.log_scalar('_bleu_counts', np.array(counts))
+                metrics.log_scalar('_bleu_totals', np.array(totals))
+                metrics.log_scalar('_bleu_sys_len', sum_logs('_bleu_sys_len'))
+                metrics.log_scalar('_bleu_ref_len', sum_logs('_bleu_ref_len'))
+
+                def compute_bleu(meters):
+                    import inspect
+                    import sacrebleu
+                    fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
+                    if 'smooth_method' in fn_sig:
+                        smooth = {'smooth_method': 'exp'}
+                    else:
+                        smooth = {'smooth': 'exp'}
+                    bleu = sacrebleu.compute_bleu(
+                        correct=meters['_bleu_counts'].sum,
+                        total=meters['_bleu_totals'].sum,
+                        sys_len=meters['_bleu_sys_len'].sum,
+                        ref_len=meters['_bleu_ref_len'].sum,
+                        **smooth
+                    )
+                    return round(bleu.score, 2)
+
+                metrics.log_derived('bleu', compute_bleu)
+
     def max_positions(self):
         """Return the max sentence length allowed by the task."""
         return (self.args.max_source_positions, self.args.max_target_positions)
@@ -232,3 +328,30 @@ def source_dictionary(self):
     def target_dictionary(self):
         """Return the target :class:`~fairseq.data.Dictionary`."""
         return self.tgt_dict
+
+    def _inference_with_bleu(self, generator, sample, model):
+        import sacrebleu
+
+        def decode(toks, escape_unk=False):
+            s = self.tgt_dict.string(
+                toks.int().cpu(),
+                self.args.eval_bleu_remove_bpe,
+                escape_unk=escape_unk,
+            )
+            if self.tokenizer:
+                s = self.tokenizer.decode(s)
+            return s
+
+        gen_out = self.inference_step(generator, [model], sample, None)
+        hyps, refs = [], []
+        for i in range(len(gen_out)):
+            hyps.append(decode(gen_out[i][0]['tokens']))
+            refs.append(decode(
+                utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
+                escape_unk=True,  # don't count <unk> as matches to the hypo
+            ))
+        if self.args.eval_bleu_print_samples:
+            logger.info('example hypothesis: ' + hyps[0])
+            logger.info('example reference: ' + refs[0])
+        tokenize = sacrebleu.DEFAULT_TOKENIZER if not self.args.eval_tokenized_bleu else 'none'
+        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
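`reduce_metrics` recomputes corpus BLEU from the summed n-gram statistics rather than averaging per-batch scores, which is why the counts and totals are logged per n-gram order. A standalone sketch of that reconstruction, using the textbook corpus-BLEU formula (unsmoothed, unlike the exponential smoothing the patch asks sacrebleu for); the helper name and the numbers below are made up for illustration:

```python
import math

def bleu_from_stats(counts, totals, sys_len, ref_len):
    """Corpus BLEU (unsmoothed) from summed n-gram statistics.

    counts[n] / totals[n] is the order-(n+1) modified precision, and the
    brevity penalty only needs the summed system/reference lengths, so
    adding these four kinds of numbers across batches and workers is
    enough to recover the corpus-level score.
    """
    if min(counts) == 0:
        return 0.0
    log_precision = sum(math.log(c / t) for c, t in zip(counts, totals)) / len(counts)
    brevity_penalty = min(1.0, math.exp(1 - ref_len / max(sys_len, 1)))
    return 100.0 * brevity_penalty * math.exp(log_precision)

# hypothetical statistics already summed across batches/workers
counts = [1523, 901, 612, 417]     # matched 1- to 4-grams
totals = [1840, 1795, 1750, 1705]  # candidate 1- to 4-grams
print(round(bleu_from_stats(counts, totals, sys_len=1840, ref_len=1900), 2))
```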
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index efa46f3f4b..da0fbdd4ce 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -163,39 +163,38 @@ def train(args, trainer, task, epoch_itr):
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
 
-    for samples in progress:
-        with metrics.aggregate('train_inner'):
+    with metrics.aggregate() as agg:
+        for samples in progress:
             log_output = trainer.train_step(samples)
             num_updates = trainer.get_num_updates()
             if log_output is None:
                 continue
 
             # log mid-epoch stats
-            stats = get_training_stats('train_inner')
+            stats = get_training_stats(agg.get_smoothed_values())
             progress.log(stats, tag='train', step=num_updates)
 
-        if (
-            not args.disable_validation
-            and args.save_interval_updates > 0
-            and num_updates % args.save_interval_updates == 0
-            and num_updates > 0
-        ):
-            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
-            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
+            if (
+                not args.disable_validation
+                and args.save_interval_updates > 0
+                and num_updates % args.save_interval_updates == 0
+                and num_updates > 0
+            ):
+                valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
+                checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
 
-        if num_updates >= max_update:
-            break
+            if num_updates >= max_update:
+                break
 
     # log end-of-epoch stats
-    stats = get_training_stats('train')
+    stats = get_training_stats(agg.get_smoothed_values())
     progress.print(stats, tag='train', step=num_updates)
 
     # reset epoch-level meters
     metrics.reset_meters('train')
 
 
-def get_training_stats(stats_key):
-    stats = metrics.get_smoothed_values(stats_key)
+def get_training_stats(stats):
     if 'nll_loss' in stats and 'ppl' not in stats:
         stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
     stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0)
@@ -233,22 +232,22 @@ def validate(args, trainer, task, epoch_itr, subsets):
             no_progress_bar='simple'
         )
 
-        # reset validation loss meters
+        # reset validation meters
        metrics.reset_meters('valid')
 
-        for sample in progress:
-            trainer.valid_step(sample)
+        with metrics.aggregate() as agg:
+            for sample in progress:
+                trainer.valid_step(sample)
 
         # log validation stats
-        stats = get_valid_stats(args, trainer)
+        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
         progress.print(stats, tag=subset, step=trainer.get_num_updates())
 
         valid_losses.append(stats[args.best_checkpoint_metric])
     return valid_losses
 
 
-def get_valid_stats(args, trainer):
-    stats = metrics.get_smoothed_values('valid')
+def get_valid_stats(args, trainer, stats):
     if 'nll_loss' in stats and 'ppl' not in stats:
         stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
     stats['num_updates'] = trainer.get_num_updates()
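The `train.py` changes replace named global meter lookups (`get_training_stats('train_inner')`, `metrics.get_smoothed_values('valid')`) with an explicit aggregation scope whose handle outlives the loop it wraps. A toy stand-in for that pattern, not fairseq's `metrics` module (`_Aggregator` and `aggregate` here are illustrative only):

```python
from contextlib import contextmanager

class _Aggregator:
    """Toy stand-in: averages every value logged while the scope is active."""

    def __init__(self):
        self._sums, self._counts = {}, {}

    def log(self, **values):
        for key, value in values.items():
            self._sums[key] = self._sums.get(key, 0.0) + value
            self._counts[key] = self._counts.get(key, 0) + 1

    def get_smoothed_values(self):
        return {k: self._sums[k] / self._counts[k] for k in self._sums}

@contextmanager
def aggregate():
    yield _Aggregator()

# mirrors the edited loop: read mid-epoch stats from the handle inside the
# loop, then read end-of-epoch stats from the same handle after it
with aggregate() as agg:
    for loss in [4.2, 3.9, 3.7]:
        agg.log(loss=loss)
        print('mid-epoch', agg.get_smoothed_values())
print('end of epoch', agg.get_smoothed_values())
```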
diff --git a/tests/test_binaries.py b/tests/test_binaries.py
index 1d4a61a035..8f5b9dfb07 100644
--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py
@@ -128,6 +128,19 @@ def test_generation(self):
                 ])
                 generate_main(data_dir, ['--prefix-size', '2'])
 
+    def test_eval_bleu(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_eval_bleu') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'fconv_iwslt_de_en', [
+                    '--eval-bleu',
+                    '--eval-bleu-print-samples',
+                    '--eval-bleu-remove-bpe',
+                    '--eval-bleu-detok', 'space',
+                    '--eval-bleu-args', '{"beam": 4, "min_len": 10}',
+                ])
+
     def test_lstm(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lstm') as data_dir:
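The new test passes `--eval-bleu-args` as a JSON string, which `build_model` in `translation.py` parses into generation arguments for the validation-time generator. A small standalone sketch of that round trip (the variable names are illustrative):

```python
import json
from argparse import Namespace

# same shape as the flag exercised by test_eval_bleu above
eval_bleu_args = '{"beam": 4, "min_len": 10}'

# mirrors build_model: a missing/empty flag falls back to an empty dict
gen_args = Namespace(**json.loads(eval_bleu_args or '{}'))
print(gen_args.beam, gen_args.min_len)  # -> 4 10
```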