PR08: C++ Platform Backends by agibsonccc · Pull Request #10441 · deeplearning4j/deeplearning4j · GitHub
Skip to content

PR08: C++ Platform Backends#10441

Open
agibsonccc wants to merge 7 commits into
masterfrom
pr/08-cpp-platform-backends
Open

PR08: C++ Platform Backends#10441
agibsonccc wants to merge 7 commits into
masterfrom
pr/08-cpp-platform-backends

Conversation

@agibsonccc

@agibsonccc agibsonccc commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Summary

PR 08 of 22 PRs in the ag_new_release_updates_2 branch split. Merge after Layer 2 (native core ops + helpers).

  • ARM Compute Library: 125 files covering full LLM op surface (rope, rms_norm, GQA, all activations/reductions/shape ops) for AArch64
  • oneDNN Flash Attention: FlashAttentionCache + SDPACache using oneDNN Graph API compiled partitions; FP32/FP16/BF16; thread-local stream avoids per-call overhead
  • llama.cpp/GGML: 59 files wrapping GGML via GgmlContextGuard RAII; zero-copy createGgmlTensor(); quantized GEMM (Q4_0/Q4_1/Q8_0); KV cache, MoE, RWKV, Mamba, Gemma4; 20 CUDA variants
  • cuDNN: GRU, dropout, transposed conv, global pooling, thread-safe per-stream handle cache; flash_attention.cu stub always declines (no cuDNN SDPA primitive), routes to cuBLAS fallback
  • Apple Accelerate: 28 files — BLAS via cblas, FFT via vDSP, conv, norm, element-wise via vDSP vector ops
  • Apple MPS: 21 Objective-C++ .mm files; zero-copy MTLBuffer wrapping on Unified Memory; SDPA via MPSGraphScaledDotProductAttentionOp; LSTM/GRU via MPSLSTM
  • MLIR: 17 files using Linalg/arith/math/SCF/tensor dialects
  • VLM: VlmBackendManager singleton probes backends at startup (CUDA>METAL>CPU); full ViT pipeline: preprocess → patch embed → vision encode → project → cross-attention → multimodal fusion
  • PJRT/TPU: 10 files dispatching XLA HLO computations via PJRT client
  • MIOpen/AMD: 5 files mirroring cuDNN surface via HIP runtime for ZLUDA/AMD GPUs
  • All backends use DECLARE_PLATFORM/PLATFORM_IMPL/PLATFORM_CHECK triple; PLATFORM_CHECK gates on compile-time feature flags and runtime dtype/shape compatibility

What Changed

ARM Compute Library (ACL) — platform/armcompute/ (~124 files)

  • armcomputeUtils.h — ACL tensor/layout conversion utilities
  • ArmComputeVersionProvider.h — build-time version detection
  • deconv2d.cpp — deconvolution via NEDeconvolutionLayer
  • 121 op files: activations, reductions, convolutions, attention (GQA), normalization, LLM ops (rope, rms_norm), embeddings, scatter, all shape ops — full DECLARE_PLATFORM/PLATFORM_IMPL/PLATFORM_CHECK for ENGINE_CPU on ARM

Apple Accelerate — platform/accelerate/ (28 files)

  • AccelerateVersionProvider.h — version detection
  • accelerateUtils.h/.cpp — BLAS/vDSP bridge for NDArray
  • matmul.cpp, gemv.cpp, dot.cpp, axpy.cpp, blas_extra.cpp — GEMM/GEMV/dot via cblas
  • fft.cpp — FFT via vDSP_fft_zrip (split-complex format)
  • conv1d.cpp, conv2d.cpp — convolution via vDSP_conv
  • batchnorm.cpp, layer_norm.cpp — normalization via vDSP vector ops
  • elementwise.cpp, arithmetic.cpp, transcendental.cpp, trigonometric.cpp — element-wise via vDSP and Accelerate math (vvsin, vvcos, vvexp, etc.)
  • reductions.cpp, pooling2d.cpp, comparison.cpp, cumulative.cpp, linalg.cpp, batch_ops.cpp — full coverage of pooling, SVD/solve, batch ops

