Commit
[distributed] add TrackTime, CUDATrackTime to monitor perf for weight loading per stage and future perf measurements (pytorch#1121)

* add TrackTime, monitor perf for weight loading per stage

* add CUDATrackTime

* ruff formatting

* add device for CUDATrackTime per PR feedback

* add comment re: cuda context, ruff format
lessw2020 authored Sep 11, 2024
1 parent e2049f4 commit 4e7332f
Showing 2 changed files with 103 additions and 2 deletions.
12 changes: 10 additions & 2 deletions dist_run.py
@@ -23,7 +23,9 @@
     get_hf_weight_map_and_path,
     load_safetensor_weights,
 )
-from distributed.utils import Color as color, GPUMemoryMonitor
+
+from distributed.utils import Color as color, TrackTime, CUDATrackTime, GPUMemoryMonitor
+
 from distributed.verification_utils import find_cpu_tensors
 from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
 from torchchat.model import ModelArgs, Transformer
@@ -188,8 +190,14 @@ def main():
 
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
-    _load_model_weights(model, hf_model_name, device=device, model_config=config)
+    with TrackTime("cuda") as timer:
+        _load_model_weights(model, hf_model_name, device=device, model_config=config)
+
+    logger.info(
+        f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}"
+    )
+
 
     # Setup input position
     # input_pos for prefill: a list of increasing integers from 0 to seqlen
     input_pos = torch.arange(seqlen, device=device)
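For orientation, here is a minimal standalone sketch (not part of this commit) of the TrackTime timer that the hunk above relies on and that is added to distributed/utils.py below. The sleep workload is a stand-in for any CPU-side work such as a stage's weight loading, and the millisecond variant is shown only to illustrate the use_ms flag:

import time

from distributed.utils import TrackTime

# Default mode: elapsed time in seconds, rounded to 4 decimal places.
with TrackTime() as timer:
    time.sleep(0.25)  # stand-in for real work, e.g. loading a stage's weights
print(f"elapsed: {timer.get_time()} {timer.unit}")  # roughly "elapsed: 0.25 seconds"

# Millisecond mode via use_ms.
with TrackTime(use_ms=True) as timer:
    time.sleep(0.25)
print(f"elapsed: {timer.get_time()} {timer.unit}")  # roughly "elapsed: ~250 milliseconds"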
93 changes: 93 additions & 0 deletions distributed/utils.py
@@ -7,6 +7,9 @@
 import os
 from dataclasses import dataclass
 from datetime import timedelta
+import time
+from typing import Optional
+
 
 import torch
 
@@ -79,6 +82,96 @@ class NoColor:
     white = ""
     reset = ""
 
+class TrackTime:
+    """integrated class for perf timing via perf_counter"""
+
+    def __init__(self, use_ms: bool = False, round_to: Optional[int] = 4):
+        self.use_ms = use_ms
+        self.round_to = round_to
+        self.start_time = 0.0
+        self.elapsed_time = 0.0
+        self.unit = "seconds" if not use_ms else "milliseconds"
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end_time = time.perf_counter()
+        self.elapsed_time = end_time - self.start_time
+
+        if self.use_ms:
+            self.elapsed_time *= 1000  # Convert to milliseconds
+
+        if self.round_to is not None:
+            self.elapsed_time = round(self.elapsed_time, self.round_to)
+
+    def get_time(self) -> float:
+        return self.elapsed_time
+
+
+class CUDATrackTime:
+    """
+    Integrated class for perf timing via cuda events.
+    Note - this uses the default stream to synchronize, and defaults to current device.
+    The event.record() will create a context on the CUDA device matching the device used at init.
+    """
+
+    def __init__(self, device=None, use_ms: bool = False, round_to: Optional[int] = 4):
+        if device is None:
+            device = torch.cuda.current_device()
+        elif isinstance(device, str):
+            device = torch.device(device)
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+
+        self.device = device
+        # Create events on the specified device
+        with torch.cuda.device(self.device):
+            self.start_event = torch.cuda.Event(enable_timing=True)
+            self.end_event = torch.cuda.Event(enable_timing=True)
+
+        self.active = False
+        self.round_to = round_to
+        self.elapsed_time = 0.0
+        self.use_ms = use_ms
+        self.unit = "seconds" if not use_ms else "milliseconds"
+
+    def start(self):
+        if self.active:
+            raise RuntimeError("Timer is already running. Use .stop() to stop it")
+        self.start_event.record()
+        self.active = True
+
+    def stop(self):
+        if not self.active:
+            raise RuntimeError("Timer is not running. Use .start() to start it")
+        self.end_event.record()
+        self.active = False
+
+    def get_time(self):
+        if self.active:
+            raise RuntimeError("Timer is still running. Use .stop() to stop it")
+
+        torch.cuda.synchronize(self.device)  # Synchronize all streams on the device
+        total_time = self.start_event.elapsed_time(self.end_event)
+
+        if not self.use_ms:
+            total_time = total_time / 1000.0  # to seconds
+
+        if self.round_to:
+            total_time = round(total_time, self.round_to)
+
+        self.elapsed_time = total_time
+
+        return self.elapsed_time
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop()
 
 class GPUMemoryMonitor:
     def __init__(self, device: str):
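Similarly, a minimal sketch (not part of this commit) of how the new CUDATrackTime class can be used. It requires a CUDA device; the matmul workload, tensor size, and the "cuda:0" device string are illustrative only:

import torch

from distributed.utils import CUDATrackTime

# Explicit start/stop around GPU work; get_time() synchronizes the device
# before reading the elapsed time between the two recorded events.
timer = CUDATrackTime(device="cuda:0")  # device may also be an int, or omitted to use the current device
x = torch.randn(4096, 4096, device="cuda:0")
timer.start()
y = x @ x  # illustrative GPU workload
timer.stop()
print(f"matmul took {timer.get_time()} {timer.unit}")

# The timer also works as a context manager, mirroring TrackTime.
with CUDATrackTime(use_ms=True) as gpu_timer:
    y = x @ x
print(f"matmul took {gpu_timer.get_time()} {gpu_timer.unit}")

Because the measurement comes from CUDA events recorded on the stream, it reflects time spent on the GPU rather than host-side launch overhead.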
