linear_layer

LinearLayer

Bases: torch.autograd.Function

forward(ctx, x, weight, bias, activation, act_inputs) staticmethod

Compute e = activation(x @ weight + bias). This wrapper launches the kernel_fma Triton kernel.

Parameters:

  • ctx: context for autograd
  • x: input tensor
  • weight: weight matrix
  • bias: optional bias tensor
  • activation: activation name; needs to be a Triton kernel
  • act_inputs: optional tensor in which to save the activation inputs (for the backward pass)

Returns: the result tensor
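A minimal usage sketch follows. It assumes the class is importable from this module and that the activation is identified by a name string ("relu" below is purely a placeholder); the import path, shapes, and dtypes are illustrative and not part of the documented API.

```python
import torch
from linear_layer import LinearLayer  # assumption: import path of this module

# Illustrative shapes: x is (M, K), weight is (K, N), bias is (N,)
x = torch.randn(128, 512, device="cuda", dtype=torch.float16)
weight = torch.randn(512, 256, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(256, device="cuda", dtype=torch.float16, requires_grad=True)

# Optional buffer for the pre-activation values, reused by the backward pass
act_inputs = torch.empty(128, 256, device="cuda", dtype=torch.float16)

# torch.autograd.Function subclasses are invoked through .apply();
# ctx is supplied by autograd, so only the remaining arguments are passed here.
out = LinearLayer.apply(x, weight, bias, "relu", act_inputs)  # "relu": assumed activation name
```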

kernel_fma(C, ACT_INPUTS, A, B, bias, M, N, K, CACHE_KEY_M, CACHE_KEY_N, CACHE_KEY_K, output_m_stride, output_n_stride, act_inputs_m_stride, act_inputs_n_stride, a_m_stride, a_k_stride, b_n_stride, b_k_stride, BLOCK_M, GROUP_M, BLOCK_N, BLOCK_K, SPLIT_K, K_LOAD_MASK_NEEDED, HAS_BIAS, SHOULD_SAVE_ACT_INPUTS, ACTIVATION)

Kernel for computing Out = activation(A x W + bias)

  • Input has shape (M, K)
  • Weight has shape (K, N)
  • Bias has shape (N,)
  • Output has shape (M, N)
  • ActInputs (optional) has shape (M, N)

'ActInputs' optionally saves the A x W + bias intermediate for the backward pass

This kernel reduces (accumulates partial products) over the K dimension
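For reference, the computation the kernel performs can be written in plain PyTorch as below. This is only an illustrative equivalent of the math (matmul reduction over K, optional bias, optional saving of the pre-activation); it does not reflect the Triton tiling, grouping, or split-K strategy implied by the BLOCK_*, GROUP_M, and SPLIT_K parameters, and the helper name is hypothetical.

```python
import torch

def kernel_fma_reference(a, b, bias=None, activation=None, save_act_inputs=False):
    """Plain-PyTorch sketch of what kernel_fma computes (illustrative only).

    a: (M, K) input, b: (K, N) weight, bias: optional (N,) vector.
    Returns (out, act_inputs); act_inputs holds a @ b + bias when requested.
    """
    acc = a @ b                          # reduction over the K dimension
    if bias is not None:
        acc = acc + bias                 # bias broadcast across the M rows
    act_inputs = acc.clone() if save_act_inputs else None
    out = activation(acc) if activation is not None else acc
    return out, act_inputs

# Shapes matching the kernel's contract: (M, K) @ (K, N) -> (M, N)
a = torch.randn(64, 128)
w = torch.randn(128, 32)
bias = torch.randn(32)
out, pre_act = kernel_fma_reference(a, w, bias, activation=torch.relu, save_act_inputs=True)
```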