Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
trholding committed Mar 19, 2024
2 parents 3fc5375 + b3c4b6c commit 1383943
Show file tree
Hide file tree
Showing 4 changed files with 1,199 additions and 11 deletions.
114 changes: 105 additions & 9 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,22 +241,16 @@ def version2_export(model, filepath, group_size=64):
# now let's write out all the params that we are quantizing to Q8_0
# note we skip classifier weights, which are shared with the embedding
ew = []
scales = []
for i, w in enumerate(weights):
# quantize this weight
q, s, err = quantize_q80(w, group_size)
# save the int8 weights to file
serialize_int8(out_file, q) # save the tensor in int8
scales.append(s) # we'll do all the scales after all the qs
serialize_fp32(out_file, s) # save scale factors
# logging
ew.append((err, w.shape))
print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")

# save the scaling factors in fp32 here
# this is done to keep all the weights contiquous, making pointer arithmetic easier in C
for s in scales:
serialize_fp32(out_file, s)

# print the highest error across all weights, should be very small, e.g. O(~0.001)
ew.sort(reverse=True)
print(f"max quantization group error across all weights: {ew[0][0]}")
Expand All @@ -265,6 +259,96 @@ def version2_export(model, filepath, group_size=64):
out_file.close()
print(f"wrote {filepath}")

def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
""" Generate the pytorch_model.bin state_dict and config.json for HuggingFace """

try:
from transformers.models.llama.configuration_llama import LlamaConfig
except ImportError:
print("Error: transformers package is required to load huggingface models")
print("Please run `pip install transformers` to install it")
return None

# Generate LlamaModel state_dict
hf_state_dict = {}

# Sometimes we have repeated key values for the heads
dim = llama_model.params.dim
num_key_value_heads = llama_model.params.n_kv_heads
n_rep = llama_model.params.n_heads // num_key_value_heads
key_value_dim = dim // n_rep

# HuggingFace needs the weights permuted.
# See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

# Transfer weights from llama model to the HF state dictionary format
hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)

# Add each layer's weights to the HF state dictionary
for i, layer in enumerate(llama_model.layers):
layer_id = layer.layer_id
hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)

# llama2.c usually uses tied weights -> reference the embed_tokens.weights instead
hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']

# We check that the embeddings are tied, else use manual output weights
_embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
if not _embeddings_are_tied:
hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)


# Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)

# Extract necessary attributes from llama.c model
vocab_size = llama_model.params.vocab_size
hidden_size = llama_model.params.dim
intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
num_hidden_layers = llama_model.params.n_layers
num_attention_heads = llama_model.params.n_heads
num_key_value_heads = llama_model.params.n_kv_heads
max_position_embeddings = llama_model.params.max_seq_len
rms_norm_eps = llama_model.params.norm_eps

# TODO check values for:
# pretraining_tp, initializer_range, use_cache,
# rope_theta, and rope_scaling.

config = LlamaConfig(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=rms_norm_eps,
tie_word_embeddings=_embeddings_are_tied,
# Manual
architectures=["LlamaForCausalLM"],
hidden_act="silu",
)


# Save files in directory filepath
# First make the directory if it doesn't exist
os.makedirs(filepath, exist_ok=True)

# Save the state dictionary in .bin format, and config as .json
torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
config.save_pretrained(filepath)


# -----------------------------------------------------------------------------
# Load / import functions
Expand Down Expand Up @@ -405,13 +489,23 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
# -----------------------------------------------------------------------------
# API entrypoint

def model_export(model, filepath, version):
def model_export(model, filepath, version, dtype=torch.float32):
"""
Versions docs:
v-1:huggingface export, i.e. intended for use outside of this repo, in HF
v0: legacy llama2.c float format, DEPRECATED
v1: float32 export
v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
# TODO: add dtype export support for other versions (?)
"""
if version == 0:
legacy_export(model, filepath)
elif version == 1:
version1_export(model, filepath)
elif version == 2:
version2_export(model, filepath)
elif version == -1:
hf_export(model, filepath, dtype)
else:
raise ValueError(f"unknown version {version}")

Expand Down Expand Up @@ -451,11 +545,13 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser = argparse.ArgumentParser()
parser.add_argument("filepath", type=str, help="the output filepath")
parser.add_argument("--version", default=0, type=int, help="the version to export with")
parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
group.add_argument("--meta-llama", type=str, help="meta llama model path")
group.add_argument("--hf", type=str, help="huggingface model path")
args = parser.parse_args()
dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]

if args.checkpoint:
model = load_checkpoint(args.checkpoint)
Expand All @@ -468,4 +564,4 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
parser.error("Can't load input model!")

# export
model_export(model, args.filepath, args.version)
model_export(model, args.filepath, args.version, args.dtype)
2 changes: 1 addition & 1 deletion run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
"token = \"replace your huggingface access token\" #@param {type:\"string\"}\n",
"path = snapshot_download(repo_id=\"meta-llama/Llama-2-7b\",cache_dir=\"Llama-2-7b\", use_auth_token=token)\n",
"\n",
"!python export_meta_llama_bin.py $path llama2_7b.bin\n",
"!python export.py llama2_7b.bin --meta-llama $path\n",
"\n",
"print(\"./run llama2_7b.bin\\n\")\n",
"!./run llama2_7b.bin"
Expand Down
Loading

0 comments on commit 1383943

Please sign in to comment.