Commit

corrected copy_weights_phi logic
ysjprojects committed Jan 9, 2025
1 parent c1c0b4f commit 130e4cd
Showing 2 changed files with 85 additions and 2 deletions.
85 changes: 84 additions & 1 deletion litgpt/scripts/convert_hf_checkpoint.py
@@ -323,7 +323,7 @@ def copy_weights_phi(
        "lm_head.bias": "lm_head.bias",
    }

-    if config.name.startswith("Phi-3"):
+    if config.name.startswith(("Phi-3", "phi-4")):
        weight_map.update(
            {
                "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight",
@@ -447,6 +447,85 @@ def copy_weights_qwen_2_5(
        pbar.update(progress_per_file)


def copy_weights_olmo2(
    config: Config,
    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
    state_dict: Dict[str, torch.Tensor],
    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
    saver: Optional[incremental_save] = None,
    dtype: Optional[torch.dtype] = None,
    pbar: Optional[tqdm] = None,
    progress_per_file: Optional[float] = None,
    debug_mode: Optional[bool] = False,
) -> None:
    weight_map = {
        "model.embed_tokens.weight": "transformer.wte.weight",
        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.q_norm.weight",
        "model.layers.{}.self_attn.q_proj.weight": None,
        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.k_norm.weight",
        "model.layers.{}.self_attn.k_proj.weight": None,
        "model.layers.{}.self_attn.v_proj.weight": None,
        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
        "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias",
        "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
        "model.norm.weight": "transformer.ln_f.weight",
        "model.norm.bias": "transformer.ln_f.bias",
        "lm_head.weight": "lm_head.weight",
    }
    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
        weight_map.update(
            {
                "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
                "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
            }
        )
    else:
        raise NotImplementedError

    if progress_per_file is not None:
        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

    for from_name, param in hf_weights.items():
        name_template, *ids = layer_template(from_name, num_matches=2)
        to_name = weight_map[name_template]
        param = load_param(param, from_name, dtype, verbose=debug_mode)
        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
            weight_name, weight_type = from_name.split(".")[-2:]
            qkv[weight_type][weight_name] = param
        if to_name is None:
            continue
        to_name = to_name.format(*ids)
        if saver is not None:
            param = saver.store_early(param)
        state_dict[to_name] = param

        if progress_per_file is not None:
            pbar.update(progress_per_file)

    if "lm_head.weight" not in state_dict:
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    for i in list(qkv_weights):
        for weight_type in list(qkv_weights[i]):
            qkv = qkv_weights[i][weight_type]
            if len(qkv) != 3:
                # qkv is split across different .bin files
                continue
            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
            qkv = torch.cat((q, k, v))
            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
            del qkv_weights[i][weight_type]

    if progress_per_file is not None:
        pbar.update(progress_per_file)


def qkv_reassemble(
    param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
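
The core of copy_weights_olmo2 above is the qkv_weights staging dict: q_proj, k_proj and v_proj tensors are collected per layer (they may arrive from different checkpoint shards) and only concatenated into litgpt's fused qkv parameter once all three are present. A self-contained sketch of that staging-and-fusing pattern, with invented toy shapes and without the incremental-save and progress-bar machinery:

from collections import defaultdict
from typing import Dict

import torch

# Toy dimensions: 4 query heads and 2 KV heads of size 8, embedding size 32.
n_embd, q_dim, kv_dim = 32, 32, 16

# Staging area keyed by layer index, then weight type ("weight"), then projection name.
qkv_weights: Dict[int, Dict[str, Dict[str, torch.Tensor]]] = defaultdict(lambda: defaultdict(dict))
state_dict: Dict[str, torch.Tensor] = {}

# Pretend these arrived (possibly from different HF shards) for layer 0.
qkv_weights[0]["weight"]["q_proj"] = torch.randn(q_dim, n_embd)
qkv_weights[0]["weight"]["k_proj"] = torch.randn(kv_dim, n_embd)
qkv_weights[0]["weight"]["v_proj"] = torch.randn(kv_dim, n_embd)

for i in list(qkv_weights):
    for weight_type in list(qkv_weights[i]):
        staged = qkv_weights[i][weight_type]
        if len(staged) != 3:
            continue  # wait until q, k and v have all been seen
        fused = torch.cat((staged["q_proj"], staged["k_proj"], staged["v_proj"]))
        state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = fused
        del qkv_weights[i][weight_type]

print(state_dict["transformer.h.0.attn.qkv.weight"].shape)  # torch.Size([64, 32])
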
@@ -537,6 +616,10 @@ def convert_hf_checkpoint(
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
    elif model_name.lower().startswith("olmo-2-"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
        copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
    elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
        # holder to reconstitute the split q, k, v
        qkv_weights = {}
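
The branch added here follows the same pattern as its neighbours: a fresh qkv_weights holder is created and functools.partial pre-binds it, together with config, so the converter can later call the resulting copy_fn with just the per-file weights. A reduced sketch of that idea (pick_copy_fn is a made-up helper, not part of the script, and the model name is only an example):

from functools import partial

from litgpt.scripts.convert_hf_checkpoint import copy_weights_olmo2

def pick_copy_fn(model_name: str, config, qkv_weights: dict):
    # Made-up helper mirroring the elif chain above, reduced to the new OLMo-2 branch.
    if model_name.lower().startswith("olmo-2-"):
        return partial(copy_weights_olmo2, config, qkv_weights)
    raise NotImplementedError(f"no copy function wired up for {model_name!r}")

# copy_fn = pick_copy_fn("OLMo-2-1124-7B", config, qkv_weights={})
# copy_fn(state_dict, hf_weights, ...)  # called once per weight file by the converter
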
2 changes: 1 addition & 1 deletion litgpt/scripts/convert_lit_checkpoint.py
@@ -236,7 +236,7 @@ def copy_weights_phi(
        "lm_head.weight": "lm_head.weight",
        "lm_head.bias": "lm_head.bias",
    }
-    if config.name.startswith("Phi-3"):
+    if config.name.startswith(("Phi-3", "phi-4")):
        weight_map.update(
            {
                "transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight",
