/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.listeners;

import java.io.Serializable;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TrainingListener} that deliberately makes training fail once its {@link FailureTrigger} fires.
 * <p>
 * Every listener callback (iteration done, epoch start/end, forward pass, gradient calculation,
 * backward pass) is routed to {@link #call(CallType, Model)}, which asks the trigger whether to fail.
 * When it does, the configured {@link FailureMode} is executed: exhausting memory, calling
 * {@code System.exit(1)}, throwing an {@link IllegalStateException}, or sleeping forever. This makes it
 * useful for testing how training code and infrastructure react to crashes and hangs.
 */
public class FailureTestingListener implements TrainingListener, Serializable {
    private static final Logger log = LoggerFactory.getLogger(FailureTestingListener.class);

    /** Decides, per callback, whether the failure should be induced. */
    private final FailureTrigger trigger;
    /** What to do once {@link #trigger} fires. */
    private final FailureMode failureMode;

    /**
     * Creates a listener that induces {@code mode} whenever {@code trigger} fires.
     *
     * @param mode    the kind of failure to induce; must not be null
     * @param trigger decides when to fail; must not be null
     * @throws NullPointerException if either argument is null
     */
    public FailureTestingListener(@NonNull FailureMode mode, @NonNull FailureTrigger trigger) {
        if (mode == null) {
            throw new NullPointerException("mode is marked non-null but is null");
        }
        if (trigger == null) {
            throw new NullPointerException("trigger is marked non-null but is null");
        }
        this.trigger = trigger;
        this.failureMode = mode;
    }

    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        call(CallType.ITER_DONE, model);
    }

    @Override
    public void onEpochStart(Model model) {
        call(CallType.EPOCH_START, model);
    }

    @Override
    public void onEpochEnd(Model model) {
        call(CallType.EPOCH_END, model);
    }

    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
        call(CallType.FORWARD_PASS, model);
    }

    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        call(CallType.FORWARD_PASS, model);
    }

    @Override
    public void onGradientCalculation(Model model) {
        call(CallType.GRADIENT_CALC, model);
    }

    @Override
    public void onBackwardPass(Model model) {
        call(CallType.BACKWARD_PASS, model);
    }

    /**
     * Evaluates the trigger for one listener callback and, if it fires, executes the configured failure mode.
     * <p>
     * The trigger is initialized lazily on the first callback. When the trigger fires, the OOM,
     * SYSTEM_EXIT_1 and INFINITE_SLEEP modes never return normally; ILLEGAL_STATE throws.
     *
     * @param callType which listener callback invoked this method
     * @param model    the model being trained; must be a {@link MultiLayerNetwork} or {@link ComputationGraph}
     *                 because its iteration and epoch counts are read from it
     * @throws IllegalArgumentException if {@code model} is null or of any other type
     * @throws IllegalStateException    if the trigger fires and the failure mode is ILLEGAL_STATE
     */
    protected void call(CallType callType, Model model) {
        if (!trigger.initialized()) {
            trigger.initialize();
        }

        final int iter;
        final int epoch;
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork net = (MultiLayerNetwork) model;
            iter = net.getIterationCount();
            epoch = net.getEpochCount();
        } else if (model instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) model;
            iter = graph.getIterationCount();
            epoch = graph.getEpochCount();
        } else {
            // Fail with a descriptive message instead of an opaque ClassCastException/NullPointerException.
            throw new IllegalArgumentException(
                    "FailureTestingListener only supports MultiLayerNetwork and ComputationGraph models, got: "
                            + (model == null ? "null" : model.getClass().getName()));
        }

        if (!trigger.triggerFailure(callType, iter, epoch, model)) {
            return;
        }

        log.error("*** FailureTestingListener was triggered on iteration {}, epoch {} - Failure mode is set to {} ***",
                iter, epoch, failureMode);
        switch (failureMode) {
            case OOM: {
                // Keep every allocation strongly reachable so nothing can be reclaimed; loop until allocation fails.
                List<INDArray> hoard = new ArrayList<>();
                while (true) {
                    hoard.add(Nd4j.createUninitialized(1_000_000_000L));
                }
            }
            case SYSTEM_EXIT_1:
                log.error("Exiting due to FailureTestingListener triggering - calling System.exit(1)");
                System.exit(1);
                break;
            case ILLEGAL_STATE:
                log.error("Throwing new IllegalStateException due to FailureTestingListener triggering");
                throw new IllegalStateException("FailureTestListener was triggered with failure mode " + failureMode
                        + " - iteration " + iter + ", epoch " + epoch);
            case INFINITE_SLEEP:
                while (true) {
                    try {
                        Thread.sleep(10_000L);
                    } catch (InterruptedException ignored) {
                        // Deliberately swallowed: this mode simulates a hung process, so interrupts must not end it.
                        // Re-setting the interrupt flag here would make every subsequent sleep throw at once (busy spin).
                    }
                }
            default:
                throw new RuntimeException("Unknown enum value: " + failureMode);
        }
    }

    /**
     * Fires when the epoch count ({@code isEpoch == true}) or the iteration count ({@code isEpoch == false})
     * is exactly equal to {@code count}. The call type is ignored, so it fires on every callback made while
     * the count matches.
     */
    public static class IterationEpochTrigger extends FailureTrigger {
        private final boolean isEpoch;
        private final int count;

        /**
         * @param isEpoch if true, compare {@code count} against the epoch count; otherwise against the iteration count
         * @param count   the epoch/iteration count at which to fire
         */
        public IterationEpochTrigger(boolean isEpoch, int count) {
            this.isEpoch = isEpoch;
            this.count = count;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return isEpoch ? epoch == count : iteration == count;
        }

        public boolean isEpoch() {
            return this.isEpoch;
        }

        public int getCount() {
            return this.count;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof IterationEpochTrigger)) {
                return false;
            }
            IterationEpochTrigger other = (IterationEpochTrigger) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return this.isEpoch() == other.isEpoch() && this.getCount() == other.getCount();
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof IterationEpochTrigger;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (this.isEpoch() ? 79 : 97);
            result = result * 59 + this.getCount();
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.IterationEpochTrigger(isEpoch=" + this.isEpoch() + ", count=" + this.getCount() + ")";
        }
    }

    /**
     * Fires on every callback if the local host name equals {@code hostName}, ignoring case. The local host
     * name comes from {@code InetAddress.getLocalHost().getHostName()}. The comparison is made once, in
     * {@link #initialize()}.
     */
    public static class HostNameTrigger extends FailureTrigger {
        private final String hostName;
        private boolean shouldFail = false;

        /**
         * @param hostName host name on which the failure should be induced; must not be null
         * @throws NullPointerException if {@code hostName} is null
         */
        public HostNameTrigger(@NonNull String hostName) {
            if (hostName == null) {
                throw new NullPointerException("hostName is marked non-null but is null");
            }
            this.hostName = hostName;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return this.shouldFail;
        }

        /**
         * Resolves the local host name and decides once whether this trigger should fire.
         *
         * @throws RuntimeException wrapping the {@link UnknownHostException} if the local host cannot be resolved
         */
        @Override
        public void initialize() {
            super.initialize();
            String localHostName;
            try {
                localHostName = InetAddress.getLocalHost().getHostName();
            } catch (UnknownHostException e) {
                throw new RuntimeException("FailureTestingListener could not resolve the local host name", e);
            }
            log.info("FailureTestingListener hostname: {}", localHostName);
            this.shouldFail = this.hostName.equalsIgnoreCase(localHostName);
        }

        public String getHostName() {
            return this.hostName;
        }

        public boolean isShouldFail() {
            return this.shouldFail;
        }

        public void setShouldFail(boolean shouldFail) {
            this.shouldFail = shouldFail;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof HostNameTrigger)) {
                return false;
            }
            HostNameTrigger other = (HostNameTrigger) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return Objects.equals(this.getHostName(), other.getHostName())
                    && this.isShouldFail() == other.isShouldFail();
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof HostNameTrigger;
        }

        @Override
        public int hashCode() {
            int result = 1;
            String name = this.getHostName();
            result = result * 59 + (name == null ? 43 : name.hashCode());
            result = result * 59 + (this.isShouldFail() ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.HostNameTrigger(hostName=" + this.getHostName() + ", shouldFail=" + this.isShouldFail() + ")";
        }
    }

    /**
     * Fires on every callback if the {@code user.name} system property equals {@code userName}, ignoring case.
     * The comparison is made once, in {@link #initialize()}. If the property is unset, the trigger never fires.
     */
    public static class UserNameTrigger extends FailureTrigger {
        private final String userName;
        private boolean shouldFail = false;

        /**
         * @param userName user name for which the failure should be induced; must not be null
         * @throws NullPointerException if {@code userName} is null
         */
        public UserNameTrigger(@NonNull String userName) {
            if (userName == null) {
                throw new NullPointerException("userName is marked non-null but is null");
            }
            this.userName = userName;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return this.shouldFail;
        }

        @Override
        public void initialize() {
            super.initialize();
            // equalsIgnoreCase(null) is false, so a missing property simply disables the trigger.
            this.shouldFail = this.userName.equalsIgnoreCase(System.getProperty("user.name"));
        }

        public String getUserName() {
            return this.userName;
        }

        public boolean isShouldFail() {
            return this.shouldFail;
        }

        public void setShouldFail(boolean shouldFail) {
            this.shouldFail = shouldFail;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof UserNameTrigger)) {
                return false;
            }
            UserNameTrigger other = (UserNameTrigger) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return Objects.equals(this.getUserName(), other.getUserName())
                    && this.isShouldFail() == other.isShouldFail();
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof UserNameTrigger;
        }

        @Override
        public int hashCode() {
            int result = 1;
            String name = this.getUserName();
            result = result * 59 + (name == null ? 43 : name.hashCode());
            result = result * 59 + (this.isShouldFail() ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.UserNameTrigger(userName=" + this.getUserName() + ", shouldFail=" + this.isShouldFail() + ")";
        }
    }

    /**
     * Fires on every callback once strictly more than {@code msSinceInit} milliseconds have elapsed since
     * {@link #initialize()}. Elapsed time is measured with {@code System.currentTimeMillis()}.
     */
    public static class TimeSinceInitializedTrigger extends FailureTrigger {
        private final long msSinceInit;
        /** Wall-clock time (ms since epoch) recorded by {@link #initialize()}. */
        private long initTime;

        /**
         * @param msSinceInit delay, in milliseconds after initialization, after which the trigger fires
         */
        public TimeSinceInitializedTrigger(long msSinceInit) {
            this.msSinceInit = msSinceInit;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return System.currentTimeMillis() - this.initTime > this.msSinceInit;
        }

        @Override
        public void initialize() {
            super.initialize();
            this.initTime = System.currentTimeMillis();
        }

        public long getMsSinceInit() {
            return this.msSinceInit;
        }

        public long getInitTime() {
            return this.initTime;
        }

        public void setInitTime(long initTime) {
            this.initTime = initTime;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof TimeSinceInitializedTrigger)) {
                return false;
            }
            TimeSinceInitializedTrigger other = (TimeSinceInitializedTrigger) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return this.getMsSinceInit() == other.getMsSinceInit() && this.getInitTime() == other.getInitTime();
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof TimeSinceInitializedTrigger;
        }

        @Override
        public int hashCode() {
            int result = 1;
            long ms = this.getMsSinceInit();
            result = result * 59 + (int) (ms >>> 32 ^ ms);
            long init = this.getInitTime();
            result = result * 59 + (int) (init >>> 32 ^ init);
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.TimeSinceInitializedTrigger(msSinceInit=" + this.getMsSinceInit() + ", initTime=" + this.getInitTime() + ")";
        }
    }

    /**
     * Fires randomly with the given probability on each callback of the configured type. Every callback
     * counts when the configured type is {@link CallType#ANY}.
     * <p>
     * The RNG is created in {@link #initialize()}. NOTE(review): calling {@link #triggerFailure} before
     * {@code initialize()} or {@code setRng(...)} will hit a null {@code rng}.
     */
    public static class RandomProb extends FailureTrigger {
        private final CallType callType;
        private final double probability;
        private Random rng;

        /**
         * @param callType    callback type to fail on, or {@link CallType#ANY} for all callbacks
         * @param probability per-callback probability of failing; compared as {@code rng.nextDouble() < probability}
         */
        public RandomProb(CallType callType, double probability) {
            this.callType = callType;
            this.probability = probability;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            boolean typeMatches = this.callType == CallType.ANY || callType == this.callType;
            // Only draw a random number for matching callbacks (short-circuit preserved).
            return typeMatches && this.rng.nextDouble() < this.probability;
        }

        @Override
        public void initialize() {
            super.initialize();
            this.rng = new Random();
        }

        public CallType getCallType() {
            return this.callType;
        }

        public double getProbability() {
            return this.probability;
        }

        public Random getRng() {
            return this.rng;
        }

        public void setRng(Random rng) {
            this.rng = rng;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof RandomProb)) {
                return false;
            }
            RandomProb other = (RandomProb) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return this.getCallType() == other.getCallType()
                    && Double.compare(this.getProbability(), other.getProbability()) == 0
                    && Objects.equals(this.getRng(), other.getRng());
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof RandomProb;
        }

        @Override
        public int hashCode() {
            int result = 1;
            CallType type = this.getCallType();
            result = result * 59 + (type == null ? 43 : type.hashCode());
            long probBits = Double.doubleToLongBits(this.getProbability());
            result = result * 59 + (int) (probBits >>> 32 ^ probBits);
            Random random = this.getRng();
            result = result * 59 + (random == null ? 43 : random.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.RandomProb(callType=" + this.getCallType() + ", probability=" + this.getProbability() + ", rng=" + this.getRng() + ")";
        }
    }

    /**
     * Fires if ANY of the child triggers fires. All children are evaluated on every callback; there is no
     * short-circuiting, so stateful children such as {@link RandomProb} are always consulted. With no
     * children it never fires.
     */
    public static class Or extends And {
        public Or(FailureTrigger... triggers) {
            super(triggers);
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            boolean any = false;
            for (FailureTrigger ft : this.triggers) {
                any |= ft.triggerFailure(callType, iteration, epoch, model);
            }
            return any;
        }

        /** Equal only to another {@code Or} with equal children (never to a plain {@link And}). */
        @Override
        public boolean equals(Object o) {
            return o instanceof Or && super.equals(o);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof Or;
        }

        @Override
        public int hashCode() {
            return super.hashCode();
        }

        @Override
        public String toString() {
            return "FailureTestingListener.Or(triggers=" + this.triggers + ")";
        }
    }

    /**
     * Fires only if ALL child triggers fire. All children are evaluated on every callback; there is no
     * short-circuiting. With no children it always fires. Initializing this trigger also initializes every child.
     */
    public static class And extends FailureTrigger {
        protected List<FailureTrigger> triggers;

        public And(FailureTrigger... triggers) {
            this.triggers = Arrays.asList(triggers);
        }

        public And(List<FailureTrigger> triggers) {
            this.triggers = triggers;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            boolean all = true;
            for (FailureTrigger ft : this.triggers) {
                all &= ft.triggerFailure(callType, iteration, epoch, model);
            }
            return all;
        }

        @Override
        public void initialize() {
            super.initialize();
            for (FailureTrigger ft : this.triggers) {
                ft.initialize();
            }
        }

        /**
         * Compares by child triggers, consistent with the value-based equality of the other triggers.
         * Without this override, the inherited {@link FailureTrigger#equals} compared only the
         * {@code initialized} flag, so unrelated composites compared equal.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof And)) {
                return false;
            }
            And other = (And) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return Objects.equals(this.triggers, other.triggers);
        }

        @Override
        protected boolean canEqual(Object other) {
            return other instanceof And;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (this.triggers == null ? 43 : this.triggers.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.And(triggers=" + this.triggers + ")";
        }
    }

    /**
     * Decides whether {@link FailureTestingListener} should induce a failure on a given callback.
     * <p>
     * The listener calls {@link #initialize()} once, before the first {@link #triggerFailure} check, if
     * {@link #initialized()} is false. Subclasses that capture state there must call {@code super.initialize()}.
     */
    public static abstract class FailureTrigger implements Serializable {
        private boolean initialized = false;

        /**
         * @param callType  which listener callback is being processed
         * @param iteration the model's current iteration count
         * @param epoch     the model's current epoch count
         * @param model     the model being trained
         * @return true if the failure should be induced now
         */
        public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, Model model);

        public boolean initialized() {
            return this.initialized;
        }

        /** Marks this trigger as initialized; subclasses extend this to capture their starting state. */
        public void initialize() {
            this.initialized = true;
        }

        public boolean isInitialized() {
            return this.initialized;
        }

        public void setInitialized(boolean initialized) {
            this.initialized = initialized;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof FailureTrigger)) {
                return false;
            }
            FailureTrigger other = (FailureTrigger) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return this.isInitialized() == other.isInitialized();
        }

        protected boolean canEqual(Object other) {
            return other instanceof FailureTrigger;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (this.isInitialized() ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "FailureTestingListener.FailureTrigger(initialized=" + this.isInitialized() + ")";
        }
    }

    /** The listener callback that led to a trigger check; {@link #ANY} is a wildcard used by {@link RandomProb}. */
    public static enum CallType {
        ANY,
        EPOCH_START,
        EPOCH_END,
        FORWARD_PASS,
        GRADIENT_CALC,
        BACKWARD_PASS,
        ITER_DONE
    }

    /** The failure induced once the trigger fires; see {@link FailureTestingListener#call(CallType, Model)}. */
    public static enum FailureMode {
        OOM,
        SYSTEM_EXIT_1,
        ILLEGAL_STATE,
        INFINITE_SLEEP
    }
}

