package org.planx.msd.graph;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.planx.msd.Discriminator;
import org.planx.msd.Discriminators;
import org.planx.msd.Extractor;
import org.planx.util.Association;
import org.planx.util.Pair;

/**
 * Removes redundant data in a directed acyclic graph using multiset
 * discrimination.
 *
 * @author Thomas Ambus
 */
public class Compactor<N> {
    /** Navigator used to read and rewrite the graph structure. */
    protected Navigator<N> nav = null;
    /** Discriminator for a single "layer" of nodes (non-recursive). */
    protected Discriminator<N> disc = null;
    /** Whether {@link #share} collects statistics. */
    protected boolean doStat = false;
    /** Number of nodes visited for the first time (outside nodes are not visited). */
    protected int inputNumOfNodes = 0;
    /** Number of times an already-visited node was reached again via another edge. */
    protected int inputShareDegree = 0;
    /** Number of equivalence classes found over all height groups. */
    protected int outputNumOfNodes = 0;
    /** Number of child slots redirected to a different (canonical) node. */
    protected int merges = 0;

    /**
     * Constructor for subclasses. Leaves {@link #nav} and {@link #disc}
     * <code>null</code>; they must be assigned before {@link #share} is
     * called, since it dereferences both.
     */
    protected Compactor() {}

    /**
     * Constructs a new <code>Compactor</code> without statistics. The provided
     * discriminator must be able to discriminate a single "layer" of nodes.
     * That is, the discriminator should not be recursive, but rather treat
     * child nodes as {@link org.planx.msd.lang.EquivalenceClassDiscriminable}s.
     *
     * @param nav the navigator used to traverse and modify the graph
     * @param d   the single-layer node discriminator
     */
    public Compactor(Navigator<N> nav, Discriminator<N> d) {
        this(nav, d, false);
    }

    /**
     * Constructs a new <code>Compactor</code>.
     *
     * @param nav           the navigator used to traverse and modify the graph
     * @param d             the single-layer node discriminator
     * @param doStatistics  whether {@link #share} should collect statistics,
     *                      retrievable via {@link #getStatistics()}
     */
    public Compactor(Navigator<N> nav, Discriminator<N> d,
                                     boolean doStatistics) {
        this.nav = nav;
        this.disc = d;
        this.doStat = doStatistics;
    }

    /**
     * Compact the graph with the specified node as root.
     * <p>
     * Nodes are first grouped by height, then each group is discriminated
     * from the lowest height upwards. Within every resulting equivalence
     * class, each parent edge that points at a non-canonical member is
     * redirected to the canonical node chosen by the navigator. The root
     * itself has no parent edge and is therefore never replaced.
     *
     * @param root the root node of the graph (assumed to be acyclic)
     * @throws NullPointerException if the navigator chooses a
     *         <code>null</code> canonical node for some equivalence class
     */
    public void share(N root) {
        if (doStat) clearStatistics();

        // STEP 1: Group nodes by height.
        List<List<Pair<N,Edge>>> groups = heightDiscriminate(root);

        // STEP 2: Perform multiset discrimination on each group from
        // the bottom up and update pointers. Index 0 holds the lowest nodes,
        // so children are canonicalized before their parents are compared.
        for (List<Pair<N,Edge>> group : groups) {
            Extractor<Pair<N,Edge>,N,Edge> extractor =
                Discriminators.pairExtractor();
            Collection<List<Edge>> eqClasses =
                disc.discriminate(group, extractor);

            for (List<Edge> eqCls : eqClasses) {
                if (doStat) outputNumOfNodes++;
                // Singleton (and, defensively, empty) classes need no
                // redirection.
                if (eqCls.size() > 1) redirectToCanonical(eqCls);
            }
        }
    }

    /**
     * Redirects every parent edge in one equivalence class to the class's
     * canonical node. Edges without a parent (the root) and edges whose node
     * already is the canonical node are left untouched.
     */
    private void redirectToCanonical(List<Edge> eqCls) {
        N canon = nav.chooseCanonical(eqCls);
        if (canon == null) throw new NullPointerException
                    ("Must choose canonical node");

        for (Edge loc : eqCls) {
            N parent = loc.parent;
            if (canon == loc.node || parent == null) continue;
            if (doStat) {
                // Only count a merge if the slot does not already hold canon.
                N current = nav.getChild(parent, loc.childIndex);
                if (canon != current) merges++;
            }
            nav.setChild(parent, loc.childIndex, canon);
        }
    }

