Add subgroup matrix multiplication by junjihashimoto · Pull Request #80 · AnswerDotAI/gpu.cpp · GitHub
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/dawn.cmake
225 changes: 182 additions & 43 deletions examples/matmul/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
return {unrolledCode, workgroupSize, precision};
}

inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
const size_t K, const size_t N,
const size_t TM, const size_t TN,
const size_t LID,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{precision}}", toString(precision)},
{"{{M}}", toString(M)},
{"{{K}}", toString(K)},
{"{{N}}", toString(N)},
{"{{TM}}", toString(TM)},
{"{{TN}}", toString(TN)},
{"{{LID}}", toString(LID)}
});
return {loopUnrolling(codeString), workgroupSize, precision};
}

// ─────────────────────────────────────────────────────────────────────────────
// Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
// and subgroupMatrixMultiplyAccumulate
// ─────────────────────────────────────────────────────────────────────────────
const char* kShaderSubgroupMatrixMultiply = R"(
enable chromium_experimental_subgroup_matrix;
diagnostic (off, chromium.subgroup_matrix_uniformity);

@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(workgroup_id) wg: vec3<u32>,
@builtin(local_invocation_id) localID : vec3<u32>) {

let rowStart: u32 = wg.x * 8u * {{TM}};
let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};

let baseA: u32 = rowStart * {{K}};
let baseB: u32 = colStart;
let cBase: u32 = rowStart * {{N}} + colStart;

var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;

// 4x4 accumulators (8x8 each)
var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;

for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
}

for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
}

for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
accxx[idx_i+idx_j*{{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
}
}

for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
workgroupBarrier();
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
}

for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
}

for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
accxx[idx_j*{{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j*{{TM}} + idx_i]);
}
}
}

workgroupBarrier();
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j*{{TM}} + idx_i], false, {{N}});
}
}
}
)";

/**
* @brief No-Op shader with matmul bindings for performance testing
*/
Expand Down Expand Up @@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
const Bindings</* input, weights, output */ 3> &bindings,
size_t M, size_t K, size_t N, NumType numtype) {
Kernel kernel;
CompilationInfo info;
if (version == 1) {
Shape wgSize = {256, 1, 1};
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 2) {
Shape wgSize = {16, 16, 1};
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
KernelCode matmul =
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
NoParam{}, &info);
} else if (version == 3) {
static constexpr size_t tileSize = 16;
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
/*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
NoParam{}, &info);
} else if (version == 4 || version == 6) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 4;
Expand All @@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 5 || version == 7) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
Expand All @@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ version == 7 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 8 || version == 10) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
Expand All @@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ true);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 9 || version == 11) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
Expand All @@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
/*wgSize*/ wgSize,
numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 12 || version == 13) {
// f16: Subgroup matrix multiply
static constexpr size_t TM = 4;
static constexpr size_t TN = 8;
static constexpr size_t LID = 2;
Shape wgSize = {32, LID, 1}; // One subgroup per workgroup
Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
NoParam{}, &info);
}

if (info.status != WGPUCompilationInfoRequestStatus_Success) {
LOG(kDefLog, kError, "Failed to compile shader");
for (size_t i = 0; i < info.messages.size(); i++) {
LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
info.linePos[i], info.messages[i].c_str());
}
exit(1);
} else {
LOG(kDefLog, kInfo, "Shader compiled successfully");
for (size_t i = 0; i < info.messages.size(); i++) {
LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
info.linePos[i], info.messages[i].c_str());
}
}

return kernel;
}

Expand All @@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
assert(numtype == kf16);
}

// Allocate GPU buffers and copy data
WGPUDeviceDescriptor devDescriptor = {};
devDescriptor.requiredFeatureCount = 1;
devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();

Context ctx;
if (numtype == kf16) {
ctx = createContext(
{}, {},
/*device descriptor, enabling f16 in WGSL*/
{
.requiredFeatureCount = 1,
.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
});
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9).");
exit(1);
static WGPUDawnTogglesDescriptor toggles = {};
toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
const char* enableList[] = {"allow_unsafe_apis"};
toggles.enabledToggles = enableList;
toggles.enabledToggleCount = 1;

static WGPUDeviceDescriptor devDesc = {};
devDesc.nextInChain = &toggles.chain;
devDesc.requiredFeatureCount = 3,
devDesc.requiredFeatures = std::array{
WGPUFeatureName_ShaderF16,
WGPUFeatureName_Subgroups,
WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
}.data();
devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
.callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void*, void*) {
LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
}
if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)");
exit(1);
};
devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
.mode = WGPUCallbackMode_AllowSpontaneous,
.callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void*, void*) {
LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
}
}
};

if (numtype == kf32) {
ctx = createContext({}, {}, {});
if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
LOG(kDefLog, kError, "Failed to create adapter or device");
// stop execution
exit(1);
} else {
LOG(kDefLog, kInfo, "Successfully created adapter and device");
static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
devDesc.requiredLimits = &requiredLimits;
Context ctx = createContext({}, {}, devDesc);

WGPULoggingCallbackInfo logCb{
.callback = [](WGPULoggingType type, WGPUStringView msg, void*, void*) {
LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
}
}
};
wgpuDeviceSetLoggingCallback(ctx.device, logCb);

if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
LOG(kDefLog, kError, "Failed to create adapter or device");
// stop execution
exit(1);
} else {
LOG(kDefLog, kInfo, "Successfully created adapter and device");
}

Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
Expand Down Expand Up @@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
// Use microsecond for more accurate time measurement
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
float gflops = 2 * M * N *
float gflops = 2.0f * M * N *
K / // factor of 2 for multiplication & accumulation
(static_cast<double>(duration.count()) / 1000000.0) /
1000000000.0 * static_cast<float>(nIter);
Expand All @@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

LOG(kDefLog, kInfo, "\n\n===================================================================="
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
"============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
":\n%.1f "
"milliseconds / dispatch ~ %.2f "
"GFLOPS\n================================================================"
Expand Down Expand Up @@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
case 12: return "f16: Subgroup matrix multiply with transpose (default)";
case 13: return "f32: Subgroup matrix multiply with transpose";
default: return "Not specified";
}
}

int main() {
std::cout << "Starting matmul test..." << std::endl;
char* version_str = getenv("MATMUL_VERSION");
int version = version_str == NULL ? 10 : atoi(version_str);
int version = version_str == NULL ? 12 : atoi(version_str);
// 1 == f32: No-Op
// 2 == f32: naive matmul
// 3 == f32: tiling
Expand All @@ -931,8 +1068,10 @@ int main() {
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
bool enableF16 = version == 10 || version ==11;
bool transposedInput = version == 9 || version == 11;
// 12 == f16: Subgroup matrix multiply with transpose (default)
// 13 == f32: Subgroup matrix multiply with transpose
bool enableF16 = version == 10 || version ==11 || version == 12;
bool transposedInput = version == 9 || version == 11 || version == 12 || version == 13;
NumType numtype = enableF16 ? kf16 : kf32;

size_t M, K, N; // Matrix dimensions
Expand Down
6 changes: 2 additions & 4 deletions gpu.hpp
Loading
Loading