cuDNN — platform/cudnn/ (20 files)

  • CudnnVersionProvider.h — version detection
  • activations.cu / activations_extended.cu — relu/sigmoid/tanh/elu/gelu/softplus and swish/mish/hardswish/hardsigmoid via cudnnActivationForward
  • biasadd.cu — bias-add via cuDNN tensor add API
  • conv1d.cu — 1D convolution via cuDNN 2D conv (expanded dims trick)
  • cudnnUtils.h/.cu — centralized per-stream cuDNN handle caching
  • deconv2d.cu / deconv3d.cu — transposed convolution via cudnnConvolutionBackwardData
  • dropout.cu — dropout via cudnnDropoutForward with stateful RNG descriptor
  • flash_attention.cu — stub: PLATFORM_CHECK always false; routes to cuBLAS FlashAttentionHelper
  • global_pooling.cu — global max/avg pool via cudnnPoolingForward
  • gru.cu — GRU forward via cudnnRNNForward
  • instancenorm.cu, layernorm.cu, log_softmax.cu, lrn.cu, softmax.cu, reduce.cu, sconv2d.cu, simple_rnn.cu, spatial_transformer.cu — additional norm/attention/RNN ops

llama.cpp/GGML — platform/llamacpp/ (60 files)

  • GgmlVersionProvider.h — version detection
  • llamacppUtils.h/.cpp — GgmlContextGuard RAII (64MB context); createGgmlTensor() wraps NDArray buffer zero-copy; executeGgmlGraph() runs ggml_cgraph; copyGgmlToNDArray() copies result back
  • matmul.cpp, quantized_matmul.cpp — GEMM via ggml_mul_mat; Q4_0/Q4_1/Q8_0 quantized variants
  • rms_norm.cpp, rope.cpp — RMSNorm and RoPE via GGML primitives
  • kv_cache_ops.cpp — KV cache update via ggml_set (in-place scatter at sequence position)
  • grouped_query_attention.cpp, flash_attention.cpp — GQA and Flash Attention
  • moe_ops.cpp, rwkv_ops.cpp, ssm_ops.cpp, gated_delta_ops.cpp, gemma4_ops.cpp — MoE, RWKV, Mamba-style SSM, gated delta rule, Gemma4
  • model_parallel.cpp, device_locality.cpp — tensor parallelism and NUMA-aware placement
  • cuda/ (20 .cu files) — CUDA variants using GGML CUDA backend

MIOpen/AMD — platform/miopen/ (5 files)

  • miopenUtils.h — ZLUDA/MIOpen bridge for ENGINE_ZLUDA_AMD
  • activations.cpp, batchnorm.cpp, conv2d.cpp, softmax.cpp — MIOpen GPU ops via HIP runtime

oneDNN/MKL-DNN — platform/mkldnn/ (80+ files)

  • OnednnVersionProvider.h — version detection
  • mkldnnUtils.h/.cpp — thread-local dnnl::stream, Graph API helpers
  • flash_attention.cpp — Flash Attention via oneDNN Graph API; FlashAttentionCache keyed on (batch, seqQ, seqKV, numHeads, headDim, dtype, isCausal, is3D); FP32/FP16/BF16
  • sdpa.cpp — SDPA via oneDNN Graph API; SDPACache with 4D/3D and additive-bias; thread-local stream
  • gru.cpp, global_pooling.cpp — GRU and global pooling
  • eltwise_*.cpp (5 files) — element-wise ops grouped by category via oneDNN eltwise primitive
  • 40+ activation files and shape/arithmetic ops (batched_gemm, conv1d, layer_norm, pooling, reshape, transpose, etc.)

MLIR — platform/mlir/ (17 files)

  • mlirUtils.h — MLIR platform utility header
  • Op implementations using Linalg/arith/math/SCF/tensor dialects: attention, matmul, element-wise, embedding, conv2d, normalization, pooling, reductions, activations

