Commit

Test SD with updated Triton
loadams committed Oct 25, 2023
1 parent 869629c commit a833c53
Showing 4 changed files with 34 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-inference.yml
@@ -39,7 +39,7 @@ jobs:
       - name: Install deepspeed
         run: |
-          pip install .[dev,1bit,autotuning,inf,triton]
+          pip install .[dev,1bit,autotuning,inf,triton,sd]
           ds_report
       - name: Python environment
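The only change here is the new `sd` extra on the CI install line. By DeepSpeed's convention, an extra named `sd` resolves to requirements/requirements-sd.txt (edited below), so this also pulls diffusers and triton into the nv-inference environment, which is what lets the new stable-diffusion test run in CI.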
9 changes: 1 addition & 8 deletions deepspeed/ops/transformer/inference/triton_ops.py
@@ -18,7 +18,6 @@ def _fwd_kernel(
     K,
     V,
     sm_scale,
-    TMP,
     Out,
     stride_qz,
     stride_qh,
@@ -57,7 +56,6 @@ def _fwd_kernel(
     k_ptrs = K + off_k
     v_ptrs = V + off_v
     # initialize pointer to m and l
-    t_ptrs = TMP + off_hz * N_CTX + offs_m
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
@@ -69,8 +67,7 @@ def _fwd_kernel(
         # -- compute qk ----
         k = tl.load(k_ptrs + start_n * stride_kn)

-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k, trans_b=True)
+        qk = tl.dot(q, tl.trans(k))
         qk *= sm_scale
         # -- compute m_ij, p, l_ij
         m_ij = tl.max(qk, 1)
@@ -87,8 +84,6 @@ def _fwd_kernel(
         p = p * p_scale[:, None]
         # scale acc
         acc_scale = l_i / l_i_new * alpha
-        tl.store(t_ptrs, acc_scale)
-        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(v_ptrs + start_n * stride_vk)
@@ -115,15 +110,13 @@ def forward(self, q, k, v, sm_scale, block_128=True):
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
-        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8

         _fwd_kernel[grid](
             q,
             k,
             v,
             sm_scale,
-            tmp,
             o,
             q.stride(0),
             q.stride(1),
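Both kernel edits track API changes in newer Triton rather than behavioral changes: tl.dot no longer accepts a trans_b keyword, so the transpose becomes an explicit tl.trans(k), and the store-then-immediately-load round trip through the TMP buffer (per the deleted comment, a workaround for a bug in the pinned 2022 dev snapshot) is no longer needed, which lets the TMP kernel argument and its host-side allocation in forward go away too. A minimal self-contained sketch of the new pattern, assuming Triton 2.1+ on a CUDA device (a toy single-block kernel, not the DeepSpeed one):

import torch
import triton
import triton.language as tl


@triton.jit
def qk_kernel(Q, K, Out, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    # Load one (BLOCK_M, BLOCK_D) tile of Q and one (BLOCK_N, BLOCK_D) tile of K.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    q = tl.load(Q + offs_m[:, None] * BLOCK_D + offs_d[None, :])
    k = tl.load(K + offs_n[:, None] * BLOCK_D + offs_d[None, :])
    # Explicit transpose replaces the removed trans_b=True keyword.
    qk = tl.dot(q, tl.trans(k))
    tl.store(Out + offs_m[:, None] * BLOCK_N + offs_n[None, :], qk)


q = torch.randn(16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(16, 64, device="cuda", dtype=torch.float16)
out = torch.empty((16, 16), device="cuda", dtype=torch.float32)
qk_kernel[(1,)](q, k, out, BLOCK_M=16, BLOCK_N=16, BLOCK_D=64)
assert torch.allclose(out, (q @ k.t()).float(), atol=1e-1)  # loose fp16 tolerance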
2 changes: 1 addition & 1 deletion requirements/requirements-sd.txt
@@ -1,2 +1,2 @@
 diffusers
-triton==2.0.0.dev20221202
+triton
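Dropping the pin means pip now resolves a current triton release instead of the December 2022 dev snapshot; the triton_ops.py changes above are what keep the kernel compatible with those newer releases.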
31 changes: 31 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -341,6 +341,37 @@ def test(
         assert assert_fn(bs_output, ds_output)


+# Setup for these models is different from other pipelines, so we add a separate test
+@pytest.mark.inference
+class TestStableDiffusion(DistributedTest):
+    world_size = 1
+
+    def test(self):
+        from diffusers import DiffusionPipeline
+
+        prompt = "a dog on a rocket"
+        model = "prompthero/midjourney-v4-diffusion"
+        local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        device = torch.device(f"cuda:{local_rank}")
+
+        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half)
+        pipe = pipe.to(device)
+        baseline_image = pipe(prompt, guidance_scale=7.5).images[0]
+
+        pipe = deepspeed.init_inference(
+            pipe,
+            mp_size=1,
+            dtype=torch.half,
+            replace_method="auto",
+            replace_with_kernel_inject=True,
+            enable_cuda_graph=False,
+        )
+        deepspeed_image = pipe(prompt, guidance_scale=7.5).images[0]
+
+        # Need to determine a heuristic for checking if images are "similar"
+        #assert baseline_image == deepspeed_image
+
+
 @pytest.mark.seq_inference
 @pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-neo-1.3B", "text-generation"),
                                           ("EleutherAI/gpt-neox-20b", "text-generation"),
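On the commented-out assertion: the two images will not be bitwise equal even when the kernels are correct, since sampling is stochastic without a fixed seed and kernel injection perturbs numerics. One possible heuristic, purely a sketch and not part of this commit (images_similar and its threshold are hypothetical), is to seed both runs and bound the mean absolute pixel error:

import numpy as np


def images_similar(img_a, img_b, threshold=0.1):
    # Hypothetical helper: mean absolute per-pixel error between two
    # same-size PIL images, normalized to [0, 1].
    a = np.asarray(img_a, dtype=np.float32) / 255.0
    b = np.asarray(img_b, dtype=np.float32) / 255.0
    return a.shape == b.shape and float(np.abs(a - b).mean()) < threshold

For this to be meaningful, both pipe(...) calls would also need a fixed generator, e.g. generator=torch.Generator(device=device).manual_seed(0), so that the injected kernels are the only remaining source of difference.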
