package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.deeplearning4j.nn.api.Trainable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;

/* loaded from: input_file:org/deeplearning4j/nn/updater/UpdaterBlock.class */
/**
 * A group of parameters that share a single {@link GradientUpdater} instance.
 *
 * <p>The block covers the parameter/gradient range selected by
 * {@code NDArrayIndex.interval(paramOffsetStart, paramOffsetEnd)} of the flattened parameter and
 * gradient arrays, and owns a slice of the updater state ({@link #getUpdaterView()}). The updater is
 * created lazily from the configuration of the <em>first</em> {@link ParamState} in the block; that
 * configuration is used for every entry.
 *
 * <p>No synchronization is performed: lazy initialization and the setters are not thread-safe.
 */
public class UpdaterBlock {
    /** Multiplier for the field-by-field hash codes (kept so hash values are unchanged). */
    private static final int HASH_PRIME = 59;

    // Offsets of this block within the flattened parameter and gradient arrays.
    private int paramOffsetStart;
    private int paramOffsetEnd;
    // Offsets of this block's slice within the updater state array; stored only, not read here.
    private int updaterViewOffsetStart;
    private int updaterViewOffsetEnd;
    // Parameters covered by this block; entry 0 supplies the updater/pretrain configuration.
    private List<ParamState> layersAndVariablesInBlock;
    // This block's slice of the updater state; handed to the updater when it is instantiated.
    private INDArray updaterView;
    // This block's gradient slice, used by update(int, int).
    private INDArray gradientView;
    // Forwarded to instantiate(...) when the updater is lazily created.
    private boolean updaterViewRequiresInitialization;
    // Created lazily by init(), or injected via setGradientUpdater(...).
    private GradientUpdater gradientUpdater;

    /**
     * One parameter variable belonging to an {@link UpdaterBlock}: the owning layer, the parameter
     * name, its offsets within the flattened parameter/gradient arrays, and views of its parameter
     * and gradient arrays. The references are final; the referenced arrays themselves are mutable.
     */
    public static class ParamState {
        private final Trainable layer;
        private final String paramName;
        private final int paramOffsetStart;
        private final int paramOffsetEnd;
        private final INDArray paramView;
        private final INDArray gradView;

        /**
         * @param layer            layer that owns the parameter
         * @param paramName        name of the parameter within {@code layer}'s configuration
         * @param paramOffsetStart start offset within the flattened parameter/gradient arrays
         * @param paramOffsetEnd   end offset within the flattened parameter/gradient arrays
         * @param paramView        view of this parameter's values
         * @param gradView         view of this parameter's gradient
         */
        public ParamState(Trainable layer, String paramName, int paramOffsetStart, int paramOffsetEnd,
                INDArray paramView, INDArray gradView) {
            this.layer = layer;
            this.paramName = paramName;
            this.paramOffsetStart = paramOffsetStart;
            this.paramOffsetEnd = paramOffsetEnd;
            this.paramView = paramView;
            this.gradView = gradView;
        }

        public Trainable getLayer() {
            return this.layer;
        }

        public String getParamName() {
            return this.paramName;
        }

        public int getParamOffsetStart() {
            return this.paramOffsetStart;
        }

        public int getParamOffsetEnd() {
            return this.paramOffsetEnd;
        }

        public INDArray getParamView() {
            return this.paramView;
        }

        public INDArray getGradView() {
            return this.gradView;
        }

