// A QuadTree is a map of (point, value) pairs.
// In this version, each point has type Point and each value is an int.
//
// Given a query point, we can list the pairs in order of their
// distance from the query.  There is no "rebalancing" in this tree,
// so use random insertion order if possible.

// Loosely based on http://algs4.cs.princeton.edu/92search/QuadTree.java
// Mic Grigni added Pair, Node.bbox, and the distance iterator.

// Note: main() is a demo whose animation only exercises the distance iterator.

import java.util.Comparator;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;

public class QuadTree //<Value>
{
    // Data: the root of the tree, or null when the tree is empty.
    private Node root = null;

    // Default constructor: an empty QuadTree.
    public QuadTree() {}

    // Given a Point array, insert each point with its array index as its
    // value.  Points are inserted in a pseudo-random order, to get
    // expected depth O(log N).  This uses the StdRandom generator; for
    // repeatable tests, call StdRandom.setSeed(s) first.
    public QuadTree(Point[] points) {
        int n = points.length;
        // Construct a pseudo-random permutation of indices 0..n-1.
        int[] perm = new int[n];
        for (int i = 0; i < n; ++i) perm[i] = i; // identity
        StdRandom.shuffle(perm);
        for (int v : perm)
            insert(points[v], v);
    }

    // Animation: leave anims negative to disable StdDraw animation.
    // Otherwise animation is enabled, and anims is a millisecond
    // delay at certain events.  We assume that an animation context
    // (like canvas size and axis scaling) is already set.
    private static int anims = -1;
    static void animate(int a) { anims = a; }
    static void hugePen() { StdDraw.setPenRadius(0.016); }
    static void fatPen()  { StdDraw.setPenRadius(0.008); }
    static void thinPen() { StdDraw.setPenRadius(0.002); }

    // QuadTree.Pair is an immutable (point, value) pair.
    // This type is user-visible via the closestPairs iterator.
    public class Pair {
        public final Point point;
        public final int value;
        Pair(Point p, int v) { point = p; value = v; }
        // Copy only the (point, value) part of that; when that is a Node,
        // its children and bbox are not copied.
        Pair(Pair that) { point = that.point; value = that.value; }
        // An int from 0 to 3: the quadrant of p, relative to our point.
        // Adds 2 when p has strictly larger x (east), and 1 when p has
        // strictly larger y (north); ties go to the west/south side.
        public int quadrant(Point p) {
            return (point.x < p.x ? 2 : 0) + (point.y < p.y ? 1 : 0);
        }
        // Quadrant ids, consistent with quadrant() above when larger y
        // is taken as north.  (Not used elsewhere in this file.)
        static final int SW = 0, NW = 1, SE = 2, NE = 3;
        // Node overrides these two methods:
        double distanceTo(Point p) { return point.distanceTo(p); }
        boolean isLeaf() { return true; } // a single point
        // For animation:
        void draw() { point.draw(); }
    }

    // A Node in our tree is a Pair with up to 4 children and a bbox.
    // Node is not user-visible, but a leaf Node may be returned as a Pair.
    private class Node extends Pair
    {
        // Both fields stay null while the Node is a leaf.
        Node[] quad = null;     // four subtrees, indexed by quadrant()
        Rectangle bbox = null;  // bounding box of all points in subtree
        // Construct as a leaf (no children yet):
        Node(Point p, int v) { super(p, v); }
        // Prepare to be internal (do before adding the first child).
        // The bbox starts as just this node's own point.
        void becomeInternal() {
            assert isLeaf();
            bbox = new Rectangle(point);
            quad = new Node[4];
        }
        // An internal node reports its bbox distance.  The iterator's
        // ordering relies on this never exceeding the distance to any
        // point in the subtree (assuming Rectangle.distanceTo measures
        // distance to the box).
        @Override
        double distanceTo(Point p) {
            return bbox == null ? point.distanceTo(p) : bbox.distanceTo(p);
        }
        // Test whether this node is still a leaf.
        @Override
        boolean isLeaf() { return quad == null; } // or: bbox == null
        @Override
        void draw() { if (bbox == null) point.draw(); else bbox.draw(); }
    }

    // Public insert method: add pair (p, v).  Duplicate points are
    // allowed; each insertion adds a separate pair.
    // Throws NullPointerException if p is null.
    public void insert(Point p, int v) {
        Objects.requireNonNull(p, "p");
        root = insert(root, p, v);
    }
    // Private recursive implementation, returns the modified subtree h.
    private Node insert(Node h, Point p, int v) {
        if (h == null) return new Node(p, v);
        if (h.isLeaf()) h.becomeInternal();
        h.bbox = h.bbox.add(p);          // grow bounding box at h
        int q = h.quadrant(p);           // which way does p go?
        h.quad[q] = insert(h.quad[q], p, v);
        return h;
    }

    // This iterator returns (point, value) pairs in order of
    // increasing distance of the point from the query point p.
    // Throws NullPointerException if p is null.
    public Iterator<Pair> closestPairs(Point p) {
        Objects.requireNonNull(p, "p");
        return new PairIter(p);
    }

