Jul 19, 2011
Jul 19, 2011
N/A Views
MD
warning
この記事は2年以上前に更新されたものです。情報が古くなっている可能性があります。

久しぶりに探索アルゴリズムを復習した。8パズル問題を深さ優先探索(DFS)、幅優先探索(BFS)、A*探索の3通りで解いてみた。以下はJavaによる簡易的な実装。

ポイントは、未展開ノードを保持するオープンリストのデータ構造を使い分けること。DFSはスタック、BFSはキュー、A*は優先度付きキュー(PriorityQueue、以下PQ)を使う。

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.Set;
import java.util.Stack;

/**
 * Solves the 8-puzzle (3x3 sliding puzzle, {@code 0} is the blank) with three
 * search strategies: depth-limited depth-first search, breadth-first search
 * and A* search.
 *
 * <p>The strategies differ only in the container used for the open list:
 * a stack for DFS, a FIFO queue for BFS and a priority queue for A*.
 * Each search prints {@code goal}, the solution depth and the number of
 * expanded states to standard output when it succeeds.
 */
public class EightPuzzle {
    /** Depth bound used by {@link #depthFirstSearch(Node, Node)}. */
    public static final int DEFAULT_MAX_DEPTH = 9;

    private static final int SIDE = 3;
    private static final int CELLS = SIDE * SIDE;

    /**
     * A search node: one board configuration plus the bookkeeping needed to
     * reconstruct the path that reached it.
     *
     * <p>Equality and hashing consider the board only, so two nodes reached by
     * different paths but showing the same board are equal. The natural
     * ordering is the A* priority {@code depth + evaluator.evaluate(node)} and
     * is therefore NOT consistent with {@code equals}.
     */
    public static class Node implements Comparable<Node> {
        /** Row-major tiles; index {@code row * 3 + col}, value 0 is the blank. */
        final int board[];
        /** Number of moves from the start node (path cost g). */
        int depth;
        /** Node this one was generated from, or {@code null} for a start node. */
        Node prev;
        /** Heuristic used by {@link #compareTo(Node)}; inherited by successors. */
        Evaluator evaluator;

        /**
         * Creates a start node (depth 0, no predecessor).
         *
         * @param state nine tiles in row-major order, each of 0..8 exactly once
         * @throws IllegalArgumentException if {@code state} is null, does not
         *         have nine entries, or is not a permutation of 0..8
         */
        public Node(int... state) {
            if (state == null || state.length != CELLS) {
                throw new IllegalArgumentException("illegal");
            }
            boolean[] used = new boolean[CELLS];
            for (int tile : state) {
                if (tile < 0 || tile >= CELLS || used[tile]) {
                    throw new IllegalArgumentException(
                            "board must contain each of 0.." + (CELLS - 1)
                                    + " exactly once: " + Arrays.toString(state));
                }
                used[tile] = true;
            }
            // Defensive copy: the node is used as a hash key, so later changes
            // to the caller's array must not alter it.
            this.board = state.clone();
            this.depth = 0;
            this.prev = null;
        }

        /** Creates a successor of {@code parent}; takes ownership of {@code ownedBoard}. */
        private Node(int[] ownedBoard, Node parent) {
            this.board = ownedBoard;
            this.prev = parent;
            this.depth = parent.depth + 1;
            this.evaluator = parent.evaluator;
        }

        /** Returns the index of the blank tile. */
        private int indexOfBlank() {
            for (int i = 0; i < board.length; i++) {
                if (board[i] == 0) {
                    return i;
                }
            }
            throw new IllegalStateException("board has no blank tile");
        }

        /** Returns the successor obtained by moving the blank from {@code blank} to {@code target}. */
        private Node slide(int blank, int target) {
            int[] moved = board.clone();
            moved[blank] = moved[target];
            moved[target] = 0;
            return new Node(moved, this);
        }

        /**
         * Generates every board reachable with a single move.
         *
         * @return successors in the order blank-up, blank-down, blank-left,
         *         blank-right (moves that would leave the board are omitted);
         *         each has this node as {@code prev}, {@code depth + 1} and
         *         this node's evaluator
         */
        public List<Node> movableNodes() {
            int blank = indexOfBlank();
            int row = blank / SIDE;
            int col = blank % SIDE;
            List<Node> successors = new ArrayList<Node>(4);

            if (row > 0) {
                successors.add(slide(blank, blank - SIDE));
            }
            if (row < SIDE - 1) {
                successors.add(slide(blank, blank + SIDE));
            }
            if (col > 0) {
                successors.add(slide(blank, blank - 1));
            }
            if (col < SIDE - 1) {
                successors.add(slide(blank, blank + 1));
            }
            return successors;
        }

        /** Prints the path from the start node to this node, start first. */
        public void printTrace() {
            Deque<Node> trace = new ArrayDeque<Node>();
            for (Node n = this; n != null; n = n.prev) {
                trace.addFirst(n);
            }
            System.out.println(trace);
        }

        @Override
        public int hashCode() {
            return 31 + Arrays.hashCode(board);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            return Arrays.equals(board, ((Node) obj).board);
        }

