---
title: DJLのPyTorchバックエンドでMPS (Metal Performance Shaders) を使うメモ
tags: ["Java", "DJL", "Machine Learning", "PyTorch", "MPS"]
categories: ["Dev", "Java", "ai", "djl"]
date: 2023-10-31T08:10:27Z
updated: 2023-10-31T09:15:27Z
---

[DJL (Deep Java Library)](https://github.com/deepjavalibrary/djl) 0.20.0以降で [MPS](https://developer.apple.com/metal/pytorch/) が使えるようになっていました。
サンプルが見当たらなかったので試したメモ。

サンプルコードは [こちら](https://github.com/making/hello-djl-pytorch) です。Apple M2 Pro、メモリ32 GB、macOS 13.5.2で試しました。

次のように`Device`インスタンスを`Device.of("mps", 0)`で作れば良いようです。 

```java
import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class Main {
	public static void main(String[] args) {
		int dimension = 1024;
		Device device = Device.of("mps", 0);
		//Device device = Device.cpu();
		System.out.println(device.isGpu()); // false
		try (NDManager manager = NDManager.newBaseManager(device)) {
			NDArray array1 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
			NDArray array2 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
			NDArray result = array1.add(array2).mul(10).matMul(array1.transpose()).div(5);
			System.out.println(result);
		}
	}
}
```


MPS自体はGPUではないですが、MPSのAPIを使うことで、GPUが利用されるため、MPSを使ったコードを実行するとアクティビティモニタで `% GPU` の数字が0より大きくなります。
