Metal Kernel Writing Guide
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
Overview
There are two workflows covered by this skill:
- Adding new MPS support: implementing a new operator from scratch
- Migrating from MPSGraph: converting existing MPSGraph-based operators to native Metal
Both workflows involve:
- Update dispatch in aten/src/ATen/native/native_functions.yaml
- Write Metal kernel in aten/src/ATen/native/mps/kernels/
- Implement host-side stub in aten/src/ATen/native/mps/operations/
Step 1: Update native_functions.yaml
Location: aten/src/ATen/native/native_functions.yaml
For New Operators
Find the operator entry and add MPS dispatch:
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps

# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
For Migrating from MPSGraph
When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation

# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
Key change: Remove the separate MPS: my_op_out_mps entry and add MPS to the shared dispatch line instead (e.g., CPU, CUDA, MPS: my_op_out).
Dispatch naming conventions:
- MPS: function_name_mps (MPS-specific implementation, the old MPSGraph pattern)
- CPU, CUDA, MPS: function_name (shared stub implementation, the native Metal pattern)
Step 2: Implement Metal Kernel
Location: aten/src/ATen/native/mps/kernels/
Unary Kernel Pattern
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define operation functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
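To make the functor pattern concrete, the following is a minimal host-side sketch in standard C++ (not Metal) of what a dense unary kernel does conceptually: one functor call per element. The helper apply_unary_dense and the x * x + 1 operation are hypothetical illustrations, not part of c10/metal; the real kernels are instantiated by the registration macros and run one GPU thread per element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for a functor defined in a .metal file; the operation is arbitrary.
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) const {
    return x * x + T(1);
  }
};

// Hypothetical host-side emulation of a dense unary kernel: each "thread"
// (here, a loop iteration) reads one input element, applies the functor,
// and writes one element of the output type.
template <typename Out, typename In, typename F>
std::vector<Out> apply_unary_dense(const std::vector<In>& input, F f) {
  std::vector<Out> output(input.size());
  for (std::size_t index = 0; index < input.size(); ++index) {
    output[index] = static_cast<Out>(f(input[index]));
  }
  return output;
}
```

An empty input produces an empty output without any functor call, which mirrors the empty-tensor case a real kernel launch has to tolerate.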
Binary Kernel Pattern
struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
Binary Kernel Type Registration Macros
For binary operations, use the convenience macros defined in BinaryKernel.metal:

// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);

// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);

// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);

// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
Common patterns:
- Math functions (atan2, copysign, logaddexp): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
- Comparison/logical ops (maximum, minimum): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
- Arithmetic ops (add, sub, mul): Use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
Example for atan2 (supports both float and int inputs):
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }
  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};

REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
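The two overloads are what let one functor serve both registration macros: floating-point inputs keep their type, while integral inputs produce float. The sketch below reproduces that overload selection in standard C++ so it can be checked on the host. It assumes std::atan2 as a stand-in for precise::atan2 and <type_traits> in place of the Metal trait helpers; the names atan2_functor_sketch and atan2_result_t are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Standard C++ stand-in for the Metal functor: std::atan2 replaces
// precise::atan2, and <type_traits> replaces the Metal trait helpers.
struct atan2_functor_sketch {
  // Floating-point inputs: compute in float, return the input type.
  template <typename T, std::enable_if_t<std::is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) const {
    return static_cast<T>(std::atan2(float(a), float(b)));
  }
  // Integral inputs (including bool): always return float.
  template <typename T, std::enable_if_t<std::is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) const {
    return std::atan2(float(a), float(b));
  }
};

// The output type the functor produces for a given input type.
template <typename T>
using atan2_result_t = decltype(atan2_functor_sketch{}(T{}, T{}));

static_assert(std::is_same_v<atan2_result_t<float>, float>, "float stays float");
static_assert(std::is_same_v<atan2_result_t<int>, float>, "int promotes to float");
static_assert(std::is_same_v<atan2_result_t<bool>, float>, "bool promotes to float");
```

The static_assert lines check at compile time that the output type per input type matches the long->float, int->float, ... bool->float pairs listed for REGISTER_INT2FLOAT_BINARY_OP.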
With Scalar Parameter
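struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

// The functor takes two tensor operands plus alpha, so register it with the
// binary alpha macro (a unary alpha functor would take only x and alpha)
REGISTER_BINARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_BINARY_ALPHA_OP(my_alpha, half, half, half);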
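As a rough illustration of how the scalar parameter differs from the tensor operands, here is a host-side sketch in standard C++. The helper apply_binary_alpha is hypothetical, and plain multiplication stands in for c10::metal::mul; the point is only that alpha is a single value shared by every element, while a and b are read per element.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Same arithmetic as the functor above; plain `*` stands in for c10::metal::mul.
struct my_alpha_functor_sketch {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) const {
    return a + alpha * b;
  }
};

// Hypothetical emulation: `alpha` is one scalar shared by every element,
// while `a` and `b` are read per element.
template <typename T, typename F>
std::vector<T> apply_binary_alpha(const std::vector<T>& a,
                                  const std::vector<T>& b,
                                  T alpha,
                                  F f) {
  assert(a.size() == b.size());
  std::vector<T> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = f(a[i], b[i], alpha);
  }
  return out;
}
```

For example, with a = {1, 2}, b = {10, 20} and alpha = 0.5, the result is {6, 12}.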
Type-Specialized Functor
struct special_functor {
  // Floating point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x); // Use precise math
  }

  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }

  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
Note on complex types: Complex numbers in Metal are represented as vector types:
- c10::complex<float> maps to float2 (x = real, y = imaginary)
- c10::complex<half> maps to half2
Use is_complex_v<T> to specialize for complex types in functors.
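Because a complex value is just a two-component vector, complex arithmetic has to be written out by hand on the x and y components. The following sketch shows complex multiplication in standard C++; float2_sketch is a hypothetical stand-in for Metal's float2, and complex_mul is an illustrative helper rather than a c10/metal function.

```cpp
#include <cassert>

// Minimal stand-in for Metal's float2 vector type.
struct float2_sketch {
  float x;  // real part
  float y;  // imaginary part
};

// (a + bi)(c + di) = (ac - bd) + (ad + bc)i
constexpr float2_sketch complex_mul(const float2_sketch a, const float2_sketch b) {
  return float2_sketch{a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}

// (1 + 2i)(3 + 4i) = -5 + 10i, checked at compile time.
static_assert(complex_mul({1.0f, 2.0f}, {3.0f, 4.0f}).x == -5.0f, "real part");
static_assert(complex_mul({1.0f, 2.0f}, {3.0f, 4.0f}).y == 10.0f, "imaginary part");
```

Note that a component-wise product of the two vectors would give (3, 8) here, which is not the complex product; this is why complex types need their own functor overload.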
Available c10/metal Utilities
utils.h:
- opmath_t<T>: Operation math type (half->float)
- accum_t<T>: Accumulation type for reductions
- max(), min() with NaN propagation
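The ideas behind these helpers can be sketched in standard C++. Everything below is an assumption-labeled illustration, not the c10/metal implementation: half_sketch is a hypothetical placeholder for a 16-bit float type, opmath_sketch_t mimics the half->float mapping described above, and max_propagate_nan shows what "NaN propagation" means in contrast to std::fmax.

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Hypothetical placeholder for a 16-bit float storage type; only the type
// identity matters for this sketch.
struct half_sketch {
  unsigned short bits;
};

// Trait in the spirit of opmath_t: the reduced-precision type computes in
// float, every other type computes in its own type.
template <typename T>
struct opmath_sketch {
  using type = T;
};
template <>
struct opmath_sketch<half_sketch> {
  using type = float;
};
template <typename T>
using opmath_sketch_t = typename opmath_sketch<T>::type;

static_assert(std::is_same_v<opmath_sketch_t<half_sketch>, float>, "half -> float");
static_assert(std::is_same_v<opmath_sketch_t<float>, float>, "float unchanged");

// max that returns NaN when either operand is NaN. std::fmax, by contrast,
// returns the non-NaN operand.
inline float max_propagate_nan(float a, float b) {
  if (std::isnan(a)) return a;
  if (std::isnan(b)) return b;
  return a > b ? a : b;
}
```

NaN propagation matters for matching PyTorch semantics on other devices, where the maximum of NaN and any number is NaN.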
special_math.h:
- precise::exp(), precise::log(), precise::sqrt()
- precise::sin(), precise::cos(), precise::tan()
- erf(), erfc(), erfinv()
indexing.h:
- REGISTER_UNARY_OP(name, in_type, out_type)
- REGISTER_BINARY_OP(name, in_type, out_type)
- REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)
Step 3: Implement Host-Side Stub
Location: aten/src/ATen/native/mps/operations/
Choose or create an appropriate file based on operation type:
- UnaryKernel.mm: Single input operations via stub dispatch
- BinaryKernel.mm: Two input operations via stub dispatch
- UnaryOps.mm / BinaryOps.mm: Legacy MPSGraph implementations (for reference)
- ReduceOps.mm: Reductions (sum, mean, max, etc.)
- Create new file for distinct operation categories
Stub Registration Pattern (Preferred for Native Metal)
For structured kernels that use the TensorIterator pattern:
// In BinaryKernel.mm (or appropriate file)

static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in .metal
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
For unary operations:
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}

REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
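The stub mechanism can be pictured as a table with one function-pointer slot per device, where each backend's registration fills in its own slot. The sketch below is a deliberately simplified toy model in standard C++; StubSketch, DeviceSketch and the string-returning kernels are all hypothetical, and the real dispatch stub machinery in ATen is considerably more involved.

```cpp
#include <cassert>
#include <string>

// Hypothetical, simplified model of a per-device dispatch stub: one function
// pointer slot per backend, filled in by each backend's registration.
enum class DeviceSketch { CPU, MPS };

struct StubSketch {
  using fn_t = std::string (*)(const std::string& op_name);
  fn_t cpu_fn = nullptr;
  fn_t mps_fn = nullptr;

  // Look up the slot for the device and call it.
  std::string operator()(DeviceSketch device, const std::string& op_name) const {
    fn_t fn = (device == DeviceSketch::MPS) ? mps_fn : cpu_fn;
    assert(fn != nullptr && "no kernel registered for this device");
    return fn(op_name);
  }
};

// Stand-ins for the CPU kernel and for a Metal-backed kernel that looks up a
// functor by name, as lib.exec_binary_kernel(iter, "my_op") does above.
inline std::string cpu_kernel(const std::string& op_name) { return "cpu:" + op_name; }
inline std::string mps_kernel(const std::string& op_name) { return "metal:" + op_name; }

// Emulates what the two registrations accomplish: filling both slots.
inline StubSketch make_my_op_stub() {
  StubSketch stub;
  stub.cpu_fn = &cpu_kernel;
  stub.mps_fn = &mps_kernel;
  return stub;
}
```

In this model, the shared dispatch line CPU, CUDA, MPS: my_op_out corresponds to one device-independent entry point that calls the stub, and the MPS registration only has to fill in the MPS slot.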
Migration: Removing Old MPSGraph Implementation
When migrating from MPSGraph, also remove the old implementation:
Remove from BinaryOps.mm (or UnaryOps.mm):
- Delete the TORCH_IMPL_FUNC(my_op_out_mps) implementation
- Remove the corresponding #include <ATen/ops/my_op_native.h> header
Add to BinaryKernel.mm (or UnaryKernel.mm):
- Add the static kernel function
- Add the REGISTER_DISPATCH call
Step 4: Compile
After making changes, compile to verify everything builds correctly:
cd build && ninja torch_cpu
Testing
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
- Remove from common_mps.py
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
# Remove entries like:
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}

MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
- Remove from OpInfo decorators
Location: torch/testing/_internal/common_methods_invocations.py (or related files)
Remove MPS-specific decorators from the OpInfo:
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
- Run tests to verify
Run the specific operator test
python test/test_mps.py -k test_output_match_my_op
Or run full MPS test suite
python test/test_mps.py
Checklist
- Added MPS dispatch to native_functions.yaml
- Implemented Metal kernel in kernels/
- Implemented host-side operator in operations/
- Handles empty tensors
- Handles non-contiguous tensors
- Supports required dtypes (float32, float16, bfloat16, and often complex types via float2/half2)
- Removed expected failures from torch/testing/_internal/common_mps.py
- Removed skip/xfail decorators from OpInfo (if applicable)
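For the empty-tensor and non-contiguous items, it can help to recall what a strided kernel has to compute. The sketch below, in standard C++, maps a linear element index to a storage offset from sizes and strides, and computes the element count that a host-side wrapper could check before launching anything. The helper names linear_to_offset and numel are hypothetical and not taken from c10/metal.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Converts a linear element index into a storage offset (in elements) for a
// tensor described by sizes and strides, iterating from the innermost dim.
inline std::int64_t linear_to_offset(std::int64_t linear_index,
                                     const std::vector<std::int64_t>& sizes,
                                     const std::vector<std::int64_t>& strides) {
  assert(sizes.size() == strides.size());
  std::int64_t offset = 0;
  for (std::size_t d = sizes.size(); d-- > 0;) {
    offset += (linear_index % sizes[d]) * strides[d];
    linear_index /= sizes[d];
  }
  return offset;
}

// Number of elements; zero for an empty tensor, so a host-side wrapper can
// return early instead of launching a kernel with no threads.
inline std::int64_t numel(const std::vector<std::int64_t>& sizes) {
  std::int64_t n = 1;
  for (std::int64_t s : sizes) {
    n *= s;
  }
  return n;
}
```

As a usage example, the transpose of a contiguous 2x3 tensor has sizes {3, 2} and strides {1, 3}, so linear index 1 (row 0, column 1) lives at storage offset 3. A kernel that assumed offset 1 would silently read the wrong element, which is the kind of bug the non-contiguous checklist item is meant to catch.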