DJL (Deep Java Library) 0.20.0以降で MPS が使えるようになっていました。 サンプルが見当たらなかったので試したメモ。
サンプルコードは こちら です。Apple M2 Pro、メモリ32 GB、macOS 13.5.2で試しました。
次のようにDevice
インスタンスをDevice.of("mps", 0)
で作れば良いようです。
import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
public class Main {
public static void main(String[] args) {
int dimension = 1024;
Device device = Device.of("mps", 0);
//Device device = Device.cpu();
System.out.println(device.isGpu()); // false
try (NDManager manager = NDManager.newBaseManager(device)) {
NDArray array1 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
NDArray array2 = manager.randomUniform(0, 1, new Shape(dimension, dimension));
NDArray result = array1.add(array2).mul(10).matMul(array1.transpose()).div(5);
System.out.println(result);
}
}
}
MPS自体はGPUではないですが、MPSのAPIを使うことで、GPUが利用されるため、MPSを使ったコードを実行するとアクティビティモニタで % GPU
の数字が0より大きくなります。