[Distributed Inference] Make torch run work for torchchat and fix TP bugs (pytorch#877)

* [Distributed Inference] Make torch run work for torchchat
fduwjj authored and malfet committed Jul 17, 2024
1 parent a47091a commit 5aacd36
Showing 6 changed files with 63 additions and 57 deletions.
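
The point of the commit is that multi-GPU generation is now launched through torchrun rather than plain python, with torchrun supplying RANK, LOCAL_RANK and WORLD_SIZE to each worker (generate.py below uses LOCAL_RANK to pick the per-rank CUDA device instead of defaulting to cuda:0). A minimal launch sketch under those assumptions; torchchat's own generate.py arguments are elided and not taken from this commit:

# single node, 2 GPUs; torchchat's own generate.py flags go where the placeholder is
torchrun --standalone --nproc-per-node=2 generate.py <model and generation args>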
27 changes: 14 additions & 13 deletions build/builder.py
@@ -21,7 +21,7 @@

from build.model import Transformer
from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
from distributed import parallelize_llama, ParallelDims
from distributed import parallelize_llama, ParallelDims, init_distributed


@dataclass
@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
builder_args.gguf_kwargs = None


def _init_model_on_meta_device(builder_args):
with torch.device("meta"):
if builder_args.params_path:
return Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
return Transformer.from_table(builder_args.params_table)
else:
return Transformer.from_name(builder_args.checkpoint_path.parent.name)

def _load_model_gguf(builder_args, only_config=False):
assert builder_args.gguf_path
if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
def _load_model_default(builder_args, only_config=False):
assert not builder_args.gguf_path

with torch.device("meta"):
if builder_args.params_path:
model = Transformer.from_params(builder_args.params_path)
elif builder_args.params_table:
model = Transformer.from_table(builder_args.params_table)
else:
model = Transformer.from_name(builder_args.checkpoint_path.parent.name)

model = _init_model_on_meta_device(builder_args)
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
cps = []
if builder_args.checkpoint_dir is not None:
@@ -357,12 +359,11 @@ def _load_model(builder_args, only_config=False):
pp=1,
world_size=world_size,
)
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
init_distributed(job_config)
init_distributed()
world_mesh = parallel_dims.build_mesh(device_type="cuda")

print("Applying model parallel to model ...")
parallelize_llama(model)
parallelize_llama(model, world_mesh, parallel_dims)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()
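
The builder.py changes factor model construction out into _init_model_on_meta_device and defer device placement until after parallelization. A minimal sketch of that meta-device pattern, assuming a toy nn.Linear and an illustrative to_empty/load_state_dict materialization step (neither is code from this commit):

import torch
import torch.nn as nn

with torch.device("meta"):
    # Under the meta device, parameters carry only shape and dtype metadata and
    # no storage, so even a very large model is "constructed" instantly.
    model = nn.Linear(4096, 4096, bias=False)

assert model.weight.is_meta

# Storage has to be materialized on a real device before the weights are usable;
# builder.py does this when it loads the checkpoint and moves the model to
# builder_args.device at the end of _load_model.
model = model.to_empty(device="cpu")
model.load_state_dict({"weight": torch.zeros(4096, 4096)}, assign=True)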
1 change: 1 addition & 0 deletions distributed/__init__.py
@@ -6,3 +6,4 @@

from distributed.parallelize_llama import parallelize_llama
from distributed.parallel_config import ParallelDims
from distributed.utils import init_distributed
1 change: 1 addition & 0 deletions distributed/parallel_config.py
@@ -6,6 +6,7 @@

from dataclasses import dataclass, field
from torch.distributed.device_mesh import init_device_mesh
from distributed.utils import logger

@dataclass
class ParallelDims:
69 changes: 36 additions & 33 deletions distributed/parallelize_llama.py
@@ -10,12 +10,13 @@
parallelize_module,
PrepareModuleInput,
RowwiseParallel,
SequenceParallel,
)

import torch.nn as nn
from torch.distributed._tensor import Replicate, Shard
from distributed.parallel_config import ParallelDims
from torch.distributed.device_mesh import DeviceMesh
from distributed.utils import logger


def apply_tp(
@@ -43,53 +44,55 @@ def apply_tp(

tp_mesh = world_mesh["tp"]

# 1. Parallelize the first embedding and the last linear proj layer
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
model = parallelize_module(
model,
tp_mesh,
{
"tok_embeddings": RowwiseParallel(
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"output": ColwiseParallel(
input_layouts=Shard(1),
output_layouts=Replicate(),
use_local_output=True,
),
"norm": SequenceParallel(),
},
)

# Apply tensor + sequence parallelism to every transformer block
for layer_id, transformer_block in model.layers.items():
# TODO: To figure out the TP for the tok_embedding and the linear proj layer.
# # 1. Parallelize the first embedding and the last linear proj layer
# # 2. Shard the first transformer block's inputs
# model = parallelize_module(
# model,
# tp_mesh,
# {
# "tok_embeddings": RowwiseParallel(
# input_layouts=Replicate(),
# output_layouts=Replicate(),
# ),
# "output": ColwiseParallel(
# input_layouts=Shard(1),
# output_layouts=Replicate(),
# use_local_output=True,
# ),
# },
# )

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
layer_plan = {
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
"attention": PrepareModuleInput(
input_layouts=(Replicate(), None),
desired_input_layouts=(Replicate(), None),
),
"attention.wq": ColwiseParallel(),
"attention.wk": ColwiseParallel(),
"attention.wv": ColwiseParallel(),
"attention.wo": RowwiseParallel(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
"attention.wo": RowwiseParallel(
output_layouts=Replicate(),
use_local_output=True,
),
"feed_forward": PrepareModuleInput(
input_layouts=(Replicate(),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
"feed_forward.w2": RowwiseParallel(
output_layouts=Replicate(),
use_local_output=True
),
"feed_forward.w3": ColwiseParallel(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
@@ -125,6 +128,6 @@ def parallelize_llama(
"""

if parallel_dims.tp_enabled:
model = apply_tp(model, world_mesh, parallel_dims)
model = apply_tp(model, world_mesh)

return model
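
The rewritten per-block plan pairs column-wise sharded projections (attention.wq/wk/wv, feed_forward.w1/w3) with row-wise sharded ones (attention.wo, feed_forward.w2) and requests replicated local outputs, replacing the earlier SequenceParallel/Shard(1) layouts. A self-contained sketch of the same Colwise/Rowwise pairing on a toy feed-forward module; the module, sizes, and mesh setup are illustrative rather than torchchat code:

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyFeedForward(nn.Module):
    def __init__(self, dim: int = 256, hidden: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))

def apply_tp_to_ffn(ffn: nn.Module, tp_mesh):
    # w1 is sharded over output columns and w2 over input rows; RowwiseParallel
    # replicates its output by default, which is where the all-reduce happens.
    return parallelize_module(
        ffn,
        tp_mesh,
        {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
        },
    )

# Inside a torchrun job one would build the mesh and apply the plan, e.g.:
#   from torch.distributed.device_mesh import init_device_mesh
#   tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))["tp"]
#   ffn = apply_tp_to_ffn(ToyFeedForward().cuda(), tp_mesh)

The head-count adjustment in the hunk above follows the same degree: with, say, 32 attention heads and a TP mesh of size 4, each rank keeps 8 local heads.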
15 changes: 4 additions & 11 deletions distributed/utils.py
@@ -8,6 +8,8 @@
from datetime import timedelta

import torch
import logging
logger = logging.getLogger()


def _warn_overwrite_env(env, val):
@@ -25,24 +27,15 @@ def _warn_overwrite_env(env, val):
SKIP_CLEANUP = "3"


def init_distributed(job_config):
def init_distributed(init_timeout_seconds: int = 120):
# FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
# to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
# This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
# behavior differences
_warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

# enable torch nccl flight recorder in the mode that would dump files if timeout is detected
_warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
if job_config.comm.trace_buf_size > 0:
# dump on timeout by default if trace buffer is enabled
_warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
dump_dir = f"{job_config.job.dump_folder}/comm_trace"
os.makedirs(dump_dir, exist_ok=True)
_warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

torch.distributed.init_process_group(
"nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
"nccl", timeout=timedelta(seconds=init_timeout_seconds)
)

# to mitigate the memory issue that collectives using
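
init_distributed no longer takes torchtitan's job_config: the flight-recorder trace plumbing is dropped, the NCCL async-error-handling override stays, and the timeout now comes from a plain argument. A condensed view of the resulting call path, assembled from the hunks above with the env-var handling elided:

from datetime import timedelta
import torch.distributed as dist

def init_distributed(init_timeout_seconds: int = 120):
    # torchrun already exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT,
    # so only the backend and the timeout need to be passed explicitly.
    dist.init_process_group(
        "nccl", timeout=timedelta(seconds=init_timeout_seconds)
    )

builder.py now calls init_distributed() with no arguments, relying on the default timeout.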
7 changes: 7 additions & 0 deletions generate.py
@@ -8,6 +8,7 @@
import logging
import sys
import time
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
@@ -504,6 +505,12 @@ def _main(
# print = lambda *args, **kwargs: None

print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}")
# If using distributed inference we cannot just assign device to be cuda
# because it will be assigned to cuda:0 by default. We need to explicitly set
# the device to be the local rank.
if builder_args.use_distributed:
device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
torch.cuda.set_device(device)
set_precision(builder_args.precision)
is_speculative = speculative_builder_args.checkpoint_path is not None

