package umontreal.ssj.stat.list.lincv;

import cern.colt.matrix.DoubleMatrix2D;
import umontreal.ssj.stat.FunctionOfMultipleMeansTally;
import umontreal.ssj.stat.Tally;
import umontreal.ssj.util.MultivariateFunction;

/* loaded from: input_file:umontreal/ssj/stat/list/lincv/FunctionOfMultipleMeansTallyWithCV.class */
/**
 * A {@link FunctionOfMultipleMeansTally} whose function is a base function of the first
 * {@code p} means, corrected by a linear combination of {@code q} control variables.
 *
 * <p>The function actually installed in the superclass is a {@link LinCVFunction} of dimension
 * {@code p + q}. For an argument {@code x} it computes
 * {@code f(x[0..p-1]) - sum_k beta[k] * (x[p+k] - e[k])}, where {@code f} is the function returned
 * by {@link #getFunctionWithoutCV()}, {@code beta} is the coefficient vector held by this object
 * and {@code e[k]} is {@code getListOfTalliesWithCV().getExpectedValue(k)}.
 *
 * <p>Instances are not synchronized; the installed function reuses an internal scratch buffer
 * across evaluations.
 */
public class FunctionOfMultipleMeansTallyWithCV extends FunctionOfMultipleMeansTally {
    /** Base function, applied to the first {@code p} coordinates only. */
    private final MultivariateFunction funcNoCV;

    /** Control-variate coefficients; its length is the number of control variables {@code q}. */
    private double[] beta;

    /**
     * Function of dimension {@code p + q} that applies the linear control-variate correction
     * described on the enclosing class.
     *
     * <p>An instance must be bound to its owner with
     * {@link #initFunctionOfMultipleMeansTallyWithCV(FunctionOfMultipleMeansTallyWithCV)} before
     * {@link #evaluate(double...)} or {@link #evaluateGradient(int, double...)} is called.
     *
     * <p>NOTE(review): the decompiler reports that this class was originally {@code private};
     * it is presumably not meant to be used outside the enclosing class.
     */
    public static class LinCVFunction implements MultivariateFunction {
        /** Owner supplying the base function, the coefficients and the expected values. */
        private FunctionOfMultipleMeansTallyWithCV fcv;

        /** Total dimension {@code p + q} reported by {@link #getDimension()}. */
        private final int pplusq;

        /** Scratch buffer receiving the leading coordinates handed to the base function. */
        private double[] tmp;

        /**
         * Creates an unbound function of the given total dimension.
         *
         * @param i the total dimension {@code p + q}
         */
        public LinCVFunction(int i) {
            this.pplusq = i;
        }

        /**
         * Binds this function to its owner and allocates the scratch buffer with
         * {@code getListOfTalliesWithCV().sizeWithoutCV()} entries.
         *
         * @param functionOfMultipleMeansTallyWithCV the owning tally
         */
        public void initFunctionOfMultipleMeansTallyWithCV(FunctionOfMultipleMeansTallyWithCV functionOfMultipleMeansTallyWithCV) {
            this.fcv = functionOfMultipleMeansTallyWithCV;
            this.tmp = new double[functionOfMultipleMeansTallyWithCV.getListOfTalliesWithCV().sizeWithoutCV()];
        }

        /**
         * Returns the total dimension given at construction.
         *
         * @return the total dimension {@code p + q}
         */
        @Override // umontreal.ssj.util.MultivariateFunction
        public int getDimension() {
            return this.pplusq;
        }

        /**
         * Evaluates the base function on the first {@code p} coordinates and subtracts
         * {@code beta[k] * (x[p + k] - expectedValue(k))} for every control variable {@code k}.
         *
         * @param dArr the point {@code x}; its length must equal {@link #getDimension()}
         * @return the corrected function value
         * @throws IllegalArgumentException if {@code dArr} has the wrong length
         */
        @Override // umontreal.ssj.util.MultivariateFunction
        public double evaluate(double... dArr) {
            MultivariateFunction baseFunction = this.fcv.getFunctionWithoutCV();
            checkLength(dArr);
            double[] coefficients = this.fcv.beta;
            int p = getDimension() - coefficients.length;
            // The base function only sees the first p coordinates.
            System.arraycopy(dArr, 0, this.tmp, 0, p);
            double value = baseFunction.evaluate(this.tmp);
            ListOfTalliesWithCV<Tally> tallies = this.fcv.getListOfTalliesWithCV();
            for (int k = 0; k < coefficients.length; k++) {
                value -= coefficients[k] * (dArr[p + k] - tallies.getExpectedValue(k));
            }
            return value;
        }