Apple MPS — platform/mps/ (21 .mm files)

  • MpsVersionProvider.h, mpsUtils.h/.mm — MPS bridge for NDArray to MPSMatrix/MPSNDArray; Metal command buffer lifecycle
  • mps_blas.mm — GEMM/GEMV via MPSMatrixMultiplication
  • mps_conv.mm — MPSCNNConvolution and depthwise conv
  • mps_activations.mm / mps_activations_ext.mm — MPSCNNNeuron variants
  • mps_attention.mm — SDPA via MPSGraphScaledDotProductAttentionOp with batched matmul fallback
  • mps_normalization.mm — MPSCNNBatchNormalization
  • mps_rnn.mm — LSTM/GRU via MPSLSTM
  • mps_comparison.mm, mps_elementwise.mm, mps_embedding.mm, mps_image.mm, mps_loss.mm, mps_math.mm, mps_matrix.mm, mps_reductions.mm, mps_sorting.mm, mps_transform.mm — full op coverage via MPSGraph

PJRT/TPU — platform/pjrt/ (10 files)

  • pjrtUtils.h/.cpp — PJRT C API bridge for ENGINE_TPU: matmul, element-wise, activations, conv2d, pooling, reductions, shape ops
  • matmul.cpp, batchnorm.cpp, conv2d.cpp, elementwise.cpp, pooling.cpp, reductions.cpp, shape_ops.cpp — XLA HLO computations via PJRT client

VLM — platform/vlm/ (12 files)

  • vlmBackend.h/.cppVlmBackendManager singleton; backend priority: CUDA=1 > METAL=2 > CPU=0; AUTO=99; 512MB context default
  • vlmUtils.h/.cpp — image preprocessing, NDArray ↔ GGML tensor conversion
  • vlm_image_preprocess.cpp — resize, normalize, patch tokenization for ViT
  • vlm_image_embed.cpp — patch embedding via GGML vision encoder weights
  • vlm_vision_encode.cpp — ViT-style transformer forward pass
  • vlm_vision_projection.cpp — project vision features to language model embedding dimension
  • vlm_cross_attention.cpp — cross-attention between visual features and language tokens (SmolDocling/InternVL)
  • vlm_multimodal_fusion.cpp — combine visual and textual representations
  • vlm_patch_embed.cpp / vlm_2d_position_encode.cpp — patch and 2D position embeddings
  • cuda/vlmCudaUtils.cu, cuda/vlm_core_ops.cu — CUDA kernels for VLM memory and patch embedding

Dependencies

  • Depends on: PR05, PR06, PR07 (libnd4j core ops, NDArray, memory management)
  • Required by: PR16 (LLM/VLM Java pipeline depends on llamacpp and VLM backends)

Merge Order

These 22 PRs must merge in layer order. Each layer depends on the layers above it being merged first. PRs within the same layer are independent and can merge in parallel.

This PR: Merge after Layer 2 (native core ops + helpers).

Layer PRs
0 (no deps) PR01, PR02, PR20
1 (build/infra) PR03, PR04
2 (native core) PR05, PR06, PR07
3 (native feat) PR08, PR09, PR10, PR11
4 (java core) PR12, PR13, PR14, PR15
5 (java feat) PR16
6 (import/gen) PR17, PR18, PR19, PR21
7 (validation) PR22

Part of the 22-PR split of ag_new_release_updates_2 branch.
Merge layer: 3 (native features)
Files: 382

See pr-plans/00-master-plan.md for the full split plan and merge order.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot wasn't able to review this pull request because it exceeds the maximum number of files (300). Try reducing the number of changed files and requesting a review from Copilot again.

… debug output

Remove the debug printf("F2 opType:[%i]\n", opNum) blocks from
NativeOpExecutioner_indexreduce.cu and NativeOpExecutioner_reduce.cu entirely,
following the no-ad-hoc-printf rule for C++.

Gate all printf diagnostic output in the llamacpp backend behind
sd::Environment::getInstance().isVerbose(): printAllBackendCapabilities()
in backend_capabilities.cpp, and printArrayLocalityInfo() plus the locality
mismatch warning in device_locality.cpp.
@agibsonccc

Copy link
Copy Markdown
Contributor Author

Add Metal Performance Shaders implementations for all major op
categories: attention, blas, comparison, conv, elementwise,
embedding, image, loss, math, matrix, normalization, reductions,
rnn, sorting, transform. Add MpsVersionProvider for runtime
Metal feature detection. Refactor mpsUtils for shared utilities.
Apply THROW_EXCEPTION macro consolidation to platform backends.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants