---
title: Java8のcomputeIfAbsentとラムダを使ってメモ化
tags: []
categories: ["Programming", "Java", "java", "util", "Map"]
date: 2014-03-08T06:58:06Z
updated: 2014-03-08T06:58:06Z
---

きしださんが[とっくに書いてた][1]けど、メモっとく

> 階段を上がる場合に1歩で1段、2段、3段登れるとして、階段の登り方は何パターンあるでしょうか。

再帰で解けそうですね。

    public static void main(String[] args) {
        int n = 35;
        long start = System.nanoTime();
        long count = countWays(n);
        long end = System.nanoTime();
        System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
        System.out.println(count);
    }

    public static long countWays(int n) {
        if (n < 0) return 0;
        if (n == 0) return 1;
        return countWays(n - 1) + countWays(n - 2) + countWays(n - 3);
    }

実行結果

    ellapsed 10,096,405 [μsec]
    1132436852

この計算はO(3^n)なので遅いです。途中に同じ計算を何度も繰り返しているので、無駄です。結果をキャッシュ(メモ化）してショートカットしましょう。

Java8から`java.util.Map`にメモ化`computeIfAbsent`メソッドが追加されました。Mapにキーが存在しなかったら引数のラムダを計算するメソッド。

書き直します。

    public static void main(String[] args) {
        int n = 35;
        long start = System.nanoTime();
        long count = countWaysWithMemo(n);
        long end = System.nanoTime();
        System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
        System.out.println(count);
    }

    private static final Map<Integer, Long> memo = new HashMap<>();

    public static long countWaysWithMemo(int n) {
        if (n < 0) return 0;
        if (n == 0) return 1;
        return memo.computeIfAbsent(n, x -> countWaysWithMemo(x - 1) + countWaysWithMemo(x - 2) + countWaysWithMemo(x - 3));
    }

実行結果

    ellapsed 80,897 [μsec]
    1132436852

くっそ速くなりましたね！

ちなみに配列使った方がより爆速ですw

    public static void main(String[] args) {
        int n = 35;
        long start = System.nanoTime();
        long count = countWaysWithFastMemo(n);
        long end = System.nanoTime();
        System.out.printf("ellapsed %1$,3d [μsec]%n",(end - start) / 1000);
        System.out.println(count);
    }

    private static long[] fastMemo = new long[100];

    public static long countWaysWithFastMemo(int n) {
        if (n < 0) return 0;
        if (n == 0) return 1;
        if (fastMemo[n] > 0) {
            return fastMemo[n];
        }
        long count = countWaysWithFastMemo(n - 1) + countWaysWithFastMemo(n - 2) + countWaysWithFastMemo(n - 3);
        fastMemo[n] = count;
        return count;
    }

実行結果

    ellapsed  23 [μsec]
    1132436852

  [1]: http://d.hatena.ne.jp/nowokay/20130523