        /** Renders the board as three bracketed rows, each preceded by a line separator. */
        @Override
        public String toString() {
            String newline = System.getProperty("line.separator");
            StringBuilder sb = new StringBuilder();
            for (int row = 0; row < SIDE; row++) {
                sb.append(newline).append("[");
                for (int col = 0; col < SIDE; col++) {
                    if (col != 0) {
                        sb.append(", ");
                    }
                    sb.append(board[row * SIDE + col]);
                }
                sb.append("]");
            }
            sb.append(newline);
            return sb.toString();
        }

        /**
         * Orders nodes by the A* priority f = depth + heuristic, using this
         * node's evaluator for both operands. Nodes without an evaluator all
         * compare as equal.
         */
        public int compareTo(Node o) {
            if (evaluator == null) {
                return 0;
            }
            int mine = depth + evaluator.evaluate(this);
            int theirs = o.depth + evaluator.evaluate(o);
            if (mine < theirs) {
                return -1;
            }
            return mine > theirs ? 1 : 0;
        }
    }

    /** Heuristic estimating the number of moves from a node to the goal. */
    public static interface Evaluator {
        int evaluate(Node node);
    }

    /**
     * Builds the Manhattan-distance heuristic for {@code goal}: the sum, over
     * the eight numbered tiles, of the row and column distance between a
     * tile's current cell and its cell in {@code goal}.
     *
     * <p>The blank is deliberately excluded. Every move shifts exactly one
     * numbered tile by one cell, so without the blank the estimate never
     * exceeds the true remaining cost and A* returns a shortest solution.
     *
     * @param goal target configuration
     * @return evaluator bound to {@code goal}
     */
    public static Evaluator manhattanDistanceTo(Node goal) {
        final int[] homeRow = new int[CELLS];
        final int[] homeCol = new int[CELLS];
        for (int i = 0; i < CELLS; i++) {
            int tile = goal.board[i];
            homeRow[tile] = i / SIDE;
            homeCol[tile] = i % SIDE;
        }
        return new Evaluator() {
            public int evaluate(Node node) {
                int total = 0;
                for (int i = 0; i < node.board.length; i++) {
                    int tile = node.board[i];
                    if (tile == 0) {
                        continue;
                    }
                    total += Math.abs(homeRow[tile] - i / SIDE)
                            + Math.abs(homeCol[tile] - i % SIDE);
                }
                return total;
            }
        };
    }

    /** Prints the success report shared by all searches and returns {@code found}. */
    private static Node reportGoal(Node found, int closedCount) {
        System.out.println("goal");
        System.out.println("depth=" + found.depth);
        System.out.println("closed=" + closedCount);
        return found;
    }

    /**
     * Depth-first search limited to {@link #DEFAULT_MAX_DEPTH} moves.
     *
     * @param init start configuration
     * @param goal target configuration
     * @return a node equal to {@code goal} whose {@code prev} chain leads back
     *         to {@code init}, or {@code null} if none exists within the bound
     */
    public static Node depthFirstSearch(Node init, Node goal) {
        return depthFirstSearch(init, goal, DEFAULT_MAX_DEPTH);
    }

    /**
     * Depth-first search that only considers solutions of at most
     * {@code maxDepth} moves. The solution found is not necessarily the
     * shortest one.
     *
     * @param init start configuration
     * @param goal target configuration
     * @param maxDepth maximum number of moves in an acceptable solution
     * @return a node equal to {@code goal}, or {@code null} if no solution of
     *         at most {@code maxDepth} moves exists
     */
    public static Node depthFirstSearch(Node init, Node goal, int maxDepth) {
        if (init.equals(goal)) {
            return reportGoal(init, 0);
        }
        // State -> shallowest depth at which it has been expanded. A plain
        // "already visited" set would be wrong here: a state first reached by
        // a long path could block a shorter path that still fits the bound.
        Map<Node, Integer> closed = new HashMap<Node, Integer>();
        Deque<Node> open = new ArrayDeque<Node>();
        open.push(init);

        while (!open.isEmpty()) {
            Node n = open.pop();
            if (n.depth >= maxDepth) {
                continue;
            }
            Integer expandedAt = closed.get(n);
            if (expandedAt != null && expandedAt <= n.depth) {
                continue;
            }
            closed.put(n, n.depth);
            for (Node next : n.movableNodes()) {
                if (next.equals(goal)) {
                    return reportGoal(next, closed.size());
                }
                if (next.depth >= maxDepth) {
                    // Could never be expanded, so there is no point in stacking it.
                    continue;
                }
                Integer seenAt = closed.get(next);
                if (seenAt == null || seenAt > next.depth) {
                    open.push(next);
                }
            }
        }
        return null;
    }

