From fe861e578f50dc9c06de33cd361d2f625017e624 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Thu, 8 Jun 2023 16:21:42 +0200
Subject: [PATCH] [`GPT2`] Add correct keys on
 `_keys_to_ignore_on_load_unexpected` on all child classes of
 `GPT2PreTrainedModel` (#24113)

* add correct keys on `_keys_to_ignore_on_load_unexpected`

* oops
---
 src/transformers/models/gpt2/modeling_gpt2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 00d92f0bb23c2b..0cb406081f2754 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -668,7 +668,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
     GPT2_START_DOCSTRING,
 )
 class GPT2Model(GPT2PreTrainedModel):
-    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
+    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1149,6 +1150,7 @@ def _reorder_cache(
     GPT2_START_DOCSTRING,
 )
 class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -1377,6 +1379,7 @@ def _reorder_cache(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]
 
     def __init__(self, config):
@@ -1600,6 +1603,7 @@ def forward(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]
     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head.weight"]
 
     def __init__(self, config):
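
For context: these class attributes hold regex patterns that checkpoint loading matches against state-dict keys to suppress spurious warnings. `h.<n>.attn.bias` and `h.<n>.attn.masked_bias` are non-persistent attention-mask buffers that older GPT-2 checkpoints serialized, so they can surface as "unexpected" keys when loaded into current model classes. Below is a minimal sketch of the matching logic, assuming plain `re.search` semantics; it is an illustration with hypothetical keys, not the actual `from_pretrained` implementation.

```python
import re

# Patterns added by this patch: buffer keys that old GPT-2 checkpoints may
# contain but that the model no longer tracks as persistent state.
KEYS_TO_IGNORE_ON_LOAD_UNEXPECTED = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

# Hypothetical state-dict keys from an older checkpoint (for illustration).
unexpected_keys = [
    "h.0.attn.bias",          # causal-mask buffer  -> should be ignored
    "h.11.attn.masked_bias",  # masked-bias buffer  -> should be ignored
    "h.0.attn.c_attn.weight", # a real weight       -> should NOT be ignored
]

def filter_ignored(keys, patterns):
    """Drop keys matching any ignore pattern; keep the rest for warnings."""
    return [k for k in keys if not any(re.search(p, k) for p in patterns)]

print(filter_ignored(unexpected_keys, KEYS_TO_IGNORE_ON_LOAD_UNEXPECTED))
# -> ['h.0.attn.c_attn.weight']
```

Without `_keys_to_ignore_on_load_unexpected`, loading such a checkpoint would warn that these buffer keys were unused even though nothing is wrong; the added patterns keep the warning limited to genuinely unexpected weights.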