Support pure meta model lm_head tp (#6812)
Add lm_head tensor parallel (tp) support when no checkpoint is provided to deepspeed.init_inference().

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Ma, Guokai <[email protected]>
3 people authored Jan 10, 2025
1 parent 1d15ef0 commit fa8db5c
Showing 1 changed file with 7 additions and 7 deletions.
deepspeed/module_inject/replace_module.py
@@ -342,13 +342,11 @@ def set_lm_head(module):
             module.lm_head, "weight") and module.lm_head.weight.is_meta:
         module.lm_head.weight = embedding_weight
     # enable tensor parallel for the last linear
-    if hasattr(module, "lm_head") and hasattr(module.lm_head,
-                                              "weight") and not module.lm_head.weight.is_meta and isinstance(
-                                                  module.lm_head, torch.nn.Linear):
+    if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
+            module.lm_head, torch.nn.Linear):
         module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
-    elif hasattr(module, "embed_out") and hasattr(module.embed_out,
-                                                  "weight") and not module.embed_out.weight.is_meta and isinstance(
-                                                      module.embed_out, torch.nn.Linear):
+    elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
+            module.embed_out, torch.nn.Linear):
         module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
     elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
         module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
@@ -389,7 +387,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                              checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
-        replaced_module = set_lm_head(replaced_module)
     # conv2d tp module replace
     # Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
     if 'Yuan' in str(replaced_module):
@@ -399,6 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                          orig_class=orig_layer_impl,
                                          replace_fn=replace_fn,
                                          _replace_policy=config.injection_policy_tuple)
+    # AutoTP default set lm_head tp
+    if not config.replace_with_kernel_inject:
+        replaced_module = set_lm_head(replaced_module)
 
     quantizer = GroupQuantizer(q_int8=quantize)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
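
For context, the code path touched here is AutoTP (replace_with_kernel_inject=False), which after this commit also tensor-parallelizes the final lm_head / embed_out linear when no checkpoint is passed to deepspeed.init_inference(). A minimal usage sketch follows; the model name, tp_size, and launch command are illustrative assumptions, not taken from the commit:

    # Sketch: AutoTP inference with no checkpoint= argument (illustrative values).
    import os
    import torch
    import deepspeed
    from transformers import AutoModelForCausalLM, AutoTokenizer

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "2"))
    model_name = "facebook/opt-1.3b"  # placeholder model

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # No checkpoint= here: the AutoTP branch now also shards the lm_head
    # linear via set_lm_head() after module replacement.
    engine = deepspeed.init_inference(
        model,
        tensor_parallel={"tp_size": world_size},
        dtype=torch.float16,
        replace_with_kernel_inject=False,
    )

    inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
    output = engine.module.generate(inputs.input_ids, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    # Launch (illustrative): deepspeed --num_gpus 2 run_autotp_inference.py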