(MTL S01E13) Metal Stitchable Functions

George Ostrobrod
Jan 16, 2025

Metal function stitching enables the implementation of a flexible computational graph whose details are defined through custom shader functions. In this article, I’ll explain the fundamental principles to help you better understand this API and its role in the Metal ecosystem.

Overview

Assume you have the following tasks:

  • Separating the design of an algorithm on the CPU side from the implementation of its parts on the GPU side.
  • Maintaining control over resources and memory consumption while keeping the graph flexible, which makes MPSGraph unsuitable for your needs.
  • Allowing a user of your API to implement a specific method while keeping the rest of the details hidden from them.

All of these tasks can be solved with the Metal function stitching mechanism, which is part of the shader library API. SwiftUI shader effects leverage the same API (their shader functions are declared `[[stitchable]]`), so understanding it will also help you implement those effects more effectively and correctly.

The main idea of the mechanism is to construct, on the CPU side, a complex `[[visible]]` function (available from Metal 2.3) using smaller, reusable components called `[[stitchable]]` functions (available from Metal 2.4). These `[[stitchable]]` functions act as the building blocks, which can be connected to each other to form the final, desired functionality.

Example

Everything becomes easier to understand with a real example. Here, we’ll implement a variant of the image enhancement algorithm from the previous episode in manual mode: the UV shift and the Y normalization range are not computed within our graph but passed in as parameters.

The algorithm consists of seven `[[stitchable]]` blocks, which can be reused or replaced depending on the algorithm we need to implement. These blocks are combined into a `[[visible]]` function, which can then be called from a shader or compute kernel.
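Before wiring anything on the GPU, it can help to have a CPU reference of the same data flow to compare against the shader’s output. The sketch below is a plain-Swift transcription of the seven blocks, assuming the standard library’s `SIMD2<Float>`/`SIMD3<Float>` types as stand-ins for MSL’s `float2`/`float3`. The Swift function names mirror the MSL blocks defined in the GPU side section; `adjustColorsReference` is a hypothetical name, not part of any API.

```swift
// CPU reference of the "manual" enhancement algorithm: one Swift function per
// stitchable block, using only standard-library SIMD types (no Metal, no simd module).
typealias Float2 = SIMD2<Float>
typealias Float3 = SIMD3<Float>

// rgb2yuv: luma as a weighted sum of R, G, B; chroma as scaled color differences.
func rgb2yuv(_ rgb: Float3) -> Float3 {
    let y = (Float3(0.299, 0.587, 0.114) * rgb).sum()
    return Float3(y, 0.493 * (rgb.z - y), 0.877 * (rgb.x - y))
}

// extractX / extractYZ: split the YUV triple into luma and chroma.
func extractX(_ xyz: Float3) -> Float { xyz.x }
func extractYZ(_ xyz: Float3) -> Float2 { Float2(xyz.y, xyz.z) }

// remap: map [range.x, range.y] linearly onto [0, 1].
func remap(_ value: Float, _ range: Float2) -> Float {
    (value - range.x) / (range.y - range.x)
}

// shift: add a constant offset to the chroma pair.
func shift(_ value: Float2, _ offset: Float2) -> Float2 { value + offset }

// mergeXYZ: recombine luma and chroma into one triple.
func mergeXYZ(_ x: Float, _ yz: Float2) -> Float3 { Float3(x, yz.x, yz.y) }

// yuv2rgb: inverse of rgb2yuv (same constants as the MSL version).
func yuv2rgb(_ yuv: Float3) -> Float3 {
    let y = yuv.x, u = yuv.y, v = yuv.z
    let r = y + v / 0.877
    let g = y - 0.39393 * u - 0.58081 * v
    let b = y + u / 0.493
    return Float3(r, g, b)
}

// The same data flow the stitched adjustColors graph encodes.
func adjustColorsReference(_ rgb: Float3, range: Float2, offset: Float2) -> Float3 {
    let yuv = rgb2yuv(rgb)
    let y = remap(extractX(yuv), range)
    let uv = shift(extractYZ(yuv), offset)
    return yuv2rgb(mergeXYZ(y, uv))
}

let adjusted = adjustColorsReference(Float3(0.2, 0.5, 0.8),
                                     range: Float2(0.1, 0.9),
                                     offset: Float2(0.0, 0.02))
print(adjusted)
```

With `range = (0, 1)` and `offset = (0, 0)` the adjustment is an identity up to floating-point error, which makes a convenient first sanity check for the GPU version.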

GPU side

[[visible]]

A function marked with the `[[visible]]` attribute is accessible outside its shader library. You can retrieve an `MTLFunction` instance for it, put it into a `visible_function_table` that you pass to your shader or kernel, and select the necessary entry at runtime based on some computed value. Alternatively, you can call a visible function directly; in that case, it must be explicitly provided through the pipeline descriptor’s linked functions.

As we’re building our method using stitchable blocks, we just need to define it as follows:

[[visible]] float3 adjustColors(float3 rgb, float2 range, float2 offset);

Later, this function can be utilized in a shader, but keep in mind that it needs to be linked in the pipeline state:

fragment float4 fshTextureQuad(ColorInOut in [[stage_in]],
                               constant float2 &range [[ buffer(0) ]],
                               constant float2 &offset [[ buffer(1) ]],
                               texture2d<float> image [[ texture(0) ]]) {
    constexpr sampler imageSampler(mag_filter::nearest, min_filter::linear);
    float4 source = image.sample(imageSampler, in.texCoord);
    float3 result = adjustColors(source.rgb, range, offset);
    return float4(result, 1.0);
}

[[stitchable]]

To use a function as a building block of a graph that will be compiled into a `[[visible]]` function, we need to mark it as `[[stitchable]]` (which inherently makes it `[[visible]]` as well).

[[stitchable]] float3 rgb2yuv(float3 rgb) {
    float y = dot(float3(0.299, 0.587, 0.114), rgb);
    return float3(y, 0.493 * (rgb.b - y), 0.877 * (rgb.r - y));
}

[[stitchable]] float extractX(float3 xyz) {
    return xyz.x;
}

[[stitchable]] float2 extractYZ(float3 xyz) {
    return xyz.yz;
}

[[stitchable]] float remap(float value, float2 range) {
    return (value - range.x) / (range.y - range.x);
}

[[stitchable]] float2 shift(float2 value, float2 offset) {
    return value + offset;
}

[[stitchable]] float3 mergeXYZ(float x, float2 yz) {
    return float3(x, yz);
}

[[stitchable]] float3 yuv2rgb(float3 yuv) {
    float y = yuv.x;
    float u = yuv.y;
    float v = yuv.z;

    return float3(y + 1.0 / 0.877 * v,
                  y - 0.39393 * u - 0.58081 * v,
                  y + 1.0 / 0.493 * u);
}
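The green-channel constants in `yuv2rgb` follow from inverting `rgb2yuv`: solving y = 0.299r + 0.587g + 0.114b for g and substituting b - y = u / 0.493 and r - y = v / 0.877 gives g = y - 0.114 / (0.587 · 0.493) · u - 0.299 / (0.587 · 0.877) · v. This small Swift check is not part of the Metal code and its variable names are illustrative; it simply recomputes both coefficients from the forward weights.

```swift
// Forward (RGB -> YUV) constants used by rgb2yuv.
let wr = 0.299, wg = 0.587, wb = 0.114   // luma weights for R, G, B
let uScale = 0.493, vScale = 0.877       // u = uScale * (b - y), v = vScale * (r - y)

// Inverse for green: g = y - gFromU * u - gFromV * v
let gFromU = wb / (wg * uScale)          // ≈ 0.39393
let gFromV = wr / (wg * vScale)          // ≈ 0.58081

// Red and blue invert directly: r = y + v / vScale, b = y + u / uScale.
print("g = y - \(gFromU) * u - \(gFromV) * v")
print("r = y + \(1 / vScale) * v, b = y + \(1 / uScale) * u")
```

Both coefficients match the literals in `yuv2rgb` to five decimal places, so `yuv2rgb(rgb2yuv(c))` returns `c` up to rounding.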

So, that’s all. Now we need to combine all of them on the CPU side.

IMPORTANT: The compiler generates additional metadata for stitchable functions so that they can be used with the Metal Function Stitching API. Use this attribute only on functions that actually need this functionality, because the metadata increases the function’s code size.

CPU Side

On the CPU side, we need to work with `MTLFunction` objects, nodes, and the pipeline state. Here’s the general workflow:

  1. Define nodes: Start by creating stitching nodes, which represent the `[[stitchable]]` functions you will use.
  2. Define the graph: Use these nodes to define the computational graph that represents the flow of data between them.
  3. Create stitchable `MTLFunction` objects: Retrieve the compiled `[[stitchable]]` functions from your shader library as `MTLFunction` objects that the Metal Stitching API can use.
  4. Build the stitched library: Combine the graph and these functions into a stitched library, which contains the compiled `[[visible]]` function.
  5. Retrieve the function: Extract the `[[visible]]` function from the stitched library for use in your pipeline.
  6. Link the function to your pipeline state: Attach the `[[visible]]` function to your pipeline state, making it usable in your Metal compute or render pipeline.

Now, let’s go through this process step by step.

Graph Definition

Firstly, we need to define input nodes for our algorithm’s input:

// [[visible]] float3 adjustColors(float3 rgb, float2 range, float2 offset);
let rgbInput = MTLFunctionStitchingInputNode(argumentIndex: 0)
let rangeInput = MTLFunctionStitchingInputNode(argumentIndex: 1)
let offsetInput = MTLFunctionStitchingInputNode(argumentIndex: 2)

Next, we need to define our stitchable blocks themselves. Note that you need to pass previous nodes as arguments to connect their outputs to the inputs of subsequent nodes:

// float3 rgb2yuv(float3 rgb)
let rgb2yuvFunc = MTLFunctionStitchingFunctionNode(
    name: "rgb2yuv",
    arguments: [rgbInput],
    controlDependencies: [])
// float extractX(float3 xyz)
let extractXFunc = MTLFunctionStitchingFunctionNode(
    name: "extractX",
    arguments: [rgb2yuvFunc],
    controlDependencies: [])
// float2 extractYZ(float3 xyz)
let extractYZFunc = MTLFunctionStitchingFunctionNode(
    name: "extractYZ",
    arguments: [rgb2yuvFunc],
    controlDependencies: [])
// float remap(float value, float2 range)
let remapFunc = MTLFunctionStitchingFunctionNode(
    name: "remap",
    arguments: [extractXFunc, rangeInput],
    controlDependencies: [])
// float2 shift(float2 value, float2 offset)
let shiftFunc = MTLFunctionStitchingFunctionNode(
    name: "shift",
    arguments: [extractYZFunc, offsetInput],
    controlDependencies: [])
// float3 mergeXYZ(float x, float2 yz)
let mergeXYZFunc = MTLFunctionStitchingFunctionNode(
    name: "mergeXYZ",
    arguments: [remapFunc, shiftFunc],
    controlDependencies: [])
// float3 yuv2rgb(float3 yuv)
let yuv2rgbFunc = MTLFunctionStitchingFunctionNode(
    name: "yuv2rgb",
    arguments: [mergeXYZFunc],
    controlDependencies: [])

NOTE: These nodes are placeholders, and the actual functions are linked during the process of building the stitched library.

When all nodes are defined, we can create our graph. `yuv2rgbFunc` is the output node: it’s the last node in our graph and returns the final result, so it is passed as `outputNode` rather than in the `nodes` list.

let graph = MTLFunctionStitchingGraph(
    functionName: "adjustColors",
    nodes: [rgb2yuvFunc, extractXFunc, extractYZFunc, remapFunc, shiftFunc, mergeXYZFunc],
    outputNode: yuv2rgbFunc,
    attributes: [])
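To make the data flow concrete, here is a Metal-free model of what this graph describes. It is only an illustration: `StitchNode`, `evaluate`, and the `blocks` registry are hypothetical names that loosely mirror `MTLFunctionStitchingInputNode` (an input selected by `argumentIndex`) and `MTLFunctionStitchingFunctionNode` (a named function applied to the outputs of other nodes). Values are plain `[Float]` arrays standing in for `float`, `float2` and `float3`.

```swift
// Hypothetical, Metal-free model of a function-stitching graph (illustration only).
typealias Vec = [Float]              // stands in for float / float2 / float3
typealias Block = ([Vec]) -> Vec     // a "stitchable" block: arguments in, one value out

// Loosely mirrors MTLFunctionStitchingInputNode / MTLFunctionStitchingFunctionNode.
indirect enum StitchNode {
    case input(argumentIndex: Int)
    case function(name: String, arguments: [StitchNode])
}

// Evaluate a node from the visible function's arguments and a registry of blocks.
func evaluate(_ node: StitchNode, arguments: [Vec], functions: [String: Block]) -> Vec {
    switch node {
    case .input(let index):
        return arguments[index]
    case .function(let name, let nodeArguments):
        guard let block = functions[name] else {
            fatalError("No block named \(name)")
        }
        let values = nodeArguments.map { evaluate($0, arguments: arguments, functions: functions) }
        return block(values)
    }
}

// The seven blocks, transcribed from the MSL versions.
let blocks: [String: Block] = [
    "rgb2yuv": { a in
        let (r, g, b) = (a[0][0], a[0][1], a[0][2])
        let y = 0.299 * r + 0.587 * g + 0.114 * b
        return [y, 0.493 * (b - y), 0.877 * (r - y)]
    },
    "extractX": { a in [a[0][0]] },
    "extractYZ": { a in [a[0][1], a[0][2]] },
    "remap": { a in [(a[0][0] - a[1][0]) / (a[1][1] - a[1][0])] },
    "shift": { a in [a[0][0] + a[1][0], a[0][1] + a[1][1]] },
    "mergeXYZ": { a in [a[0][0], a[1][0], a[1][1]] },
    "yuv2rgb": { a in
        let (y, u, v) = (a[0][0], a[0][1], a[0][2])
        return [y + v / 0.877, y - 0.39393 * u - 0.58081 * v, y + u / 0.493]
    },
]

// adjustColors(rgb, range, offset): inputs 0, 1, 2; output node is yuv2rgb.
let rgbInput = StitchNode.input(argumentIndex: 0)
let rangeInput = StitchNode.input(argumentIndex: 1)
let offsetInput = StitchNode.input(argumentIndex: 2)
let yuvNode = StitchNode.function(name: "rgb2yuv", arguments: [rgbInput])
let lumaNode = StitchNode.function(name: "remap", arguments: [
    .function(name: "extractX", arguments: [yuvNode]), rangeInput])
let chromaNode = StitchNode.function(name: "shift", arguments: [
    .function(name: "extractYZ", arguments: [yuvNode]), offsetInput])
let outputNode = StitchNode.function(name: "yuv2rgb", arguments: [
    .function(name: "mergeXYZ", arguments: [lumaNode, chromaNode])])

let result = evaluate(outputNode, arguments: [[0.2, 0.5, 0.8], [0, 1], [0, 0]], functions: blocks)
print(result)
```

Unlike the real API, this naive evaluator walks the graph as a tree, so the shared `rgb2yuv` node is evaluated once per consumer; it models which outputs feed which inputs, not how Metal compiles the stitched function.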

Build Stitchable Library

Next, we need to retrieve the `MTLFunction` instances of our stitchable functions from the shader library:

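// makeFunction(name:) returns an optional, so unwrap each result explicitly:
// MTLStitchedLibraryDescriptor.functions expects [MTLFunction].
let functionNames = ["rgb2yuv", "extractX", "extractYZ", "remap",
                     "shift", "mergeXYZ", "yuv2rgb"]
let nodeFunctions: [MTLFunction] = functionNames.map { name in
    guard let function = library.makeFunction(name: name) else {
        fatalError("Stitchable function '\(name)' not found in the library")
    }
    return function
}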

Once we have these functions, we can build the stitched library. Because `functionGraphs` takes an array, multiple graphs can be built from the same set of stitchable functions, which gives you the flexibility to construct different algorithms from the same blocks. Here’s how we build the stitched library:

let libDescriptor = MTLStitchedLibraryDescriptor()
libDescriptor.functions = nodeFunctions
libDescriptor.functionGraphs = [graph]
let stitchedLibrary = try? device.makeLibrary(stitchedDescriptor: libDescriptor)

Using Stitched Function in a Pipeline State

With the stitched library ready, the final `[[visible]]` function can be extracted for use in a Metal pipeline or compute operation.

let funcDescriptor = MTLFunctionDescriptor()
funcDescriptor.name = "adjustColors"
let adjustColorsFunc = try? stitchedLibrary?.makeFunction(descriptor: funcDescriptor)

To make this function ready to be used in a pipeline state, we need to describe it as a linked function. We place our function into the `privateFunctions` section because it’s not required to be exported beyond its usage in the pipeline.

NOTE: The `functions` and `binaryFunctions` sections are also available for linking Metal libraries. The distinction lies in how the libraries are linked, but that is outside the scope of this article.

let linkedFunctions = MTLLinkedFunctions()
linkedFunctions.privateFunctions = [adjustColorsFunc!]

That’s all. Now the function is ready to be linked to your pipeline state. By adding it to your `MTLRenderPipelineDescriptor` or `MTLComputePipelineDescriptor`, you can use the `adjustColors` function seamlessly in your rendering or compute workflows.

let pipelineDescriptor = MTLRenderPipelineDescriptor()
pipelineDescriptor.vertexFunction = library.makeFunction(name: "vshSimpleQuad")
pipelineDescriptor.fragmentFunction = library.makeFunction(name: "fshTextureQuad")
pipelineDescriptor.fragmentLinkedFunctions = linkedFunctions
pipelineDescriptor.colorAttachments[0].pixelFormat = .rgba8Unorm
let renderPipeline = try? device.makeRenderPipelineState(descriptor: pipelineDescriptor)

NOTE: As we use the stitched function in our fragment shader, we linked it through `fragmentLinkedFunctions`. For a vertex shader, use `vertexLinkedFunctions` on the render pipeline descriptor; for a compute kernel, set `linkedFunctions` on the `MTLComputePipelineDescriptor`.

What’s under the hood

First, let me explain what happens at each level of shader code in Metal.

1. MSL (Metal Shading Language): This is the high-level language used for writing shaders. The MSL code is then compiled into AIR.

2. AIR (Apple Intermediate Representation): This is a platform-independent intermediate representation. It enables optimization and ensures portability across different GPU architectures. `[[stitchable]]` and `[[visible]]` functions live at this level. AIR is then compiled into MSA.

3. MSA (Metal Shading Assembly): This is the device-specific assembly code optimized for the target GPU. This step occurs when the pipeline state is built.

The graph we constructed operates within AIR and is optimized at that level. Since all of this happens under the hood, we are unable to view the compiled method’s code directly, as it exists in an optimized, intermediate form.

Conclusion

  • The Metal Stitching API enables the creation of flexible computational graphs on the CPU side, leveraging reusable `[[stitchable]]` functions to build complex `[[visible]]` functions.
  • By combining shader modularity with CPU-side graph management, it offers an efficient and scalable solution for advanced graphics and compute operations.
  • The API is well-suited for scenarios where tight control over resources, memory, and graph flexibility is required, especially when MPSGraph is unsuitable.

Software Engineer with a background in image processing and computer graphics. Made some cool stuff for PicsArt, Pixelmator, Procreate and several others.
