package org.deeplearning4j.util;

import java.util.Arrays;
import lombok.NonNull;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
import org.deeplearning4j.nn.conf.layers.Deconvolution3D;
import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.conf.layers.SpaceToBatchLayer;
import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Upsampling2D;
import org.deeplearning4j.nn.conf.layers.Upsampling3D;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/util/ConvolutionUtils.class */
public class ConvolutionUtils {
    /**
     * Explanatory note about NCHW versus NHWC data formats, intended to be appended to
     * error messages raised when the configured format and the actual data disagree.
     */
    public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first) or NHWC (channels last) format for input images and activations.\nLayers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using .setInputType(InputType.convolutional(height, width, depth, CNN2DFormat.NCHW/NHWC)).\nImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";

    /** Dilation of 1 in both dimensions (i.e. no dilation); shared instance, must not be mutated. */
    private static final int[] ONES = {1, 1};

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.deeplearning4j.util.ConvolutionUtils$1, reason: invalid class name */
    /* loaded from: input_file:org/deeplearning4j/util/ConvolutionUtils$1.class */
    /**
     * Decompiler rendering of the synthetic helper class that backs a {@code switch} over
     * {@link ConvolutionMode}. It maps each enum constant's {@code ordinal()} to a small
     * case label (1..4); ordinals that are never assigned keep the array default of 0.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Indexed by ConvolutionMode.ordinal(); value is the case label used by the switch.
        static final /* synthetic */ int[] $SwitchMap$org$deeplearning4j$nn$conf$ConvolutionMode = new int[ConvolutionMode.values().length];

        static {
            // Each assignment is guarded separately so that a constant missing at run time
            // (NoSuchFieldError) only leaves its own slot at 0 instead of aborting the whole table.
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$ConvolutionMode[ConvolutionMode.Same.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$ConvolutionMode[ConvolutionMode.Causal.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$ConvolutionMode[ConvolutionMode.Strict.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$ConvolutionMode[ConvolutionMode.Truncate.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
        }
    }

    /** Private constructor: this class only exposes static helpers and is not meant to be instantiated. */
    private ConvolutionUtils() {
    }

    /**
     * Normalises a one- or two-element configuration array to exactly two elements.
     *
     * <ul>
     *   <li>length 2: the input array itself is returned (same reference);</li>
     *   <li>length 1: the single value is used for both positions;</li>
     *   <li>{@code null}, empty, or longer than 2: {@code {i, i}} is returned.</li>
     * </ul>
     *
     * @param iArr configured values, may be {@code null}
     * @param i    default value used when {@code iArr} cannot be interpreted
     * @return a two-element array
     */
    public static int[] getIntConfig(int[] iArr, int i) {
        // null and empty input previously failed with NullPointerException /
        // ArrayIndexOutOfBoundsException; both now fall back to the default.
        if (iArr == null || iArr.length == 0) {
            return new int[]{i, i};
        }
        if (iArr.length == 1) {
            return new int[]{iArr[0], iArr[0]};
        }
        if (iArr.length == 2) {
            return iArr;
        }
        return new int[]{i, i};
    }

    /**
     * Computes the 2D convolution output size using a dilation of 1 in both dimensions.
     *
     * @deprecated use the overload that also takes a dilation array and a {@link CNN2DFormat}.
     */
    @Deprecated
    public static int[] getOutputSize(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionMode convolutionMode) {
        final int[] noDilation = ONES;
        return getOutputSize(iNDArray, iArr, iArr2, iArr3, convolutionMode, noDilation);
    }

    /**
     * Computes the {height, width} output size of a 2D deconvolution.
     *
     * @param iNDArray        input activations; height/width are read from axes 2/3 for NCHW and 1/2 otherwise
     * @param iArr            kernel size
     * @param iArr2           strides
     * @param iArr3           padding
     * @param convolutionMode for {@code Same} the output is {@code stride * input}; otherwise
     *                        {@code stride * (input - 1) + effectiveKernel - 2 * padding}
     * @param iArr4           dilation, used to compute the effective kernel size
     * @param cNN2DFormat     data format of {@code iNDArray}
     * @return two-element array {outHeight, outWidth}
     * @throws ND4JArraySizeException if either spatial size does not fit in an int
     */
    public static int[] getDeconvolutionOutputSize(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat) {
        final int heightAxis;
        final int widthAxis;
        if (cNN2DFormat == CNN2DFormat.NCHW) {
            heightAxis = 2;
            widthAxis = 3;
        } else {
            heightAxis = 1;
            widthAxis = 2;
        }

        if (iNDArray.size(heightAxis) > Integer.MAX_VALUE || iNDArray.size(widthAxis) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int inHeight = (int) iNDArray.size(heightAxis);
        int inWidth = (int) iNDArray.size(widthAxis);

        int[] dilatedKernel = effectiveKernelSize(iArr, iArr4);
        if (convolutionMode == ConvolutionMode.Same) {
            return new int[]{iArr2[0] * inHeight, iArr2[1] * inWidth};
        }

        int outHeight = iArr2[0] * (inHeight - 1) + dilatedKernel[0] - 2 * iArr3[0];
        int outWidth = iArr2[1] * (inWidth - 1) + dilatedKernel[1] - 2 * iArr3[1];
        return new int[]{outHeight, outWidth};
    }

    /**
     * Computes the three spatial output sizes of a 3D deconvolution.
     *
     * @param iNDArray        input activations; spatial axes are 2,3,4 for NCDHW and 1,2,3 otherwise
     * @param iArr            kernel size (3 values)
     * @param iArr2           strides (3 values)
     * @param iArr3           padding (3 values)
     * @param iArr4           dilation, used to compute the effective kernel size
     * @param convolutionMode for {@code Same} each output is {@code stride * input}; otherwise
     *                        {@code stride * (input - 1) + effectiveKernel - 2 * padding}
     * @param dataFormat      data format of {@code iNDArray}
     * @return three-element array of output sizes, in the same axis order as the input's spatial axes
     */
    public static long[] getDeconvolution3DOutputSize(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, ConvolutionMode convolutionMode, Convolution3D.DataFormat dataFormat) {
        // First spatial axis: right after [batch, channels] for NCDHW, right after [batch] otherwise.
        final int firstSpatialAxis = (dataFormat == Convolution3D.DataFormat.NCDHW) ? 2 : 1;
        long[] inSizes = new long[3];
        for (int d = 0; d < 3; d++) {
            inSizes[d] = iNDArray.size(firstSpatialAxis + d);
        }

        int[] dilatedKernel = effectiveKernelSize(iArr, iArr4);
        boolean sameMode = convolutionMode == ConvolutionMode.Same;

        long[] outSizes = new long[3];
        for (int d = 0; d < 3; d++) {
            if (sameMode) {
                outSizes[d] = iArr2[d] * inSizes[d];
            } else {
                outSizes[d] = iArr2[d] * (inSizes[d] - 1) + dilatedKernel[d] - 2 * iArr3[d];
            }
        }
        return outSizes;
    }

    /**
     * Computes the 2D convolution output size assuming the input is in NCHW format.
     *
     * @deprecated use the overload that takes an explicit {@link CNN2DFormat}.
     */
    @Deprecated
    public static int[] getOutputSize(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionMode convolutionMode, int[] iArr4) {
        final CNN2DFormat defaultFormat = CNN2DFormat.NCHW;
        return getOutputSize(iNDArray, iArr, iArr2, iArr3, convolutionMode, iArr4, defaultFormat);
    }

    /**
     * Tests whether the given layer configuration is one of the layer types this utility
     * treats as having a 2D convolution data layout.
     *
     * @param layer layer configuration to inspect
     * @return {@code true} if the layer matches one of the checks below
     */
    public static boolean layerHasConvolutionLayout(Layer layer) {
        boolean isThreeDimensional = layer instanceof Convolution3D
                || layer instanceof Subsampling3DLayer
                || layer instanceof Deconvolution3D
                || layer instanceof Upsampling3D;

        // NOTE(review): this clause requires the layer to be BOTH a ConvolutionLayer and a
        // SubsamplingLayer. If those two types are unrelated it can never be true, and an
        // "or" was probably intended - confirm against the class hierarchy and callers
        // before changing; behaviour is deliberately left as-is here.
        boolean convolutionAndSubsampling = layer instanceof ConvolutionLayer && layer instanceof SubsamplingLayer;
        if (!isThreeDimensional && convolutionAndSubsampling) {
            return true;
        }

        return layer instanceof SpaceToBatchLayer
                || layer instanceof Upsampling2D
                || layer instanceof SpaceToDepthLayer
                || layer instanceof ZeroPaddingLayer
                || layer instanceof SeparableConvolution2D
                || layer instanceof Deconvolution2D
                || layer instanceof Cropping2D
                || layer instanceof DepthwiseConvolution2D;
    }

    /**
     * Returns the {@link CNN2DFormat} configured on the given layer.
     *
     * <p>The type checks are order-sensitive: the first matching type wins.
     *
     * @param layer layer configuration to read the format from
     * @return the layer's configured data format
     * @throws IllegalArgumentException if the layer is none of the supported types
     */
    public static CNN2DFormat getFormatForLayer(Layer layer) {
        final CNN2DFormat format;
        if (layer instanceof Convolution1DLayer) {
            format = ((Convolution1DLayer) layer).getCnn2dDataFormat();
        } else if (layer instanceof ConvolutionLayer) {
            format = ((ConvolutionLayer) layer).getCnn2dDataFormat();
        } else if (layer instanceof SubsamplingLayer) {
            format = ((SubsamplingLayer) layer).getCnn2dDataFormat();
        } else if (layer instanceof SpaceToBatchLayer) {
            format = ((SpaceToBatchLayer) layer).getFormat();
        } else if (layer instanceof Upsampling2D) {
            format = ((Upsampling2D) layer).getFormat();
        } else if (layer instanceof SpaceToDepthLayer) {
            format = ((SpaceToDepthLayer) layer).getDataFormat();
        } else if (layer instanceof ZeroPaddingLayer) {
            format = ((ZeroPaddingLayer) layer).getDataFormat();
        } else if (layer instanceof SeparableConvolution2D) {
            format = ((SeparableConvolution2D) layer).getCnn2dDataFormat();
        } else if (layer instanceof Deconvolution2D) {
            format = ((Deconvolution2D) layer).getCnn2dDataFormat();
        } else if (layer instanceof DepthwiseConvolution2D) {
            format = ((DepthwiseConvolution2D) layer).getCnn2dDataFormat();
        } else if (layer instanceof Cropping2D) {
            format = ((Cropping2D) layer).getDataFormat();
        } else {
            throw new IllegalArgumentException("Illegal type given " + layer.getClass().getName());
        }
        return format;
    }

    /**
     * Maps a DL4J {@link ConvolutionMode} to the corresponding ND4J {@link PaddingMode}.
     *
     * <p>{@code Same} maps to {@code SAME}, {@code Causal} to {@code CAUSAL}, and both
     * {@code Strict} and {@code Truncate} map to {@code VALID}.
     *
     * @param convolutionMode mode to convert; must not be {@code null}
     * @return the matching padding mode
     * @throws IllegalArgumentException if the mode is not one of the four handled values
     */
    public static PaddingMode paddingModeForConvolutionMode(ConvolutionMode convolutionMode) {
        // Switch directly on the enum rather than on a hand-maintained ordinal table whose
        // case labels were tied to an unrelated constant (MergeVertex.DEFAULT_MERGE_DIM).
        switch (convolutionMode) {
            case Same:
                return PaddingMode.SAME;
            case Causal:
                return PaddingMode.CAUSAL;
            case Strict:
            case Truncate:
                return PaddingMode.VALID;
            default:
                throw new IllegalArgumentException("Invalid input convolution mode: " + convolutionMode);
        }
    }

    /**
     * Computes the {height, width} output size of a 2D convolution/subsampling operation.
     *
     * @param iNDArray        input activations; height/width are read from axes 1/2 for NHWC and 2/3 otherwise
     * @param iArr            kernel size
     * @param iArr2           strides
     * @param iArr3           padding
     * @param convolutionMode for {@code Same}/{@code Causal} the output is {@code ceil(input / stride)};
     *                        otherwise {@code (input - effectiveKernel + 2 * padding) / stride + 1}
     * @param iArr4           dilation, used to compute the effective kernel size
     * @param cNN2DFormat     data format of {@code iNDArray}
     * @return two-element array {outHeight, outWidth}
     * @throws ND4JArraySizeException if either spatial size does not fit in an int
     */
    public static int[] getOutputSize(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat) {
        int heightAxis = 2;
        int widthAxis = 3;
        if (cNN2DFormat == CNN2DFormat.NHWC) {
            heightAxis = 1;
            widthAxis = 2;
        }
        if (iNDArray.size(heightAxis) > Integer.MAX_VALUE || iNDArray.size(widthAxis) > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        int inHeight = (int) iNDArray.size(heightAxis);
        int inWidth = (int) iNDArray.size(widthAxis);

        int[] eKernel = effectiveKernelSize(iArr, iArr4);
        // effectiveKernelSize returns the very same array when no dilation applies, so a
        // different reference means the kernel really was dilated. (The previous "=="
        // made error messages say "effective kernel" exactly when there was no dilation.)
        boolean dilated = eKernel != iArr;
        validateShapes(iNDArray, eKernel, iArr2, iArr3, convolutionMode, iArr4, new int[]{inHeight, inWidth}, dilated);

        if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) {
            // Divide in floating point: with int division the ceil() was a no-op and the
            // result was rounded down whenever the input is not a multiple of the stride.
            int outHeight = (int) Math.ceil(inHeight / (double) iArr2[0]);
            int outWidth = (int) Math.ceil(inWidth / (double) iArr2[1]);
            return new int[]{outHeight, outWidth};
        }

        int outHeight = (inHeight - eKernel[0] + 2 * iArr3[0]) / iArr2[0] + 1;
        int outWidth = (inWidth - eKernel[1] + 2 * iArr3[1]) / iArr2[1] + 1;
        return new int[]{outHeight, outWidth};
    }

    /**
     * Validates that kernel, stride and padding are compatible with the input size for the
     * given convolution mode, throwing a descriptive exception when they are not.
     *
     * <ul>
     *   <li>{@code Truncate}: each kernel dimension must satisfy
     *       {@code 0 < kernel <= input + 2 * padding}.</li>
     *   <li>{@code Strict}: {@code (input - kernel + 2 * padding)} must be divisible by the
     *       stride in each dimension.</li>
     *   <li>Other modes: nothing is checked.</li>
     * </ul>
     * A third dimension is only checked when {@code iArr} has length 3.
     *
     * @param iNDArray        input array; only its shape is used, for the error message
     * @param iArr            kernel size as it should be validated (callers may pass the effective kernel)
     * @param iArr2           strides
     * @param iArr3           padding
     * @param convolutionMode convolution mode selecting which checks apply
     * @param iArr4           dilation; only used for the error message
     * @param iArr5           input sizes per kernel dimension, e.g. {height, width}
     * @param z               if {@code true}, messages describe the kernel as "effective"
     * @throws DL4JInvalidInputException  if a {@code Truncate} constraint is violated
     * @throws DL4JInvalidConfigException if a {@code Strict} constraint is violated
     */
    public static void validateShapes(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionMode convolutionMode, int[] iArr4, int[] iArr5, boolean z) {
        int inHeight = iArr5[0];
        int inWidth = iArr5[1];
        String effective = z ? "effective " : "";

        if (convolutionMode == ConvolutionMode.Truncate) {
            if (iArr[0] <= 0 || iArr[0] > inHeight + 2 * iArr3[0]) {
                throw new DL4JInvalidInputException("Invalid input data or configuration: " + effective
                        + "kernel height and input height must satisfy 0 < " + effective
                        + "kernel height <= input height + 2 * padding height. \nGot " + effective
                        + "kernel height = " + iArr[0] + ", input height = " + inHeight
                        + " and padding height = " + iArr3[0]
                        + " which do not satisfy 0 < " + iArr[0] + " <= " + (inHeight + 2 * iArr3[0])
                        + getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
            if (iArr[1] <= 0 || iArr[1] > inWidth + 2 * iArr3[1]) {
                throw new DL4JInvalidInputException("Invalid input data or configuration: " + effective
                        + "kernel width and input width must satisfy  0 < kernel width <= input width + 2 * padding width. "
                        + "\nGot " + effective
                        + "kernel width = " + iArr[1] + ", input width = " + inWidth
                        + " and padding width = " + iArr3[1]
                        + " which do not satisfy 0 < " + iArr[1] + " <= " + (inWidth + 2 * iArr3[1])
                        + "\nInput size: [numExamples,inputDepth,inputHeight,inputWidth]=" + Arrays.toString(iNDArray.shape())
                        + getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
            if (iArr.length == 3 && (iArr[2] <= 0 || iArr[2] > iArr5[2] + 2 * iArr3[2])) {
                int inChannels = iArr5[2];
                // The padding label previously said "padding height" for this dimension.
                throw new DL4JInvalidInputException("Invalid input data or configuration: " + effective
                        + "kernel channels and input channels must satisfy 0 < " + effective
                        + "kernel channels <= input channels + 2 * padding channels. \nGot " + effective
                        + "kernel channels = " + iArr[2] + ", input channels = " + inChannels
                        + " and padding channels = " + iArr3[2]
                        + " which do not satisfy 0 < " + iArr[2] + " <= " + (inChannels + 2 * iArr3[2])
                        + getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
            return;
        }

        if (convolutionMode == ConvolutionMode.Strict) {
            if ((inHeight - iArr[0] + 2 * iArr3[0]) % iArr2[0] != 0) {
                throw strictModeViolation("height", "input height", inHeight, iArr[0], iArr2[0], iArr3[0],
                        getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
            if ((inWidth - iArr[1] + 2 * iArr3[1]) % iArr2[1] != 0) {
                throw strictModeViolation("width", "input", inWidth, iArr[1], iArr2[1], iArr3[1],
                        getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
            if (iArr.length == 3 && (iArr5[2] - iArr[2] + 2 * iArr3[2]) % iArr2[2] != 0) {
                // Previously this message printed strides[1] and the input width, and called
                // the dimension "width"; it now reports the third dimension's own values.
                throw strictModeViolation("channels", "input", iArr5[2], iArr[2], iArr2[2], iArr3[2],
                        getCommonErrorMsg(iNDArray, iArr, iArr2, iArr3, iArr4));
            }
        }
    }

    /**
     * Builds the exception describing a {@code ConvolutionMode.Strict} violation in one dimension.
     *
     * @param dim        dimension name used in the message ("height", "width", ...)
     * @param inputLabel label used for the input term inside the formula text
     * @param inSize     input size in this dimension
     * @param kernel     kernel size in this dimension
     * @param stride     stride in this dimension
     * @param padding    padding in this dimension
     * @param commonMsg  trailing text describing the overall configuration
     */
    private static DL4JInvalidConfigException strictModeViolation(String dim, String inputLabel, int inSize, int kernel, int stride, int padding, String commonMsg) {
        // Floating-point division on purpose: the point of the message is to show the
        // non-integer output size, which int division would silently truncate.
        double exactOut = (inSize - kernel + 2 * padding) / (double) stride + 1.0d;
        String exactOutStr = String.format("%.2f", exactOut);
        int truncatedOut = (int) exactOut;
        int sameOut = (int) Math.ceil(inSize / (double) stride);

        String msg = "Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input "
                + dim + ", using ConvolutionMode.Strict\n"
                + "ConvolutionMode.Strict requires: output " + dim + " = (" + inputLabel
                + " - kernelSize + 2*padding)/stride + 1 to be an integer. Got: ("
                + inSize + " - " + kernel + " + 2*" + padding + ")/" + stride + " + 1 = " + exactOutStr + "\n"
                + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n"
                + "To truncate/crop the input, such that output " + dim + " = floor(" + exactOutStr + ") = " + truncatedOut
                + ", use ConvolutionType.Truncate.\n"
                + "Alternatively use ConvolutionType.Same, which will use padding to give an output " + dim + " of ceil("
                + inSize + "/" + stride + ")=" + sameOut
                + commonMsg;
        return new DL4JInvalidConfigException(msg);
    }

    /**
     * Computes the effective kernel size once dilation is taken into account:
     * {@code kernel + (kernel - 1) * (dilation - 1)} per dimension.
     *
     * <p>When every dilation value is 1 the input array itself is returned (same
     * reference, no copy), so callers can detect "no dilation" by reference comparison.
     *
     * @param iArr  kernel size, length 2 or 3
     * @param iArr2 dilation, at least as long as {@code iArr}
     * @return the effective kernel size
     * @throws IllegalArgumentException if the kernel does not have 2 or 3 elements
     */
    public static int[] effectiveKernelSize(int[] iArr, int[] iArr2) {
        final int dims = iArr.length;
        if (dims != 2 && dims != 3) {
            throw new IllegalArgumentException("Kernel size has to be either two or three, got: " + iArr.length);
        }

        boolean undilated = true;
        for (int d = 0; d < dims; d++) {
            if (iArr2[d] != 1) {
                undilated = false;
                break;
            }
        }
        if (undilated) {
            return iArr;
        }

        int[] effective = new int[dims];
        for (int d = 0; d < dims; d++) {
            effective[d] = iArr[d] + (iArr[d] - 1) * (iArr2[d] - 1);
        }
        return effective;
    }

    /**
     * Builds the trailing part shared by the validation error messages: input shape,
     * kernel, (when the first or second dilation value is not 1) the dilated kernel,
     * strides, padding and dilation.
     *
     * @param iNDArray input array whose shape is reported
     * @param iArr     kernel size
     * @param iArr2    strides
     * @param iArr3    padding
     * @param iArr4    dilation
     * @return the message fragment, starting with a line break
     */
    private static String getCommonErrorMsg(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4) {
        StringBuilder msg = new StringBuilder();
        msg.append("\nInput size: [numExamples,inputDepth,inputHeight,inputWidth]=")
                .append(Arrays.toString(iNDArray.shape()))
                .append(", inputKernel=")
                .append(Arrays.toString(iArr));

        boolean firstTwoUndilated = iArr4[0] == 1 && iArr4[1] == 1;
        if (!firstTwoUndilated) {
            msg.append(", effectiveKernelGivenDilation=")
                    .append(Arrays.toString(effectiveKernelSize(iArr, iArr4)));
        }

        msg.append(", strides=").append(Arrays.toString(iArr2))
                .append(", padding=").append(Arrays.toString(iArr3))
                .append(", dilation=").append(Arrays.toString(iArr4));
        return msg.toString();
    }

    /**
     * Computes the leading (top/left) padding for "same" mode, per kernel dimension:
     * {@code ((outSize - 1) * stride + effectiveKernel - inSize) / 2}.
     *
     * @param iArr  output size per dimension
     * @param iArr2 input size per dimension
     * @param iArr3 kernel size
     * @param iArr4 strides
     * @param iArr5 dilation
     * @return padding per kernel dimension
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when any
     *                               computed padding is negative
     */
    public static int[] getSameModeTopLeftPadding(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, int[] iArr5) {
        int[] eKernel = effectiveKernelSize(iArr3, iArr5);
        final int dims = iArr3.length;
        int[] topLeft = new int[dims];
        boolean allNonNegative = true;
        for (int d = 0; d < dims; d++) {
            int totalPadding = (iArr[d] - 1) * iArr4[d] + eKernel[d] - iArr2[d];
            topLeft[d] = totalPadding / 2;
            if (topLeft[d] < 0) {
                allNonNegative = false;
            }
        }
        Preconditions.checkState(allNonNegative, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", topLeft, iArr2, iArr, iArr3, iArr4, iArr5);
        return topLeft;
    }

    /**
     * Computes the trailing (bottom/right) padding for "same" mode in the first two dimensions:
     * {@code ((outSize - 1) * stride + effectiveKernel - inSize + 1) / 2}.
     *
     * @param iArr  output size per dimension
     * @param iArr2 input size per dimension
     * @param iArr3 kernel size
     * @param iArr4 strides
     * @param iArr5 dilation
     * @return two-element padding array
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when either
     *                               computed padding is negative
     */
    public static int[] getSameModeBottomRightPadding(int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, int[] iArr5) {
        int[] eKernel = effectiveKernelSize(iArr3, iArr5);
        int[] bottomRight = new int[2];
        for (int d = 0; d < 2; d++) {
            // The "+ 1" rounds up, so an odd total padding puts the extra cell on this side.
            int totalPaddingRoundedUp = (iArr[d] - 1) * iArr4[d] + eKernel[d] - iArr2[d] + 1;
            bottomRight[d] = totalPaddingRoundedUp / 2;
        }
        boolean allNonNegative = bottomRight[0] >= 0 && bottomRight[1] >= 0;
        Preconditions.checkState(allNonNegative, "Invalid padding values calculated: %s - layer configuration is invalid? Input size %s, output size %s, kernel %s, strides %s, dilation %s", bottomRight, iArr2, iArr, iArr3, iArr4, iArr5);
        return bottomRight;
    }

    /**
     * Applies {@link #getHeightAndWidth(int[])} to the kernel size of the configuration's layer.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to {@link ConvolutionLayer}
     * @return the last two kernel-size entries, last entry first
     * @throws ClassCastException if the configured layer is not a {@link ConvolutionLayer}
     */
    public static int[] getHeightAndWidth(NeuralNetConfiguration neuralNetConfiguration) {
        ConvolutionLayer convLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        return getHeightAndWidth(convLayer.getKernelSize());
    }

    /**
     * Returns the {@code nOut} value of the configuration's convolution layer.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to {@link ConvolutionLayer}
     * @return the layer's {@code getNOut()} value
     * @throws ClassCastException if the configured layer is not a {@link ConvolutionLayer}
     */
    public static long numFeatureMap(NeuralNetConfiguration neuralNetConfiguration) {
        ConvolutionLayer convLayer = (ConvolutionLayer) neuralNetConfiguration.getLayer();
        return convLayer.getNOut();
    }

    /**
     * Returns the last two entries of the given array, with the last entry first:
     * {@code {iArr[n-1], iArr[n-2]}}.
     *
     * <p>NOTE(review): for a shape ending in [..., height, width] this yields
     * {width, height} - confirm the ordering callers expect.
     *
     * @param iArr shape-like array with at least two entries
     * @return two-element array as described above
     * @throws IllegalArgumentException if the array has fewer than two entries
     */
    public static int[] getHeightAndWidth(int[] iArr) {
        final int n = iArr.length;
        if (n < 2) {
            throw new IllegalArgumentException("No width and height able to be found: array must be at least length 2");
        }
        int last = iArr[n - 1];
        int secondLast = iArr[n - 2];
        return new int[]{last, secondLast};
    }

    /**
     * Returns the second entry of a shape with at least four entries, or 1 for shorter shapes.
     *
     * @param iArr shape-like array
     * @return {@code iArr[1]} when {@code iArr.length >= 4}, otherwise 1
     */
    public static int numChannels(int[] iArr) {
        return iArr.length >= 4 ? iArr[1] : 1;
    }

    /**
     * Rejects explicit (non-zero) padding when the convolution mode is {@code Same}.
     * Other modes are not checked.
     *
     * @param convolutionMode configured convolution mode
     * @param iArr            configured padding values
     * @throws IllegalArgumentException if the mode is {@code Same} and any padding value is non-zero
     */
    public static void validateConvolutionModePadding(ConvolutionMode convolutionMode, int[] iArr) {
        if (convolutionMode != ConvolutionMode.Same) {
            return;
        }
        for (int padding : iArr) {
            if (padding != 0) {
                throw new IllegalArgumentException("Padding cannot be used when using the `same' convolution mode");
            }
        }
    }

    /**
     * Validates 2D kernel, stride and padding arrays.
     *
     * <p>All three length checks run first (kernel, stride, padding), followed by the value
     * checks in the same order: kernel and stride values must be {@code > 0}, padding
     * values must be {@code >= 0}.
     *
     * @param iArr  kernel size, must be non-null with length 2
     * @param iArr2 strides, must be non-null with length 2
     * @param iArr3 padding, must be non-null with length 2
     * @throws IllegalStateException on the first violated constraint
     */
    public static void validateCnnKernelStridePadding(int[] iArr, int[] iArr2, int[] iArr3) {
        // Arrays.toString renders a null array as "null", which is what the messages need.
        boolean kernelLengthOk = iArr != null && iArr.length == 2;
        if (!kernelLengthOk) {
            throw new IllegalStateException("Invalid kernel size: expected int[] of length 2, got " + Arrays.toString(iArr));
        }
        boolean strideLengthOk = iArr2 != null && iArr2.length == 2;
        if (!strideLengthOk) {
            throw new IllegalStateException("Invalid stride configuration: expected int[] of length 2, got " + Arrays.toString(iArr2));
        }
        boolean paddingLengthOk = iArr3 != null && iArr3.length == 2;
        if (!paddingLengthOk) {
            throw new IllegalStateException("Invalid padding configuration: expected int[] of length 2, got " + Arrays.toString(iArr3));
        }

        if (!(iArr[0] > 0 && iArr[1] > 0)) {
            throw new IllegalStateException("Invalid kernel size: values must be positive (> 0) for all dimensions. Got: " + Arrays.toString(iArr));
        }
        if (!(iArr2[0] > 0 && iArr2[1] > 0)) {
            throw new IllegalStateException("Invalid stride configuration: values must be positive (> 0) for all dimensions. Got: " + Arrays.toString(iArr2));
        }
        if (!(iArr3[0] >= 0 && iArr3[1] >= 0)) {
            throw new IllegalStateException("Invalid padding configuration: values must be >= 0 for all dimensions. Got: " + Arrays.toString(iArr3));
        }
    }

    /**
     * Reshapes a rank-4 array to 2D, treating the input as NCHW.
     *
     * @param iNDArray          rank-4 input
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return the 2D array produced by the format-aware overload
     */
    public static INDArray reshape4dTo2d(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        final CNN2DFormat assumedFormat = CNN2DFormat.NCHW;
        return reshape4dTo2d(iNDArray, assumedFormat, layerWorkspaceMgr, arrayType);
    }

    /**
     * Reshapes a rank-4 array to 2D with the former axis 1 (NCHW) or axis 3 (any other
     * format) as the column dimension and all remaining axes flattened into rows.
     *
     * <p>NCHW input is first permuted to (0, 2, 3, 1). The array is duplicated in 'c' order
     * through the workspace manager when it is not already 'c'-ordered with the expected
     * strides, so that the reshape can be applied.
     *
     * @param iNDArray          rank-4 input
     * @param cNN2DFormat       data format of the input
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return 2D array leveraged to {@code arrayType}
     * @throws IllegalArgumentException if the input is not rank 4
     */
    public static INDArray reshape4dTo2d(INDArray iNDArray, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (iNDArray.rank() != 4) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 4, got rank " + iNDArray.rank() + " with shape " + Arrays.toString(iNDArray.shape()));
        }
        long[] shape = iNDArray.shape();

        INDArray channelsLast;
        long rows;
        long columns;
        if (cNN2DFormat == CNN2DFormat.NCHW) {
            channelsLast = iNDArray.permute(new int[]{0, 2, 3, 1});
            rows = shape[0] * shape[2] * shape[3];
            columns = shape[1];
        } else {
            channelsLast = iNDArray;
            rows = shape[0] * shape[1] * shape[2];
            columns = shape[3];
        }

        boolean reshapeReady = channelsLast.ordering() == 'c' && Shape.strideDescendingCAscendingF(channelsLast);
        if (!reshapeReady) {
            channelsLast = layerWorkspaceMgr.dup(arrayType, channelsLast, 'c');
        }
        INDArray flattened = channelsLast.reshape('c', new long[]{rows, columns});
        return layerWorkspaceMgr.leverageTo(arrayType, flattened);
    }

    /**
     * Reshapes a rank-5 array to 2D with the last axis (after an optional permute) as the
     * column dimension and the first four axes flattened into rows.
     *
     * <p>Input that is not NDHWC is first permuted to (0, 2, 3, 4, 1). The array is duplicated
     * in 'c' order when it is not 'c'-ordered with default strides.
     *
     * @param dataFormat        data format of the input; must not be {@code null}
     * @param iNDArray          rank-5 input
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return 2D array leveraged to {@code arrayType}
     * @throws NullPointerException if {@code dataFormat} is {@code null}
     */
    public static INDArray reshape5dTo2d(@NonNull Convolution3D.DataFormat dataFormat, INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (dataFormat == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        Preconditions.checkState(iNDArray.rank() == 5, "Invalid input: expect NDArray with rank 5, got rank %ndRank with shape %ndShape", iNDArray, iNDArray);

        INDArray channelsLast = iNDArray;
        if (dataFormat != Convolution3D.DataFormat.NDHWC) {
            channelsLast = channelsLast.permute(new int[]{0, 2, 3, 4, 1});
        }
        boolean reshapeReady = channelsLast.ordering() == 'c' && Shape.hasDefaultStridesForShape(channelsLast);
        if (!reshapeReady) {
            channelsLast = layerWorkspaceMgr.dup(arrayType, channelsLast, 'c');
        }

        long rows = channelsLast.size(0) * channelsLast.size(1) * channelsLast.size(2) * channelsLast.size(3);
        long columns = channelsLast.size(4);
        INDArray flattened = channelsLast.reshape('c', new long[]{rows, columns});
        return layerWorkspaceMgr.leverageTo(arrayType, flattened);
    }

    /**
     * Reshapes a rank-5 mask to the 2D layout produced by {@link #reshape5dTo2d} for the labels.
     *
     * <p>The mask is reshaped directly when it already has the labels' shape, or when it
     * matches the labels on every axis except the channel axis (axis 4 for NDHWC, axis 1
     * for NCDHW). Otherwise it is first assigned into a new array that has the labels'
     * shape except for the channel axis, which keeps the mask's size.
     *
     * @param dataFormat        data format of mask and labels; must not be {@code null}
     * @param iNDArray          mask array, may be {@code null}
     * @param iNDArray2         labels array
     * @param layerWorkspaceMgr workspace manager used for allocations and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return the reshaped mask, or {@code null} if the mask is {@code null}
     * @throws NullPointerException if {@code dataFormat} is {@code null}
     */
    public static INDArray reshapeCnn3dMask(@NonNull Convolution3D.DataFormat dataFormat, INDArray iNDArray, INDArray iNDArray2, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (dataFormat == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        if (iNDArray == null) {
            return null;
        }
        Preconditions.checkState(iNDArray.rank() == 5, "Expected rank 5 mask for Cnn3DLossLayer in a shape broadcastable to labels shape: got mask shape %ndShape with label shape %ndShape", iNDArray, iNDArray2);

        boolean matchesExceptLastAxis = iNDArray.size(0) == iNDArray2.size(0)
                && iNDArray.size(1) == iNDArray2.size(1)
                && iNDArray.size(2) == iNDArray2.size(2)
                && iNDArray.size(3) == iNDArray2.size(3);
        boolean matchesExceptAxisOne = iNDArray.size(0) == iNDArray2.size(0)
                && iNDArray.size(2) == iNDArray2.size(2)
                && iNDArray.size(3) == iNDArray2.size(3)
                && iNDArray.size(4) == iNDArray2.size(4);

        // The axis-1 variant belongs to NCDHW (channels on axis 1, as used for the
        // broadcast target below); it was previously also gated on NDHWC, which let an
        // NDHWC mask that differs from the labels on a spatial axis skip the broadcast.
        boolean directlyReshapable = iNDArray.equalShapes(iNDArray2)
                || (dataFormat == Convolution3D.DataFormat.NDHWC && matchesExceptLastAxis)
                || (dataFormat == Convolution3D.DataFormat.NCDHW && matchesExceptAxisOne);
        if (directlyReshapable) {
            return reshape5dTo2d(dataFormat, iNDArray, layerWorkspaceMgr, arrayType);
        }

        // Broadcast target: the labels' shape, but keeping the mask's size on the channel axis.
        long[] broadcastShape = iNDArray2.shape().clone();
        int channelAxis = dataFormat == Convolution3D.DataFormat.NCDHW ? 1 : 4;
        broadcastShape[channelAxis] = iNDArray.size(channelAxis);
        INDArray broadcastMask = layerWorkspaceMgr.createUninitialized(arrayType, iNDArray.dataType(), broadcastShape, 'c');
        Nd4j.exec(new Assign(new INDArray[]{broadcastMask, iNDArray}, new INDArray[]{broadcastMask}));
        return reshape5dTo2d(dataFormat, broadcastMask, layerWorkspaceMgr, arrayType);
    }

    /**
     * Reshapes a 2D array back to the given rank-4 shape.
     *
     * <p>For NCHW the data is reshaped to {@code [s0, s2, s3, s1]} and then permuted to
     * (0, 3, 1, 2); for any other format it is reshaped to {@code jArr} directly. The input
     * is duplicated in 'c' order first when it is not 'c'-ordered with default strides.
     *
     * @param iNDArray          rank-2 input
     * @param jArr              target shape with 4 entries
     * @param cNN2DFormat       target data format
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return rank-4 array leveraged to {@code arrayType}
     * @throws IllegalArgumentException if the input is not rank 2 or the shape does not have 4 entries
     */
    public static INDArray reshape2dTo4d(INDArray iNDArray, long[] jArr, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (iNDArray.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        if (jArr.length != 4) {
            throw new IllegalArgumentException("Invalid input: expect toShape with 4 elements: got " + Arrays.toString(jArr));
        }

        INDArray contiguous = iNDArray;
        boolean reshapeReady = contiguous.ordering() == 'c' && Shape.hasDefaultStridesForShape(contiguous);
        if (!reshapeReady) {
            contiguous = layerWorkspaceMgr.dup(arrayType, contiguous, 'c');
        }

        if (cNN2DFormat != CNN2DFormat.NCHW) {
            return layerWorkspaceMgr.leverageTo(arrayType, contiguous.reshape('c', jArr));
        }
        long[] channelsLastShape = {jArr[0], jArr[2], jArr[3], jArr[1]};
        INDArray channelsFirst = contiguous.reshape('c', channelsLastShape).permute(new int[]{0, 3, 1, 2});
        return layerWorkspaceMgr.leverageTo(arrayType, channelsFirst);
    }

    /**
     * Reshapes a 2D array to rank 5 with shape {@code [j, j2, j3, j4, j5]}; for any format
     * other than NDHWC the result is additionally permuted to (0, 4, 1, 2, 3).
     *
     * <p>The input is duplicated in 'c' order first when it is not 'c'-ordered with default strides.
     *
     * @param dataFormat        target data format
     * @param iNDArray          rank-2 input
     * @param j                 size of axis 0 of the reshaped array
     * @param j2                size of axis 1 of the reshaped array
     * @param j3                size of axis 2 of the reshaped array
     * @param j4                size of axis 3 of the reshaped array
     * @param j5                size of axis 4 of the reshaped array
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return rank-5 array leveraged to {@code arrayType}
     * @throws IllegalArgumentException if the input is not rank 2
     */
    public static INDArray reshape2dTo5d(Convolution3D.DataFormat dataFormat, INDArray iNDArray, long j, long j2, long j3, long j4, long j5, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (iNDArray.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }

        INDArray contiguous = iNDArray;
        boolean reshapeReady = contiguous.ordering() == 'c' && Shape.hasDefaultStridesForShape(contiguous);
        if (!reshapeReady) {
            contiguous = layerWorkspaceMgr.dup(arrayType, contiguous, 'c');
        }

        INDArray result = contiguous.reshape('c', new long[]{j, j2, j3, j4, j5});
        if (dataFormat != Convolution3D.DataFormat.NDHWC) {
            result = result.permute(new int[]{0, 4, 1, 2, 3});
        }
        return layerWorkspaceMgr.leverageTo(arrayType, result);
    }

    /**
     * Reshapes a mask array for use with 2D-reshaped CNN output, assuming NCHW format.
     *
     * <p>This overload used to forward a {@code null} format, which made every rank-2 mask
     * fail with a {@link NullPointerException} in {@code adapt2dMask}. It now defaults to
     * NCHW, matching the other deprecated format-less overloads in this class.
     *
     * @param iNDArray          mask array, may be {@code null}
     * @param iNDArray2         output array the mask applies to
     * @param layerWorkspaceMgr workspace manager used for allocations and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return the reshaped mask, or {@code null} if the mask is {@code null}
     * @deprecated use the overload that takes an explicit {@link CNN2DFormat}.
     */
    @Deprecated
    public static INDArray reshapeMaskIfRequired(INDArray iNDArray, INDArray iNDArray2, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        return reshapeMaskIfRequired(iNDArray, iNDArray2, CNN2DFormat.NCHW, layerWorkspaceMgr, arrayType);
    }

    /**
     * Dispatches mask reshaping by mask rank: rank 2 goes to {@code adapt2dMask}, rank 3 to
     * {@code reshape3dMask}, and every other rank to the NCHW variant of {@code reshape4dTo2d}.
     *
     * <p>NOTE(review): only the rank-2 path uses {@code cNN2DFormat}; other masks are always
     * treated as NCHW - confirm this is intended for NHWC networks.
     *
     * @param iNDArray          mask array, may be {@code null}
     * @param iNDArray2         output array the mask applies to (used for rank-2 masks only)
     * @param cNN2DFormat       data format, used for rank-2 masks only
     * @param layerWorkspaceMgr workspace manager used for allocations and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return the reshaped mask, or {@code null} if the mask is {@code null}
     */
    public static INDArray reshapeMaskIfRequired(INDArray iNDArray, INDArray iNDArray2, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (iNDArray == null) {
            return null;
        }
        if (iNDArray.rank() == 2) {
            return adapt2dMask(iNDArray, iNDArray2, cNN2DFormat, layerWorkspaceMgr, arrayType);
        }
        if (iNDArray.rank() == 3) {
            return reshape3dMask(iNDArray, layerWorkspaceMgr, arrayType);
        }
        return reshape4dTo2d(iNDArray, layerWorkspaceMgr, arrayType);
    }

    /**
     * Broadcasts a 2D mask against the rank-4 output shape and flattens the result to a
     * single column with {@code s[0] * s[2] * s[3]} rows, where {@code s} is the output shape.
     *
     * <p>For NCHW the mask is broadcast into {@code [s0, 1, s2, s3]} along dimensions (0, 1),
     * permuted to (0, 2, 3, 1) and copied in 'c' order before reshaping. For any other format
     * it is broadcast into {@code [s0, s2, s3, 1]} along dimensions (0, 3) and reshaped directly.
     *
     * <p>NOTE(review): the non-NCHW branch also reads axes 2 and 3 of the output shape as
     * the spatial sizes; if the output is laid out as [n, h, w, c] these would be axes 1
     * and 2 - confirm against callers.
     *
     * @param iNDArray          2D mask
     * @param iNDArray2         rank-4 output array whose shape drives the broadcast
     * @param cNN2DFormat       data format; must not be {@code null}
     * @param layerWorkspaceMgr workspace manager used for allocations and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return single-column mask leveraged to {@code arrayType}
     * @throws NullPointerException if {@code cNN2DFormat} is {@code null}
     */
    public static INDArray adapt2dMask(INDArray iNDArray, INDArray iNDArray2, @NonNull CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        if (cNN2DFormat == null) {
            throw new NullPointerException("format is marked non-null but is null");
        }
        final boolean channelsFirst = cNN2DFormat == CNN2DFormat.NCHW;
        final long[] outShape = iNDArray2.shape();

        final long[] broadcastShape;
        final int[] broadcastDims;
        if (channelsFirst) {
            broadcastShape = new long[]{outShape[0], 1, outShape[2], outShape[3]};
            broadcastDims = new int[]{0, 1};
        } else {
            broadcastShape = new long[]{outShape[0], outShape[2], outShape[3], 1};
            broadcastDims = new int[]{0, 3};
        }

        INDArray broadcastMask = layerWorkspaceMgr.create(arrayType, iNDArray.dataType(), broadcastShape, 'c');
        Nd4j.getExecutioner().exec(new BroadcastCopyOp(broadcastMask, iNDArray, broadcastMask, broadcastDims));

        INDArray channelsLast = broadcastMask;
        if (channelsFirst) {
            channelsLast = broadcastMask.permute(new int[]{0, 2, 3, 1}).dup('c');
        }
        long rows = outShape[0] * outShape[2] * outShape[3];
        return layerWorkspaceMgr.leverageTo(arrayType, channelsLast.reshape('c', new long[]{rows, 1}));
    }

    /**
     * Flattens a mask into a single column of {@code length()} rows, duplicating it in 'c'
     * order first when it is not 'c'-ordered with default strides.
     *
     * @param iNDArray          mask to flatten
     * @param layerWorkspaceMgr workspace manager used for the copy, if one is needed
     * @param arrayType         array type passed to the workspace manager
     * @return array of shape {@code [length, 1]}
     */
    public static INDArray reshape3dMask(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        INDArray contiguous = iNDArray;
        boolean reshapeReady = contiguous.ordering() == 'c' && Shape.hasDefaultStridesForShape(contiguous);
        if (!reshapeReady) {
            contiguous = layerWorkspaceMgr.dup(arrayType, contiguous, 'c');
        }
        long[] columnShape = {contiguous.length(), 1};
        return contiguous.reshape('c', columnShape);
    }

    /**
     * Reshapes a rank-4 mask to 2D; identical to the NCHW variant of {@code reshape4dTo2d}.
     *
     * @param iNDArray          rank-4 mask
     * @param layerWorkspaceMgr workspace manager used for any copy and for the result
     * @param arrayType         array type passed to the workspace manager
     * @return the 2D mask
     */
    public static INDArray reshape4dMask(INDArray iNDArray, LayerWorkspaceMgr layerWorkspaceMgr, ArrayType arrayType) {
        INDArray mask2d = reshape4dTo2d(iNDArray, layerWorkspaceMgr, arrayType);
        return mask2d;
    }

    /**
     * Extracts {height, width, depth} from a convolutional input type.
     *
     * <p>For {@code InputTypeConvolutional} the depth is its channel count; for
     * {@code InputTypeConvolutionalFlat} it is its depth.
     *
     * @param inputType convolutional or flattened-convolutional input type
     * @return three-element array {height, width, depth}
     * @throws IllegalStateException  if the input type is neither of the two supported kinds
     * @throws ND4JArraySizeException if any of the three values does not fit in an int
     */
    public static int[] getHWDFromInputType(InputType inputType) {
        final long height;
        final long width;
        final long depth;
        if (inputType instanceof InputType.InputTypeConvolutional) {
            InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional) inputType;
            height = conv.getHeight();
            width = conv.getWidth();
            depth = conv.getChannels();
        } else if (inputType instanceof InputType.InputTypeConvolutionalFlat) {
            InputType.InputTypeConvolutionalFlat flat = (InputType.InputTypeConvolutionalFlat) inputType;
            height = flat.getHeight();
            width = flat.getWidth();
            depth = flat.getDepth();
        } else {
            throw new IllegalStateException("Invalid input type: expected InputTypeConvolutional or InputTypeConvolutionalFlat. Got: " + inputType);
        }

        if (height > Integer.MAX_VALUE || width > Integer.MAX_VALUE || depth > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        return new int[]{(int) height, (int) width, (int) depth};
    }

    /**
     * Reduces a rank-2 mask array to match the output length of a 1D convolution/pooling
     * operation, by max pooling along the mask's second dimension.
     *
     * <p>The mask is viewed as {@code [size(0), 1, size(1), 1]} and pooled with a
     * {@code [kernel, 1]} window, so only the original second dimension changes size.
     * When the mode is {@code Same} or {@code Causal} and the stride is 1, the mask is
     * returned unchanged.
     *
     * @param iNDArray        the rank-2 mask array
     * @param i               kernel size along the pooled dimension
     * @param i2              stride along the pooled dimension
     * @param i3              padding along the pooled dimension; only used when the mode is
     *                        neither {@code Same} nor {@code Causal}
     * @param i4              dilation along the pooled dimension
     * @param convolutionMode the convolution mode of the layer
     * @return the input mask itself (Same/Causal with stride 1), or a new array of shape
     *         {@code [size(0), outputLength]} with the same data type as the input mask
     */
    public static INDArray cnn1dMaskReduction(INDArray iNDArray, int i, int i2, int i3, int i4, ConvolutionMode convolutionMode) {
        Preconditions.checkState(iNDArray.rank() == 2, "Rank must be 2 for cnn1d mask array - shape ", iNDArray.shape());

        final boolean sameOrCausal = convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal;
        if (sameOrCausal && i2 == 1) {
            return iNDArray;
        }

        // The reshape below needs default strides; copy the mask if it does not have them.
        final INDArray mask = Shape.hasDefaultStridesForShape(iNDArray) ? iNDArray : iNDArray.dup();
        final long batch = mask.size(0);
        final INDArray mask4d = mask.reshape(new long[]{batch, 1, mask.size(1), 1});

        final int[] kernel = {i, 1};
        final int[] stride = {i2, 1};
        final int[] dilation = {i4, 1};
        // Explicit padding is only passed on when the mode is neither Same nor Causal.
        final int[] padding = sameOrCausal ? null : new int[]{i3, 0};

        final int outLength = getOutputSize(mask4d, kernel, stride, padding, convolutionMode, dilation, CNN2DFormat.NCHW)[0];

        // Allocate the output with the mask's own data type (as cnn2dMaskReduction does) rather than
        // the global default data type, so the pooling input and output types always agree. Using a
        // long[] shape also avoids narrowing the batch size to int.
        final INDArray pooled = Nd4j.createUninitialized(mask.dataType(), new long[]{batch, 1, outLength, 1}, 'c');

        Nd4j.getExecutioner().exec(new MaxPooling2D(mask4d, pooled, Pooling2DConfig.builder()
                .kH(kernel[0]).kW(kernel[1])
                .sH(stride[0]).sW(stride[1])
                .pH(padding == null ? 0L : padding[0]).pW(padding == null ? 0L : padding[1])
                .dH(dilation[0]).dW(dilation[1])
                .paddingMode(ConvolutionMode.mapToMode(convolutionMode))
                .isNHWC(false)
                .build()));

        return pooled.reshape('c', new long[]{batch, outLength});
    }

    /**
     * Reduces a rank-4 mask array to match the output size of a 2D convolution/pooling
     * operation, by max pooling over its last two dimensions.
     *
     * <p>The mask is returned unchanged when:
     * <ul>
     *   <li>the mode is {@code Same} and both strides are 1;</li>
     *   <li>dimensions 2 and 3 of the mask both have size 1;</li>
     *   <li>the computed output height and width equal the mask's sizes at dimensions 2 and 3.</li>
     * </ul>
     * If exactly one of dimensions 2 or 3 has size 1, pooling is restricted to the other
     * dimension (kernel, stride and dilation 1, padding 0 along the singleton dimension).
     *
     * @param iNDArray        the rank-4 mask array
     * @param iArr            kernel size, {@code {height, width}}
     * @param iArr2           stride, {@code {height, width}}
     * @param iArr3           padding, {@code {height, width}}
     * @param iArr4           dilation, {@code {height, width}}
     * @param convolutionMode the convolution mode of the layer
     * @return the input mask itself in the cases listed above, otherwise a new array of shape
     *         {@code [size(0), size(1), outHeight, outWidth]} with the mask's data type
     * @throws IllegalStateException if the mask is not rank 4
     */
    public static INDArray cnn2dMaskReduction(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, ConvolutionMode convolutionMode) {
        if (iNDArray.rank() != 4) {
            throw new IllegalStateException("Expected rank 4 mask array for 2D CNN layers. Mask arrays for 2D CNN layers must have shape [batchSize,channels,X,Y] where X = (1 or activationsHeight) and Y = (1 or activationsWidth): Got rank " + iNDArray.rank() + " array with shape " + Arrays.toString(iNDArray.shape()));
        }
        if (convolutionMode == ConvolutionMode.Same && iArr2[0] == 1 && iArr2[1] == 1) {
            return iNDArray;
        }
        if (iNDArray.size(2) == 1 && iNDArray.size(3) == 1) {
            return iNDArray;
        }

        final int[] kernel;
        final int[] stride;
        final int[] padding;
        final int[] dilation;
        if (iNDArray.size(3) == 1) {
            // Singleton last dimension: pool along dimension 2 only.
            kernel = new int[]{iArr[0], 1};
            stride = new int[]{iArr2[0], 1};
            padding = new int[]{iArr3[0], 0};
            dilation = new int[]{iArr4[0], 1};
        } else if (iNDArray.size(2) == 1) {
            // Singleton dimension 2: pool along dimension 3 only.
            kernel = new int[]{1, iArr[1]};
            stride = new int[]{1, iArr2[1]};
            padding = new int[]{0, iArr3[1]};
            dilation = new int[]{1, iArr4[1]};
        } else {
            kernel = iArr;
            stride = iArr2;
            padding = iArr3;
            dilation = iArr4;
        }

        final int[] outputSize = getOutputSize(iNDArray, kernel, stride, padding, convolutionMode, dilation);

        // outputSize holds the pooled sizes of mask dimensions 2 and 3 (see the allocation below),
        // so it must be compared against those dimensions and not against dimensions 0 and 1.
        // Comparing against batch/channel sizes could return an un-pooled mask of the wrong shape.
        if (outputSize[0] == iNDArray.size(2) && outputSize[1] == iNDArray.size(3)) {
            return iNDArray;
        }

        final INDArray pooled = Nd4j.createUninitialized(iNDArray.dataType(), new long[]{iNDArray.size(0), iNDArray.size(1), outputSize[0], outputSize[1]});
        Nd4j.exec(new MaxPooling2D(iNDArray, pooled, Pooling2DConfig.builder()
                .kH(kernel[0]).kW(kernel[1])
                .sH(stride[0]).sW(stride[1])
                .pH(padding[0]).pW(padding[1])
                .dH(dilation[0]).dW(dilation[1])
                .paddingMode(ConvolutionMode.mapToMode(convolutionMode))
                .isNHWC(false)
                .build()));
        return pooled;
    }
}
