Commit
Llama3.1 with torchtune (pytorch#1123)
* added model source and type for torchtune flamingo support

* added model source and type for torchtune flamingo support

* grab missing enum

* fix ModelArgs init

* create init func for ModelArgs for BC

* update pipeline for ModelSource and ModelType

* revert lintrunner update on ET

* introduce flamingo modules from torchtune

* back up to move to linux

* mitigate building issue

* pass local test

* structural model builder

* update torchtune address

* update install requirement

* support new torchtune flamingo component

* specific version for vision and ao

* unify text-only model generation pipeline

* convert installation back and bypass torchtune

* restructure model definition

* update export variable name

* remove redundant function

* 1/n torchtune 3.1 8b

* installation update

* torchtune 3.1 8b / 70b

* bring torchchat llama3.1 back

* bring tokenizer validation back to torchchat model + revert install_requirements.sh

* solve bugs related to torchtune model support

* bypass torchtune import issue

* address Jack's comments

* remove extra dot

* add typing.Callable

* fix torchchat typos

* solve bug when args.model is None

* support builder_args.params_table is None

* remove all .DS_Store

* bring gguf back

* remove redundant updates

* bring checkpoint back

* debug

* debug

* debug

* new factory function to produce Model from ModelArgs

* address review comments
Gasoonjia authored Sep 11, 2024
1 parent 964d437 commit e2049f4
Showing 7 changed files with 173 additions and 56 deletions.
25 changes: 19 additions & 6 deletions torchchat/cli/builder.py
@@ -35,6 +35,13 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

# Bypass the torchtune import until torchao is ready on macOS.
try:
from torchtune.models.convert_weights import meta_to_tune
except:
pass



@dataclass
class BuilderArgs:
@@ -328,11 +335,15 @@ def _load_model_default(builder_args, only_config=False):
assert not builder_args.gguf_path

model = _init_model_on_meta_device(builder_args)
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
cps = []
if builder_args.checkpoint_dir is not None:

if builder_args.params_table and builder_args.params_table.endswith("Tune"):
print("Loading Tune checkpoint")
meta_checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
checkpoint = meta_to_tune(meta_checkpoint)
elif builder_args.checkpoint_dir is not None:
# Load multiple checkpoint; ignore the single path.
builder_args.checkpoint_path = None
cps = []
for i in range(4):
cp_name = f"consolidated.{i}.pth"
print(f"Loading {cp_name}")
@@ -363,10 +374,10 @@ def _load_model_default(builder_args, only_config=False):

if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path):
checkpoint = checkpoint["model"]

checkpoint = {"text_transformer." + k: v for k, v in checkpoint.items()}

checkpoint = {"model." + k: v for k, v in checkpoint.items()}
model.load_state_dict(checkpoint, assign=True, strict=True)

return model


@@ -534,7 +545,9 @@ def _initialize_model(
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length
max_batch_size=1,
max_seq_length=max_seq_length
or model.config.transformer_args["text"].max_seq_length,
)

model.to(dtype=builder_args.precision)
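The builder now has a dedicated loading path for the torchtune-style "*-Tune" params tables: the single Meta-format checkpoint is read with torch.load and its keys are remapped with torchtune's meta_to_tune before load_state_dict. Below is a minimal standalone sketch of that branch, assuming torchtune is installed; the path and table name are placeholders, and the "model." key prefixing applied by torchchat's Model wrapper is omitted.

from typing import Optional

import torch
from torchtune.models.convert_weights import meta_to_tune


def load_single_checkpoint(checkpoint_path: str, params_table: Optional[str]) -> dict:
    # Load the raw Meta-format weights (memory-mapped, weights only).
    state_dict = torch.load(checkpoint_path, mmap=True, weights_only=True)
    if params_table and params_table.endswith("Tune"):
        # Remap Meta parameter names to torchtune's naming scheme.
        state_dict = meta_to_tune(state_dict)
    return state_dict


# Hypothetical usage; the path and table name below are placeholders:
# sd = load_single_checkpoint("checkpoints/model.pth", "Meta-Llama-3.1-8B-Tune")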
23 changes: 19 additions & 4 deletions torchchat/generate.py
@@ -116,6 +116,7 @@ class GeneratorArgs:
speculate_k: int = 5
sequential_prefill: bool = False
max_autotune: bool = False
is_torchtune_model: bool = False

def __post_init__(self):
if self.compile_prefill and self.sequential_prefill:
@@ -161,6 +162,7 @@ def from_args(cls, args):
speculate_k=args.speculate_k,
sequential_prefill=sequential_prefill,
max_autotune=args.max_autotune,
is_torchtune_model=args.model and args.model.endswith("tune"),
)


@@ -197,6 +199,8 @@ def __init__(
self.profile = profile
self.quantize = quantize
self.draft_quantize = draft_quantize
self.is_torchtune_model = generator_args.is_torchtune_model
self.dtype = builder_args.precision

# global print
# from tp import maybe_init_dist
@@ -263,7 +267,10 @@ def __init__(
else:
self.draft_model = None

self.tokenizer_args.validate_model(self.model)
# torchtune model does not contain essential info for validation
# TODO: refactor model config to be more generic
if not self.is_torchtune_model:
self.tokenizer_args.validate_model(self.model)
self.tokenizer_args.validate_model(self.draft_model, "draft model")
generator_args.validate_build(self.builder_args)
generator_args.validate_build(self.speculative_builder_args, "draft model")
@@ -295,7 +302,7 @@ def sample(
need_probs: bool,
temperature: float = 1.0,
top_k: Optional[int] = None,
):
):
if temperature == 0 and not need_probs:
_, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
return (idx_next, None)
@@ -517,7 +524,10 @@ def generate(
if start_pos == 0:
model = model.to(device=device)
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if self.is_torchtune_model:
model.setup_caches(max_batch_size=1, dtype=self.dtype)
else:
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length
@@ -686,7 +696,12 @@ def chat(

self.system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:

# This is a hack to work around the fact that different models record max_seq_length in different ways, and the recorded value may be wrong.
# TODO: unify the max_seq_length config representation.
if generator_args.is_torchtune_model:
max_seq_length = self.model.config.transformer_args["text"]["max_seq_len"]
elif generator_args.chat_mode:
max_seq_length = self.model.config.transformer_args["text"].max_seq_length
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
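torchtune decoders size their KV caches from the model config and take a cache dtype instead of a max_seq_length, and their text config is kept as a plain dict rather than a dataclass, which is why the generator now branches on is_torchtune_model in both places. A rough standalone sketch of that branching, with model, device, and dtype assumed to exist:

import torch


def setup_caches_for(model, device, is_torchtune_model: bool,
                     max_seq_length: int, dtype: torch.dtype) -> None:
    with torch.device(device):
        if is_torchtune_model:
            # torchtune decoders derive the cache length from their own config.
            model.setup_caches(max_batch_size=1, dtype=dtype)
        else:
            # Native torchchat transformers need an explicit sequence length.
            model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)


def max_seq_length_of(model, is_torchtune_model: bool) -> int:
    text_args = model.config.transformer_args["text"]
    # torchtune params stay a plain dict; native params are a dataclass.
    return text_args["max_seq_len"] if is_torchtune_model else text_args.max_seq_length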
135 changes: 91 additions & 44 deletions torchchat/model.py
@@ -11,6 +11,7 @@
from enum import Enum
from pathlib import Path
from typing import Callable, Dict, Optional, Union
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
@@ -33,13 +34,20 @@
try:
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
except:
pass

config_path = Path(f"{str(Path(__file__).parent)}/model_params")

def identity(**kwargs):
if len(kwargs) != 1:
raise ValueError("Only one argument is expected")
return list(kwargs.values())[0]

class ModelType(Enum):
TextOnly = "text_only"
Llama3_1 = "llama3_1"
Flamingo = "flamingo"

# Type for objects that can generate nn.Module instance
@@ -72,9 +80,18 @@ class ModelRecipe:
def _text_only(cls):
return cls(
model_type=ModelType.TextOnly,
modules={'text_transformer': Transformer},
fusion_class=nn.Identity,
modules={'text': Transformer},
fusion_class=identity,
)

@classmethod
def _llama3_1(cls):
return cls(
model_type=ModelType.Llama3_1,
modules={'text': llama3_1_builder},
fusion_class=identity,
)

@classmethod
def _flamingo(cls):
return cls(
@@ -92,6 +109,8 @@ def get_recipe(cls, model_type):
return cls._text_only()
elif model_type == ModelType.Flamingo:
return cls._flamingo()
elif model_type == ModelType.Llama3_1:
return cls._llama3_1()
else:
raise ValueError(f"Can not find the model recipe for {model_type}")

@@ -184,11 +203,7 @@ def from_params(cls, params_path):
except TypeError:
# try to interpret as a dict of transformer configs
model_type = ModelType(loaded_params["model_type"])

# Currently only supporting flamingo model
assert model_type == ModelType.Flamingo
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}

return cls(transformer_args, model_type)

@classmethod
@@ -266,18 +281,14 @@ def update(self, input_pos, k_val, v_val):
return k_out, v_out


class Model(nn.Module):
class Model(ABC, nn.Module):
"""
The entrance for model construction in torchchat.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
# TODO: unify the model init logic
if config.model_type == ModelType.TextOnly:
self.text_transformer = Transformer(config.transformer_args["text"])
else:
self.model = self.build_model()
self.model = self.build_model()

def build_model(self) -> nn.Module:
"""
@@ -290,50 +301,43 @@ def build_model(self) -> nn.Module:
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
modules[name] = module_class(**self.config.transformer_args[name])

if isinstance(config_args := self.config.transformer_args[name], dict):
modules[name] = module_class(**config_args)
else:
modules[name] = module_class(config_args)

return recipe.fusion_class(**modules)

@abstractmethod
def forward(self, *args, **kwargs):
raise NotImplementedError("forward method is not implemented")

def forward(self,
tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None) -> Tensor:
@abstractmethod
def setup_caches(self, *args, **kwargs):
raise NotImplementedError("setup_caches method is not implemented")

if self.config.model_type == ModelType.TextOnly:
return self.text_transformer(tokens, input_pos)
else:
assert self.config.model_type == ModelType.Flamingo
if input_pos:
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.")
if encoder_input is None:
return self.model(tokens, encoder_mask = encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask)

def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None):
if self.config.model_type == ModelType.TextOnly:
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
else:
assert self.config.model_type == ModelType.Flamingo
if max_seq_length is not None:
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.")
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
assert self.config.model_type == ModelType.Flamingo
self.model.reset_caches()
@classmethod
def _get_model_instance(cls, config: ModelArgs):
model_class = MODEL_TYPE_TO_CLASS.get(config.model_type)
if model_class is None:
raise ValueError("Unsupported model type:", str(config.model_type))
return model_class(config)

@classmethod
def from_model_args(cls, config: ModelArgs):
return cls._get_model_instance(config)

@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
return cls._get_model_instance(ModelArgs.from_name(name))

@classmethod
def from_table(cls, name: str):
return cls(ModelArgs.from_table(name))
return cls._get_model_instance(ModelArgs.from_table(name))

@classmethod
def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))
return cls._get_model_instance(ModelArgs.from_params(params_path))

@classmethod
def from_gguf(cls, gguf_path: str, **kwargs):
@@ -345,6 +349,49 @@ def from_gguf(cls, gguf_path: str, **kwargs):
return model


class TextOnlyModel(Model):
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.model(tokens, input_pos)

def setup_caches(self, max_batch_size, max_seq_length):
self.model.setup_caches(max_batch_size, max_seq_length)


class Llama31Model(Model):
def forward(self, tokens: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.model(tokens=tokens, input_pos=input_pos)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
self.model.reset_caches()


class FlamingoModel(Model):
def forward(
self,
tokens: Tensor,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None,
) -> Tensor:
if encoder_input is None:
return self.model(tokens, encoder_mask=encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask=encoder_mask)

def setup_caches(self, max_batch_size, dtype):
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
self.model.reset_caches()


MODEL_TYPE_TO_CLASS = {
ModelType.TextOnly: TextOnlyModel,
ModelType.Flamingo: FlamingoModel,
ModelType.Llama3_1: Llama31Model,
}

class Transformer(nn.Module):
def __init__(self, config: TransformerArgs) -> None:
super().__init__()
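After this refactor, Model is an abstract base class and every from_* classmethod dispatches through MODEL_TYPE_TO_CLASS, so callers receive a concrete TextOnlyModel, FlamingoModel, or Llama31Model. A small sketch of the new factory path, assuming torchtune is installed and the snippet runs from the repo root; the model is built on the meta device so no real weights are allocated (weight loading still goes through torchchat/cli/builder.py):

import torch
from torchchat.model import Model, ModelType

with torch.device("meta"):
    # from_params parses the JSON into ModelArgs and dispatches on model_type.
    model = Model.from_params("torchchat/model_params/Meta-Llama-3.1-8B-Tune.json")

print(type(model).__name__)                            # Llama31Model
print(model.config.model_type is ModelType.Llama3_1)   # True
# model.model is the torchtune llama3_1 decoder produced by the recipe's builder.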
12 changes: 12 additions & 0 deletions torchchat/model_config/models.json
@@ -57,6 +57,18 @@
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B"
},
"meta-llama/Meta-Llama-3.1-8B-Instruct-Tune": {
"aliases": ["llama3.1-tune", "llama3.1-chat-tune", "llama3.1-instruct-tune"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-8B-Tune"
},
"meta-llama/Meta-Llama-3.1-70B-Instruct-Tune": {
"aliases": ["llama3.1-70b-tune"],
"distribution_channel": "HuggingFaceSnapshot",
"distribution_path": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"transformer_params_key": "Meta-Llama-3.1-70B-Tune"
},
"meta-llama/CodeLlama-7b-Python-hf": {
"aliases": ["codellama", "codellama-7b"],
"distribution_channel": "HuggingFaceSnapshot",
15 changes: 15 additions & 0 deletions torchchat/model_params/Meta-Llama-3.1-70B-Tune.json
@@ -0,0 +1,15 @@
{
"model_type": "llama3_1",
"text": {
"vocab_size": 128256,
"num_layers": 80,
"num_heads": 64,
"num_kv_heads": 8,
"embed_dim": 8192,
"max_seq_len": 8192,
"intermediate_dim": 28672,
"attn_dropout": 0.0,
"norm_eps": 1e-5,
"rope_base": 500000.0
}
}
15 changes: 15 additions & 0 deletions torchchat/model_params/Meta-Llama-3.1-8B-Tune.json
@@ -0,0 +1,15 @@
{
"model_type": "llama3_1",
"text": {
"vocab_size": 128256,
"num_layers": 32,
"num_heads": 32,
"num_kv_heads": 8,
"embed_dim": 4096,
"max_seq_len": 8192,
"intermediate_dim": 14336,
"attn_dropout": 0.0,
"norm_eps": 1e-5,
"rope_base": 500000.0
}
}
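The new *-Tune params files mirror the keyword arguments of torchtune's llama3_1 component builder, which is what lets build_model splat the "text" dict straight into the builder. A hedged sketch of that mapping, built on the meta device to avoid allocating 8B random weights:

import json

import torch
from torchtune.models.llama3_1._component_builders import llama3_1

with open("torchchat/model_params/Meta-Llama-3.1-8B-Tune.json") as f:
    params = json.load(f)

with torch.device("meta"):
    # Every key under "text" (vocab_size, num_layers, ..., rope_base) is a
    # llama3_1 builder argument, so the dict can be passed through unchanged.
    decoder = llama3_1(**params["text"])

print(sum(p.numel() for p in decoder.parameters()))  # roughly 8B parameters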