        /**
         * Returns the partial derivative with respect to coordinate {@code i}: the base
         * function's own partial derivative for {@code i < p}, and {@code -beta[i - p]} for a
         * control-variable coordinate.
         *
         * @param i the coordinate index; not range-checked here
         * @param dArr the point {@code x}; its length must equal {@link #getDimension()}
         * @return the partial derivative at {@code x}
         * @throws IllegalArgumentException if {@code dArr} has the wrong length
         */
        @Override // umontreal.ssj.util.MultivariateFunction
        public double evaluateGradient(int i, double... dArr) {
            MultivariateFunction baseFunction = this.fcv.getFunctionWithoutCV();
            checkLength(dArr);
            double[] coefficients = this.fcv.beta;
            int p = getDimension() - coefficients.length;
            if (i >= p) {
                // The correction is linear in each control variable.
                return -coefficients[i - p];
            }
            System.arraycopy(dArr, 0, this.tmp, 0, p);
            return baseFunction.evaluateGradient(i, this.tmp);
        }

        /** Rejects an argument vector whose length differs from {@link #getDimension()}. */
        private void checkLength(double[] x) {
            if (x.length != getDimension()) {
                throw new IllegalArgumentException("x has length " + x.length + ", which differs from the dimension " + getDimension());
            }
        }
    }

    /**
     * Creates a tally whose list is obtained from {@code ListOfTalliesWithCV.createWithTally(i, i2)}
     * and whose installed function has dimension {@code i + i2}. All coefficients start at zero.
     *
     * @param multivariateFunction the base function, later applied to the first coordinates only
     * @param i first argument of {@code createWithTally}; presumably the number {@code p} of
     *     tallies that are not control variables (TODO confirm against ListOfTalliesWithCV)
     * @param i2 the number {@code q} of control variables, which is also the length of the
     *     coefficient vector
     */
    public FunctionOfMultipleMeansTallyWithCV(MultivariateFunction multivariateFunction, int i, int i2) {
        super(new LinCVFunction(i + i2), ListOfTalliesWithCV.createWithTally(i, i2));
        this.funcNoCV = multivariateFunction;
        this.beta = new double[i2];
        ((LinCVFunction) getFunction()).initFunctionOfMultipleMeansTallyWithCV(this);
    }

    /**
     * Creates a tally over an existing list. The installed function has dimension
     * {@code listOfTalliesWithCV.size()} and the coefficient vector has
     * {@code listOfTalliesWithCV.getNumControlVariables()} entries, all initially zero.
     *
     * @param multivariateFunction the base function, later applied to the first coordinates only
     * @param listOfTalliesWithCV the list of tallies backing this probe
     */
    public FunctionOfMultipleMeansTallyWithCV(MultivariateFunction multivariateFunction, ListOfTalliesWithCV<Tally> listOfTalliesWithCV) {
        super(new LinCVFunction(listOfTalliesWithCV.size()), listOfTalliesWithCV);
        this.funcNoCV = multivariateFunction;
        this.beta = new double[listOfTalliesWithCV.getNumControlVariables()];
        ((LinCVFunction) getFunction()).initFunctionOfMultipleMeansTallyWithCV(this);
    }

    /**
     * Returns the base function given at construction.
     *
     * @return the function without the control-variate correction
     */
    public MultivariateFunction getFunctionWithoutCV() {
        return this.funcNoCV;
    }

    /**
     * Returns the number of control variables, i.e. the length of the coefficient vector.
     *
     * @return the number of control variables {@code q}
     */
    public int getNumControlVariables() {
        return this.beta.length;
    }

    /**
     * Returns {@code getDimension()} minus the number of control variables.
     *
     * @return the dimension {@code p} seen by the base function
     */
    public int getDimensionWithoutCV() {
        return getDimension() - this.beta.length;
    }

    /**
     * Returns one control-variate coefficient.
     *
     * @param i index of the control variable
     * @return the coefficient at index {@code i}
     */
    public double getBeta(int i) {
        return this.beta[i];
    }

    /**
     * Sets one control-variate coefficient.
     *
     * @param i index of the control variable
     * @param d the new coefficient
     */
    public void setBeta(int i, double d) {
        this.beta[i] = d;
    }

    /**
     * Returns the internal coefficient array itself (not a copy); writes to it are seen by
     * this object.
     *
     * @return the coefficient vector
     */
    public double[] getBeta() {
        return this.beta;
    }

    /**
     * Replaces the coefficient vector. The array is stored by reference, not copied.
     *
     * @param dArr the new coefficients; must have as many entries as the current vector
     * @throws IllegalArgumentException if the length differs from the current one
     */
    public void setBeta(double[] dArr) {
        if (dArr.length != this.beta.length) {
            throw new IllegalArgumentException("Invalid length of beta");
        }
        this.beta = dArr;
    }

