From 122fc1db49534a5ca295fcae1b362bbd6308c32f Mon Sep 17 00:00:00 2001
From: Zeming Lin
Date: Fri, 17 Jan 2020 11:38:34 -0800
Subject: [PATCH] Add begin_epoch to FairseqTask (#984)

Summary:
Adds a begin_epoch hook to FairseqTask.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/984

Differential Revision: D19429433

Pulled By: myleott

fbshipit-source-id: 367bd4d0d2d2bc995cca9ac151256c77ede36c83
---
 fairseq/tasks/fairseq_task.py | 6 +++++-
 fairseq_cli/train.py          | 3 +++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 89326977a4..24004313e5 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -308,8 +308,12 @@ def inference_step(self, generator, models, sample, prefix_tokens=None):
         with torch.no_grad():
             return generator.generate(models, sample, prefix_tokens=prefix_tokens)
 
+    def begin_epoch(self, epoch, model):
+        """Hook function called before the start of each epoch."""
+        pass
+
     def update_step(self, num_updates):
-        """Task level update when number of update increases.
+        """Task level update when number of updates increases.
 
         This is called after the optimization step and learning rate update
         at each iteration.
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index f9e03cb4bd..efa46f3f4b 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -158,6 +158,9 @@ def train(args, trainer, task, epoch_itr):
         args, itr, epoch_itr.epoch, no_progress_bar='simple',
     )
 
+    # task specific setup per epoch
+    task.begin_epoch(epoch_itr.epoch, trainer.get_model())
+
     valid_subsets = args.valid_subset.split(',')
     max_update = args.max_update or math.inf
     for samples in progress:
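
Usage note (not part of the patch): the sketch below shows how a downstream task might override the new begin_epoch hook to run per-epoch setup. Only the hook signature (epoch, model) and the fact that train() calls it once per epoch come from the patch; the task name 'epoch_aware_translation', the EpochAwareTranslationTask class, and the model.set_epoch call are illustrative assumptions.

# Minimal sketch, assuming a hypothetical task registered on top of fairseq's
# TranslationTask. Only begin_epoch(epoch, model) comes from the patch above.
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


@register_task('epoch_aware_translation')  # hypothetical task name
class EpochAwareTranslationTask(TranslationTask):

    def begin_epoch(self, epoch, model):
        """Called by fairseq_cli/train.py before each training epoch."""
        # Illustrative per-epoch setup: remember the epoch on the task and
        # forward it to the model if the model happens to expose such a hook.
        self.current_epoch = epoch
        if hasattr(model, 'set_epoch'):  # assumed, not a fairseq API
            model.set_epoch(epoch)
        print('| starting epoch {}'.format(epoch))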