Skip to content

Commit

Permalink
Create top-level torchat.py CLI binary (pytorch#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
mergennachin authored and malfet committed Jul 17, 2024
1 parent eeaf0b1 commit f69429f
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 138 deletions.
57 changes: 57 additions & 0 deletions .github/workflows/test_torchchat_commands.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
name: Run torchchat command tests

on:
push:
branches:
- main
pull_request:
workflow_dispatch:

jobs:
torchchat-command-load-test:
strategy:
matrix:
runner: [macos-14]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout repo
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.11
- name: Print machine info
run: |
uname -a
if [ $(uname -s) == Darwin ]; then
sysctl machdep.cpu.brand_string
sysctl machdep.cpu.core_count
fi
- name: Install requirements
run: |
echo "Installing pip packages"
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -r requirements.txt
- name: Download Stories files
run: |
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
curl -fsSL -O https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
curl -fsSL -O https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
- name: Test generate
run: |
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
python generate.py --device cpu --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager1
python torchchat.py generate --device cpu --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager2
cat ./output_eager1
cat ./output_eager2
echo "Tests complete."
16 changes: 8 additions & 8 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools

import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Optional, Union

import torch
import torch._dynamo.config
import torch._inductor.config
from cli import cli_args

from quantize import get_precision, name_to_dtype, quantize_model, set_precision
from quantize import name_to_dtype, quantize_model

from sentencepiece import SentencePieceProcessor

Expand Down Expand Up @@ -110,7 +110,7 @@ def from_args(cls, args): # -> TokenizerArgs:
elif args.checkpoint_dir:
tokenizer_path = args.checkpoint_dir / "tokenizer.model"
else:
raise RuntimeError(f"cannot find tokenizer model")
raise RuntimeError("cannot find tokenizer model")

if not tokenizer_path.is_file():
raise RuntimeError(f"did not find tokenizer at {tokenizer_path}")
Expand Down Expand Up @@ -243,7 +243,7 @@ def _initialize_model(
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported DSO model. Specify quantization during export."
), "quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
Expand All @@ -262,12 +262,12 @@ def _initialize_model(
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert (
quantize is None or quantize == "{ }"
), f"quantize not valid for exported PTE model. Specify quantization during export."
), "quantize not valid for exported PTE model. Specify quantization during export."
try:
from build.model_et import PTEModel

model = PTEModel(model_.config, builder_args.pte_path)
except Exception as e:
except Exception:
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
else:
model = model_
Expand Down
40 changes: 16 additions & 24 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import json
from pathlib import Path

import torch
import torch.nn as nn


default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

Expand Down Expand Up @@ -41,11 +40,19 @@ def check_args(args, command_name: str):
print(f"Warning: {text}")


def cli_args():
import argparse
def add_arguments_for_generate(parser):
_add_arguments_common(parser)


def add_arguments_for_eval(parser):
_add_arguments_common(parser)


def add_arguments_for_export(parser):
_add_arguments_common(parser)

parser = argparse.ArgumentParser(description="Your CLI description.")

def _add_arguments_common(parser):
parser.add_argument(
"--seed",
type=int,
Expand All @@ -60,21 +67,6 @@ def cli_args():
action="store_true",
help="Whether to use tiktoken tokenizer.",
)
parser.add_argument(
"--export",
action="store_true",
help="Use torchchat to export a model.",
)
parser.add_argument(
"--eval",
action="store_true",
help="Use torchchat to eval a model.",
)
parser.add_argument(
"--generate",
action="store_true",
help="Use torchchat to generate a sequence using a model.",
)
parser.add_argument(
"--chat",
action="store_true",
Expand Down Expand Up @@ -162,10 +154,10 @@ def cli_args():
parser.add_argument(
"--quantize", type=str, default="{ }", help="Quantization options."
)
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
parser.add_argument(
"--device", type=str, default=default_device, help="Device to use"
)
parser.add_argument("--params-table", type=str, default=None, help="Device to use")
parser.add_argument(
"--tasks",
nargs="+",
Expand All @@ -183,13 +175,13 @@ def cli_args():
help="maximum length sequence to evaluate",
)

args = parser.parse_args()

def arg_init(args):

if Path(args.quantize).is_file():
with open(args.quantize, "r") as f:
args.quantize = json.loads(f.read())

if args.seed:
torch.manual_seed(args.seed)

return args
54 changes: 21 additions & 33 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,33 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import argparse
import time
from pathlib import Path
from typing import Optional

import torch
import torch._dynamo.config
import torch._inductor.config

from build.builder import (
_initialize_model,
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)

from build.model import Transformer
from cli import add_arguments_for_eval, arg_init
from generate import encode_tokens, model_forward

from quantize import set_precision

torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

from build.model import Transformer
from cli import cli_args
from quantize import name_to_dtype, set_precision

try:
import lm_eval
Expand All @@ -29,13 +38,6 @@
except:
lm_eval_available = False

from build.builder import (
_initialize_model,
_initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)
from generate import encode_tokens, model_forward

if lm_eval_available:
try: # lm_eval version 0.4
Expand Down Expand Up @@ -218,30 +220,19 @@ def main(args) -> None:

builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)

checkpoint_path = args.checkpoint_path
checkpoint_dir = args.checkpoint_dir
params_path = args.params_path
params_table = args.params_table
gguf_path = args.gguf_path
tokenizer_path = args.tokenizer_path
dso_path = args.dso_path
pte_path = args.pte_path
quantize = args.quantize
device = args.device
model_dtype = args.dtype
tasks = args.tasks
limit = args.limit
max_seq_length = args.max_seq_length
use_tiktoken = args.tiktoken

print(f"Using device={device}")
set_precision(buildeer_args.precision)
set_precision(builder_args.precision)

tokenizer = _initialize_tokenizer(tokenizer_args)
builder_args.setup_caches = False
model = _initialize_model(
buildeer_args,
builder_args,
quantize,
)

Expand Down Expand Up @@ -280,11 +271,8 @@ def main(args) -> None:


if __name__ == "__main__":

def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
parser = argparse.ArgumentParser(description="Export specific CLI.")
add_arguments_for_eval(parser)
args = parser.parse_args()
args = arg_init(args)
main(args)
34 changes: 13 additions & 21 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import time
from pathlib import Path

import torch
import torch.nn as nn
from cli import cli_args

from quantize import get_precision, name_to_dtype, quantize_model, set_precision
from torch.export import Dim, export
from build.builder import _initialize_model, BuilderArgs
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

from quantize import set_precision

try:
executorch_export_available = True
Expand All @@ -22,13 +22,6 @@
executorch_exception = f"ET EXPORT EXCEPTION: {e}"
executorch_export_available = False

from build.builder import _initialize_model, BuilderArgs, TokenizerArgs

from build.model import Transformer
from export_aoti import export_model as export_model_aoti
from generate import decode_one_token
from quantize import name_to_dtype, quantize_model
from torch._export import capture_pre_autograd_graph

default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'

Expand All @@ -44,7 +37,6 @@ def device_sync(device):

def main(args):
builder_args = BuilderArgs.from_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
quantize = args.quantize

print(f"Using device={builder_args.device}")
Expand All @@ -70,7 +62,7 @@ def main(args):
export_model_et(model, builder_args.device, args.output_pte_path, args)
else:
print(
f"Export with executorch requested but Executorch could not be loaded"
"Export with executorch requested but Executorch could not be loaded"
)
print(executorch_exception)
if output_dso_path:
Expand All @@ -79,10 +71,10 @@ def main(args):
export_model_aoti(model, builder_args.device, output_dso_path, args)


def cli():
args = cli_args()
main(args)


if __name__ == "__main__":
cli()
parser = argparse.ArgumentParser(description="Export specific CLI.")
add_arguments_for_export(parser)
args = parser.parse_args()
check_args(args, "export")
args = arg_init(args)
main(args)
Loading

0 comments on commit f69429f

Please sign in to comment.