diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 51377bc6092c..f9cf23373c26 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,9 +4,4 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer - -try: - import triton - from .fp8_gemm import matmul_fp8 -except ImportError: - pass +from .fp8_gemm import matmul_fp8 diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py index 55504e3af8c9..db4fa5ae2c92 100644 --- a/deepspeed/ops/fp_quantizer/fp8_gemm.py +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -11,161 +11,18 @@ ################################### import torch -import triton -import triton.language as tl -@triton.jit -def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m +def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer): + from deepspeed import get_accelerator - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + if not get_accelerator().is_triton_supported(): + return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer) + else: + # Import dynamically to prevent failures on systems without triton. 
+ from .fp8_gemm_triton import matmul_fp8_triton + return matmul_fp8_triton(inp, weight, scale, quantization_group_size) - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> bf16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) - w = (w + 0x3C00).to(tl.uint16) - w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K - weight = tl.load(weight_data, mask=weight_mask, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), - mask=weight_mask, - other=0.0) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.bfloat16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -@triton.jit -def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> fp16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) - w = (w + 0x2000).to(tl.uint16) - w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - - weight = tl.load(weight_data, mask=offs_k[:, 
None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
-        scale = tl.load(scale_ptr + (weight_ptrs_offset +
-                                     (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))
-
-        accumulator += tl.dot(inp, w)
-
-    out = accumulator.to(tl.float16)
-
-    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))
-
-
-def matmul_fp8(inp, weight, scale, quantization_group_size):
-
-    assert inp.shape[1] == weight.shape[0], \
-        f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"
-
-    M, K = inp.shape
-    K, N = weight.shape
-
-    out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)
-
-    # GEMM tuning parameters!
-    # TODO: Add a more configurable tuning for selecting the best GeMM
-    BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
-    BLOCK_SIZE_N = 64
-    BLOCK_SIZE_K = max(64, quantization_group_size)
-    GROUP_SIZE_M = 8
-    num_stages = 4
-    num_warps = 4
-    if M >= 256:
-        BLOCK_SIZE_M = 256
-        BLOCK_SIZE_N = 128
-        BLOCK_SIZE_K = max(128, quantization_group_size)
-        num_stages = 3
-        num_warps = 8
-
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-    kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
-    kernel[grid](inp,
-                 weight,
-                 out,
-                 scale,
-                 M,
-                 N,
-                 K,
-                 inp.stride(0),
-                 inp.stride(1),
-                 weight.stride(0),
-                 weight.stride(1),
-                 out.stride(0),
-                 out.stride(1),
-                 quantization_group_size=quantization_group_size,
-                 BLOCK_SIZE_M=BLOCK_SIZE_M,
-                 BLOCK_SIZE_N=BLOCK_SIZE_N,
-                 BLOCK_SIZE_K=BLOCK_SIZE_K,
-                 GROUP_SIZE_M=GROUP_SIZE_M,
-                 num_stages=num_stages,
-                 num_warps=num_warps)
-    return out
+def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer):
+    return torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
new file mode 100644
index 000000000000..746e217d4194
--- /dev/null
+++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
@@ -0,0 +1,171 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+######## Fused MoE kernel #########
+# These kernels are implemented for
+# fusing GeMM with dequantization of
+# fp8 weight data when using bit-16
+# activation.
+################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + 
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8_triton(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py index c9ae58a121de..628bf86a61da 100755 --- a/deepspeed/runtime/zero/mics.py +++ b/deepspeed/runtime/zero/mics.py @@ -38,7 +38,7 @@ class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle): def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None: super().__init__(allgather_handle, params, partitions, world_size) - def wait(self) -> None: + def wait(self, **kwargs) -> None: """ """ # let the current stream to op diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 0be88a1e1ba6..d5b7bac55146 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -145,6 
+145,16 @@ def __init__(
             module.ds_inflight_param_registry = InflightParamRegistry()
         self.__inflight_param_registry = module.ds_inflight_param_registry
 
+        self.fast_sharding_for_leaf_module = False
+
+        if zero_module_granularity_threshold > 0:
+            self.min_granularity_value = sys.maxsize
+            self.min_granularity_layer = None
+            self.granularity_info = set()
+            self.z3_leaf_layers = []
+            self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
+            self.fast_sharding_for_leaf_module = True
+
         self.param_coordinator = PartitionedParameterCoordinator(
             prefetch_bucket_sz=self._prefetch_bucket_sz,
             max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
@@ -155,14 +165,7 @@ def __init__(
             timers=self.timers,
             zero_quantized_weights=self.zero_quantized_weights,
             zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-        )
-
-        if zero_module_granularity_threshold > 0:
-            self.min_granularity_value = sys.maxsize
-            self.min_granularity_layer = None
-            self.granularity_info = set()
-            self.z3_leaf_layers = []
-            self._set_z3_leaf_modules_by_threshold(module, zero_module_granularity_threshold)
+            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
 
         self.forward_hooks = []
         self.backward_hooks = []
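Note: the granularity scan is hoisted above the coordinator construction because its result, fast_sharding_for_leaf_module, is now a constructor argument of PartitionedParameterCoordinator. The internals of _set_z3_leaf_modules_by_threshold are not part of this diff; the sketch below only illustrates a threshold-based leaf scan in the same spirit (the helper name and the parameter-count metric are assumptions, not DeepSpeed's actual implementation):

```python
import torch.nn as nn

def mark_leaf_modules_by_threshold(module: nn.Module, threshold: int, leaves: list) -> None:
    # Mark small subtrees as "leaf" modules that ZeRO-3 can fetch and
    # release as a single unit instead of hooking every child module.
    for child in module.children():
        numel = sum(p.numel() for p in child.parameters())
        if 0 < numel <= threshold:
            leaves.append(child)
        else:
            mark_leaf_modules_by_threshold(child, threshold, leaves)
```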
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index cb0cd7c8017d..e8cb797b8a5b 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -55,7 +55,7 @@ def __init__(self, param: Parameter) -> None:
                                                            non_blocking=True).view(param.ds_shape)
         self.__param = param
 
-    def wait(self) -> None:
+    def wait(self, **kwargs) -> None:
         if not get_accelerator().resolves_data_dependency():
             get_accelerator().current_stream().synchronize()
         self.__param.ds_status = ZeroParamStatus.AVAILABLE
@@ -78,7 +78,7 @@ def __init__(self, params: List[Parameter]) -> None:
                                                                non_blocking=True).view(param.ds_shape)
 
     @instrument_w_nvtx
-    def wait(self) -> None:
+    def wait(self, **kwargs) -> None:
         if self.__complete:
             return
 
@@ -639,7 +639,7 @@ def __init__(self, handle, param: Parameter, quantization=None) -> None:
         self.__param = param
         self.__quantization = quantization
 
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         instrument_w_nvtx(self.__handle.wait)()
         if self.__quantization:
             instrument_w_nvtx(self.__quantization.quant_handle.wait)()
@@ -650,6 +650,8 @@ def wait(self) -> None:
 
 class AllGatherCoalescedHandle:
 
+    data_buffer = []
+
     def __init__(
         self,
         allgather_handle,
@@ -672,7 +674,7 @@ def __init__(
                 raise RuntimeError(f"expected param {param.ds_summary()} to not be available")
 
     @instrument_w_nvtx
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         if self.complete:
             return
 
@@ -704,14 +706,20 @@ def wait(self) -> None:
                     partitions.append(part_to_copy)
                 param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
                 param.ds_status = ZeroParamStatus.AVAILABLE
-
-                for part_to_copy in partitions:
-                    if not get_accelerator().is_synchronized_device():
+                if not get_accelerator().is_synchronized_device() and handle_dependency:
+                    for part_to_copy in partitions:
                         part_to_copy.record_stream(get_accelerator().current_stream())
 
             param_offset += ds_tensor_numel
 
         self.complete = True
+        if not get_accelerator().is_synchronized_device() and not handle_dependency:
+            # the caller resolves the stream dependency itself; buffer the partitions so they stay alive until free_buffer() is called
+            AllGatherCoalescedHandle.data_buffer.append(partitions)
+
+    @staticmethod
+    def free_buffer():
+        AllGatherCoalescedHandle.data_buffer = []
 
 
 class MultipleAllGatherHandles:
 
     def __init__(self, handles: List[AllGatherCoalescedHandle]):
         self.handles = handles
 
-    def wait(self) -> None:
+    def wait(self, handle_dependency=True) -> None:
         for handle in self.handles:
-            handle.wait()
+            handle.wait(handle_dependency)
 
 
 class AllReduceCoalescedHandle:
@@ -1377,13 +1385,13 @@ def all_gather_coalesced(params: Iterable[Parameter],
                 quantization=quant_info,
             )
 
-        def partition(param_list=None, hierarchy=0, has_been_updated=False):
+        def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):
             cls = param
             print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}",
                          force=False)
             if param_list is None:
                 param_list = [cls]
-            self._partition(param_list, has_been_updated=has_been_updated)
+            self._partition(param_list, has_been_updated=has_been_updated, free_data=free_data)
 
         def reduce_gradients_at_owner(param_list=None, hierarchy=0):
             cls = param
@@ -1527,12 +1535,12 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None):
 
         return handles
 
-    def _partition(self, param_list, force=False, has_been_updated=False):
+    def _partition(self, param_list, force=False, has_been_updated=False, free_data=True):
         for param in param_list:
             print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False)
             if self.zero_param_process_group is not None:
                 self._partition_param_sec(param)
-            self._partition_param(param, has_been_updated=has_been_updated)
+            self._partition_param(param, has_been_updated=has_been_updated, free_data=free_data)
 
             param.ds_status = ZeroParamStatus.NOT_AVAILABLE
             # if param.ds_tensor is not None:
@@ -1540,7 +1548,7 @@ def _partition(self, param_list, force=False, has_been_updated=False):
             #    "After the parameters are initially partitioned, make sure we are not recreating the partition."
             #print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False)
 
     @instrument_w_nvtx
-    def _partition_param(self, param, buffer=None, has_been_updated=False):
+    def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True):
         assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight"
         global reuse_buffers
         print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False)
@@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False):
                 see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False)
                 # param.data does not store anything meaningful in partitioned state
-                free_param(param)
+                if free_data:
+                    free_param(param)
                 see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False)
 
                 if param.ds_tensor.final_location == OffloadDeviceEnum.nvme:
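Note: the wait(handle_dependency=...) plumbing above implements a deferred-dependency pattern. With handle_dependency=True (the default), every flattened partition is registered with the caching allocator via record_stream; with False, the partitions are parked in the class-level data_buffer and released in one shot by free_buffer() after the caller has synchronized with the allgather stream. A stripped-down sketch of the pattern (the _assemble helper is hypothetical):

```python
from deepspeed.accelerator import get_accelerator

class GatherHandleSketch:
    data_buffer = []  # keeps gathered partitions alive until free_buffer()

    def _assemble(self):
        return []  # placeholder for the real partition-assembly logic

    def wait(self, handle_dependency=True):
        partitions = self._assemble()
        if handle_dependency:
            # default path: per-tensor bookkeeping with the caching allocator
            for part in partitions:
                part.record_stream(get_accelerator().current_stream())
        else:
            # fast path: defer; the caller synchronizes the stream once,
            # then releases every buffered partition together
            GatherHandleSketch.data_buffer.append(partitions)

    @staticmethod
    def free_buffer():
        GatherHandleSketch.data_buffer = []
```

fetch_sub_module (below) waits on the allgather stream before calling free_buffer(), so buffered partitions cannot be reclaimed while kernels on that stream still read them.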
#print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) @instrument_w_nvtx - def _partition_param(self, param, buffer=None, has_been_updated=False): + def _partition_param(self, param, buffer=None, has_been_updated=False, free_data=True): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" global reuse_buffers print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) @@ -1565,7 +1573,8 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - free_param(param) + if free_data: + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 596d0e9c20f9..08cb6c0de54f 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -76,18 +76,17 @@ class __ParamInTrace: param: Parameter step_id_last_used_at: int - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: get_accelerator().Stream, - inflight_param_registry: InflightParamRegistry, - prefetch_nvme: bool = False, - timers=None, - zero_quantized_weights=False, - zero_quantized_nontrainable_weights=False, - ) -> None: + def __init__(self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, + prefetch_nvme: bool = False, + timers=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + fast_sharding_for_leaf_module=False) -> None: # mapping of param -> handle for each param that is currently in flight self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. @@ -130,6 +129,10 @@ def __init__( self.__max_ongoing_fetch_events: int = 2 self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None) + # whether to enable fast fetch for the z3 leaf module. + # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure. + self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module + """Tracing and Tracking TODO. consider performing trace before initializing PartitionedParameterCoordinator and passing trace results into constructor. 
This way all the code in here can
@@ -308,6 +311,7 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
         wait_numel = 0
         wait_event_name = __class__.FORWARD_FETCH_WAIT if forward else __class__.BACKWARD_FETCH_WAIT
         self.__profiler.start_event(wait_event_name)
+        fast_fetch = self.fast_sharding_for_leaf_module and z3_leaf_module(current_submodule)
         # wait for parameters in the immediately needed submodule to become available
         for param in params_to_fetch:
             param.ds_active_sub_modules.add(current_submodule.id)
@@ -321,9 +325,9 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
                     if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events:
                         self.__ongoing_fetch_events.popleft().synchronize()
 
-                    self.__inflight_param_registry.pop(param).wait()
+                    self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch)
 
-                    if not get_accelerator().handles_memory_backpressure():
+                    if not get_accelerator().handles_memory_backpressure() and not fast_fetch:
                         event = get_accelerator().Event()
                         event.record()
                         self.__ongoing_fetch_events.append(event)
@@ -331,6 +335,8 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None:
                 assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
             if not get_accelerator().resolves_data_dependency():
                 get_accelerator().current_stream().wait_stream(self.__allgather_stream)
+            if fast_fetch:
+                AllGatherCoalescedHandle.free_buffer()
         self.__profiler.stop_event(wait_event_name, wait_numel)
 
         # kick off parameter prefetches for upcoming modules
@@ -412,10 +418,20 @@ def release_sub_module(self, submodule: Module) -> None:
         be released."""
         params_to_release = (self.__params_to_release(submodule, self.__step_id) if self.is_complete_trace() else set(
             p.ds_id for p in iter_params(submodule, recurse=z3_leaf_module(submodule))))
+
+        free_data = not z3_leaf_module(submodule) or not self.fast_sharding_for_leaf_module
+        if not free_data:
+            # allocated as early as possible; swapping it into param.data below releases the gathered weights only after use
+            empty_buffer = torch.empty(1, device=get_accelerator().current_device())
+
         for param in iter_params(submodule, recurse=z3_leaf_module(submodule)):
             param.ds_active_sub_modules.discard(submodule.id)
             if param.ds_id in params_to_release and not param.is_external_param:
-                self.__release_param(param)
+                self.__release_param(param, free_data)
+            if not free_data and param.ds_id in params_to_release and not param.is_external_param:
+                # assigning the shared empty buffer here ensures that all
+                # computations on the gathered parameter are complete
+                param.data = empty_buffer
 
     @instrument_w_nvtx
     @torch.no_grad()
@@ -490,11 +506,11 @@ def __all_gather_params_(self, params: Set[Parameter], forward: bool, quantize:
 
     @compiler.disable
     @instrument_w_nvtx
-    def __release_param(self, param: Parameter) -> None:
+    def __release_param(self, param: Parameter, free_data: bool = True) -> None:
         if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules:
             if logger.isEnabledFor(logging.DEBUG):
                 debug_rank0(f"-release: {param.ds_summary()}")
-            param.partition()
+            param.partition(free_data=free_data)
         self.__n_available_params -= param.ds_numel
 
     @instrument_w_nvtx
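Note: on the fast path, releasing a leaf module does not free each parameter's storage inside __release_param. Instead, every released parameter's .data is rebound to one shared one-element placeholder, and the gathered storage is reclaimed once its last reference disappears. A condensed illustration of the idea (simplified from release_sub_module above; params is assumed to be an iterable of ZeRO parameters):

```python
import torch

def fast_release(params, release_ids, device):
    # One tiny tensor, allocated up front and shared by all released parameters.
    empty_buffer = torch.empty(1, device=device)
    for p in params:
        if p.ds_id in release_ids and not p.is_external_param:
            # Rebinding .data drops the reference to the gathered weights while
            # keeping the parameter object valid for the next fetch.
            p.data = empty_buffer
```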
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 461281d4a569..ab26054bda7d 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -415,10 +415,11 @@ def cpu_arch(self):
             return '-mcpu=native'
         return '-march=native'
 
-    def is_cuda_enable(self):
+    def get_cuda_compile_flag(self):
         try:
-            assert_no_cuda_mismatch(self.name)
-            return '-D__ENABLE_CUDA__'
+            if not self.is_rocm_pytorch():
+                assert_no_cuda_mismatch(self.name)
+            return "-D__ENABLE_CUDA__"
         except MissingCUDAException:
             print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
                   "only cpu ops can be compiled!")
@@ -839,7 +840,7 @@ def cxx_args(self):
         CPU_ARCH = self.cpu_arch()
         SIMD_WIDTH = self.simd_width()
-        CUDA_ENABLE = self.is_cuda_enable()
+        CUDA_ENABLE = self.get_cuda_compile_flag()
         args += [
             CPU_ARCH,
             '-fopenmp',
diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
index d66f7c8cb4cc..a4cf579f5943 100644
--- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
+++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py
@@ -14,6 +14,8 @@
 
 from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8
 
+from deepspeed import get_accelerator
+
 
 @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
 @pytest.mark.parametrize("q_bits", [8], ids=[
@@ -21,23 +23,19 @@
 ])
 @pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048])
 def test_fp_quant(dtype, q_bits, M):
+    device_name = get_accelerator().device_name()
     quantization_group_size = 128
     fpq = FP_Quantize(group_size=quantization_group_size)
 
     N = 8192
     H = 4096
 
-    x = torch.randn(M, H, dtype=dtype, device='cuda')
-    weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda')
+    x = torch.randn(M, H, dtype=dtype, device=device_name)
+    weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name)
 
-    weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True)
+    weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True)
     scale = fpq.get_scales()
-    out = matmul_fp8(
-        x,
-        weight,
-        scale,
-        quantization_group_size,
-    )
+    out = matmul_fp8(x, weight, scale, quantization_group_size, fpq)
 
     out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale))
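Note: for reference, the dequantization inside matmul_kernel_fp8_bf16 is a pure bit manipulation: the fp8 byte's sign moves to bit 15, the remaining 7 exponent/mantissa bits shift under the bf16 layout, and adding 0x3C00, i.e. (127 - 7) << 7, rebiasses the 4-bit exponent to bf16's bias of 127. The sketch below mirrors that trick outside Triton, assuming an E4M3-style code with bias 7 and no zero/denormal codes (0x00 does not decode to 0.0 under this trick; in the kernel that is harmless because masked-out weight positions pair with activations that are also masked to zero):

```python
import numpy as np
import torch

def dequant_fp8_to_bf16(q: np.ndarray, scale: float) -> torch.Tensor:
    """q holds uint8 fp8 codes; returns the bf16 values the kernel would compute."""
    q = q.astype(np.uint16)
    bits = ((q & 0x80) << 8) | ((q & 0x7F) << 4)  # sign -> bit 15, payload -> bits 10..4
    bits = (bits + 0x3C00).astype(np.uint16)      # add (127 - 7) << 7: rebias the exponent
    w = torch.from_numpy(bits.view(np.int16)).view(torch.bfloat16)
    return w * scale                              # the per-group scale restores magnitude

# The fp16 kernel is identical in spirit: the payload shifts by 7 instead of 4,
# and the rebias constant is 0x2000 == (15 - 7) << 10.
```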