IK.AM

@making's tech note


DJLのPyTorchバックエンドでMPS (Metal Performance Shaders) を使うメモ

🗃 {Dev/Java/ai/djl}
🏷 Java 🏷 DJL 🏷 Machine Learning 🏷 PyTorch 🏷 MPS 
🗓 Updated at 2023-10-31T09:15:27Z  🗓 Created at 2023-10-31T08:10:27Z   🌎 English Page

DJL (Deep Java Library) 0.20.0以降で MPS が使えるようになっていました。 サンプルが見当たらなかったので試したメモ。

サンプルコードは こちら です。Apple M2 Pro、メモリ32 GB、macOS 13.5.2で試しました。

次のようにDeviceインスタンスをDevice.of("mps", 0)で作れば良いようです。

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class Main {
    public static void main(String[] args) {
        int dimension = 1024;
        Device device = Device.of("mps", 0);
        //Device device = Device.cpu();
        System.out.println(device.isGpu()); // false
        try (NDManager manager = NDManager.newBaseManager(device)) {
            NDArray array1 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
            NDArray array2 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
            NDArray result = array1.add(array2).mul(10).matMul(array1.transpose()).div(5);
            System.out.println(result);
        }
    }
}

MPS自体はGPUではないですが、MPSのAPIを使うことで、GPUが利用されるため、MPSを使ったコードを実行するとアクティビティモニタで % GPU の数字が0より大きくなります。


✒️️ Edit  ⏰ History  🗑 Delete