
batched_matmul

matmul_kernel(a_ptr, b_ptr, c_ptr, m_size, n_size, k_size, a_batch_stride, a_m_stride, a_k_stride, b_batch_stride, b_k_stride, b_n_stride, c_batch_stride, c_m_stride, c_n_stride, BLOCK_M_SIZE, BLOCK_N_SIZE, BLOCK_K_SIZE, GROUP_M_SIZE)
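The `GROUP_M_SIZE` parameter suggests that output tiles are visited in grouped order to improve cache reuse. The sketch below is a plain-Python illustration of how a linear program id could be mapped to a `(tile_m, tile_n)` pair. It assumes the kernel follows the grouped-ordering scheme from the Triton matrix-multiplication tutorial; the function name `grouped_tile_coords` is hypothetical and not part of this module.

```python
def grouped_tile_coords(pid, m_size, n_size, block_m, block_n, group_m):
    """Map a linear program id to (tile_m, tile_n) using grouped ordering.

    Hypothetical helper. Assumes the Triton-tutorial grouping scheme;
    the real kernel's mapping may differ.
    """
    num_pid_m = -(-m_size // block_m)  # number of row tiles (ceil division)
    num_pid_n = -(-n_size // block_n)  # number of column tiles
    num_pid_in_group = group_m * num_pid_n  # programs per group of row tiles
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    # The last group may contain fewer than group_m row tiles.
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    # Walk down the rows of a group first, then move to the next column.
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


# Usage: a 48 x 64 output with 16 x 16 tiles (3 x 4 tiles), groups of 2 rows.
coords = [grouped_tile_coords(p, 48, 64, 16, 16, 2) for p in range(12)]
print(coords[:4])
```

Consecutive program ids stay within a narrow band of rows, so neighbouring programs tend to reuse the same tiles of `A` while sweeping across `B`.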

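Kernel that computes the batched matrix product C = A x B. For each batch index, A has shape (M, K), B has shape (K, N), and C has shape (M, N). The `*_stride` arguments give the element offset between consecutive indices along each dimension of each operand.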
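As a rough reference for what each program computes, the following plain-Python sketch performs the same blocked, stride-addressed batched matmul on flat buffers. It is not the Triton implementation: the name `batched_matmul_reference`, the tuple-packed strides, and the small default block sizes are assumptions chosen for illustration. Bounds checks stand in for the masked loads a GPU kernel would use on partial edge tiles.

```python
def batched_matmul_reference(a, b, c, batch, m_size, n_size, k_size,
                             a_strides, b_strides, c_strides,
                             block_m=2, block_n=2, block_k=2):
    """Hypothetical CPU reference: C[b] = A[b] @ B[b] on flat buffers.

    Strides are (batch, row, col) element offsets, mirroring the kernel's
    *_batch_stride / *_m_stride / ... arguments.
    """
    a_bs, a_ms, a_ks = a_strides
    b_bs, b_ks, b_ns = b_strides
    c_bs, c_ms, c_ns = c_strides
    for bi in range(batch):
        for m0 in range(0, m_size, block_m):
            for n0 in range(0, n_size, block_n):
                # One BLOCK_M x BLOCK_N accumulator per output tile.
                acc = [[0.0] * block_n for _ in range(block_m)]
                tile_m = min(block_m, m_size - m0)  # clip edge tiles
                tile_n = min(block_n, n_size - n0)
                for k0 in range(0, k_size, block_k):
                    tile_k = min(block_k, k_size - k0)
                    for i in range(tile_m):
                        for j in range(tile_n):
                            for kk in range(tile_k):
                                m, n, k = m0 + i, n0 + j, k0 + kk
                                acc[i][j] += (a[bi * a_bs + m * a_ms + k * a_ks]
                                              * b[bi * b_bs + k * b_ks + n * b_ns])
                # Write the finished tile back to C.
                for i in range(tile_m):
                    for j in range(tile_n):
                        c[bi * c_bs + (m0 + i) * c_ms + (n0 + j) * c_ns] = acc[i][j]


# Usage: two batches of (3 x 2) @ (2 x 3), stored contiguously (row-major).
batch, M, N, K = 2, 3, 3, 2
a = [float(x) for x in range(batch * M * K)]
b = [float(x) for x in range(batch * K * N)]
c = [0.0] * (batch * M * N)
batched_matmul_reference(a, b, c, batch, M, N, K,
                         a_strides=(M * K, K, 1),
                         b_strides=(K * N, N, 1),
                         c_strides=(M * N, N, 1))
```

Because every access goes through the stride arguments, the same loop handles non-contiguous layouts. For example, a B stored as (batch, N, K) can be read as (batch, K, N) by passing `b_strides=(K * N, 1, K)`.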