        /** Field-by-field equality (INDArray fields compared with {@link INDArray#equals(Object)}). */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ParamState)) {
                return false;
            }
            ParamState other = (ParamState) obj;
            return other.canEqual(this)
                    && getParamOffsetStart() == other.getParamOffsetStart()
                    && getParamOffsetEnd() == other.getParamOffsetEnd()
                    && Objects.equals(getLayer(), other.getLayer())
                    && Objects.equals(getParamName(), other.getParamName())
                    && Objects.equals(getParamView(), other.getParamView())
                    && Objects.equals(getGradView(), other.getGradView());
        }

        /** Lets subclasses that add state opt out of equality with plain ParamState (keeps equals symmetric). */
        protected boolean canEqual(Object obj) {
            return obj instanceof ParamState;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * HASH_PRIME + getParamOffsetStart();
            result = result * HASH_PRIME + getParamOffsetEnd();
            result = result * HASH_PRIME + nullSafeHash(getLayer());
            result = result * HASH_PRIME + nullSafeHash(getParamName());
            result = result * HASH_PRIME + nullSafeHash(getParamView());
            result = result * HASH_PRIME + nullSafeHash(getGradView());
            return result;
        }

        @Override
        public String toString() {
            return "UpdaterBlock.ParamState(layer=" + getLayer() + ", paramName=" + getParamName() + ", paramOffsetStart=" + getParamOffsetStart() + ", paramOffsetEnd=" + getParamOffsetEnd() + ", paramView=" + getParamView() + ", gradView=" + getGradView() + ")";
        }
    }

    /**
     * @param paramOffsetStart          start offset within the flattened parameter/gradient arrays
     * @param paramOffsetEnd            end offset within the flattened parameter/gradient arrays
     * @param updaterViewOffsetStart    start offset of this block's slice of the updater state
     * @param updaterViewOffsetEnd      end offset of this block's slice of the updater state
     * @param layersAndVariablesInBlock parameters in this block; stored by reference, not copied
     */
    public UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart,
            int updaterViewOffsetEnd, List<ParamState> layersAndVariablesInBlock) {
        this.paramOffsetStart = paramOffsetStart;
        this.paramOffsetEnd = paramOffsetEnd;
        this.updaterViewOffsetStart = updaterViewOffsetStart;
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    /**
     * Creates the {@link GradientUpdater} for this block if one does not exist yet.
     *
     * <p>The updater config is looked up for the first {@link ParamState}'s parameter and
     * instantiated over {@link #getUpdaterView()}, forwarding
     * {@link #isUpdaterViewRequiresInitialization()}.
     *
     * @throws IndexOutOfBoundsException if the block contains no parameters and no updater exists
     */
    public void init() {
        if (this.gradientUpdater != null) {
            return;
        }
        ParamState first = this.layersAndVariablesInBlock.get(0);
        this.gradientUpdater = first.getLayer().getConfig()
                .getUpdaterByParam(first.getParamName())
                .instantiate(this.updaterView, this.updaterViewRequiresInitialization);
    }

    /** Returns whether the first parameter of this block is flagged as a pretrain parameter by its layer config. */
    public boolean isPretrainUpdaterBlock() {
        ParamState first = this.layersAndVariablesInBlock.get(0);
        return first.getLayer().getConfig().isPretrainParam(first.getParamName());
    }

    /**
     * @param applyToPretrainParams when {@code true}, pretrain blocks are never skipped
     * @return {@code true} iff this is a pretrain block and {@code applyToPretrainParams} is {@code false}
     */
    public boolean skipDueToPretrainConfig(boolean applyToPretrainParams) {
        return isPretrainUpdaterBlock() && !applyToPretrainParams;
    }

    /** Returns the updater, creating it via {@link #init()} on first access. */
    public GradientUpdater getGradientUpdater() {
        if (this.gradientUpdater == null) {
            init();
        }
        return this.gradientUpdater;
    }

    /**
     * Runs regularization and the updater over this block's own {@link #getGradientView()}, using
     * each {@link ParamState}'s parameter and gradient views.
     *
     * @param iteration iteration count forwarded to the updater and regularizations
     * @param epoch     epoch count forwarded to the updater and regularizations
     */
    public void update(int iteration, int epoch) {
        update(iteration, epoch, false, this.gradientView, null);
    }

    /**
     * Like {@link #update(int, int)} but operates on caller-supplied arrays: both are flattened and
     * this block's (and each parameter's) offsets are used to slice them.
     *
     * @param iteration iteration count forwarded to the updater and regularizations
     * @param epoch     epoch count forwarded to the updater and regularizations
     * @param gradient  gradient array indexed by this block's parameter offsets
     * @param params    parameter array indexed by each {@link ParamState}'s offsets; must be non-null
     *                  if any parameter is present
     */
    public void updateExternalGradient(int iteration, int epoch, INDArray gradient, INDArray params) {
        update(iteration, epoch, true, gradient, params);
    }

    private void update(int iteration, int epoch, boolean externalGradient, INDArray gradient, INDArray params) {
        if (this.gradientUpdater == null) {
            init();
        }
        // Only the first entry is consulted: a layer with no parameters has nothing to update.
        if (this.layersAndVariablesInBlock.get(0).getLayer().numParams() == 0) {
            return;
        }
        INDArray blockGradient = externalGradient
                ? flatten(gradient).get(NDArrayIndex.interval(this.paramOffsetStart, this.paramOffsetEnd))
                : this.gradientView;

        applyRegularizationAllVariables(Regularization.ApplyStep.BEFORE_UPDATER, iteration, epoch,
                externalGradient, gradient, params);
        // NOTE(review): the updater presumably transforms the gradient in place; this relies on the
        // flattened array sharing storage with blockGradient — confirm reshape yields a view here.
        this.gradientUpdater.applyUpdater(flatten(blockGradient), iteration, epoch);
        applyRegularizationAllVariables(Regularization.ApplyStep.POST_UPDATER, iteration, epoch,
                externalGradient, gradient, params);
    }

    /**
     * Applies every regularization configured for {@code applyStep} to each parameter in the block.
     *
     * <p>The learning rate passed to the regularizations is the updater's scheduled rate for
     * ({@code iteration}, {@code epoch}) when the updater config has one, otherwise {@code 1.0}.
     *
     * @param externalGradient when {@code true}, per-parameter views are sliced out of the flattened
     *                         {@code gradient}/{@code params}; otherwise each {@link ParamState}'s own
     *                         views are used
     * @throws NullPointerException if {@code externalGradient} is set, the block is non-empty and
     *                              {@code params} is {@code null}
     */
    protected void applyRegularizationAllVariables(Regularization.ApplyStep applyStep, int iteration, int epoch,
            boolean externalGradient, INDArray gradient, INDArray params) {
        if (this.layersAndVariablesInBlock.isEmpty()) {
            return;
        }
        // Loop-invariant: the same learning rate applies to every variable of the block.
        double learningRate = this.gradientUpdater.getConfig().hasLearningRate()
                ? this.gradientUpdater.getConfig().getLearningRate(iteration, epoch)
                : 1.0d;

        INDArray flatParams = null;
        INDArray flatGradient = null;
        if (externalGradient) {
            flatParams = flatten(Objects.requireNonNull(params,
                    "params must be non-null when applying an external gradient"));
            flatGradient = flatten(gradient);
        }

        for (ParamState state : this.layersAndVariablesInBlock) {
            INDArray paramView;
            INDArray gradView;
            if (externalGradient) {
                // Separate index objects: interval indices may be initialized against the array they index.
                paramView = flatParams.get(NDArrayIndex.interval(state.getParamOffsetStart(), state.getParamOffsetEnd()));
                gradView = flatGradient.get(NDArrayIndex.interval(state.getParamOffsetStart(), state.getParamOffsetEnd()));
            } else {
                paramView = state.getParamView();
                gradView = state.getGradView();
            }
            applyRegularization(applyStep, state.getLayer(), state.getParamName(), gradView, paramView,
                    iteration, epoch, learningRate);
        }
    }

    /**
     * Applies the regularizations configured for {@code paramName} on {@code trainable} whose
     * {@code applyStep()} equals {@code applyStep}. Does nothing if none are configured.
     *
     * @param gradView  gradient of the parameter
     * @param paramView parameter values
     */
    protected void applyRegularization(Regularization.ApplyStep applyStep, Trainable trainable, String paramName,
            INDArray gradView, INDArray paramView, int iteration, int epoch, double learningRate) {
        List<Regularization> regularizations = trainable.getConfig().getRegularizationByParam(paramName);
        if (regularizations == null || regularizations.isEmpty()) {
            return;
        }
        for (Regularization regularization : regularizations) {
            if (regularization.applyStep() == applyStep) {
                regularization.apply(paramView, gradView, learningRate, iteration, epoch);
            }
        }
    }

    /** Returns {@code array} reshaped to rank 1 with the same number of elements. */
    private static INDArray flatten(INDArray array) {
        return array.reshape(new long[] {array.length()});
    }

    /** Hash contribution of a possibly-null value; 43 is used for {@code null}. */
    private static int nullSafeHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    // ---- Plain accessors -------------------------------------------------------------------

    public int getParamOffsetStart() {
        return this.paramOffsetStart;
    }

    public int getParamOffsetEnd() {
        return this.paramOffsetEnd;
    }

    public int getUpdaterViewOffsetStart() {
        return this.updaterViewOffsetStart;
    }

    public int getUpdaterViewOffsetEnd() {
        return this.updaterViewOffsetEnd;
    }

    /** Returns the internal list (not a copy). */
    public List<ParamState> getLayersAndVariablesInBlock() {
        return this.layersAndVariablesInBlock;
    }

    public INDArray getUpdaterView() {
        return this.updaterView;
    }

    public INDArray getGradientView() {
        return this.gradientView;
    }

    public boolean isUpdaterViewRequiresInitialization() {
        return this.updaterViewRequiresInitialization;
    }

    public void setParamOffsetStart(int paramOffsetStart) {
        this.paramOffsetStart = paramOffsetStart;
    }

    public void setParamOffsetEnd(int paramOffsetEnd) {
        this.paramOffsetEnd = paramOffsetEnd;
    }

    public void setUpdaterViewOffsetStart(int updaterViewOffsetStart) {
        this.updaterViewOffsetStart = updaterViewOffsetStart;
    }

    public void setUpdaterViewOffsetEnd(int updaterViewOffsetEnd) {
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
    }

    public void setLayersAndVariablesInBlock(List<ParamState> layersAndVariablesInBlock) {
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    public void setUpdaterView(INDArray updaterView) {
        this.updaterView = updaterView;
    }

    public void setGradientView(INDArray gradientView) {
        this.gradientView = gradientView;
    }

    public void setUpdaterViewRequiresInitialization(boolean updaterViewRequiresInitialization) {
        this.updaterViewRequiresInitialization = updaterViewRequiresInitialization;
    }

    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        this.gradientUpdater = gradientUpdater;
    }

    // ---- Object methods --------------------------------------------------------------------
    // These read the gradientUpdater FIELD rather than getGradientUpdater(): the getter lazily calls
    // init(), which would make equals/hashCode/toString instantiate the updater over updaterView
    // (possibly initializing it) and throw for a block with no parameters.

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof UpdaterBlock)) {
            return false;
        }
        UpdaterBlock other = (UpdaterBlock) obj;
        return other.canEqual(this)
                && getParamOffsetStart() == other.getParamOffsetStart()
                && getParamOffsetEnd() == other.getParamOffsetEnd()
                && getUpdaterViewOffsetStart() == other.getUpdaterViewOffsetStart()
                && getUpdaterViewOffsetEnd() == other.getUpdaterViewOffsetEnd()
                && isUpdaterViewRequiresInitialization() == other.isUpdaterViewRequiresInitialization()
                && Objects.equals(getLayersAndVariablesInBlock(), other.getLayersAndVariablesInBlock())
                && Objects.equals(getUpdaterView(), other.getUpdaterView())
                && Objects.equals(getGradientView(), other.getGradientView())
                && Objects.equals(this.gradientUpdater, other.gradientUpdater);
    }

    /** Lets subclasses that add state opt out of equality with plain UpdaterBlock (keeps equals symmetric). */
    protected boolean canEqual(Object obj) {
        return obj instanceof UpdaterBlock;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + getParamOffsetStart();
        result = result * HASH_PRIME + getParamOffsetEnd();
        result = result * HASH_PRIME + getUpdaterViewOffsetStart();
        result = result * HASH_PRIME + getUpdaterViewOffsetEnd();
        result = result * HASH_PRIME + (isUpdaterViewRequiresInitialization() ? 79 : 97);
        result = result * HASH_PRIME + nullSafeHash(getLayersAndVariablesInBlock());
        result = result * HASH_PRIME + nullSafeHash(getUpdaterView());
        result = result * HASH_PRIME + nullSafeHash(getGradientView());
        result = result * HASH_PRIME + nullSafeHash(this.gradientUpdater);
        return result;
    }

    @Override
    public String toString() {
        return "UpdaterBlock(paramOffsetStart=" + getParamOffsetStart() + ", paramOffsetEnd=" + getParamOffsetEnd() + ", updaterViewOffsetStart=" + getUpdaterViewOffsetStart() + ", updaterViewOffsetEnd=" + getUpdaterViewOffsetEnd() + ", layersAndVariablesInBlock=" + getLayersAndVariablesInBlock() + ", updaterView=" + getUpdaterView() + ", gradientView=" + getGradientView() + ", updaterViewRequiresInitialization=" + isUpdaterViewRequiresInitialization() + ", gradientUpdater=" + this.gradientUpdater + ")";
    }
}