    /**
     * Breadth-first search; returns a solution with the fewest moves.
     *
     * @param init start configuration
     * @param goal target configuration
     * @return a node equal to {@code goal}, or {@code null} if {@code goal}
     *         is unreachable from {@code init}
     */
    public static Node breadthFirstSearch(Node init, Node goal) {
        if (init.equals(goal)) {
            return reportGoal(init, 0);
        }
        // States are marked when enqueued, not when dequeued, so the queue
        // never holds the same board twice.
        Set<Node> seen = new HashSet<Node>();
        Queue<Node> open = new ArrayDeque<Node>();
        seen.add(init);
        open.add(init);
        int expanded = 0;

        while (!open.isEmpty()) {
            Node n = open.poll();
            expanded++;
            for (Node next : n.movableNodes()) {
                if (!seen.add(next)) {
                    continue;
                }
                if (next.equals(goal)) {
                    return reportGoal(next, expanded);
                }
                open.add(next);
            }
        }
        return null;
    }

    /**
     * A* search ordered by {@code depth + evaluator.evaluate(node)}. With an
     * evaluator that never overestimates the remaining cost the returned
     * solution is a shortest one.
     *
     * <p>Side effect: sets {@code init.evaluator}.
     *
     * @param init start configuration
     * @param goal target configuration
     * @param evaluator heuristic estimate of the moves remaining to the goal
     * @return a node equal to {@code goal}, or {@code null} if {@code goal}
     *         is unreachable from {@code init}
     * @throws IllegalArgumentException if {@code evaluator} is null
     */
    public static Node aStarSearch(Node init, Node goal, Evaluator evaluator) {
        if (evaluator == null) {
            throw new IllegalArgumentException("evaluator must not be null");
        }
        init.evaluator = evaluator;
        // Keyed by the node itself (equals on the board), not by its hash
        // code: distinct boards can share a 32-bit hash.
        Map<Node, Integer> cheapest = new HashMap<Node, Integer>();
        Set<Node> closed = new HashSet<Node>();
        PriorityQueue<Node> open = new PriorityQueue<Node>();
        cheapest.put(init, init.depth);
        open.add(init);

        while (!open.isEmpty()) {
            Node n = open.poll();
            if (cheapest.get(n) < n.depth) {
                // Stale queue entry: a cheaper path to this board was queued later.
                continue;
            }
            closed.add(n);
            if (n.equals(goal)) {
                return reportGoal(n, closed.size());
            }
            for (Node next : n.movableNodes()) {
                Integer known = cheapest.get(next);
                // Re-queue a known board only when this path is strictly cheaper.
                if (known == null || next.depth < known) {
                    cheapest.put(next, next.depth);
                    open.add(next);
                }
            }
        }
        return null;
    }

    /** Prints the solution path, or {@code NG} when the search failed. */
    private static void printResult(Node result) {
        if (result != null) {
            result.printTrace();
        } else {
            System.out.println("NG");
        }
    }

    public static void main(String[] args) {
        Node init = new Node(8, 1, 3, 0, 4, 5, 2, 7, 6);
        Node goal = new Node(1, 2, 3, 8, 0, 4, 7, 6, 5);

        System.out.println("== DFS ==");
        printResult(depthFirstSearch(init, goal));

        System.out.println("== BFS ==");
        printResult(breadthFirstSearch(init, goal));

        System.out.println("== A*  ==");
        printResult(aStarSearch(init, goal, manhattanDistanceTo(goal)));
    }
}

実行結果

ヒューリスティック関数は適当な実装だが、その割にはA*が良い結果になった。3手法とも深さ9の解を見つけているが、展開したノード数(closed)はA*が最も少ない。

== DFS ==
goal
depth=9
closed=161
[
[8, 1, 3]
[0, 4, 5]
[2, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[0, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 0, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 6, 0]
, 
[8, 1, 3]
[2, 4, 0]
[7, 6, 5]
, 
[8, 1, 3]
[2, 0, 4]
[7, 6, 5]
, 
[8, 1, 3]
[0, 2, 4]
[7, 6, 5]
, 
[0, 1, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 0, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 2, 3]
[8, 0, 4]
[7, 6, 5]
]
== BFS ==
goal
depth=9
closed=242
[
[8, 1, 3]
[0, 4, 5]
[2, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[0, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 0, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 6, 0]
, 
[8, 1, 3]
[2, 4, 0]
[7, 6, 5]
, 
[8, 1, 3]
[2, 0, 4]
[7, 6, 5]
, 
[8, 1, 3]
[0, 2, 4]
[7, 6, 5]
, 
[0, 1, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 0, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 2, 3]
[8, 0, 4]
[7, 6, 5]
]
== A*  ==
goal
depth=9
closed=15
[
[8, 1, 3]
[0, 4, 5]
[2, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[0, 7, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 0, 6]
, 
[8, 1, 3]
[2, 4, 5]
[7, 6, 0]
, 
[8, 1, 3]
[2, 4, 0]
[7, 6, 5]
, 
[8, 1, 3]
[2, 0, 4]
[7, 6, 5]
, 
[8, 1, 3]
[0, 2, 4]
[7, 6, 5]
, 
[0, 1, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 0, 3]
[8, 2, 4]
[7, 6, 5]
, 
[1, 2, 3]
[8, 0, 4]
[7, 6, 5]
]
Found a mistake? Update the entry.
Share this article: