/*
 * Decompiled with CFR 0.152.
 */
package umontreal.ssj.stochprocess;

import umontreal.ssj.probdist.NormalDist;
import umontreal.ssj.randvar.NormalGen;
import umontreal.ssj.rng.RandomStream;
import umontreal.ssj.stochprocess.StochasticProcess;

/**
 * Stochastic process simulated with the explicit one-step recurrence
 *
 * <pre>
 *   X(t + dt) = X(t) + alpha * (beta - X(t)) * dt + sigma * sqrt(dt * X(t)) * Z
 * </pre>
 *
 * where {@code Z} is the next variate drawn from the {@link NormalGen} held by
 * this object. Any step that would produce a negative value is replaced by
 * {@code 0.0}, which keeps the argument of the square root non-negative on the
 * following step.
 *
 * <p>The observation times {@code t}, the sample path {@code path}, the number
 * of steps {@code d}, the current {@code observationIndex}, the initial value
 * {@code x0} and the {@code observationTimesSet} flag are all inherited from
 * {@link StochasticProcess}.
 */
public class CIRProcessEuler
extends StochasticProcess {
    /** Generator supplying the variate {@code Z} of each step. */
    protected NormalGen gen;
    /** Coefficient multiplying {@code (beta - x) * dt} in the drift term. */
    protected double alpha;
    /** Level towards which the drift term pulls the process. */
    protected double beta;
    /** Coefficient multiplying {@code sqrt(dt * x) * Z} in the noise term. */
    protected double sigma;
    /** Precomputed {@code alpha * (t[j+1] - t[j])} for every step {@code j}; filled by {@link #init()}. */
    protected double[] alphadt;
    /** Precomputed {@code sigma * sqrt(t[j+1] - t[j])} for every step {@code j}; filled by {@link #init()}. */
    protected double[] sigmasqrdt;

    /**
     * Builds the process around a new {@link NormalGen} that wraps {@code stream}
     * and a default-constructed {@link NormalDist} (presumably the standard
     * normal distribution; confirm against {@code NormalDist}).
     *
     * @param x0     initial value of the process
     * @param alpha  drift coefficient
     * @param b      level stored as {@code beta}
     * @param sigma  noise coefficient
     * @param stream random stream driving the normal generator
     */
    public CIRProcessEuler(double x0, double alpha, double b, double sigma, RandomStream stream) {
        this(x0, alpha, b, sigma, new NormalGen(stream, new NormalDist()));
    }

    /**
     * Builds the process with a caller-supplied normal generator.
     *
     * <p>The recurrence multiplies each variate of {@code gen} by
     * {@code sigma * sqrt(dt * x)} without any further rescaling, so the
     * distribution of {@code gen} is used as is.
     *
     * @param x0    initial value of the process
     * @param alpha drift coefficient
     * @param b     level stored as {@code beta}
     * @param sigma noise coefficient
     * @param gen   generator supplying the variate of each step
     */
    public CIRProcessEuler(double x0, double alpha, double b, double sigma, NormalGen gen) {
        this.alpha = alpha;
        this.beta = b;
        this.sigma = sigma;
        this.x0 = x0;
        this.gen = gen;
    }

    /**
     * Advances the process by one step using the precomputed coefficients of the
     * current observation index, stores the result in {@code path} and
     * increments {@code observationIndex}.
     *
     * @return the new observation, truncated at zero; this is the same value
     *         that is written to {@code path[observationIndex]}
     */
    @Override
    public double nextObservation() {
        int j = this.observationIndex;
        double xOld = this.path[j];
        double drift = (this.beta - xOld) * this.alphadt[j];
        double noise = this.sigmasqrdt[j] * Math.sqrt(xOld) * this.gen.nextDouble();
        double x = xOld + drift + noise;
        // Return exactly what is stored so the caller never sees a negative
        // value that disagrees with the recorded path.
        double truncated = x >= 0.0 ? x : 0.0;
        ++this.observationIndex;
        this.path[this.observationIndex] = truncated;
        return truncated;
    }

    /**
     * Advances the process from the current observation time to
     * {@code nextTime}, recording {@code nextTime} in {@code t} and the new
     * value in {@code path}, and increments {@code observationIndex}.
     *
     * <p>The step length is computed directly from the two times; the
     * precomputed arrays {@code alphadt} and {@code sigmasqrdt} are not used.
     *
     * @param nextTime time of the new observation
     * @return the new observation, truncated at zero; this is the same value
     *         that is written to {@code path[observationIndex]}
     */
    public double nextObservation(double nextTime) {
        double previousTime = this.t[this.observationIndex];
        double xOld = this.path[this.observationIndex];
        ++this.observationIndex;
        this.t[this.observationIndex] = nextTime;
        double dt = nextTime - previousTime;
        double x = xOld + this.alpha * (this.beta - xOld) * dt
                + this.sigma * Math.sqrt(dt * xOld) * this.gen.nextDouble();
        // Same truncation as the stored path entry, so both always agree.
        double truncated = x >= 0.0 ? x : 0.0;
        this.path[this.observationIndex] = truncated;
        return truncated;
    }

    /**
     * Computes a single step of length {@code dt} starting from {@code x}.
     * Neither {@code path}, {@code t} nor {@code observationIndex} is modified;
     * one variate is consumed from the generator.
     *
     * @param x  starting value
     * @param dt step length
     * @return the value after the step, or {@code 0.0} if the step did not
     *         produce a value {@code >= 0.0}
     */
    public double nextObservation(double x, double dt) {
        double next = x + this.alpha * (this.beta - x) * dt
                + this.sigma * Math.sqrt(dt * x) * this.gen.nextDouble();
        return next >= 0.0 ? next : 0.0;
    }

    /**
     * Generates the {@code d} steps of a path starting from {@code x0}, writing
     * them to {@code path[1..d]}, and sets {@code observationIndex} to
     * {@code d}. Negative step results are replaced by {@code 0.0}.
     *
     * @return the internal {@code path} array (not a copy)
     */
    @Override
    public double[] generatePath() {
        double xOld = this.x0;
        for (int j = 0; j < this.d; ++j) {
            double x = xOld + (this.beta - xOld) * this.alphadt[j]
                    + this.sigmasqrdt[j] * Math.sqrt(xOld) * this.gen.nextDouble();
            if (x < 0.0) {
                x = 0.0;
            }
            this.path[j + 1] = x;
            xOld = x;
        }
        this.observationIndex = this.d;
        return this.path;
    }

    /**
     * Installs {@code stream} as the stream of the normal generator and then
     * generates a path with {@link #generatePath()}. The generator keeps using
     * {@code stream} afterwards.
     *
     * @param stream random stream to use from now on
     * @return the internal {@code path} array (not a copy)
     */
    @Override
    public double[] generatePath(RandomStream stream) {
        this.gen.setStream(stream);
        return this.generatePath();
    }

    /**
     * Replaces the parameters of the process. When {@code observationTimesSet}
     * is true, {@link #init()} is called so that the precomputed per-step
     * coefficients reflect the new values.
     *
     * @param x0    initial value of the process
     * @param alpha drift coefficient
     * @param b     level stored as {@code beta}
     * @param sigma noise coefficient
     */
    public void setParams(double x0, double alpha, double b, double sigma) {
        this.alpha = alpha;
        this.beta = b;
        this.sigma = sigma;
        this.x0 = x0;
        if (this.observationTimesSet) {
            this.init();
        }
    }

    /**
     * Sets the random stream used by the normal generator.
     *
     * @param stream the new random stream
     */
    @Override
    public void setStream(RandomStream stream) {
        this.gen.setStream(stream);
    }

    /**
     * Returns the random stream currently used by the normal generator.
     *
     * @return the generator's random stream
     */
    @Override
    public RandomStream getStream() {
        return this.gen.getStream();
    }

    /**
     * Returns the drift coefficient.
     *
     * @return the value of {@code alpha}
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Returns the level parameter.
     *
     * @return the value of {@code beta}
     */
    public double getB() {
        return this.beta;
    }

    /**
     * Returns the noise coefficient.
     *
     * @return the value of {@code sigma}
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * Returns the normal generator used by this process.
     *
     * @return the generator itself (not a copy)
     */
    public NormalGen getGen() {
        return this.gen;
    }

    /**
     * Runs the superclass initialization, then allocates {@code alphadt} and
     * {@code sigmasqrdt} with {@code d} entries each and fills them from the
     * spacing of consecutive observation times in {@code t}.
     */
    @Override
    protected void init() {
        super.init();
        this.alphadt = new double[this.d];
        this.sigmasqrdt = new double[this.d];
        for (int j = 0; j < this.d; ++j) {
            double dt = this.t[j + 1] - this.t[j];
            this.alphadt[j] = this.alpha * dt;
            this.sigmasqrdt[j] = this.sigma * Math.sqrt(dt);
        }
    }
}

