/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.zoo.model.helper;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.nd4j.linalg.activations.Activation;

public class FaceNetHelper {
    public static String getModuleName() {
        return "inception";
    }

    public static String getModuleName(String layerName) {
        return FaceNetHelper.getModuleName() + "-" + layerName;
    }

    public static ConvolutionLayer conv1x1(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{1, 1}, new int[]{1, 1}, new int[]{0, 0}).nIn(in)).nOut(out)).biasInit(bias)).cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)).build();
    }

    public static ConvolutionLayer c3x3reduce(int in, int out, double bias) {
        return FaceNetHelper.conv1x1(in, out, bias);
    }

    public static ConvolutionLayer c5x5reduce(int in, int out, double bias) {
        return FaceNetHelper.conv1x1(in, out, bias);
    }

    public static ConvolutionLayer conv3x3(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{1, 1}).nIn(in)).nOut(out)).biasInit(bias)).build();
    }

    public static ConvolutionLayer conv5x5(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{2, 2}).nIn(in)).nOut(out)).biasInit(bias)).cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)).build();
    }

    public static ConvolutionLayer conv7x7(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{7, 7}, new int[]{2, 2}, new int[]{3, 3}).nIn(in)).nOut(out)).biasInit(bias)).cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)).build();
    }

    public static SubsamplingLayer avgPool7x7(int stride) {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[]{7, 7}, new int[]{1, 1}).build();
    }

    public static SubsamplingLayer avgPoolNxN(int size, int stride) {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[]{size, size}, new int[]{stride, stride}).build();
    }

    public static SubsamplingLayer pNormNxN(int pNorm, int size, int stride) {
        return ((SubsamplingLayer.Builder)new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.PNORM, new int[]{size, size}, new int[]{stride, stride}).pnorm(pNorm)).build();
    }

    public static SubsamplingLayer maxPool3x3(int stride) {
        return new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{stride, stride}, new int[]{1, 1}).build();
    }

    public static SubsamplingLayer maxPoolNxN(int size, int stride) {
        return new SubsamplingLayer.Builder(new int[]{size, size}, new int[]{stride, stride}, new int[]{1, 1}).build();
    }

    public static DenseLayer fullyConnected(int in, int out, double dropOut) {
        return ((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(in)).nOut(out)).dropOut(dropOut)).build();
    }

    public static ConvolutionLayer convNxN(int reduceSize, int outputSize, int kernelSize, int kernelStride, boolean padding) {
        int pad = padding ? (int)Math.floor(kernelStride / 2) * 2 : 0;
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{kernelSize, kernelSize}, new int[]{kernelStride, kernelStride}, new int[]{pad, pad}).nIn(reduceSize)).nOut(outputSize)).biasInit(0.2)).cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)).build();
    }

    public static ConvolutionLayer convNxNreduce(int inputSize, int reduceSize, int reduceStride) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{1, 1}, new int[]{reduceStride, reduceStride}).nIn(inputSize)).nOut(reduceSize)).biasInit(0.2)).cudnnAlgoMode(ConvolutionLayer.AlgoMode.NO_WORKSPACE)).build();
    }

    public static BatchNormalization batchNorm(int in, int out) {
        return ((BatchNormalization.Builder)((BatchNormalization.Builder)new BatchNormalization.Builder(false).nIn(in)).nOut(out)).build();
    }

    public static ComputationGraphConfiguration.GraphBuilder appendGraph(ComputationGraphConfiguration.GraphBuilder graph, String moduleLayerName, int inputSize, int[] kernelSize, int[] kernelStride, int[] outputSize, int[] reduceSize, SubsamplingLayer.PoolingType poolingType, Activation transferFunction, String inputLayer) {
        return FaceNetHelper.appendGraph(graph, moduleLayerName, inputSize, kernelSize, kernelStride, outputSize, reduceSize, poolingType, 0, 3, 1, transferFunction, inputLayer);
    }

    public static ComputationGraphConfiguration.GraphBuilder appendGraph(ComputationGraphConfiguration.GraphBuilder graph, String moduleLayerName, int inputSize, int[] kernelSize, int[] kernelStride, int[] outputSize, int[] reduceSize, SubsamplingLayer.PoolingType poolingType, int pNorm, Activation transferFunction, String inputLayer) {
        return FaceNetHelper.appendGraph(graph, moduleLayerName, inputSize, kernelSize, kernelStride, outputSize, reduceSize, poolingType, pNorm, 3, 1, transferFunction, inputLayer);
    }

    public static ComputationGraphConfiguration.GraphBuilder appendGraph(ComputationGraphConfiguration.GraphBuilder graph, String moduleLayerName, int inputSize, int[] kernelSize, int[] kernelStride, int[] outputSize, int[] reduceSize, SubsamplingLayer.PoolingType poolingType, int poolSize, int poolStride, Activation transferFunction, String inputLayer) {
        return FaceNetHelper.appendGraph(graph, moduleLayerName, inputSize, kernelSize, kernelStride, outputSize, reduceSize, poolingType, 0, poolSize, poolStride, transferFunction, inputLayer);
    }

    public static ComputationGraphConfiguration.GraphBuilder appendGraph(ComputationGraphConfiguration.GraphBuilder graph, String moduleLayerName, int inputSize, int[] kernelSize, int[] kernelStride, int[] outputSize, int[] reduceSize, SubsamplingLayer.PoolingType poolingType, int pNorm, int poolSize, int poolStride, Activation transferFunction, String inputLayer) {
        int i;
        for (i = 0; i < kernelSize.length; ++i) {
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-cnn1-" + i, (Layer)FaceNetHelper.conv1x1(inputSize, reduceSize[i], 0.2), new String[]{inputLayer});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-batch1-" + i, (Layer)FaceNetHelper.batchNorm(reduceSize[i], reduceSize[i]), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-cnn1-" + i});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-transfer1-" + i, (Layer)new ActivationLayer.Builder().activation(transferFunction).build(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-batch1-" + i});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-reduce1-" + i, (Layer)FaceNetHelper.convNxN(reduceSize[i], outputSize[i], kernelSize[i], kernelStride[i], true), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-transfer1-" + i});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-batch2-" + i, (Layer)FaceNetHelper.batchNorm(outputSize[i], outputSize[i]), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-reduce1-" + i});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-" + i, (Layer)new ActivationLayer.Builder().activation(transferFunction).build(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-batch2-" + i});
        }
        i = kernelSize.length;
        try {
            int checkIndex = reduceSize[i];
            switch (poolingType) {
                case AVG: {
                    graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-pool1", (Layer)FaceNetHelper.avgPoolNxN(poolSize, poolStride), new String[]{inputLayer});
                    break;
                }
                case MAX: {
                    graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-pool1", (Layer)FaceNetHelper.maxPoolNxN(poolSize, poolStride), new String[]{inputLayer});
                    break;
                }
                case PNORM: {
                    if (pNorm <= 0) {
                        throw new IllegalArgumentException("p-norm must be greater than zero.");
                    }
                    graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-pool1", (Layer)FaceNetHelper.pNormNxN(pNorm, poolSize, poolStride), new String[]{inputLayer});
                    break;
                }
                default: {
                    throw new IllegalStateException("You must specify a valid pooling type of avg or max for Inception module.");
                }
            }
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-cnn2", (Layer)FaceNetHelper.convNxNreduce(inputSize, reduceSize[i], 1), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-pool1"});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-batch3", (Layer)FaceNetHelper.batchNorm(reduceSize[i], reduceSize[i]), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-cnn2"});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-transfer3", (Layer)new ActivationLayer.Builder().activation(transferFunction).build(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-batch3"});
        }
        catch (IndexOutOfBoundsException indexOutOfBoundsException) {
            // empty catch block
        }
        ++i;
        try {
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-reduce2", (Layer)FaceNetHelper.convNxNreduce(inputSize, reduceSize[i], 1), new String[]{inputLayer});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-batch4", (Layer)FaceNetHelper.batchNorm(reduceSize[i], reduceSize[i]), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-reduce2"});
            graph.addLayer(FaceNetHelper.getModuleName(moduleLayerName) + "-transfer4", (Layer)new ActivationLayer.Builder().activation(transferFunction).build(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-batch4"});
        }
        catch (IndexOutOfBoundsException indexOutOfBoundsException) {
            // empty catch block
        }
        if (kernelSize.length == 1 && reduceSize.length == 3) {
            graph.addVertex(FaceNetHelper.getModuleName(moduleLayerName), (GraphVertex)new MergeVertex(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-0", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer3", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer4"});
        } else if (kernelSize.length == 2 && reduceSize.length == 2) {
            graph.addVertex(FaceNetHelper.getModuleName(moduleLayerName), (GraphVertex)new MergeVertex(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-0", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-1"});
        } else if (kernelSize.length == 2 && reduceSize.length == 3) {
            graph.addVertex(FaceNetHelper.getModuleName(moduleLayerName), (GraphVertex)new MergeVertex(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-0", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-1", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer3"});
        } else if (kernelSize.length == 2 && reduceSize.length == 4) {
            graph.addVertex(FaceNetHelper.getModuleName(moduleLayerName), (GraphVertex)new MergeVertex(), new String[]{FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-0", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer2-1", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer3", FaceNetHelper.getModuleName(moduleLayerName) + "-transfer4"});
        } else {
            throw new IllegalStateException("Only kernel of shape 1 or 2 and a reduce shape between 2 and 4 is supported.");
        }
        return graph;
    }
}