    /**
     * Groups every edge reachable from <code>node</code> by the height of
     * the node it points to. The list at index <code>h</code> contains the
     * (node, edge) pairs for nodes of height <code>h</code>.
     */
    private List<List<Pair<N,Edge>>> heightDiscriminate(N node) {
        List<List<Pair<N,Edge>>> groups = new ArrayList<List<Pair<N,Edge>>>();
        // A fresh token marks nodes visited during this run only.
        visitDepthFirst(node, null, 0, new Object(), groups);
        return groups;
    }

    /**
     * Depth-first traversal that computes and caches node heights and records
     * the edge (<code>parent</code>, <code>childIndex</code>) to
     * <code>node</code> in the group for its height.
     * <p>
     * A node without children has height 0. Outside children contribute
     * height 0 and are not traversed. A node reached again via another edge
     * is not re-traversed: its cached height is reused and only the new edge
     * is recorded.
     * <p>
     * NOTE(review): correctness relies on the graph being acyclic. The visit
     * token is set before the children are processed, but the height is
     * cached only afterwards. A cycle would therefore read a stale height.
     *
     * @return the height of <code>node</code>
     */
    private int visitDepthFirst(N node, N parent, int childIndex,
              Object visitToken, List<List<Pair<N,Edge>>> groups) {
        int height = -1;

        if (nav.getVisitToken(node) == visitToken) {
            // Already visited in this run: record the extra incoming edge.
            height = nav.getHeight(node);
            visit(node, parent, childIndex, height, groups);
            if (doStat) inputShareDegree++;
            return height;
        }
        nav.setVisitToken(node, visitToken);

        if (doStat) inputNumOfNodes++;

        for (int i=0, max=nav.childCount(node); i<max; i++) {
            N child = nav.getChild(node, i);
            int h = nav.isOutside(child) ? 0 : visitDepthFirst(
                            child, node, i, visitToken, groups);
            if (h > height) height = h;
        }

        height++;
        visit(node, parent, childIndex, height, groups);
        nav.setHeight(node, height);
        return height;
    }

    /**
     * Appends a (node, edge) pair to the group for <code>height</code>,
     * growing the group list as needed.
     */
    private void visit(N node, N parent, int childIndex, int height,
                                    List<List<Pair<N,Edge>>> groups) {
        while (groups.size() <= height) {
            groups.add(new ArrayList<Pair<N,Edge>>());
        }
        Edge loc = new Edge(node, parent, childIndex);
        groups.get(height).add(new Association<N,Edge>(node, loc));
    }

    /**
     * A container class that specifies the location of a node in a graph:
     * the node itself, its parent (<code>null</code> for the root) and the
     * index of the node among the parent's children.
     */
    public class Edge {
        public N node, parent;
        public int childIndex;

        Edge(N node, N parent, int childIndex) {
            this.node = node;
            this.parent = parent;
            this.childIndex = childIndex;
        }

        @Override
        public String toString() {
            // String.valueOf yields "null" for null references.
            return "{"+String.valueOf(node)+","+String.valueOf(parent)+","
                   +childIndex+"}";
        }
    }

    /** Resets all statistics counters to zero. */
    protected void clearStatistics() {
        inputNumOfNodes = 0;
        inputShareDegree = 0;
        outputNumOfNodes = 0;
        merges = 0;
    }

    /**
     * Return a <code>Statistics</code> object containing information
     * about the last run of <code>share()</code>. The counters are only
     * updated when this compactor was constructed with statistics enabled.
     */
    public Statistics getStatistics() {
        return new Statistics(this);
    }

    /**
     * A container class for statistical information; a snapshot of the
     * compactor's counters at the time of construction.
     */
    public static class Statistics {
        /**
         * Copied from the compactor, except <code>outputShareDegree</code>.
         * That field is not computed by {@link Compactor} and is always 0.
         */
        public int inputNumOfNodes, inputShareDegree, outputNumOfNodes,
                                                     outputShareDegree;
        public int merges;

        Statistics(Compactor<?> c) {
            this.inputNumOfNodes = c.inputNumOfNodes;
            this.inputShareDegree = c.inputShareDegree;
            this.outputNumOfNodes = c.outputNumOfNodes;
            this.merges = c.merges;
        }

        @Override
        public String toString() {
            return "Compactor{in="+inputNumOfNodes+",shared="+inputShareDegree+
                   ",out="+outputNumOfNodes+",merges="+merges+"}";
        }
    }
}

