
Commit

Merge branch 'master' into saforem2/ucp-bug
loadams authored Jan 9, 2025
2 parents d143141 + 53fb579 commit ec9556f
Showing 27 changed files with 392 additions and 290 deletions.
2 changes: 1 addition & 1 deletion blogs/windows/08-2024/README.md
@@ -48,7 +48,7 @@ Regardless of the installation choice, you can check that the installation was successful
We use an image classification model, CIFAR10, and a language model, BERT, to demonstrate pretraining on Windows with DeepSpeed.

## Pretraining CIFAR10
The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py deepspeed`. The final output should look something like this:
The scripts and codes required for CIFAR10 pretraining example are available in the following path: DeepSpeedExamples\training\cifar. You can launch the CIFAR10 pretraining experiment using the following command: `deepspeed cifar10_deepspeed.py --deepspeed`. The final output should look something like this:
<div align="center">
<img src="./media/cifar10_training.png" style="width:6.5in;height:3.42153in" />
</div>
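For context on the corrected command, the sketch below shows how a training script can consume the `--deepspeed` flag through `deepspeed.add_config_arguments`; the argument parser setup, stand-in model, and inline config are assumptions for illustration and are not the contents of `cifar10_deepspeed.py`.

```python
# Hedged sketch (not the actual cifar10_deepspeed.py): the DeepSpeed launcher runs
# this script, and add_config_arguments registers the --deepspeed/--deepspeed_config
# flags that the corrected launch command passes through.
import argparse

import torch
import deepspeed


def get_args():
    parser = argparse.ArgumentParser(description="CIFAR10 pretraining sketch")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local rank supplied by the DeepSpeed launcher")
    # Adds --deepspeed and --deepspeed_config to the parser.
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    model = torch.nn.Linear(3 * 32 * 32, 10)  # stand-in for the real CIFAR10 model
    # Inline config used here instead of a JSON file; values are illustrative only.
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=model.parameters(),
        config={"train_batch_size": 32,
                "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}},
    )
```

The corrected command passes `--deepspeed` to the script as a flag rather than as a stray positional argument.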
24 changes: 0 additions & 24 deletions deepspeed/inference/engine.py
@@ -80,7 +80,6 @@ def __init__(self, model, config):
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu

#self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None

@@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", [0])

# TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support SD pipeline we need to avoid this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")

if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")

supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
if self._config.dtype not in supported_dtypes:
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")

if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")

def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
12 changes: 12 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
@@ -19,6 +19,18 @@
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):

def __init__(self, **kwargs):
# Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
from importlib.metadata import version
v_transformers = version('transformers')
vers = v_transformers.split('.')
major = int(vers[0])
minor = int(vers[1])
if major > 4 or (major == 4 and minor > 43):
import sys
sys.exit(
f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
)

super().__init__(**kwargs)

# All model specific things should be defined here instead of the base class.
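As a point of comparison, here is a hedged sketch of the same transformers-version guard written with `packaging.version` rather than manual string splitting; it keeps the major/minor comparison used in the diff above, and the `MAX_SUPPORTED` constant name is an assumption.

```python
# Sketch only: an equivalent guard using packaging.version; MAX_SUPPORTED is an
# assumed name, and the (4, 43) cutoff mirrors the check added in this diff.
import sys
from importlib.metadata import version

from packaging.version import Version

MAX_SUPPORTED = (4, 43)  # transformers releases after 4.43.x are unsupported for BLOOM inference

installed = Version(version("transformers"))
if (installed.major, installed.minor) > MAX_SUPPORTED:
    sys.exit(f"Transformers version {installed} exceeds version 4.43.4! After transformers "
             "version 4.43.4, BLOOM inference with DeepSpeed is no longer supported.")
```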
7 changes: 1 addition & 6 deletions deepspeed/ops/fp_quantizer/__init__.py
@@ -4,9 +4,4 @@
# DeepSpeed Team

from .quantize import FP_Quantize, Quantizer

try:
import triton
from .fp8_gemm import matmul_fp8
except ImportError:
pass
from .fp8_gemm import matmul_fp8
163 changes: 10 additions & 153 deletions deepspeed/ops/fp_quantizer/fp8_gemm.py
@@ -11,161 +11,18 @@
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk
weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
weight = tl.load(weight_data, mask=weight_mask, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
mask=weight_mask,
other=0.0)

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp, weight, scale, quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps)
return out

def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer):
    from deepspeed import get_accelerator

    if not get_accelerator().is_triton_supported():
        return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer)
    else:
        # Import dynamically to prevent failures on systems without triton.
        from .fp8_gemm_triton import matmul_fp8_triton
        return matmul_fp8_triton(inp, weight, scale, quantization_group_size)


def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer):
    return torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
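To make the non-Triton path concrete, below is a small self-contained sketch of the dequantize-then-matmul pattern used by `matmul_fp8_fallback`; `ToyQuantizer` is a hypothetical stand-in whose `dequantize(weight, scale=...)` method only mimics the call shown above.

```python
# Self-contained sketch of the fallback path's dequantize-then-matmul pattern.
# ToyQuantizer is hypothetical; a real quantizer would decode packed fp8 weight data.
import torch


class ToyQuantizer:
    def dequantize(self, weight, scale=None):
        # Here "weight" is already bf16 for simplicity; real quantizers decode fp8 bytes.
        return weight * scale


inp = torch.randn(4, 8, dtype=torch.bfloat16)
weight = torch.randn(8, 16, dtype=torch.bfloat16)
scale = torch.tensor(0.5, dtype=torch.bfloat16)

quantizer = ToyQuantizer()
# Same computation as matmul_fp8_fallback(inp, weight, scale, group_size, quantizer).
out = torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
print(out.shape)  # torch.Size([4, 16])
```

On systems where `get_accelerator().is_triton_supported()` returns True, the dispatcher above routes the same call to the Triton kernel instead.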
