/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation.functions;

import java.util.Arrays;
import org.nd4j.common.function.Function;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;

public class RelErrorFn
implements Function<INDArray, String> {
    private final INDArray expected;
    private final double maxRelativeError;
    private final double minAbsoluteError;

    public String apply(INDArray actual) {
        if (!Arrays.equals(this.expected.shape(), actual.shape())) {
            throw new IllegalStateException("Shapes differ! " + Arrays.toString(this.expected.shape()) + " vs " + Arrays.toString(actual.shape()));
        }
        NdIndexIterator iter = new NdIndexIterator(this.expected.shape());
        while (iter.hasNext()) {
            double re;
            long[] next = iter.next();
            double d1 = this.expected.getDouble(next);
            double d2 = actual.getDouble(next);
            if (d1 == 0.0 && d2 == 0.0 || Math.abs(d1 - d2) < this.minAbsoluteError || !((re = Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2))) > this.maxRelativeError)) continue;
            return "Failed on relative error at position " + Arrays.toString(next) + ": relativeError=" + re + ", maxRE=" + this.maxRelativeError + ", absError=" + Math.abs(d1 - d2) + ", minAbsError=" + this.minAbsoluteError + " - values (" + d1 + "," + d2 + ")";
        }
        return null;
    }

    public RelErrorFn(INDArray expected, double maxRelativeError, double minAbsoluteError) {
        this.expected = expected;
        this.maxRelativeError = maxRelativeError;
        this.minAbsoluteError = minAbsoluteError;
    }

    public INDArray getExpected() {
        return this.expected;
    }

    public double getMaxRelativeError() {
        return this.maxRelativeError;
    }

    public double getMinAbsoluteError() {
        return this.minAbsoluteError;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RelErrorFn)) {
            return false;
        }
        RelErrorFn other = (RelErrorFn)o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray this$expected = this.getExpected();
        INDArray other$expected = other.getExpected();
        if (this$expected == null ? other$expected != null : !this$expected.equals(other$expected)) {
            return false;
        }
        if (Double.compare(this.getMaxRelativeError(), other.getMaxRelativeError()) != 0) {
            return false;
        }
        return Double.compare(this.getMinAbsoluteError(), other.getMinAbsoluteError()) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof RelErrorFn;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        INDArray $expected = this.getExpected();
        result = result * 59 + ($expected == null ? 43 : $expected.hashCode());
        long $maxRelativeError = Double.doubleToLongBits(this.getMaxRelativeError());
        result = result * 59 + (int)($maxRelativeError >>> 32 ^ $maxRelativeError);
        long $minAbsoluteError = Double.doubleToLongBits(this.getMinAbsoluteError());
        result = result * 59 + (int)($minAbsoluteError >>> 32 ^ $minAbsoluteError);
        return result;
    }

    public String toString() {
        return "RelErrorFn(expected=" + this.getExpected() + ", maxRelativeError=" + this.getMaxRelativeError() + ", minAbsoluteError=" + this.getMinAbsoluteError() + ")";
    }
}

