/*
 * Decompiled with CFR 0.152.
 */
package smile.data.formula;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.All;
import smile.data.formula.Delete;
import smile.data.formula.HyperTerm;
import smile.data.formula.Term;
import smile.data.formula.Variable;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.ObjectType;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;

/**
 * A model formula of the form {@code response ~ predictors}.
 *
 * <p>The response term is optional; the predictors are a list of {@link HyperTerm}s.
 * Binding the formula to an input schema resolves every term against that schema,
 * expands {@link All} predictors to the terms whose names are not referenced explicitly
 * elsewhere in the formula, drops the terms listed by {@link Delete} predictors, and
 * caches the resulting output schemas and term arrays.
 *
 * <p>The cached state is {@code transient}: {@link #schema()} and {@link #xschema()}
 * return {@code null} until the formula is bound (again, after deserialization).
 * The cache is updated without any synchronization.
 */
public class Formula
implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The response term, or {@code null} if the formula has no left-hand side. */
    private Term response = null;
    /** The right-hand-side predictors as given by the caller (not yet expanded). */
    private HyperTerm[] predictors;
    /** Cached schema of response + predictors; set by bind. */
    private transient StructType schema;
    /** Cached schema of the predictors only; set by bind. */
    private transient StructType xschema;
    /** Cached expanded predictor terms; set by bind. */
    private transient Term[] x;
    /** Cached expanded terms, response first when present; set by bind. */
    private transient Term[] xy;

    /**
     * Creates a formula with the named response variable and all remaining columns as predictors.
     *
     * @param response the response column name.
     */
    public Formula(String response) {
        this(new Variable(response));
    }

    /**
     * Creates a formula with the given response term and all remaining columns as predictors.
     *
     * @param response the response term.
     */
    public Formula(Term response) {
        this.response = response;
        this.predictors = new HyperTerm[]{new All()};
    }

    /**
     * Creates a formula without a response.
     *
     * @param predictors the predictor terms.
     */
    public Formula(HyperTerm[] predictors) {
        this.predictors = predictors;
    }

    /**
     * Creates a formula with the named response variable and the given predictors.
     *
     * @param response the response column name.
     * @param predictors the predictor terms.
     */
    public Formula(String response, HyperTerm[] predictors) {
        this(new Variable(response), predictors);
    }

    /**
     * Creates a formula with the given response term and predictors.
     *
     * @param response the response term.
     * @param predictors the predictor terms.
     */
    public Formula(Term response, HyperTerm[] predictors) {
        this.response = response;
        this.predictors = predictors;
    }

    /**
     * Returns a formula that contains only this formula's predictors (no response).
     * The predictor array is shared with this formula, not copied.
     *
     * @return the right-hand-side-only formula.
     */
    public Formula predictors() {
        return Formula.rhs(this.predictors);
    }

    /**
     * Returns the response term, if any.
     *
     * @return the response term, or empty if the formula has no left-hand side.
     */
    public Optional<Term> response() {
        return Optional.ofNullable(this.response);
    }

    /**
     * Returns the formula in {@code "response ~ p1 + p2 - p3"} notation. Predictor
     * strings that already start with {@code "- "} are kept as-is; all others get a
     * {@code "+ "} prefix, and a leading {@code "+ "} on the whole list is dropped.
     */
    @Override
    public String toString() {
        String r = this.response == null ? "" : this.response.toString();
        String p = Arrays.stream(this.predictors).map(predictor -> {
            String s = predictor.toString();
            return s.startsWith("- ") ? s : "+ " + s;
        }).collect(Collectors.joining(" "));
        if (p.startsWith("+ ")) {
            p = p.substring(2);
        }
        return String.format("%s ~ %s", r, p);
    }

    /**
     * Creates a formula whose response is the named variable and whose predictors are all other columns.
     *
     * @param lhs the response column name.
     * @return the formula.
     */
    public static Formula lhs(String lhs) {
        return new Formula(lhs);
    }

    /**
     * Creates a formula whose response is the given term and whose predictors are all other columns.
     *
     * @param lhs the response term.
     * @return the formula.
     */
    public static Formula lhs(Term lhs) {
        return new Formula(lhs);
    }

    /**
     * Creates a formula without a response whose predictors are the named variables.
     *
     * @param predictors the predictor column names.
     * @return the formula.
     */
    public static Formula rhs(String ... predictors) {
        return new Formula(Formula.variables(predictors));
    }

    /**
     * Creates a formula without a response.
     *
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula rhs(HyperTerm ... predictors) {
        return new Formula(predictors);
    }

    /**
     * Creates a formula from a response column name and predictor column names.
     *
     * @param response the response column name.
     * @param predictors the predictor column names.
     * @return the formula.
     */
    public static Formula of(String response, String ... predictors) {
        return new Formula(response, Formula.variables(predictors));
    }

    /**
     * Creates a formula from a response column name and predictor terms.
     *
     * @param response the response column name.
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula of(String response, HyperTerm ... predictors) {
        return new Formula(response, predictors);
    }

    /**
     * Creates a formula from a response term and predictor terms.
     *
     * @param response the response term.
     * @param predictors the predictor terms.
     * @return the formula.
     */
    public static Formula of(Term response, HyperTerm ... predictors) {
        return new Formula(response, predictors);
    }

    /** Wraps each column name in a {@link Variable}, collected directly into a {@code HyperTerm[]}. */
    private static HyperTerm[] variables(String[] names) {
        return Arrays.stream(names).map(name -> new Variable(name)).toArray(HyperTerm[]::new);
    }

    /**
     * Returns the output schema (response first, if any, then predictors).
     *
     * @return the cached schema, or {@code null} if the formula has not been bound.
     */
    public StructType schema() {
        return this.schema;
    }

    /**
     * Returns the predictor-only output schema.
     *
     * @return the cached predictor schema, or {@code null} if the formula has not been bound.
     */
    public StructType xschema() {
        return this.xschema;
    }

    /**
     * Binds the formula to the given input schema, always recomputing the cached state.
     *
     * @param inputSchema the schema of the input data.
     * @return the output schema (response first, if any, then predictors).
     */
    public StructType bind(StructType inputSchema) {
        return this.bind(inputSchema, true);
    }

    /**
     * Binds all terms to {@code inputSchema} and rebuilds the cached term arrays and schemas.
     *
     * @param inputSchema the schema of the input data.
     * @param forced when {@code false}, an existing binding is reused regardless of {@code inputSchema}.
     * @return the output schema.
     */
    private StructType bind(StructType inputSchema, boolean forced) {
        if (this.schema != null && !forced) {
            return this.schema;
        }
        if (this.response != null) {
            this.response.bind(inputSchema);
        }
        for (HyperTerm predictor : this.predictors) {
            predictor.bind(inputSchema);
        }

        // Names referenced explicitly by the response or by non-All predictors; an All
        // predictor must not add these a second time.
        HashSet<String> columns = new HashSet<>();
        if (this.response != null) {
            columns.addAll(this.response.variables());
        }
        Arrays.stream(this.predictors)
                .filter(predictor -> !(predictor instanceof All))
                .flatMap(predictor -> predictor.terms().stream())
                .filter(term -> term instanceof Variable)
                .forEach(term -> columns.add(term.name()));

        List<Term> factors = new ArrayList<>();
        if (this.response != null) {
            factors.add(this.response);
        }
        factors.addAll(Arrays.stream(this.predictors)
                .filter(term -> !(term instanceof Delete))
                .flatMap(term -> {
                    if (term instanceof All) {
                        return term.terms().stream().filter(t -> !columns.contains(t.name()));
                    }
                    return term.terms().stream();
                })
                .collect(Collectors.toList()));

        // Terms named by Delete predictors are removed after expansion.
        List<Term> removes = Arrays.stream(this.predictors)
                .filter(term -> term instanceof Delete)
                .flatMap(term -> term.terms().stream())
                .collect(Collectors.toList());
        factors.removeAll(removes);

        this.xy = factors.toArray(new Term[0]);
        StructField[] fields = factors.stream().map(factor -> factor.field()).toArray(StructField[]::new);
        this.schema = DataTypes.struct(fields);
        if (this.response != null) {
            // The response is always the first factor; the rest are the predictors.
            this.x = Arrays.copyOfRange(this.xy, 1, this.xy.length);
            this.xschema = DataTypes.struct(Arrays.copyOfRange(fields, 1, fields.length));
        } else {
            this.x = this.xy;
            this.xschema = this.schema;
        }
        return this.schema;
    }

    /**
     * Returns a lazy view of {@code t} through this formula (response and predictors).
     * The formula is bound to {@code t}'s schema only if it is not bound yet.
     *
     * @param t the input tuple.
     * @return a tuple whose fields are evaluated from {@code t} on access.
     */
    public Tuple apply(final Tuple t) {
        this.bind(t.schema(), false);
        // Snapshot so that a later rebind of this formula does not change the returned view.
        return new FormulaTuple(t, this.schema, this.xy);
    }

    /**
     * Returns a lazy view of the predictors of {@code t}.
     * The formula is bound to {@code t}'s schema only if it is not bound yet.
     *
     * @param t the input tuple.
     * @return a tuple whose predictor fields are evaluated from {@code t} on access.
     */
    public Tuple x(final Tuple t) {
        this.bind(t.schema(), false);
        return new FormulaTuple(t, this.xschema, this.x);
    }

    /**
     * Evaluates all predictors of {@code t} as doubles.
     * The formula is bound to {@code t}'s schema only if it is not bound yet.
     *
     * @param t the input tuple.
     * @return the predictor values.
     */
    public double[] xarray(Tuple t) {
        // Without this, x is null on a fresh or deserialized formula.
        this.bind(t.schema(), false);
        return Arrays.stream(this.x).mapToDouble(term -> term.applyAsDouble(t)).toArray();
    }

    /**
     * Binds the formula to {@code df}'s schema and evaluates response and predictors.
     *
     * @param df the input data frame.
     * @return a data frame with one column per output term.
     */
    public DataFrame apply(DataFrame df) {
        this.bind(df.schema(), true);
        BaseVector[] vectors = Arrays.stream(this.xy).map(term -> term.apply(df)).toArray(BaseVector[]::new);
        return DataFrame.of(vectors);
    }

    /**
     * Binds the formula to {@code df}'s schema and evaluates the predictors.
     *
     * @param df the input data frame.
     * @return a data frame with one column per predictor term.
     */
    public DataFrame x(DataFrame df) {
        this.bind(df.schema(), true);
        BaseVector[] vectors = Arrays.stream(this.x).map(term -> term.apply(df)).toArray(BaseVector[]::new);
        return DataFrame.of(vectors);
    }

    /**
     * Returns the predictor design matrix of {@code df} without an intercept column.
     *
     * @param df the input data frame.
     * @return the design matrix.
     */
    public DenseMatrix matrix(DataFrame df) {
        return this.matrix(df, false);
    }

    /**
     * Binds the formula to {@code df}'s schema and returns the predictor design matrix.
     *
     * @param df the input data frame.
     * @param bias if {@code true}, a trailing column of ones is appended.
     * @return the design matrix.
     * @throws UnsupportedOperationException if a predictor's type cannot be converted to double.
     */
    public DenseMatrix matrix(DataFrame df, boolean bias) {
        this.bind(df.schema(), true);
        int nrows = df.nrows();
        int ncols = this.x.length;
        DenseMatrix m = Matrix.of(nrows, bias ? ncols + 1 : ncols, 0.0);
        if (bias) {
            // The intercept column is the last one, after all predictor columns.
            for (int i = 0; i < nrows; ++i) {
                m.set(i, ncols, 1.0);
            }
        }
        for (int j = 0; j < ncols; ++j) {
            Formula.fillColumn(m, j, this.x[j].apply(df), this.x[j].type(), nrows);
        }
        return m;
    }

    /**
     * Writes vector {@code v} of the given type into column {@code j} of {@code m};
     * null strings and null Booleans become NaN.
     */
    private static void fillColumn(DenseMatrix m, int j, BaseVector v, DataType type, int nrows) {
        switch (type.id()) {
            case Double:
            case Integer:
            case Float:
            case Long:
            case Boolean:
            case Byte:
            case Short:
            case Char:
                for (int i = 0; i < nrows; ++i) {
                    m.set(i, j, v.getDouble(i));
                }
                return;
            case String:
                for (int i = 0; i < nrows; ++i) {
                    String s = (String) v.get(i);
                    m.set(i, j, s == null ? Double.NaN : Double.parseDouble(s));
                }
                return;
            case Object: {
                Class<?> clazz = ((ObjectType) type).getObjectClass();
                if (clazz == Boolean.class) {
                    for (int i = 0; i < nrows; ++i) {
                        Boolean b = (Boolean) v.get(i);
                        m.set(i, j, b == null ? Double.NaN : (b ? 1.0 : 0.0));
                    }
                    return;
                }
                if (Number.class.isAssignableFrom(clazz)) {
                    for (int i = 0; i < nrows; ++i) {
                        m.set(i, j, v.getDouble(i));
                    }
                    return;
                }
                break;
            }
            default:
                break;
        }
        throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", type));
    }

    /**
     * Binds the response to {@code df}'s schema and evaluates it.
     *
     * @param df the input data frame.
     * @return the response vector, or {@code null} if the formula has no response.
     */
    public BaseVector y(DataFrame df) {
        if (this.response == null) {
            return null;
        }
        this.response.bind(df.schema());
        return this.response.apply(df);
    }

    /**
     * Binds the response to {@code t}'s schema and evaluates it as a double.
     *
     * @param t the input tuple.
     * @return the response value, or {@code 0.0} if the formula has no response.
     */
    public double y(Tuple t) {
        if (this.response == null) {
            return 0.0;
        }
        this.response.bind(t.schema());
        return this.response.applyAsDouble(t);
    }

    /**
     * Binds the response to {@code t}'s schema and evaluates it as an int.
     *
     * @param t the input tuple.
     * @return the response value, or {@code -1} if the formula has no response.
     */
    public int yint(Tuple t) {
        if (this.response == null) {
            return -1;
        }
        this.response.bind(t.schema());
        return this.response.applyAsInt(t);
    }

    /**
     * Lazy tuple view that evaluates {@code terms[i]} against {@code tuple} on each access.
     * Static, so it holds no hidden reference to the enclosing formula and cannot observe
     * a later rebind.
     */
    private static final class FormulaTuple extends AbstractTuple {
        private final Tuple tuple;
        private final StructType schema;
        private final Term[] terms;

        FormulaTuple(Tuple tuple, StructType schema, Term[] terms) {
            this.tuple = tuple;
            this.schema = schema;
            this.terms = terms;
        }

        @Override
        public StructType schema() {
            return this.schema;
        }

        @Override
        public Object get(int i) {
            return this.terms[i].apply(this.tuple);
        }

        @Override
        public int getInt(int i) {
            return this.terms[i].applyAsInt(this.tuple);
        }

        @Override
        public long getLong(int i) {
            return this.terms[i].applyAsLong(this.tuple);
        }

        @Override
        public float getFloat(int i) {
            return this.terms[i].applyAsFloat(this.tuple);
        }

        @Override
        public double getDouble(int i) {
            return this.terms[i].applyAsDouble(this.tuple);
        }

        @Override
        public String toString() {
            return this.schema.toString(this);
        }
    }
}

