/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.NameScope;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffLambda;
import org.nd4j.autodiff.samediff.SameDiffNoArgSingleLambda;
import org.nd4j.autodiff.samediff.SameDiffSingleLambda;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.custom.Invoke;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.collect.Sets;

public class ControlFlow {
    /**
     * Creates the seed variables for a simple counted loop body.
     *
     * <p>The returned array holds, in order:
     * <ol>
     *   <li>a constant named {@code namesToUse[1]} holding {@code maxIterations};</li>
     *   <li>a variable named {@code namesToUse[0]} initialised from {@code Nd4j.zeros(1)}.</li>
     * </ol>
     * NOTE(review): this order (max constant first, counter second) is the reverse of the
     * 4-argument overload, which puts the counter first. Confirm callers rely on it before
     * changing it.
     *
     * @param namesToUse    exactly two names: {@code [counterName, maxIterationsName]}
     * @param loopBody      graph in which both variables are created
     * @param maxIterations value stored in the max-iterations constant
     * @return {@code [maxIterationsConstant, counterVariable]}
     * @throws IllegalStateException if {@code namesToUse} is null or does not have length 2
     */
    public static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations) {
        boolean hasTwoNames = namesToUse != null && namesToUse.length == 2;
        Preconditions.checkState(hasTwoNames, "Number of input names must be 2.");
        SDVariable maxIterationsConstant = loopBody.constant(namesToUse[1], maxIterations);
        SDVariable counter = loopBody.var(namesToUse[0], Nd4j.zeros(1));
        return new SDVariable[]{maxIterationsConstant, counter};
    }

    /**
     * Creates the seed variables for a counted loop body that also carries an extra boolean
     * condition.
     *
     * <p>The returned array holds, in order:
     * <ol>
     *   <li>a variable named {@code namesToUse[0]} initialised from {@code Nd4j.zeros(1)};</li>
     *   <li>a constant named {@code namesToUse[1]} holding {@code maxIterations};</li>
     *   <li>a constant named {@code namesToUse[2]} holding {@code extraCond}.</li>
     * </ol>
     *
     * @param namesToUse    exactly three names: {@code [counterName, maxIterationsName, extraCondName]}
     * @param loopBody      graph in which the variables are created
     * @param maxIterations value stored in the max-iterations constant
     * @param extraCond     value stored in the extra-condition constant
     * @return {@code [counterVariable, maxIterationsConstant, extraConditionConstant]}
     * @throws IllegalStateException if {@code namesToUse} is null or does not have length 3
     */
    public static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations, boolean extraCond) {
        boolean hasThreeNames = namesToUse != null && namesToUse.length == 3;
        Preconditions.checkState(hasThreeNames, "Number of input names must be 3.");
        SDVariable counter = loopBody.var(namesToUse[0], Nd4j.zeros(1));
        SDVariable maxIterationsConstant = loopBody.constant(namesToUse[1], maxIterations);
        SDVariable extraCondition = loopBody.constant(namesToUse[2], extraCond);
        return new SDVariable[]{counter, maxIterationsConstant, extraCondition};
    }

    /**
     * Packs loop arguments into the flat layout produced by {@link LoopArgs#toArgs()}:
     * {@code [startIterations, maxIterations, condIn, extraArgs...]}.
     *
     * @param maxIterations   maximum iteration count
     * @param condIn          extra loop condition
     * @param startIterations starting iteration value
     * @param extraArgs       additional loop-carried variables
     * @return the flattened argument array
     */
    public static SDVariable[] args(SDVariable maxIterations, SDVariable condIn, SDVariable startIterations, SDVariable[] extraArgs) {
        LoopArgs packed = LoopArgs.builder()
                .startIter(startIterations)
                .maxIters(maxIterations)
                .condIn(condIn)
                .extraArgs(extraArgs)
                .build();
        return packed.toArgs();
    }

    /**
     * Builds an if/else construct in {@code sameDiff} out of switch and merge ops.
     *
     * <p>{@code cond} is defined under the {@code <block>/cond} name scope and must produce a
     * {@link DataType#BOOL} variable. Each branch is then defined under its own name scope
     * ({@code trueBody} / {@code falseBody}). While a branch is being defined, an argument
     * interceptor reroutes every variable that already existed before that branch. Each such
     * variable is passed through a {@code switchOp} on the predicate, created once per variable
     * name. Index 1 of the switch result is used inside the true branch and index 0 inside the
     * false branch. The two branch outputs are combined with {@code merge}.
     *
     * @param sameDiff   graph to build into
     * @param outputName name for the merged output, applied via {@code updateVariableNameAndReference}
     * @param ifName     base block name; {@code "if"} when null
     * @param cond       defines the predicate; must return a BOOL variable
     * @param trueBody   defines the value produced by the true branch
     * @param falseBody  defines the value produced by the false branch
     * @return the merged branch output
     * @throws NullPointerException  if {@code cond}, {@code trueBody} or {@code falseBody} is null
     * @throws IllegalStateException if the predicate is not BOOL. The partially built block is
     *                               removed and its name scope closed before throwing.
     */
    public static SDVariable ifCond(SameDiff sameDiff, String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody) {
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (trueBody == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (falseBody == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }
        ifName = sameDiff.newBlockName(ifName == null ? "if" : ifName);
        NameScope ifScope = sameDiff.withNameScope(ifName);
        NameScope condScope = sameDiff.withNameScope("cond");
        SDVariable pred = cond.define(sameDiff);
        condScope.close();
        if (pred.dataType() != DataType.BOOL) {
            removeScopeContents(sameDiff, ifScope);
            // Close the block scope as well. If it stayed open, every variable the caller
            // creates after catching this exception would inherit the abandoned block's prefix.
            ifScope.close();
            throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean.");
        }

        // Switch ops created so far, keyed by the name of the variable being switched.
        HashMap<String, SDVariable[]> switches = new HashMap<>();

        // Snapshot of the variables that existed before the true branch was defined.
        HashSet<String> declared = Sets.newHashSet(sameDiff.variableMap().keySet());
        sameDiff.addArgumentInterceptor(argument -> {
            if (argument == null) {
                return null;
            }
            if (!declared.contains(argument.name())) {
                return argument;
            }
            if (switches.containsKey(argument.name())) {
                return switches.get(argument.name())[1];
            }
            SDVariable[] s = sameDiff.switchOp(argument, pred);
            switches.put(argument.name(), s);
            return s[1];
        });
        NameScope trueScope = sameDiff.withNameScope("trueBody");
        SDVariable trueOut = trueBody.define(sameDiff);
        sameDiff.removeArgumentInterceptor();
        // A branch that returns a pre-existing variable directly must still be gated on the predicate.
        if (declared.contains(trueOut.name())) {
            SDVariable[] s = sameDiff.switchOp(trueOut, pred);
            switches.put(trueOut.name(), s);
            trueOut = s[1];
        }
        trueScope.close();

        // Snapshot again so variables created by the true branch are also routed through switches.
        HashSet<String> declared2 = Sets.newHashSet(sameDiff.variableMap().keySet());
        sameDiff.addArgumentInterceptor(argument -> {
            // Same null handling as the true-branch interceptor (previously an NPE here).
            if (argument == null) {
                return null;
            }
            if (!declared2.contains(argument.name())) {
                return argument;
            }
            if (switches.containsKey(argument.name())) {
                return switches.get(argument.name())[0];
            }
            SDVariable[] s = sameDiff.switchOp(argument, pred);
            switches.put(argument.name(), s);
            return s[0];
        });
        NameScope falseScope = sameDiff.withNameScope("falseBody");
        SDVariable falseOut = falseBody.define(sameDiff);
        sameDiff.removeArgumentInterceptor();
        if (declared2.contains(falseOut.name())) {
            SDVariable[] s = sameDiff.switchOp(falseOut, pred);
            switches.put(falseOut.name(), s);
            falseOut = s[0];
        }
        falseScope.close();
        SDVariable output = sameDiff.merge(trueOut, falseOut);
        ifScope.close();
        return sameDiff.updateVariableNameAndReference(output, outputName);
    }

    /**
     * Rolls back a partially built block by removing every variable and op registered under
     * {@code scope}. Each op is first detached from its inputs via {@code removeArgFromOp}.
     */
    private static void removeScopeContents(SameDiff sameDiff, NameScope scope) {
        // Walk copies so that removals cannot invalidate the lists being iterated.
        for (SDVariable v : new ArrayList<>(sameDiff.getVariablesInScope(scope))) {
            sameDiff.getVariables().remove(v.name());
        }
        for (SameDiffOp op : new ArrayList<>(sameDiff.getOpsInScope(scope))) {
            List<String> inputs = op.getInputsToOp();
            if (inputs != null) {
                for (String in : new ArrayList<>(inputs)) {
                    sameDiff.removeArgFromOp(in, op.getOp());
                }
            }
            sameDiff.getOps().remove(op.getName());
        }
    }

    /**
     * Convenience overload that unpacks a {@link LoopParams} and delegates to
     * {@link #loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[])}.
     *
     * @param loopParams all loop settings
     * @return the loop outputs produced by the delegate
     */
    public static SDVariable[] loopWithConditions(LoopParams loopParams) {
        String[] outputs = loopParams.getOutputVarNames();
        String name = loopParams.getLoopName();
        SameDiff outer = loopParams.getParent();
        SameDiff body = loopParams.getFunctionBody();
        String function = loopParams.getFunctionName();
        SDVariable[] vars = loopParams.getLoopVars();
        String[] bodyInputs = loopParams.getFunctionBodyInputs();
        String[] bodyOutputs = loopParams.getFunctionBodyOutputs();
        return loopWithConditions(outputs, name, outer, body, function, vars, bodyInputs, bodyOutputs);
    }

    /**
     * Runs {@code functionBody} as the body of a while loop in {@code parent}.
     *
     * <p>The loop condition comes from {@link #condBody()}. The body lambda comes from
     * {@link #loopBody}, which registers {@code functionBody} on {@code parent} as a
     * sub-function named {@code functionName}.
     *
     * @param outputVarNames      names for the loop outputs
     * @param loopName            base name of the loop block
     * @param parent              outer graph; every loop variable must belong to it
     * @param functionBody        graph invoked on each iteration
     * @param functionName        name under which {@code functionBody} is registered
     * @param loopVars            initial loop variables
     * @param functionBodyInputs  input variable names of the sub-graph
     * @param functionBodyOutputs output variable names of the sub-graph
     * @return the loop outputs
     * @throws IllegalStateException    if the input/output names are missing or of unequal length,
     *                                  or if {@code loopVars} differs in length from them
     * @throws IllegalArgumentException if a loop variable does not belong to {@code parent}
     */
    public static SDVariable[] loopWithConditions(String[] outputVarNames, String loopName, SameDiff parent, SameDiff functionBody, String functionName, SDVariable[] loopVars, String[] functionBodyInputs, String[] functionBodyOutputs) {
        boolean namesConsistent = functionBodyInputs != null
                && functionBodyOutputs != null
                && functionBodyInputs.length == functionBodyOutputs.length;
        Preconditions.checkState(namesConsistent, "Sub graph input and output names must  be defined and equal in length.");
        Preconditions.checkState(loopVars.length == functionBodyInputs.length, "Loop variables and function body inputs must be equal in length.");
        for (SDVariable loopVar : loopVars) {
            if (loopVar.getSameDiff() != parent) {
                throw new IllegalArgumentException("Variable named " + loopVar.name() + " does not have correct samediff instance. Must have parent outer samediff instance.");
            }
        }
        SameDiffSingleLambda condition = condBody();
        SameDiffLambda iteration = loopBody(parent, functionBody, functionName, functionBodyInputs, functionBodyOutputs);
        return parent.whileLoop(outputVarNames, loopName, loopVars, condition, iteration);
    }

    /**
     * Unpacks the flat loop-body input array into a {@link LoopLambdaArgs}.
     *
     * <p>Layout: {@code inputs[0]} is the iteration start, {@code inputs[1]} the iteration count,
     * {@code inputs[2]} the condition input. Any further entries become the extra arguments
     * (an empty array when there are none).
     *
     * @param inputs loop-body inputs; must contain at least three variables
     * @return the unpacked arguments
     * @throws IllegalArgumentException if {@code inputs} is null or has fewer than three entries
     */
    public static LoopLambdaArgs argsFromInputs(SDVariable[] inputs) {
        if (inputs == null || inputs.length < 3) {
            throw new IllegalArgumentException("Loop inputs must contain at least 3 variables (iteration start, iteration count, condition), got "
                    + (inputs == null ? "null" : String.valueOf(inputs.length)));
        }
        SDVariable[] extraArgs = Arrays.copyOfRange(inputs, 3, inputs.length);
        return LoopLambdaArgs.builder()
                .iterCount(inputs[1])
                .iterStart(inputs[0])
                .condIn(inputs[2])
                .extraArgs(extraArgs)
                .build();
    }

    /**
     * Registers {@code functionBody} on {@code parent} under {@code functionName} and returns a
     * loop-body lambda that invokes that sub-function on every iteration.
     *
     * <p>The lambda unpacks its inputs with {@link #argsFromInputs} and invokes the sub-function.
     * It returns {@code [inputs[0] + 1, inputs[1], invoked[2], invoked[3], ...]}: the first input
     * is advanced by one, the second is passed through unchanged, and the remaining slots are
     * taken from the invocation's outputs. Outputs 0 and 1 of the invocation are not used.
     *
     * @param parent              graph that receives the sub-function
     * @param functionBody        sub-graph to invoke
     * @param functionName        name under which the sub-graph is registered
     * @param subGraphInputNames  input variable names of the sub-graph
     * @param subGraphOutputNames output variable names of the sub-graph
     * @return the loop-body lambda
     * @throws IllegalStateException if the name arrays are null or differ in length
     */
    public static SameDiffLambda loopBody(SameDiff parent, SameDiff functionBody, String functionName, String[] subGraphInputNames, String[] subGraphOutputNames) {
        boolean namesConsistent = subGraphInputNames != null
                && subGraphOutputNames != null
                && subGraphInputNames.length == subGraphOutputNames.length;
        Preconditions.checkState(namesConsistent, "Sub graph input and output names must  be defined and equal in length.");
        parent.putSubFunction(functionName, functionBody);
        return (sameDiff, inputs) -> {
            Invoke.InvokeParams params = argsFromInputs(inputs).invokeParams(functionName, subGraphInputNames, subGraphOutputNames);
            SDVariable[] invoked = sameDiff.invoke(params);
            ArrayList<SDVariable> next = new ArrayList<>(invoked.length);
            next.add(inputs[0].add(1.0));
            next.add(inputs[1]);
            next.add(invoked[2]);
            // Reached only when invoked has at least three entries, so this sub-list is valid.
            next.addAll(Arrays.asList(invoked).subList(3, invoked.length));
            return next.toArray(new SDVariable[0]);
        };
    }

    /**
     * Builds a while loop in {@code sameDiff} from Enter / Merge / Switch / NextIteration / Exit ops.
     *
     * <p>Each loop variable is passed through an {@code Enter} op for a frame named after the
     * loop block, then through a {@code Merge} op. The Merge's second input is later replaced
     * with the {@code NextIteration} of the body output. {@code cond} is defined on the merged
     * values under the {@code cond} name scope and must return a BOOL variable. Each merged
     * value is then switched on it: output 1 feeds the body and output 0 feeds an {@code Exit}
     * op.
     *
     * <p>While {@code body} is defined (under the {@code body} scope), an argument interceptor
     * wraps every previously declared variable in a single {@code Enter} op per name. The
     * true-switch outputs are exempt.
     *
     * @param sameDiff    graph to build into
     * @param outputNames names for the Exit outputs, applied via {@code updateVariableNamesAndReferences}
     * @param loopName    base block/frame name; {@code "while"} when null
     * @param loopVars    initial values of the loop-carried variables
     * @param cond        defines the loop condition from the merged loop variables
     * @param body        defines the next value of each loop variable; must return exactly
     *                    {@code loopVars.length} variables
     * @return the Exit outputs, one per loop variable
     * @throws NullPointerException     if {@code loopVars}, {@code cond} or {@code body} is null
     * @throws IllegalStateException    if the condition is not BOOL
     * @throws IllegalArgumentException if the body returns a different number of outputs than
     *                                  there are loop variables
     */
    public static SDVariable[] whileLoop(SameDiff sameDiff, String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body) {
        if (loopVars == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (cond == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (body == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }
        String frameName = sameDiff.newBlockName(loopName == null ? "while" : loopName);
        NameScope loopScope = sameDiff.withNameScope(frameName);
        SDVariable counter = sameDiff.scalar(sameDiff.generateNewVarName("counter", 0), 0);

        SDVariable[] entered = new SDVariable[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            entered[i] = new Enter(sameDiff, frameName, loopVars[i]).outputVariable();
        }
        // Both Merge inputs start as the entered value. Input 1 is replaced with the body's
        // NextIteration output once the body has been defined (see the end of this method).
        SDVariable[] merged = new SDVariable[loopVars.length];
        Merge[] mergeOps = new Merge[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            mergeOps[i] = new Merge(sameDiff, entered[i], entered[i]);
            merged[i] = mergeOps[i].outputVariable();
        }
        Merge counterMerge = new Merge(sameDiff, counter, counter);
        counter = counterMerge.outputVariable();

        NameScope condScope = sameDiff.withNameScope("cond");
        SDVariable condResult = cond.define(sameDiff, merged);
        condScope.close();
        if (condResult.dataType() != DataType.BOOL) {
            // Close the loop scope so names created after the caller handles this exception are
            // not prefixed with the abandoned frame name.
            // NOTE(review): unlike ifCond, the partially built ops are not removed here.
            loopScope.close();
            throw new IllegalStateException("Can not use " + condResult.name() + " as the condition of an While loop, the condition must be a boolean.");
        }

        HashSet<String> alreadyEntered = new HashSet<>();
        SDVariable[] trueSwitches = new SDVariable[loopVars.length];
        SDVariable[] exits = new SDVariable[loopVars.length];
        for (int i = 0; i < loopVars.length; ++i) {
            SDVariable[] s = sameDiff.switchOp(merged[i], condResult);
            trueSwitches[i] = s[1];
            alreadyEntered.add(s[1].name());
            exits[i] = new Exit(sameDiff, s[0]).outputVariable();
        }

        // Variables that existed before the body is defined. When the body uses one as an op
        // argument, it is brought into the frame by a single Enter op per name.
        HashSet<String> declared = Sets.newHashSet(sameDiff.variableMap().keySet());
        HashMap<String, SDVariable> done = new HashMap<>();
        sameDiff.addArgumentInterceptor(argument -> {
            if (argument == null) {
                return null;
            }
            if (!declared.contains(argument.name())) {
                return argument;
            }
            if (alreadyEntered.contains(argument.name())) {
                return argument;
            }
            if (done.containsKey(argument.name())) {
                return done.get(argument.name());
            }
            // Fourth Enter argument presumably marks the entered value as loop-constant. TODO confirm.
            SDVariable e = new Enter(sameDiff, frameName, argument, true).outputVariable();
            done.put(argument.name(), e);
            return e;
        });
        NameScope bodyScope = sameDiff.withNameScope("body");
        SDVariable[] outs = body.define(sameDiff, trueSwitches);
        // Restore scope and interceptor state before validating. Previously a failed check left
        // the interceptor installed and the scopes open on the shared SameDiff instance.
        bodyScope.close();
        sameDiff.removeArgumentInterceptor();
        if (outs.length != mergeOps.length) {
            loopScope.close();
            throw new IllegalArgumentException("Number of loop variables must be equal to number of outputs.");
        }
        // NOTE(review): the result of this add is discarded. As a result,
        // counterMerge.replaceArg(1, counter) below wires the counter Merge's own output back
        // into its second input, so the counter is never incremented; cond and body do not read
        // it either. Left unchanged because fixing it alters the generated graph. Confirm the
        // intended behaviour.
        counter.add(1.0);
        for (int i = 0; i < outs.length; ++i) {
            SDVariable n = new NextIteration(sameDiff, outs[i]).outputVariable();
            mergeOps[i].replaceArg(1, n);
        }
        counterMerge.replaceArg(1, counter);
        loopScope.close();
        return sameDiff.updateVariableNamesAndReferences(exits, outputNames);
    }

    /**
     * Returns the standard counted-loop condition.
     *
     * <p>The condition is {@code inputs[0] < inputs[1]} (with {@code inputs[1]} cast to the
     * data type of {@code inputs[0]}) combined with {@code inputs[2]}. Both operands are cast to
     * INT64, combined with a bitwise AND, and the result is cast to BOOL.
     *
     * @return a lambda reading {@code [currentIteration, maxIterations, extraCondition, ...]}
     */
    public static SameDiffSingleLambda condBody() {
        return (sameDiff, inputs) -> {
            SDVariable iteration = inputs[0];
            SDVariable limit = inputs[1];
            SDVariable userCondition = inputs[2];
            SDVariable belowLimit = iteration.lt(limit.castTo(iteration.dataType())).castTo(DataType.INT64);
            SDVariable userFlag = userCondition.castTo(DataType.INT64);
            SDVariable bothHold = sameDiff.bitwise().and(belowLimit, userFlag);
            return bothHold.castTo(DataType.BOOL);
        };
    }

    /**
     * Unpacked view of the inputs a loop-body lambda receives: iteration start, iteration count,
     * the boolean condition input, and any extra loop-carried variables.
     */
    public static class LoopLambdaArgs {
        private SDVariable iterStart;
        private SDVariable iterCount;
        private SDVariable condIn;
        private SDVariable[] extraArgs;

        /**
         * Creates the argument holder after validating data types.
         *
         * @param iterStart iteration start value
         * @param iterCount iteration count; must have a numerical data type
         * @param extraArgs extra loop-carried variables; may be null, which
         *                  {@link #invokeParams} treats as none
         * @param condIn    condition input; must be BOOL
         * @throws IllegalArgumentException if {@code condIn} is not BOOL or {@code iterCount} is
         *                                  not numerical
         */
        public LoopLambdaArgs(SDVariable iterStart, SDVariable iterCount, SDVariable[] extraArgs, SDVariable condIn) {
            if (condIn.dataType() != DataType.BOOL) {
                throw new IllegalArgumentException("Data type for condition must be boolean!");
            }
            if (!iterCount.dataType().isNumerical()) {
                // Previously reported as "condition", which pointed at the wrong argument.
                throw new IllegalArgumentException("Data type for iteration count must be numerical!");
            }
            this.iterCount = iterCount;
            this.extraArgs = extraArgs;
            this.condIn = condIn;
            this.iterStart = iterStart;
        }

        /**
         * Builds invocation parameters for sub-function {@code functionName}. The inputs are
         * {@code [iterStart, iterCount, condIn, extraArgs...]}, and their variable names are
         * passed as the input variable names.
         *
         * @param functionName        name of the sub-function to invoke
         * @param subGraphInputNames  input variable names inside the sub-graph
         * @param subGraphOutputNames output variable names inside the sub-graph
         * @return the invocation parameters
         */
        public Invoke.InvokeParams invokeParams(String functionName, String[] subGraphInputNames, String[] subGraphOutputNames) {
            int extraCount = this.extraArgs == null ? 0 : this.extraArgs.length;
            ArrayList<SDVariable> inputs = new ArrayList<>(3 + extraCount);
            inputs.add(this.iterStart);
            inputs.add(this.iterCount);
            inputs.add(this.condIn);
            // The builder leaves extraArgs null by default; Arrays.asList(null) would throw.
            if (this.extraArgs != null) {
                inputs.addAll(Arrays.asList(this.extraArgs));
            }
            SDVariable[] inputArray = inputs.toArray(new SDVariable[0]);
            String[] inputNames = inputs.stream().map(SDVariable::name).toArray(String[]::new);
            return Invoke.InvokeParams.builder()
                    .functionName(functionName)
                    .inputs(inputArray)
                    .subGraphInputVarNames(subGraphInputNames)
                    .subGraphOutputVarNames(subGraphOutputNames)
                    .inputVarNames(inputNames)
                    .build();
        }

        public static LoopLambdaArgsBuilder builder() {
            return new LoopLambdaArgsBuilder();
        }

        public SDVariable getIterStart() {
            return this.iterStart;
        }

        public SDVariable getIterCount() {
            return this.iterCount;
        }

        public SDVariable getCondIn() {
            return this.condIn;
        }

        public SDVariable[] getExtraArgs() {
            return this.extraArgs;
        }

        public void setIterStart(SDVariable iterStart) {
            this.iterStart = iterStart;
        }

        public void setIterCount(SDVariable iterCount) {
            this.iterCount = iterCount;
        }

        public void setCondIn(SDVariable condIn) {
            this.condIn = condIn;
        }

        public void setExtraArgs(SDVariable[] extraArgs) {
            this.extraArgs = extraArgs;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof LoopLambdaArgs)) {
                return false;
            }
            LoopLambdaArgs other = (LoopLambdaArgs) o;
            return other.canEqual(this)
                    && Objects.equals(this.getIterStart(), other.getIterStart())
                    && Objects.equals(this.getIterCount(), other.getIterCount())
                    && Objects.equals(this.getCondIn(), other.getCondIn())
                    && Arrays.deepEquals(this.getExtraArgs(), other.getExtraArgs());
        }

        protected boolean canEqual(Object other) {
            return other instanceof LoopLambdaArgs;
        }

        @Override
        public int hashCode() {
            // 59 multiplier and 43 null sentinel kept so hash values are unchanged.
            int result = 1;
            SDVariable iterStart = this.getIterStart();
            result = result * 59 + (iterStart == null ? 43 : iterStart.hashCode());
            SDVariable iterCount = this.getIterCount();
            result = result * 59 + (iterCount == null ? 43 : iterCount.hashCode());
            SDVariable condIn = this.getCondIn();
            result = result * 59 + (condIn == null ? 43 : condIn.hashCode());
            result = result * 59 + Arrays.deepHashCode(this.getExtraArgs());
            return result;
        }

        @Override
        public String toString() {
            return "ControlFlow.LoopLambdaArgs(iterStart=" + this.getIterStart() + ", iterCount=" + this.getIterCount() + ", condIn=" + this.getCondIn() + ", extraArgs=" + Arrays.deepToString(this.getExtraArgs()) + ")";
        }

        /** Fluent builder for {@link LoopLambdaArgs}; unset fields stay null. */
        public static class LoopLambdaArgsBuilder {
            private SDVariable iterStart;
            private SDVariable iterCount;
            private SDVariable[] extraArgs;
            private SDVariable condIn;

            LoopLambdaArgsBuilder() {
            }

            public LoopLambdaArgsBuilder iterStart(SDVariable iterStart) {
                this.iterStart = iterStart;
                return this;
            }

            public LoopLambdaArgsBuilder iterCount(SDVariable iterCount) {
                this.iterCount = iterCount;
                return this;
            }

            public LoopLambdaArgsBuilder extraArgs(SDVariable[] extraArgs) {
                this.extraArgs = extraArgs;
                return this;
            }

            public LoopLambdaArgsBuilder condIn(SDVariable condIn) {
                this.condIn = condIn;
                return this;
            }

            public LoopLambdaArgs build() {
                return new LoopLambdaArgs(this.iterStart, this.iterCount, this.extraArgs, this.condIn);
            }

            @Override
            public String toString() {
                return "ControlFlow.LoopLambdaArgs.LoopLambdaArgsBuilder(iterStart=" + this.iterStart + ", iterCount=" + this.iterCount + ", extraArgs=" + Arrays.deepToString(this.extraArgs) + ", condIn=" + this.condIn + ")";
            }
        }
    }

    /**
     * Parameter object for {@code loopWithConditions(LoopParams)}. It bundles the output names,
     * loop name, parent graph, sub-function body and name, loop variables, and sub-graph
     * input/output names.
     */
    public static class LoopParams {
        private String[] outputVarNames;
        private String loopName;
        private SameDiff parent;
        private SameDiff functionBody;
        private String functionName;
        private SDVariable[] loopVars;
        private String[] functionBodyInputs;
        private String[] functionBodyOutputs;

        LoopParams(String[] outputVarNames, String loopName, SameDiff parent, SameDiff functionBody, String functionName, SDVariable[] loopVars, String[] functionBodyInputs, String[] functionBodyOutputs) {
            this.outputVarNames = outputVarNames;
            this.loopName = loopName;
            this.parent = parent;
            this.functionBody = functionBody;
            this.functionName = functionName;
            this.loopVars = loopVars;
            this.functionBodyInputs = functionBodyInputs;
            this.functionBodyOutputs = functionBodyOutputs;
        }

        public static LoopParamsBuilder builder() {
            return new LoopParamsBuilder();
        }

        // Accessor pairs, one field at a time.

        public String[] getOutputVarNames() { return outputVarNames; }

        public void setOutputVarNames(String[] outputVarNames) { this.outputVarNames = outputVarNames; }

        public String getLoopName() { return loopName; }

        public void setLoopName(String loopName) { this.loopName = loopName; }

        public SameDiff getParent() { return parent; }

        public void setParent(SameDiff parent) { this.parent = parent; }

        public SameDiff getFunctionBody() { return functionBody; }

        public void setFunctionBody(SameDiff functionBody) { this.functionBody = functionBody; }

        public String getFunctionName() { return functionName; }

        public void setFunctionName(String functionName) { this.functionName = functionName; }

        public SDVariable[] getLoopVars() { return loopVars; }

        public void setLoopVars(SDVariable[] loopVars) { this.loopVars = loopVars; }

        public String[] getFunctionBodyInputs() { return functionBodyInputs; }

        public void setFunctionBodyInputs(String[] functionBodyInputs) { this.functionBodyInputs = functionBodyInputs; }

        public String[] getFunctionBodyOutputs() { return functionBodyOutputs; }

        public void setFunctionBodyOutputs(String[] functionBodyOutputs) { this.functionBodyOutputs = functionBodyOutputs; }

        /** Null-safe equality: both null, or {@code a.equals(b)}. */
        private static boolean sameValue(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Hash contribution of a single field: 43 for null, otherwise its own hash code. */
        private static int hashOrSentinel(Object value) {
            return value == null ? 43 : value.hashCode();
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof LoopParams)) {
                return false;
            }
            LoopParams that = (LoopParams) o;
            // Evaluated left to right and short-circuited, in field declaration order.
            return that.canEqual(this)
                    && Arrays.deepEquals(getOutputVarNames(), that.getOutputVarNames())
                    && sameValue(getLoopName(), that.getLoopName())
                    && sameValue(getParent(), that.getParent())
                    && sameValue(getFunctionBody(), that.getFunctionBody())
                    && sameValue(getFunctionName(), that.getFunctionName())
                    && Arrays.deepEquals(getLoopVars(), that.getLoopVars())
                    && Arrays.deepEquals(getFunctionBodyInputs(), that.getFunctionBodyInputs())
                    && Arrays.deepEquals(getFunctionBodyOutputs(), that.getFunctionBodyOutputs());
        }

        protected boolean canEqual(Object other) {
            return other instanceof LoopParams;
        }

        @Override
        public int hashCode() {
            int h = 1;
            h = h * 59 + Arrays.deepHashCode(getOutputVarNames());
            h = h * 59 + hashOrSentinel(getLoopName());
            h = h * 59 + hashOrSentinel(getParent());
            h = h * 59 + hashOrSentinel(getFunctionBody());
            h = h * 59 + hashOrSentinel(getFunctionName());
            h = h * 59 + Arrays.deepHashCode(getLoopVars());
            h = h * 59 + Arrays.deepHashCode(getFunctionBodyInputs());
            h = h * 59 + Arrays.deepHashCode(getFunctionBodyOutputs());
            return h;
        }

        @Override
        public String toString() {
            return new StringBuilder("ControlFlow.LoopParams(")
                    .append("outputVarNames=").append(Arrays.deepToString(getOutputVarNames()))
                    .append(", loopName=").append(getLoopName())
                    .append(", parent=").append(getParent())
                    .append(", functionBody=").append(getFunctionBody())
                    .append(", functionName=").append(getFunctionName())
                    .append(", loopVars=").append(Arrays.deepToString(getLoopVars()))
                    .append(", functionBodyInputs=").append(Arrays.deepToString(getFunctionBodyInputs()))
                    .append(", functionBodyOutputs=").append(Arrays.deepToString(getFunctionBodyOutputs()))
                    .append(")")
                    .toString();
        }

        /** Fluent builder for {@link LoopParams}; unset fields stay null. */
        public static class LoopParamsBuilder {
            private String[] outputVarNames;
            private String loopName;
            private SameDiff parent;
            private SameDiff functionBody;
            private String functionName;
            private SDVariable[] loopVars;
            private String[] functionBodyInputs;
            private String[] functionBodyOutputs;

            LoopParamsBuilder() {
            }

            public LoopParamsBuilder outputVarNames(String[] outputVarNames) { this.outputVarNames = outputVarNames; return this; }

            public LoopParamsBuilder loopName(String loopName) { this.loopName = loopName; return this; }

            public LoopParamsBuilder parent(SameDiff parent) { this.parent = parent; return this; }

            public LoopParamsBuilder functionBody(SameDiff functionBody) { this.functionBody = functionBody; return this; }

            public LoopParamsBuilder functionName(String functionName) { this.functionName = functionName; return this; }

            public LoopParamsBuilder loopVars(SDVariable[] loopVars) { this.loopVars = loopVars; return this; }

            public LoopParamsBuilder functionBodyInputs(String[] functionBodyInputs) { this.functionBodyInputs = functionBodyInputs; return this; }

            public LoopParamsBuilder functionBodyOutputs(String[] functionBodyOutputs) { this.functionBodyOutputs = functionBodyOutputs; return this; }

            public LoopParams build() {
                return new LoopParams(outputVarNames, loopName, parent, functionBody, functionName, loopVars, functionBodyInputs, functionBodyOutputs);
            }

            @Override
            public String toString() {
                return new StringBuilder("ControlFlow.LoopParams.LoopParamsBuilder(")
                        .append("outputVarNames=").append(Arrays.deepToString(outputVarNames))
                        .append(", loopName=").append(loopName)
                        .append(", parent=").append(parent)
                        .append(", functionBody=").append(functionBody)
                        .append(", functionName=").append(functionName)
                        .append(", loopVars=").append(Arrays.deepToString(loopVars))
                        .append(", functionBodyInputs=").append(Arrays.deepToString(functionBodyInputs))
                        .append(", functionBodyOutputs=").append(Arrays.deepToString(functionBodyOutputs))
                        .append(")")
                        .toString();
            }
        }
    }

    /**
     * Holder for the standard loop arguments. {@link #toArgs()} flattens them into the order
     * {@code [startIter, maxIters, condIn, extraArgs...]}.
     */
    public static class LoopArgs {
        private SDVariable condIn;
        private SDVariable maxIters;
        private SDVariable startIter;
        private SDVariable[] extraArgs;

        /**
         * Flattens these arguments into {@code [startIter, maxIters, condIn, extraArgs...]}.
         * A null {@code extraArgs} (the builder default) is treated as empty instead of throwing
         * a {@link NullPointerException}.
         *
         * @return a new array with {@code 3 + extraArgs.length} entries
         */
        public SDVariable[] toArgs() {
            SDVariable[] extras = this.extraArgs == null ? new SDVariable[0] : this.extraArgs;
            SDVariable[] ret = new SDVariable[3 + extras.length];
            ret[0] = this.startIter;
            ret[1] = this.maxIters;
            ret[2] = this.condIn;
            System.arraycopy(extras, 0, ret, 3, extras.length);
            return ret;
        }

        LoopArgs(SDVariable condIn, SDVariable maxIters, SDVariable startIter, SDVariable[] extraArgs) {
            this.condIn = condIn;
            this.maxIters = maxIters;
            this.startIter = startIter;
            this.extraArgs = extraArgs;
        }

        public static LoopArgsBuilder builder() {
            return new LoopArgsBuilder();
        }

        public SDVariable getCondIn() {
            return this.condIn;
        }

        public SDVariable getMaxIters() {
            return this.maxIters;
        }

        public SDVariable getStartIter() {
            return this.startIter;
        }

        public SDVariable[] getExtraArgs() {
            return this.extraArgs;
        }

        public void setCondIn(SDVariable condIn) {
            this.condIn = condIn;
        }

        public void setMaxIters(SDVariable maxIters) {
            this.maxIters = maxIters;
        }

        public void setStartIter(SDVariable startIter) {
            this.startIter = startIter;
        }

        public void setExtraArgs(SDVariable[] extraArgs) {
            this.extraArgs = extraArgs;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof LoopArgs)) {
                return false;
            }
            LoopArgs other = (LoopArgs) o;
            return other.canEqual(this)
                    && Objects.equals(this.getCondIn(), other.getCondIn())
                    && Objects.equals(this.getMaxIters(), other.getMaxIters())
                    && Objects.equals(this.getStartIter(), other.getStartIter())
                    && Arrays.deepEquals(this.getExtraArgs(), other.getExtraArgs());
        }

        protected boolean canEqual(Object other) {
            return other instanceof LoopArgs;
        }

        @Override
        public int hashCode() {
            // 59 multiplier and 43 null sentinel kept so hash values are unchanged.
            int result = 1;
            SDVariable condIn = this.getCondIn();
            result = result * 59 + (condIn == null ? 43 : condIn.hashCode());
            SDVariable maxIters = this.getMaxIters();
            result = result * 59 + (maxIters == null ? 43 : maxIters.hashCode());
            SDVariable startIter = this.getStartIter();
            result = result * 59 + (startIter == null ? 43 : startIter.hashCode());
            result = result * 59 + Arrays.deepHashCode(this.getExtraArgs());
            return result;
        }

        @Override
        public String toString() {
            return "ControlFlow.LoopArgs(condIn=" + this.getCondIn() + ", maxIters=" + this.getMaxIters() + ", startIter=" + this.getStartIter() + ", extraArgs=" + Arrays.deepToString(this.getExtraArgs()) + ")";
        }

        /** Fluent builder for {@link LoopArgs}; unset fields stay null. */
        public static class LoopArgsBuilder {
            private SDVariable condIn;
            private SDVariable maxIters;
            private SDVariable startIter;
            private SDVariable[] extraArgs;

            LoopArgsBuilder() {
            }

            public LoopArgsBuilder condIn(SDVariable condIn) {
                this.condIn = condIn;
                return this;
            }

            public LoopArgsBuilder maxIters(SDVariable maxIters) {
                this.maxIters = maxIters;
                return this;
            }

            public LoopArgsBuilder startIter(SDVariable startIter) {
                this.startIter = startIter;
                return this;
            }

            public LoopArgsBuilder extraArgs(SDVariable[] extraArgs) {
                this.extraArgs = extraArgs;
                return this;
            }

            public LoopArgs build() {
                return new LoopArgs(this.condIn, this.maxIters, this.startIter, this.extraArgs);
            }

            @Override
            public String toString() {
                return "ControlFlow.LoopArgs.LoopArgsBuilder(condIn=" + this.condIn + ", maxIters=" + this.maxIters + ", startIter=" + this.startIter + ", extraArgs=" + Arrays.deepToString(this.extraArgs) + ")";
            }
        }
    }
}

