attention_skinny

SkinnyAttention

Bases: torch.autograd.Function

forward(ctx, q, k, v, output, sm_scale, is_causal, attention_mask=None) staticmethod

Computes attention. FP32 input and output are not supported by the Triton kernel (https://github.com/openai/triton/issues/674). This is not an issue in practice because the function is annotated with @custom_fwd(cast_inputs=torch.float16), so inputs are cast to float16 before the function is called.
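
The casting behaviour comes from PyTorch's torch.cuda.amp.custom_fwd decorator. Below is a minimal sketch (a toy Function, not the real kernel) of how cast_inputs=torch.float16 converts FP32 CUDA inputs when the call happens inside an autocast region:

```python
import torch
from torch.cuda.amp import custom_fwd


class CastingExample(torch.autograd.Function):
    """Toy Function illustrating cast_inputs; this is NOT the attention kernel."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x):
        # Inside an autocast-enabled region, floating-point CUDA tensors arrive
        # here already cast to float16, so an FP32-only limitation downstream
        # (here, the Triton kernel) is never hit.
        return x * 2


x = torch.randn(4, device="cuda", dtype=torch.float32)
with torch.autocast("cuda"):
    y = CastingExample.apply(x)
print(y.dtype)  # torch.float16
```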

Parameters:

ctx: context for autograd
q: query matrix of size (batch, heads, size_m, dhead)
k: key matrix of size (batch, heads, size_n, dhead)
v: value matrix of size (batch, heads, size_n, dhead)
output: output matrix of size (batch, heads, size_m, dhead)
sm_scale: scaling factor applied after the Q x K operation
is_causal: autoregressive (causal) decoder attention
attention_mask: attention mask, broadcastable to (batch, heads, size_m, size_n)

Returns:
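
As a usage sketch: torch.autograd.Function subclasses are invoked through .apply, with ctx supplied by autograd. The shapes and dtypes below simply follow the parameter description above (these names and values are illustrative; the actual call site in the library may differ):

```python
import torch
# assumes SkinnyAttention is importable from this module

batch, heads, size_m, size_n, dhead = 2, 8, 128, 128, 64

q = torch.randn(batch, heads, size_m, dhead, device="cuda", dtype=torch.float16)
k = torch.randn(batch, heads, size_n, dhead, device="cuda", dtype=torch.float16)
v = torch.randn(batch, heads, size_n, dhead, device="cuda", dtype=torch.float16)
output = torch.empty_like(q)      # output buffer, per the parameter description
sm_scale = 1.0 / (dhead ** 0.5)   # typical softmax scaling (illustrative choice)

# ctx is supplied by autograd; attention_mask is omitted and defaults to None
SkinnyAttention.apply(q, k, v, output, sm_scale, False)
```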