From 937535dba036dc3759a5334ab5b8110febbe8e6e Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Sun, 8 Mar 2020 06:49:07 -0700 Subject: [PATCH] Allow dictionaries to overwrite entries with #fairseq:overwrite comment (#1073) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: [This commit](https://github.com/pytorch/fairseq/commit/dd1298e15fdbfc0c3639906eee9934968d63fc29) made it so that duplicate entries in a dictionary are ignored. Unfortunately the Camembert model depends on overwriting ``, `` and ``. The proposed solution here is to allow the dictionary to have entries like: ``` 999 #fairseq:overwrite 999 #fairseq:overwrite 999 #fairseq:overwrite , 999 ▁de 999 . 999 (...) ``` These will preserve the old overwriting behavior. Thus we can release a new `camembert.v0.tar.gz` with a dictionary like above and it works. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1073 Reviewed By: kahne Differential Revision: D20284569 Pulled By: myleott fbshipit-source-id: bf78fbff13c94bf8a6485cbdda62305ddc30c056 --- fairseq/data/dictionary.py | 32 +++++++++++++++++++------- tests/test_dictionary.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index 6b5202c4b2..d995976e26 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -91,9 +91,9 @@ def unk_string(self, escape=False): else: return self.unk_word - def add_symbol(self, word, n=1): + def add_symbol(self, word, n=1, overwrite=False): """Adds a word to the dictionary""" - if word in self.indices: + if word in self.indices and not overwrite: idx = self.indices[word] self.count[idx] = self.count[idx] + n return idx @@ -215,15 +215,31 @@ def add_from_file(self, f): lines = f.readlines() indices_start_line = self._load_meta(lines) + for line in lines[indices_start_line:]: - idx = line.rfind(" ") - if idx == -1: + try: + line, field = line.rstrip().rsplit(" ", 1) + if field == "#fairseq:overwrite": + overwrite = True + line, field = line.rsplit(" ", 1) + else: + overwrite = False + count = int(field) + word = line + if word in self and not overwrite: + raise RuntimeError( + "Duplicate word found when loading Dictionary: '{}'. " + "Duplicate words can overwrite earlier ones by adding the " + "#fairseq:overwrite flag at the end of the corresponding row " + "in the dictionary file. If using the Camembert model, please " + "download an updated copy of the model file." + .format(word) + ) + self.add_symbol(word, n=count, overwrite=overwrite) + except ValueError: raise ValueError( - "Incorrect dictionary format, expected ' '" + "Incorrect dictionary format, expected ' [flags]'" ) - word = line[:idx] - count = int(line[idx + 1 :]) - self.add_symbol(word, n=count) def _save(self, f, kv_iterator): if isinstance(f, str): diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index b41838b54f..d9a1ec72c8 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import io import tempfile import unittest @@ -65,6 +66,51 @@ def assertMatch(ids, ref_ids): assertMatch(reload_ids, ref_ids2) assertMatch(finalized_ids, reload_ids) + def test_overwrite(self): + # for example, Camembert overwrites , and + dict_file = io.StringIO( + " 999 #fairseq:overwrite\n" + " 999 #fairseq:overwrite\n" + " 999 #fairseq:overwrite\n" + ", 999\n" + "▁de 999\n" + ) + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.index(''), 1) + self.assertEqual(d.index('foo'), 3) + self.assertEqual(d.index(''), 4) + self.assertEqual(d.index(''), 5) + self.assertEqual(d.index(''), 6) + self.assertEqual(d.index(','), 7) + self.assertEqual(d.index('▁de'), 8) + + def test_no_overwrite(self): + # for example, Camembert overwrites , and + dict_file = io.StringIO( + " 999\n" + " 999\n" + " 999\n" + ", 999\n" + "▁de 999\n" + ) + d = Dictionary() + with self.assertRaisesRegex(RuntimeError, 'Duplicate'): + d.add_from_file(dict_file) + + def test_space(self): + # for example, character models treat space as a symbol + dict_file = io.StringIO( + " 999\n" + "a 999\n" + "b 999\n" + ) + d = Dictionary() + d.add_from_file(dict_file) + self.assertEqual(d.index(' '), 4) + self.assertEqual(d.index('a'), 5) + self.assertEqual(d.index('b'), 6) + if __name__ == '__main__': unittest.main()