/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import smile.math.MathEx;
import smile.stat.distribution.DiscreteDistribution;

/**
 * Bernoulli distribution over the outcomes {@code {0, 1}}, with
 * {@code P(X = 1) = p} and {@code P(X = 0) = q = 1 - p}.
 */
public class BernoulliDistribution extends DiscreteDistribution {
    private static final long serialVersionUID = 2L;
    /** Probability of success, {@code P(X = 1)}. */
    public final double p;
    /** Probability of failure, {@code P(X = 0) = 1 - p}. */
    public final double q;
    /** Cached entropy, computed with {@code MathEx.log2} (i.e. in bits). */
    private final double entropy;

    /**
     * Constructs a Bernoulli distribution.
     *
     * @param p the probability of success, in {@code [0, 1]}.
     * @throws IllegalArgumentException if {@code p} is outside {@code [0, 1]} or is NaN.
     */
    public BernoulliDistribution(double p) {
        // Negated range test so that NaN (for which every comparison is false)
        // is rejected too; the plain "p < 0 || p > 1" form lets NaN through.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        this.p = p;
        this.q = 1.0 - p;
        this.entropy = binaryEntropy(this.p, this.q);
    }

    /**
     * Estimates a Bernoulli distribution from a sample of 0/1 values. The
     * success probability is the fraction of elements equal to 1.
     *
     * @param data the sample; every element must be 0 or 1.
     * @return the fitted distribution.
     * @throws IllegalArgumentException if {@code data} is empty or contains a
     *         value other than 0 or 1.
     */
    public static BernoulliDistribution fit(int[] data) {
        // An empty sample would otherwise yield p = 0/0 = NaN.
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        int successes = 0;
        for (int x : data) {
            if (x == 1) {
                ++successes;
            } else if (x != 0) {
                throw new IllegalArgumentException("Invalid value " + x);
            }
        }
        return new BernoulliDistribution((double) successes / data.length);
    }

    /**
     * Estimates a Bernoulli distribution from a boolean sample. The success
     * probability is the fraction of {@code true} elements.
     *
     * @param data the sample.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public BernoulliDistribution(boolean[] data) {
        // Delegate so that validation and field/entropy initialisation live in one place.
        this(successRate(data));
    }

    /**
     * Returns the fraction of {@code true} elements in {@code data}.
     *
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    private static double successRate(boolean[] data) {
        // An empty sample would otherwise yield p = 0/0 = NaN.
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        int successes = 0;
        for (boolean b : data) {
            if (b) {
                ++successes;
            }
        }
        return (double) successes / data.length;
    }

    /**
     * Returns {@code -p*log2(p) - q*log2(q)}, treating {@code 0 * log2(0)} as 0.
     * Skipping the zero terms avoids NaN from {@code 0 * -Infinity} when
     * {@code p} is 0 or 1.
     */
    private static double binaryEntropy(double p, double q) {
        double h = 0.0;
        if (p > 0.0) {
            h -= p * MathEx.log2(p);
        }
        if (q > 0.0) {
            h -= q * MathEx.log2(q);
        }
        return h;
    }

    /**
     * Returns 1, the count of parameters ({@code p}) this distribution has.
     * NOTE(review): assumes the base-class contract is "number of parameters"; confirm.
     */
    @Override
    public int length() {
        return 1;
    }

    /** Returns the mean, {@code p}. */
    @Override
    public double mean() {
        return this.p;
    }

    /** Returns the variance, {@code p * q}. */
    @Override
    public double variance() {
        return this.p * this.q;
    }

    /** Returns the cached entropy; 0 for the degenerate cases {@code p = 0} and {@code p = 1}. */
    @Override
    public double entropy() {
        return this.entropy;
    }

    @Override
    public String toString() {
        return String.format("Bernoulli Distribution(%.4f)", this.p);
    }

    /**
     * Draws a sample: 0 when {@code MathEx.random()} falls below {@code q}, otherwise 1.
     * This gives {@code P(0) = q}, assuming {@code MathEx.random()} is uniform on [0, 1).
     */
    @Override
    public double rand() {
        if (MathEx.random() < this.q) {
            return 0.0;
        }
        return 1.0;
    }

    /**
     * Returns the probability mass at {@code k}: {@code q} for 0, {@code p} for 1,
     * and 0 otherwise.
     */
    @Override
    public double p(int k) {
        if (k == 0) {
            return this.q;
        }
        if (k == 1) {
            return this.p;
        }
        return 0.0;
    }

    /**
     * Returns the natural log of the probability mass at {@code k}. Values
     * outside {@code {0, 1}} give negative infinity.
     */
    @Override
    public double logp(int k) {
        if (k == 0) {
            return Math.log(this.q);
        }
        if (k == 1) {
            return Math.log(this.p);
        }
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns {@code P(X <= k)}. The CDF is a step function: 0 below 0,
     * {@code q} on [0, 1), and 1 from 1 onward.
     */
    @Override
    public double cdf(double k) {
        if (k < 0.0) {
            return 0.0;
        }
        // Every k in [0, 1) (not only k == 0) has P(X <= k) = P(X = 0) = q.
        if (k < 1.0) {
            return this.q;
        }
        return 1.0;
    }

    /**
     * Returns the smallest outcome {@code x} in {@code {0, 1}} with
     * {@code cdf(x) >= prob}.
     *
     * @param p the cumulative probability, in {@code [0, 1]}.
     * @throws IllegalArgumentException if {@code p} is outside {@code [0, 1]} or is NaN.
     */
    @Override
    public double quantile(double p) {
        // Negated range test so NaN is rejected instead of mapping to 1.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new IllegalArgumentException("Invalid p: " + p);
        }
        // cdf(0) == q, so outcome 0 covers every probability up to q.
        if (p <= this.q) {
            return 0.0;
        }
        return 1.0;
    }
}

