IK.AM

@making's tech note


Java8のcomputeIfAbsentとラムダを使ってメモ化

🗃 {Programming/Java/java/util/Map}
🗓 Updated at 2014-03-08T06:58:06Z  🗓 Created at 2014-03-08T06:58:06Z   🌎 English Page

きしださんがとっくに書いてたけど、メモっとく

階段を上がる場合に1歩で1段、2段、3段登れるとして、階段の登り方は何パターンあるでしょうか。

再帰で解けそうですね。

public static void main(String[] args) {
    int n = 35;
    long start = System.nanoTime();
    long count = countWays(n);
    long end = System.nanoTime();
    System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
    System.out.println(count);
}

public static long countWays(int n) {
    if (n < 0) return 0;
    if (n == 0) return 1;
    return countWays(n - 1) + countWays(n - 2) + countWays(n - 3);
}

実行結果

ellapsed 10,096,405 [μsec]
1132436852

この計算はO(3^n)なので遅いです。途中に同じ計算を何度も繰り返しているので、無駄です。結果をキャッシュ(メモ化)してショートカットしましょう。

Java8からjava.util.Mapにメモ化computeIfAbsentメソッドが追加されました。Mapにキーが存在しなかったら引数のラムダを計算するメソッド。

書き直します。

public static void main(String[] args) {
    int n = 35;
    long start = System.nanoTime();
    long count = countWaysWithMemo(n);
    long end = System.nanoTime();
    System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
    System.out.println(count);
}

private static final Map<Integer, Long> memo = new HashMap<>();

public static long countWaysWithMemo(int n) {
    if (n < 0) return 0;
    if (n == 0) return 1;
    return memo.computeIfAbsent(n, x -> countWaysWithMemo(x - 1) + countWaysWithMemo(x - 2) + countWaysWithMemo(x - 3));
}

実行結果

ellapsed 80,897 [μsec]
1132436852

くっそ速くなりましたね!

ちなみに配列使った方がより爆速ですw

public static void main(String[] args) {
    int n = 35;
    long start = System.nanoTime();
    long count = countWaysWithFastMemo(n);
    long end = System.nanoTime();
    System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
    System.out.println(count);
}

private static long[] fastMemo = new long[100];

public static long countWaysWithFastMemo(int n) {
    if (n < 0) return 0;
    if (n == 0) return 1;
    if (fastMemo[n] > 0) {
        return fastMemo[n];
    }
    long count = countWaysWithFastMemo(n - 1) + countWaysWithFastMemo(n - 2) + countWaysWithFastMemo(n - 3);
    fastMemo[n] = count;
    return count;
}

実行結果

ellapsed  23 [μsec]
1132436852

✒️️ Edit  ⏰ History  🗑 Delete