Commit 4311911 · liminghong committed
1 Parent(s): f784b6e

Make compatible with recent versions of Triton

Files changed (1):
  1. flash_attn_triton.py +4 -4
flash_attn_triton.py CHANGED
@@ -188,7 +188,7 @@ def _fwd_kernel(
                     (offs_d[None, :] < headdim),
                     other=0.0)
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k, trans_b=True)
+        qk += tl.dot(q, tl.trans(k))
         # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
                     (offs_d[None, :] < headdim),
                     other=0.0)
         # recompute p = softmax(qk, dim=-1).T
-        qk = tl.dot(q, k, trans_b=True)
+        qk = tl.dot(q, tl.trans(k))
         # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
@@ -491,7 +491,7 @@ def _bwd_kernel_one_col_block(
         # else:
         #     do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
         #                  & (offs_d[None, :] < headdim), other=0.0)
-        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
+        dv += tl.dot(tl.trans(p).to(do.dtype), do)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
@@ -509,7 +509,7 @@ def _bwd_kernel_one_col_block(
         # for BLOCK_HEADDIM=128
         ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
         # compute dk = dot(ds.T, q)
-        dk += tl.dot(ds, q, trans_a=True)
+        dk += tl.dot(tl.trans(ds), q)
         # compute dq
         if not ATOMIC_ADD:
             if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
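
For context on what the four hunks have in common: recent Triton releases dropped the trans_a/trans_b keyword arguments of tl.dot, so the transpose now has to be written explicitly with tl.trans on the operand itself. The sketch below is a minimal illustration of the new-style call, not code from this repository; the kernel name, wrapper, and block sizes are invented for the example, and it assumes a Triton 2.x install and a CUDA device. The same pattern with the first operand transposed, e.g. tl.dot(tl.trans(p).to(do.dtype), do), covers the backward hunks.

# Minimal sketch (assumed names/sizes): compute scores = q @ k.T with the
# post-trans_b API, i.e. tl.dot(q, tl.trans(k)) instead of
# tl.dot(q, k, trans_b=True).
import torch
import triton
import triton.language as tl


@triton.jit
def _qk_kernel(Q, K, Out,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    # One program handles a single (BLOCK_M, BLOCK_N) tile; tensors are
    # assumed contiguous and row-major, so the row stride equals BLOCK_D.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    q = tl.load(Q + offs_m[:, None] * BLOCK_D + offs_d[None, :])  # (BLOCK_M, BLOCK_D)
    k = tl.load(K + offs_n[:, None] * BLOCK_D + offs_d[None, :])  # (BLOCK_N, BLOCK_D)
    # Old API (removed): qk = tl.dot(q, k, trans_b=True)
    qk = tl.dot(q, tl.trans(k))  # (BLOCK_M, BLOCK_N), fp32 accumulator
    tl.store(Out + offs_m[:, None] * BLOCK_N + offs_n[None, :], qk)


def qk_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Illustrative wrapper: all dimensions must be powers of two and >= 16,
    # since tl.arange needs power-of-two ranges and tl.dot needs dims >= 16.
    m, d = q.shape
    n, _ = k.shape
    out = torch.empty((m, n), device=q.device, dtype=torch.float32)
    _qk_kernel[(1,)](q, k, out, BLOCK_M=m, BLOCK_N=n, BLOCK_D=d)
    return out


if __name__ == "__main__":
    q = torch.randn(16, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(32, 64, device="cuda", dtype=torch.float16)
    # Should match the reference up to fp16 rounding.
    print(torch.allclose(qk_scores(q, k), q.float() @ k.float().T, atol=1e-2))

Functionally the new calls are the same matmuls as before; the change only tracks the API, which is why the diff touches exactly the four tl.dot call sites and nothing else.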