    // The implementation of closestPairs: two private classes.
    // DistComp orders Pairs (and Nodes) by distanceTo the query point.
    private static class DistComp implements Comparator<Pair>
    {
        final Point p;          // the query point
        DistComp(Point p) { this.p = p; }
        @Override
        public int compare(Pair a, Pair b) { // a or b may be a Node
            return Double.compare(a.distanceTo(p), b.distanceTo(p));
        }
    }
    // PairIter does a best-first search.  Each internal Node taken from
    // the priority queue is replaced by its own (point, value) pair and
    // its non-null children, so the first leaf reached is the closest.
    private class PairIter implements Iterator<Pair>
    {
        final Point p;            // the query point
        final MinPQ<Pair> pq;     // ordered by distance from query
        PairIter(Point query) {
            p = query;
            pq = new MinPQ<Pair>(30, new DistComp(query));
            add(root);
        }
        // Queue a pair (ignoring null), drawing it in black if animating.
        void add(Pair pair) {
            if (pair == null) return;
            if (anims >= 0) {
                StdDraw.setPenColor(StdDraw.BLACK);
                fatPen();
                pair.draw();
            }
            pq.insert(pair);
        }
        @Override
        public boolean hasNext() { return !pq.isEmpty(); }
        @Override
        public Pair next() {
            // Iterator contract: signal exhaustion explicitly.
            if (pq.isEmpty())
                throw new NoSuchElementException("no more pairs");
            // Every internal Node re-adds at least its own pair, so the
            // queue cannot run dry inside this loop.
            while (true) {
                Pair pair = pq.delMin();
                if (anims >= 0) {
                    // erase old pair (was black, redraw in white)
                    StdDraw.setPenColor(StdDraw.WHITE);
                    fatPen();
                    pair.draw();
                }
                if (pair.isLeaf()) {
                    if (anims >= 0) {
                        // found one, draw red edge
                        StdDraw.setPenColor(StdDraw.RED);
                        fatPen();
                        pair.draw();
                        thinPen();
                        p.drawTo(pair.point);
                        StdDraw.show(anims);
                    }
                    return pair;
                }
                // pair is an internal Node: need to disassemble it.
                Node node = (Node) pair;
                add(new Pair(node));              // node without children
                for (Node child : node.quad) add(child); // each child
            }
        }
        // An Iterator is not required to implement remove.
        @Override
        public void remove() { throw new UnsupportedOperationException(); }
    }

    // Draw every point (fat) and every internal bbox (thin).
    public void draw() { draw(root); }
    private void draw(Node n) {
        if (n == null) return;
        fatPen();
        n.point.draw();
        if (n.isLeaf()) return;
        thinPen();
        n.bbox.draw();
        for (Node child : n.quad) draw(child);
    }

    // This iterator returns just the values, in order of increasing
    // distance of the points from the query point p.  It is only a
    // wrapper around the closestPairs iterator.
    public Iterator<Integer> closestValues(Point p) {
        return new ValueOfPairIter(closestPairs(p));
    }
    private static class ValueOfPairIter implements Iterator<Integer> {
        private final Iterator<Pair> it;
        ValueOfPairIter(Iterator<Pair> it) { this.it = it; }
        @Override
        public boolean hasNext() { return it.hasNext(); }
        @Override
        public Integer next() { return it.next().value; }
        @Override
        public void remove() { it.remove(); }
    }

    // A random point in the unit square.
    // Use StdRandom.setSeed(s) to make this repeatable.
    static Point randPoint() {
        return new Point(StdRandom.uniform(), StdRandom.uniform());
    }

    // main: run an animation of the iterator.
    // Optional command line arguments: N DELAY
    public static void main(String[] args)
    {
        int N = 200;             // number of random points in qt
        int DELAY = 100;         // delay during animation
        int SIDE = 900;          // size of window and png image
        if (args.length > 0) N = Integer.parseInt(args[0]);
        if (args.length > 1) DELAY = Integer.parseInt(args[1]);
        Point[] points = new Point[N];
        QuadTree qt = new QuadTree();
        // Generate N random points, save to array and qt.
        for (int i = 0; i < N; ++i)
            qt.insert(points[i] = randPoint(), i);
        // Draw the QuadTree, in gray, thin lines.
        StdDraw.setCanvasSize(SIDE, SIDE);
        StdDraw.show(0);
        thinPen();
        StdDraw.setPenColor(StdDraw.GRAY);
        qt.draw();
        StdDraw.setPenColor(StdDraw.RED);
        hugePen();
        Point p = randPoint();
        p.draw();
        StdDraw.show(0);        // update onscreen image
        StdOut.println("Hit return to start iterator at red point.");
        StdIn.readLine();
        StdDraw.setPenColor(StdDraw.BLACK);
        animate(DELAY);
        Iterator<Integer> closest = qt.closestValues(p);
        double dist = 0.0;
        int count = 0;
        while (closest.hasNext()) {
            int w = closest.next();
            ++count;
            double wdist = p.distanceTo(points[w]);
            if (wdist < dist)
                StdOut.printf("Warning! distance decreased (%f to %f)%n",
                              dist, wdist);
            dist = wdist;
        }
        // Every inserted pair should be returned exactly once.
        if (count != N)
            StdOut.printf("Warning! expected %d points%n", N);
        StdOut.printf("Iterator returned %d points, hit return to exit.%n",
                      count);
        StdIn.readLine();
        System.exit(0);         // close StdDraw window
    }
}