    /**
     * Returns the superclass's list of tallies viewed as a {@code ListOfTalliesWithCV<Tally>}.
     *
     * @return the list of tallies backing this probe
     */
    @SuppressWarnings("unchecked") // both constructors install a ListOfTalliesWithCV<Tally>
    public ListOfTalliesWithCV<Tally> getListOfTalliesWithCV() {
        return (ListOfTalliesWithCV<Tally>) super.getListOfTallies();
    }

    /**
     * Delegates to {@code getListOfTalliesWithCV().getExpectedValue(i)}.
     *
     * @param i index passed through to the list
     * @return the expected value stored by the list for that index
     */
    public double getExpectedValue(int i) {
        return getListOfTalliesWithCV().getExpectedValue(i);
    }

    /**
     * Delegates to {@code getListOfTalliesWithCV().setExpectedValue(i, d)}.
     *
     * @param i index passed through to the list
     * @param d the expected value to store
     */
    public void setExpectedValue(int i, double d) {
        getListOfTalliesWithCV().setExpectedValue(i, d);
    }

    /**
     * Delegates to {@code getListOfTalliesWithCV().getExpectedValues()}.
     *
     * @return the array returned by the list
     */
    public double[] getExpectedValues() {
        return getListOfTalliesWithCV().getExpectedValues();
    }

    /**
     * Delegates to {@code getListOfTalliesWithCV().setExpectedValues(dArr)}.
     *
     * @param dArr the expected values to hand to the list
     */
    public void setExpectedValues(double[] dArr) {
        getListOfTalliesWithCV().setExpectedValues(dArr);
    }

    /**
     * Asks the list of tallies to estimate its own beta matrix, then derives this object's
     * coefficient vector from it through {@link #estimateBetaFromMatrix(DoubleMatrix2D)}.
     */
    public void estimateBeta() {
        ListOfTalliesWithCV<Tally> tallies = getListOfTalliesWithCV();
        tallies.estimateBeta();
        estimateBetaFromMatrix(tallies.getBeta());
    }

    /**
     * Recomputes the coefficient vector as the product of the given matrix with the gradient of
     * the base function, evaluated at the current averages of the first {@code p} tallies:
     * {@code beta[k] = sum_j matrix(k, j) * gradient[j]}.
     *
     * <p>The existing coefficient array is updated in place. Entry {@code (k, j)} of the matrix
     * is read for {@code 0 <= k < q} and {@code 0 <= j < p}; the shape is validated up front
     * because the cells are read with the unchecked {@code getQuick} accessor.
     *
     * @param doubleMatrix2D the matrix to combine with the gradient; needs at least {@code q}
     *     rows and {@code p} columns
     * @throws IllegalArgumentException if the matrix is too small to be read as described
     */
    public void estimateBetaFromMatrix(DoubleMatrix2D doubleMatrix2D) {
        ListOfTalliesWithCV<Tally> tallies = getListOfTalliesWithCV();
        int q = this.beta.length;
        int p = getDimension() - q;
        // Only validate when the loops below would actually touch the matrix.
        if (q > 0 && p > 0 && (doubleMatrix2D.rows() < q || doubleMatrix2D.columns() < p)) {
            throw new IllegalArgumentException("The beta matrix is " + doubleMatrix2D.rows() + "x" + doubleMatrix2D.columns()
                    + ", but at least " + q + "x" + p + " is required");
        }
        double[] means = new double[p];
        double[] gradient = new double[p];
        for (int j = 0; j < p; j++) {
            means[j] = tallies.get(j).average();
        }
        for (int j = 0; j < p; j++) {
            gradient[j] = this.funcNoCV.evaluateGradient(j, means);
        }
        for (int k = 0; k < q; k++) {
            // Accumulate locally so a coefficient is never left half-summed.
            double sum = 0.0d;
            for (int j = 0; j < p; j++) {
                sum += doubleMatrix2D.getQuick(k, j) * gradient[j];
            }
            this.beta[k] = sum;
        }
    }

    /**
     * Clones this probe. The copy gets its own coefficient array and a fresh
     * {@link LinCVFunction} bound to the copy, so evaluating the copy reads the copy's
     * coefficients and tallies. The base function reference is shared with the original.
     *
     * @return the cloned probe
     */
    @Override // umontreal.ssj.stat.FunctionOfMultipleMeansTally, umontreal.ssj.stat.StatProbe
    /* renamed from: clone */
    public FunctionOfMultipleMeansTallyWithCV mo57clone() {
        FunctionOfMultipleMeansTallyWithCV copy = (FunctionOfMultipleMeansTallyWithCV) super.mo57clone();
        copy.beta = this.beta.clone();
        LinCVFunction copyFunction = new LinCVFunction(getDimension());
        copy.func = copyFunction;
        copyFunction.initFunctionOfMultipleMeansTallyWithCV(copy);
        return copy;
    }
}
