/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.mkldnn;

import java.util.Collections;
import java.util.Map;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative;
import org.nd4j.linalg.factory.Nd4j;

/**
 * {@link ConvolutionHelper} that executes 2D convolution forward and backward passes through the
 * ND4J {@link Conv2D} and {@link Conv2DDerivative} custom ops. It is enabled only when
 * {@link BaseMKLDNNHelper#mklDnnEnabled()} reports true.
 *
 * <p>Only {@link DataType#FLOAT} inputs and weights are handled. For any other data type,
 * {@link #preOutput} and {@link #backpropGradient} return {@code null} without doing any work.
 * NOTE(review): callers presumably treat {@code null} as "fall back to the built-in
 * implementation"; confirm against the calling layer.
 *
 * <p>Each instance lazily creates one {@link OpContext} for the forward pass and one for the
 * backward pass. Their integer arguments (kernel, strides, padding, dilation, Same-mode flag,
 * data layout, weight layout) are set only when the context is first created. Later calls reuse
 * them as-is, so an instance is expected to serve a single, fixed layer configuration. The
 * contexts are plain mutable fields with no synchronization.
 */
public class MKLDNNConvHelper
implements ConvolutionHelper {
    /**
     * Weight-layout flag passed as the last integer argument to the native ops (value unchanged
     * from the original). {@code weights.size(0)} is read as the output depth below, which is
     * consistent with an {@code [oC, iC, kH, kW]} layout. NOTE(review): confirm this mapping
     * against the libnd4j conv2d op documentation.
     */
    private static final long WEIGHT_FORMAT = 1L;

    /** Reusable context for the forward {@link Conv2D} op; created on first use. */
    protected OpContext context;
    /** Reusable context for the backward {@link Conv2DDerivative} op; created on first use. */
    protected OpContext contextBwd;

    /**
     * Creates the helper.
     *
     * @param dataType network data type; currently unused by this helper
     */
    public MKLDNNConvHelper(DataType dataType) {
    }

    /** Returns whether the MKL-DNN backend is enabled. */
    @Override
    public boolean checkSupported() {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Computes gradients for a 2D convolution with the native {@link Conv2DDerivative} op.
     *
     * <p>The bias and its gradient take part only when {@code biasGradView} is non-null. In that
     * case the op inputs are {@code {input, weights, bias, delta}} and the outputs are
     * {@code {gradAtInput, weightGradView, biasGradView}}. Otherwise the bias is omitted from
     * both arrays.
     *
     * @return the gradient (keys {@code "W"} and, when present, {@code "b"}, bound to the supplied
     *     views) paired with the freshly allocated input gradient; or {@code null} if
     *     {@code input} or {@code weights} is not FLOAT
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, ConvolutionLayer.AlgoMode mode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if (!isFloat(input, weights)) {
            return null;
        }
        int[] axes = spatialAxes(format);
        int hDim = axes[0];
        int wDim = axes[1];
        if (convolutionMode == ConvolutionMode.Same) {
            // delta has the forward output's spatial size, from which Same-mode padding is derived.
            int[] outSize = {(int) delta.size(hDim), (int) delta.size(wDim)};
            int[] inSize = {(int) input.size(hDim), (int) input.size(wDim)};
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, inSize, kernel, strides, dilation);
        }
        if (contextBwd == null) {
            // Integer arguments are fixed at first use; later calls reuse them unchanged.
            contextBwd = Nd4j.getExecutioner().buildContext();
            contextBwd.setIArguments(convIArgs(kernel, strides, pad, dilation, convolutionMode, format));
        }

        INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape());

        boolean withBiasGrad = biasGradView != null;
        INDArray[] inputs = withBiasGrad
                ? new INDArray[]{input, weights, bias, delta}
                : new INDArray[]{input, weights, delta};
        INDArray[] outputs = withBiasGrad
                ? new INDArray[]{gradAtInput, weightGradView, biasGradView}
                : new INDArray[]{gradAtInput, weightGradView};
        runOp(contextBwd, new Conv2DDerivative(), inputs, outputs);

        DefaultGradient g = new DefaultGradient();
        if (withBiasGrad) {
            g.gradientForVariable().put("b", biasGradView);
        }
        g.gradientForVariable().put("W", weightGradView);
        return new Pair<>(g, gradAtInput);
    }

    /**
     * Computes the pre-activation output of a 2D convolution with the native {@link Conv2D} op.
     *
     * <p>The op inputs are {@code {input, weights}}, with {@code bias} appended when it is
     * non-null.
     *
     * <p>The output array is allocated in the {@link ArrayType#ACTIVATIONS} workspace. Its shape
     * is {@code [mb, outDepth, outH, outW]} for NCHW and {@code [mb, outH, outW, outDepth]}
     * otherwise, where {@code outDepth = weights.size(0)}.
     *
     * @return the convolution output, or {@code null} if {@code input} or {@code weights} is not
     *     FLOAT
     */
    @Override
    public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, ConvolutionLayer.AlgoMode mode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) {
        if (!isFloat(input, weights)) {
            return null;
        }
        int[] axes = spatialAxes(format);
        int inH = (int) input.size(axes[0]);
        int inW = (int) input.size(axes[1]);

        int[] outSize;
        if (convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format);
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{inH, inW}, kernel, strides, dilation);
        } else {
            outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format);
        }
        if (context == null) {
            // Integer arguments (including pad) are fixed at first use; later calls reuse them.
            context = Nd4j.getExecutioner().buildContext();
            context.setIArguments(convIArgs(kernel, strides, pad, dilation, convolutionMode, format));
        }

        long miniBatch = input.size(0);
        int outDepth = (int) weights.size(0);
        long[] outShape = format == CNN2DFormat.NCHW
                ? new long[]{miniBatch, outDepth, outSize[0], outSize[1]}
                : new long[]{miniBatch, outSize[0], outSize[1], outDepth};
        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape);

        INDArray[] inputs = bias == null
                ? new INDArray[]{input, weights}
                : new INDArray[]{input, weights, bias};
        runOp(context, new Conv2D(), inputs, new INDArray[]{out});
        return out;
    }

    /** Applies {@code afn} to {@code z}; no native acceleration is used here. */
    @Override
    public INDArray activate(INDArray z, IActivation afn, boolean training) {
        return afn.getActivation(z, training);
    }

    /** Reports no helper-specific memory use. */
    @Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /** Returns true when both arrays are FLOAT, the only data type this helper handles. */
    private static boolean isFloat(INDArray input, INDArray weights) {
        return input.dataType() == DataType.FLOAT && weights.dataType() == DataType.FLOAT;
    }

    /** Returns {heightAxis, widthAxis}: {1, 2} for NHWC, otherwise {2, 3}. */
    private static int[] spatialAxes(CNN2DFormat format) {
        return format == CNN2DFormat.NHWC ? new int[]{1, 2} : new int[]{2, 3};
    }

    /**
     * Builds the integer-argument vector shared by the forward and backward native ops:
     * kH, kW, sH, sW, pH, pW, dH, dW, isSameMode (1/0), dataFormat (0 = NCHW, 1 otherwise),
     * weightFormat.
     */
    private static long[] convIArgs(int[] kernel, int[] strides, int[] pad, int[] dilation,
                                    ConvolutionMode convolutionMode, CNN2DFormat format) {
        return new long[]{
                kernel[0], kernel[1],
                strides[0], strides[1],
                pad[0], pad[1],
                dilation[0], dilation[1],
                ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same),
                format == CNN2DFormat.NCHW ? 0L : 1L,
                WEIGHT_FORMAT
        };
    }

    /** Clears the reusable context, binds inputs and outputs by position, then executes the op. */
    private static void runOp(OpContext ctx, CustomOp op, INDArray[] inputs, INDArray[] outputs) {
        ctx.purge();
        for (int i = 0; i < inputs.length; i++) {
            ctx.setInputArray(i, inputs[i]);
        }
        for (int i = 0; i < outputs.length; i++) {
            ctx.setOutputArray(i, outputs[i]);
        }
        Nd4j.exec(op, ctx);
    }
}

