(MTL S01E12) MPSGraph

George Ostrobrod
12 min read · Dec 29, 2024


You might need to perform tensor operations on your data or integrate such operations into your image processing pipeline. Metal provides Metal Performance Shaders Graph (MPSGraph) to handle such tasks. It offers a vast library of tensor operations that can be organized as a graph, eliminating the need for extensive boilerplate code.

Overview

If you have experience with frameworks like Keras or PyTorch, you’re likely familiar with how computational graphs are constructed. If not, the concept is straightforward: a graph represents a sequence of operations applied to an input tensor, which eventually produces an output tensor. Using MPSGraph, you can train or pre-train your ML models directly and then seamlessly integrate them into your image, data processing, or rendering pipelines.

While CoreML excels in supporting a broader range of machine learning operations and offers numerous tools for converting models from other frameworks, performance-critical applications might benefit from MPSGraph. It allows you to work directly with Metal resources, avoiding the overhead of multiple intermediate data conversions.

Importantly, MPSGraph isn’t just for machine learning — it’s a valuable tool for any image processing algorithm that can represent image data as tensors.

Operations

Before diving into how to use MPSGraph, let’s first explore the operations it provides:

  • Basic Math: Includes arithmetic, trigonometric, logical, and bitwise operations.
  • Branching: Control flow operations such as `for` loops and conditional statements (`if`).
  • Recurrent Neural Networks: Operations for building RNNs, including `LSTM` and `GRU` layers.
  • Convolutional Neural Networks: Support for standard CNN layers and operations.
  • Fast Fourier Transform (FFT): Enables frequency domain processing.
  • Miscellaneous: A range of additional operations for various use cases.

For a comprehensive list of supported operations and their details, refer to the official documentation.

Terminology

First, you need to understand how a graph is constructed:

  • Graph: A graph is a sequence of compute operations performed on tensors. It can also be a subgraph within a larger computational graph. This modularity allows complex computations to be broken down into manageable parts.
  • Tensor: In the context of `MPSGraph`, a tensor is an n-dimensional array that serves as the fundamental data structure for computation.
  • Constant: Some tensors can be initialized with constant data, acting as fixed inputs in the graph.
  • Placeholder: Tensors can also serve as placeholders, representing inputs that will be provided dynamically during graph execution.

Understanding these concepts is crucial as they form the foundation for building and manipulating computational graphs in `MPSGraph`.
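
To make these terms concrete, here is a minimal self-contained sketch of my own (the `run` call is covered in the next section): a placeholder is fed at run time, a constant is baked into the graph, and an operation produces the output tensor.

import Metal
import MetalPerformanceShadersGraph

let graph = MPSGraph()
let x = graph.placeholder(shape: [4], dataType: .float32, name: "x") // placeholder
let two = graph.constant(2.0, dataType: .float32)                    // constant
let y = graph.multiplication(x, two, name: "y")                      // operation producing a tensor

let device = MTLCreateSystemDefaultDevice()!
var values: [Float] = [1, 2, 3, 4]
let xData = MPSGraphTensorData(
    device: MPSGraphDevice(mtlDevice: device),
    data: Data(bytes: &values, count: values.count * MemoryLayout<Float>.stride),
    shape: [4],
    dataType: .float32)

let fetch = graph.run(feeds: [x: xData], targetTensors: [y], targetOperations: nil)
var out = [Float](repeating: 0, count: 4)
fetch[y]?.mpsndarray().readBytes(&out, strideBytes: nil)
print(out) // [2.0, 4.0, 6.0, 8.0]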

Preparation

Creating

MPSGraph is just a representation, so it doesn’t require a Metal device for initialization:

let graph = MPSGraph()

Running

If you need to simply run your graph without additional operations, you can do so with the following straightforward method:

let fetch = graph.run(           // 1
    feeds: [inputTensor: input], // 2
    targetTensors: [output],     // 3
    targetOperations: nil)       // 4
  1. Runs the graph for the given feeds and returns the target tensor values, ensuring all target operations are also executed. `fetch` is a dictionary with the resulting data.
  2. The feeds dictionary for the placeholder tensors: you need to define the `inputTensor` placeholder and provide its data in `input`.
  3. Tensors for which the caller wants `MPSGraphTensorData` to be returned. Typically these come from a built graph (see below).
  4. Operations to be completed at the end of the run.

You can also run the graph on an `MTLCommandQueue`.
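
A sketch of that variant, assuming the same `inputTensor`, `input`, and `output` as above and an already-initialised `commandQueue`:

// Same semantics as run(feeds:...), but the work is submitted on the
// given queue.
let fetch = graph.run(with: commandQueue,
                      feeds: [inputTensor: input],
                      targetTensors: [output],
                      targetOperations: nil)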

Encoding

If you need to perform more than one operation (for example, data preparation, pre-/post-processing, blitting, etc.), you can use an `MPSCommandBuffer`, which is created from an `MTLCommandQueue`, and encode your graph there.

var commandQueue: MTLCommandQueue
// ... initialise the queue
let commandBuffer = MPSCommandBuffer(from: commandQueue)
// ... initialise `input`, `output`, `inputData`
let fetch = graph.encode(to: commandBuffer,
                         feeds: [input: inputData],
                         targetTensors: [output],
                         targetOperations: nil,
                         executionDescriptor: nil)
// ... parsing `fetch[output]`
commandBuffer.commit()
commandBuffer.waitUntilCompleted() // optional if you don't need to read the result

Alternatively, you can base the `MPSCommandBuffer` on an existing `MTLCommandBuffer`:

var mtlCommandBuffer: MTLCommandBuffer
// ... initialise the Metal command buffer
let mpsCommandBuffer = MPSCommandBuffer(commandBuffer: mtlCommandBuffer)
// ... encoding graph
mpsCommandBuffer.commit()
mpsCommandBuffer.waitUntilCompleted() // optional if you don't need to read result
// ... further Metal command buffer processing

ATTENTION: There are nuances to `MPSCommandBuffer` that aren’t fully covered in the documentation. Here are some key points from the documentation comments in the framework headers:

Once we create this MPSCommandBuffer, any methods utilizing it could call commitAndContinue and so the user’s original commandBuffer may have been committed.

Please use the rootCommandBuffer method to get the current alive underlying MTLCommandBuffer.

`commitAndContinue()` commits the underlying root MTLCommandBuffer, and makes a new one on the same command queue. The MPS heap is moved forward to the new command buffer such that temporary objects used by the previous command buffer can still be used with the new one.

This provides a way to move work already encoded into consideration by the Metal back end sooner. For large workloads, e.g. a neural networking graph, periodically calling `commitAndContinue` may allow you to improve CPU / GPU parallelism without the substantial memory increases associated with double buffering. It will also help decrease overall latency.

Any Metal schedule or completion callbacks previously attached to this object will remain attached to the old command buffer and will fire as expected as the old command buffer is scheduled and completes. If your application is relying on such callbacks to coordinate retain / release of important objects that are needed for work encoded after `commitAndContinue`, your application should retain these objects BEFORE calling `commitAndContinue`, and attach new release callbacks to this object with a new completion handler so that they persist through the lifetime of the new underlying command buffer. You may do this, for example, by adding the objects to a mutable array before calling `commitAndContinue`, then releasing the mutable array in a new completion callback added after `commitAndContinue`.

Because `commitAndContinue` commits the old command buffer then switches to a new one, some aspects of command buffer completion may surprise unwary developers. For example, `waitUntilCompleted` called immediately after `commitAndContinue` asks Metal to wait for the new command buffer to finish, not the old one. Since the new command buffer presumably hasn’t been committed yet, it is formally a deadlock, resources may leak and Metal may complain. Your application should either call `commit` before `waitUntilCompleted`, or capture the `rootCommandBuffer` from before the call to `commitAndContinue` and wait on that. Similarly, your application should be sure to use the appropriate command buffer when querying the `MTLCommandBuffer.status` property.

If the underlying MTLCommandBuffer also implements `commitAndContinue`, then the message will be forwarded to that object instead. In this way, underlying predicate objects and other state will be preserved.

The following example breaks down the behavior of `MPSCommandBuffer`, highlighting how it manages the underlying Metal command buffer and its implications:

let mtlCmdBuf = commandQueue.makeCommandBuffer()!          // 1
let mpsCmdBuf = MPSCommandBuffer(commandBuffer: mtlCmdBuf) // 2
mpsCmdBuf.rootCommandBuffer.addCompletedHandler { _ in     // 3
    print("Completed the 1st command buffer")
}

graph.encode(to: mpsCmdBuf, ...)                           // 4
mpsCmdBuf.rootCommandBuffer.addCompletedHandler { _ in
    print("Completed the 2nd command buffer")
}

mpsCmdBuf.commit()                                         // 5
  1. A Metal command buffer (`mtlCmdBuf`) is created from the command queue (or is passed from outside). This is the initial command buffer that serves as the basis for the `MPSCommandBuffer`.
  2. The `MPSCommandBuffer` wraps around the existing Metal command buffer (`mtlCmdBuf`). At this point, the `mpsCmdBuf.rootCommandBuffer` references the original `mtlCmdBuf`.
  3. A completion handler is attached to the original `mtlCmdBuf`. This handler will execute when the command buffer completes execution.
  4. During graph encoding, `mpsCmdBuf` is implicitly committed via `commitAndContinue()`, which also commits the underlying `mtlCmdBuf`. The `MPSCommandBuffer` then creates a new internal Metal command buffer to replace the previous one. After this step, `mtlCmdBuf`’s status changes from `notEnqueued` to `committed`, and the new command buffer becomes the active `mpsCmdBuf.rootCommandBuffer`.
  5. `mpsCmdBuf.commit()` commits the new underlying Metal command buffer created in step 4.

The output will be:

Completed the 1st command buffer
Completed the 2nd command buffer

ATTENTION: This behavior lets `MPSGraph` manage command submission efficiently, but it requires careful handling of command buffer state in more intricate processing pipelines: implicit commits can disrupt your workflow if you rely on the original command buffer.

I would recommend using asserts to check the `MTLCommandBuffer.status` during development to detect any problems early in the pipeline.
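
A minimal sketch of such a check, assuming `mtlCmdBuf` is the Metal command buffer you originally wrapped (as in the example above):

// If graph encoding implicitly committed the original buffer via
// commitAndContinue, its status will have advanced past .notEnqueued;
// this catches the surprise in debug builds.
assert(mtlCmdBuf.status == .notEnqueued,
       "Original command buffer was implicitly committed during encoding")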

Compiling

You can also `compile` your graph into an `MPSGraphExecutable` with fixed inputs and outputs, which lets you reuse the optimized graph across runs. That’s a big topic of its own.
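
Still, here is a brief sketch of what compilation looks like, assuming `input` and `output` are tensors from a graph like the ones above; the shape is an arbitrary example:

// Shapes are fixed at compile time via MPSGraphShapedType.
let mpsDevice = MPSGraphDevice(mtlDevice: MTLCreateSystemDefaultDevice()!)
let inputType = MPSGraphShapedType(shape: [1, 256, 256, 4], dataType: .float32)
let executable = graph.compile(with: mpsDevice,
                               feeds: [input: inputType],
                               targetTensors: [output],
                               targetOperations: nil,
                               compilationDescriptor: nil)
// The executable can then be run repeatedly without rebuilding the graph:
// executable.run(with: commandQueue, inputs: [inputData], results: nil,
//                executionDescriptor: nil)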

Metal Resources

We’ve already discussed running `MPSGraph` on an `MTLCommandQueue` and the nuances of operating on `MTLCommandBuffer`s, but what about transferring data between vanilla Metal and MPSGraph?

Although we can’t directly use `MTLTexture` or `MTLBuffer` with `MPSGraph`, we can perform the following conversions:

// 1
let tensorData = MPSGraphTensorData(mtlBuffer: buffer, shape: shape, dataType: .float32)

// 2
let mpsImage = MPSImage(texture: texture, featureChannels: featureChannels)
let tensorData = MPSGraphTensorData([mpsImage])
  1. `MTLBuffer` to `MPSGraphTensorData`: Metal buffers can be wrapped directly into `MPSGraphTensorData`, provided the shape and data type you pass describe the buffer’s contents.
  2. `MTLTexture` to `MPSImage` to `MPSGraphTensorData`: Metal textures can be converted into `MPSImage` objects. From there, you can wrap the image into `MPSGraphTensorData` for graph usage.

ATTENTION:

When converting `MTLBuffer` or `MTLTexture` to `MPSGraphTensorData`, ensure that the data layout and shapes align with the expected tensor format in your graph.

Use GPU-side data export (e.g., from `MPSNDArray` to `MTLBuffer` or `MTLTexture`) whenever possible to minimize CPU-GPU synchronization overhead.
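
For contrast, here is the CPU-side route as a sketch, assuming the result came from a completed `run`/`encode` (with `fetch` and `output` as in the earlier examples, and `width`/`height` matching the tensor shape):

// Synchronous CPU readback: every such round trip introduces a
// CPU-GPU synchronization point, which GPU-side export avoids.
var pixels = [Float](repeating: 0, count: width * height * 4)
fetch[output]?.mpsndarray().readBytes(&pixels, strideBytes: nil)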

Example

Now let’s design and build a simple graph that performs basic automatic image enhancements:

  1. Convert RGB to YUV: This step separates the luminance (Y) from chrominance (UV), enabling better control over brightness and color adjustments.
  2. Normalize Y channel values: Adjust the luminance values to enhance contrast and fix exposure issues. This improves the overall brightness and dynamic range of the image.
  3. Shift UV values to their average: Recenter the chrominance values to correct white balance issues, ensuring colors appear more natural.
  4. Convert YUV back to RGB: After the adjustments, the image is converted back to RGB format for display or further processing.

Encapsulate nodes

As we can see, the graph has some nodes (e.g., RGB to YUV, YUV to RGB) or even groups of nodes (e.g., Y normalization, UV shift) that aren’t directly provided by the original `MPSGraph` but can be encapsulated as separate modules or subgraphs. To improve modularity and reusability, we can start by implementing these isolated subgraphs.

Color space conversions

As we want to use the new node in our bigger graph, we’ll just extend `MPSGraph` to include the RGB-to-YUV conversion using existing operations:

extension MPSGraph {
    func rgb2yuv(rgbTensor: MPSGraphTensor) -> MPSGraphTensor { // 1
        let rgbToYUVMatrixData = [Float]([ // 2
            0.299, -0.14713, 0.615,
            0.587, -0.28886, -0.51499,
            0.114, 0.436, -0.10001
        ]).withUnsafeBufferPointer {
            Data(buffer: $0)
        }
        let rgbToYUVMatrix = constant( // 3
            rgbToYUVMatrixData,
            shape: [3, 3],
            dataType: .float32)
        let yuvTensor = matrixMultiplication( // 4
            primary: rgbTensor,
            secondary: rgbToYUVMatrix,
            name: "rgb2yuv")
        return yuvTensor
    }
}
  1. Extending `MPSGraph`: We add a new method `rgb2yuv(rgbTensor:)` to handle RGB-to-YUV conversion as part of the graph. The new method takes a tensor representing a source RGB image and returns a tensor representing the YUV image.
  2. Wrapping the RGB-to-YUV transformation matrix into a `Data` object: The transformation matrix is defined as a flat array and converted into a `Data` object. Note that each pixel is multiplied as a row vector (see step 4), so the coefficients are laid out as the transpose of the textbook RGB-to-YUV matrix.
  3. Creating a constant tensor: The transformation matrix is added to the graph as a constant tensor, with the shape `[3, 3]` and data type `float32`.
  4. Performing matrix multiplication: Using the `matrixMultiplication` operation, each “pixel” in the input RGB tensor is multiplied by the transformation matrix to produce the output YUV tensor. This operation applies the RGB-to-YUV transformation to the entire input image efficiently.

In the same way, we can define the YUV-to-RGB operation:

extension MPSGraph {
    func yuv2rgb(yuvTensor: MPSGraphTensor) -> MPSGraphTensor {
        let yuvToRGBMatrixData = [Float]([
            1.0, 1.0, 1.0,
            0.0, -0.39465, 2.03211,
            1.13983, -0.58060, 0.0
        ]).withUnsafeBufferPointer {
            Data(buffer: $0)
        }
        let yuvToRGBMatrix = constant(
            yuvToRGBMatrixData,
            shape: [3, 3],
            dataType: .float32)
        let rgbTensor = matrixMultiplication(
            primary: yuvTensor,
            secondary: yuvToRGBMatrix,
            name: "yuv2rgb")
        return rgbTensor
    }
}
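
As a quick sanity check (my own addition, not from the original), the two conversions should be inverses of each other up to rounding, so feeding a color through both should return it almost unchanged:

// A sketch of a round-trip test for the conversion pair.
let graph = MPSGraph()
let rgb = graph.placeholder(shape: [1, 3], dataType: .float32, name: "rgb")
let roundTrip = graph.yuv2rgb(yuvTensor: graph.rgb2yuv(rgbTensor: rgb))
// Feed e.g. [0.25, 0.5, 0.75] and compare the fetched `roundTrip`
// values to the input; they should match within a small tolerance.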

Lightness Normalisation

The idea of this part is very straightforward: we get the minimal and maximal values of the Y channel in our image and normalize the entire image by mapping this range to the interval [0, 1]. This process ensures that the lightness values in the image are evenly distributed, improving contrast and brightness while preserving the relative relationships between pixels.

extension MPSGraph {
    func normalize(input: MPSGraphTensor) -> MPSGraphTensor {
        let minVal = reductionMinimum(with: input, axes: [1, 2], name: "minVal")
        let maxVal = reductionMaximum(with: input, axes: [1, 2], name: "maxVal")
        let normalized = division(
            subtraction(input, minVal, name: "YCentered"),
            subtraction(maxVal, minVal, name: "range"),
            name: "YNormalized")
        return normalized
    }
}
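
One edge case worth a guard (my addition, not in the original): for a perfectly flat image `maxVal == minVal`, so the division yields NaNs. A sketch of a fix inside `normalize(input:)`, with an arbitrarily chosen epsilon:

// Clamp the range away from zero before dividing.
let epsilon = constant(1e-6, dataType: .float32)
let range = maximum(
    subtraction(maxVal, minVal, name: "range"),
    epsilon,
    name: "safeRange")
// ... then use `range` as the divisor in place of the raw difference.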

White Balance

Here we assume that the average value of an image with good white balance should be gray. To achieve this, we compute the average value of all pixels in the image and offset each pixel by this value. This adjustment ensures that the overall color balance of the image is neutral, correcting any unwanted tints or biases caused by imbalanced colors.

extension MPSGraph {
    func meanShift(input: MPSGraphTensor) -> MPSGraphTensor {
        let average = mean(of: input, axes: [1, 2], name: "Average")
        let shifted = subtraction(input, average, name: "Shifted")
        return shifted
    }
}

Composing Graph

Now that we have most of our components ready, we can assemble the entire graph. This includes all steps of the image processing pipeline, such as slicing the RGBA input into RGB and Alpha channels, performing enhancements on the RGB channels, and finally merging them back into a complete RGBA tensor.

func buildGraph(graph: MPSGraph, rgbaTensor: MPSGraphTensor) -> MPSGraphTensor {
    let rgbTensor = graph.sliceTensor(
        rgbaTensor,
        dimension: -1,
        start: 0,
        length: 3,
        name: "RGB")
    let yuvTensor = graph.rgb2yuv(rgbTensor: rgbTensor)

    let yChannel = graph.sliceTensor(
        yuvTensor,
        dimension: -1,
        start: 0,
        length: 1,
        name: "YChannel")
    let yNormalized = graph.normalize(input: yChannel)

    let uvChannels = graph.sliceTensor(
        yuvTensor,
        dimension: -1,
        start: 1,
        length: 2,
        name: "UVChannels")
    let shiftedUV = graph.meanShift(input: uvChannels)

    let normalizedYUV = graph.concatTensors(
        [yNormalized, shiftedUV],
        dimension: -1,
        name: "NormalizedShiftedYUV")

    let normalizedRGB = graph.yuv2rgb(yuvTensor: normalizedYUV)

    let alphaChannel = graph.sliceTensor(
        rgbaTensor,
        dimension: -1,
        start: 3,
        length: 1,
        name: "AlphaChannel")
    let normalizedRGBA = graph.concatTensors(
        [normalizedRGB, alphaChannel],
        dimension: -1,
        name: "NormalizedShiftedRGBA")

    return normalizedRGBA
}

Running the Graph

func run(image: MPSImage, commandQueue: MTLCommandQueue) async -> MPSImage? {
    // 1
    let inputData = MPSGraphTensorData([image])
    // 2
    let cmdBuf = MPSCommandBuffer(from: commandQueue)
    // 3
    let graph = MPSGraph()
    // 4
    let input = graph.placeholder(shape: [-1, -1, -1, -1],
                                  dataType: MPSDataType.float32,
                                  name: "input")
    // 5
    let output = buildGraph(graph: graph, rgbaTensor: input)
    // 6
    let fetch = graph.encode(to: cmdBuf,
                             feeds: [input: inputData],
                             targetTensors: [output],
                             targetOperations: nil,
                             executionDescriptor: nil)
    // 7
    var result: MPSImage?
    if let resArray = fetch[output]?.mpsndarray() {
        let resDesc = resArray.descriptor()
        let imgDesc = MPSImageDescriptor(
            channelFormat: MPSImageFeatureChannelFormat.float16,
            width: resDesc.sliceRange(forDimension: 1).length,
            height: resDesc.sliceRange(forDimension: 2).length,
            featureChannels: 4)
        let resImage = MPSImage(device: image.device, imageDescriptor: imgDesc)
        // 8
        resArray.exportData(with: cmdBuf,
                            to: [resImage],
                            offset: MPSImageCoordinate(x: 0, y: 0, channel: 0))
        result = resImage
    }
    // 9
    cmdBuf.commit()
    cmdBuf.waitUntilCompleted()

    return result
}
  1. Convert the input `MPSImage` into `MPSGraphTensorData`.
  2. Create an `MPSCommandBuffer` from the Metal CommandQueue.
  3. Initialize an `MPSGraph` object for constructing the computation graph.
  4. Define a placeholder tensor for input, allowing dynamic data during execution.
  5. Construct the image-processing graph using the `buildGraph` function from the previous section.
  6. Encode the graph into the `MPSCommandBuffer` for execution. We could simply `run` it, but we have a blitting operation below.
  7. Retrieve the output as an `MPSNDArray` and convert it back to an `MPSImage`.
  8. Export the processed data from the `MPSNDArray` back into the `MPSImage`. This operation runs entirely on the GPU.
  9. Commit the command buffer and wait for execution to complete so we can use the resulting image on the CPU.

As a result of running the graph, we get the following output (left is the original image):

What’s under the hood

When you perform a GPU capture in Xcode, you can get an approximate view of what’s happening under the hood of MPSGraph. While you cannot see the specific optimizations or exact implementation details of your graph (as MPSGraph handles this internally), the GPU capture allows you to inspect important aspects like memory consumption, resource flow, and overall execution timeline.

NOTE: you can see that there are many resources allocated by MPSGraph that are out of your control. This can significantly impact the memory consumption of your app.

Conclusion

  • Simplified Graph-Based Workflow: `MPSGraph` provides a streamlined approach for performing tensor operations without the extensive boilerplate code required by vanilla Metal.
  • Versatility: It is useful for a wide range of tasks, including image processing, machine learning, and general-purpose GPU computations.
  • Integration with Metal: While `MPSGraph` can be integrated into Metal pipelines, there are certain nuances and limitations, especially regarding resource management and command buffer handling.
  • High-Level Abstraction: The graph-based approach abstracts many low-level details, enabling developers to focus on the logic of tensor computations while relying on Metal for optimized execution.
  • Performance Optimization: GPU captures in Xcode can help analyze resource usage and execution timelines, aiding in debugging and performance tuning.
  • Reusability and Modularity: By encapsulating subgraphs and custom operations, you can build reusable and modular pipelines for complex workflows.
  • Memory Consumption: `MPSGraph` allocates many internal intermediate resources, so be careful with its memory consumption.
