/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import com.google.gson.Gson;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.ServiceLoader;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.TFGraphRunnerService;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Layer that executes a single TensorFlow op coming from an imported tf.keras model.
 *
 * <p>At construction time the op's node definition is wrapped in a small TensorFlow graph (one
 * Placeholder node per op input, followed by the op node itself). That graph is handed to the
 * first {@link TFGraphRunnerService} found through {@link ServiceLoader}. Only the forward pass
 * is supported; {@link #backpropGradient} always throws.
 */
public class TFOpLayerImpl
extends AbstractLayer<TFOpLayer> {
    private static final Logger log = LoggerFactory.getLogger(TFOpLayerImpl.class);

    /** Value written to {@code versions.producer} in the generated GraphDef. */
    private static final int GRAPH_DEF_PRODUCER = 22;

    /** Map form of the TensorFlow NodeDef; serialised to JSON with Gson and parsed with JsonFormat. */
    private Map nodeDef;
    /** Constant inputs keyed by the input's position in the NodeDef, as a decimal string. */
    private Map constants;
    /** Names of the op inputs that are not supplied from {@link #constants}. */
    private List<String> inputNames;
    TFGraphRunnerService graphRunnerService;

    /**
     * Creates the layer and immediately initialises its graph runner.
     *
     * @param nodeDef   map form of the TensorFlow NodeDef to execute
     * @param constants constant inputs keyed by input index (as a string); each value is passed as
     *                  a {@code List} to {@code Nd4j.create}
     * @param conf      layer configuration
     * @param dtype     network data type
     * @throws RuntimeException if the node definition cannot be parsed, or if no
     *                          {@link TFGraphRunnerService} implementation is on the classpath
     */
    public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype) {
        super(conf, dtype);
        this.nodeDef = nodeDef;
        this.constants = constants;
        this.setGraphRunner();
    }

    /**
     * Not supported: always throws.
     *
     * @throws RuntimeException always
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet. TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models (tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
    }

    /**
     * Parses {@link #nodeDef}, builds the one-op graph and initialises {@link #graphRunnerService}.
     * Also (re)populates {@link #inputNames}.
     */
    private void setGraphRunner() {
        List<String> allInputNames = new ArrayList<>();
        Map<String, String> inputDataTypes = new HashMap<>();
        Map<String, INDArray> constArrays = new HashMap<>();
        List<String> outputNames;
        byte[] graphBytes;
        try {
            NodeDef op = parseNodeDef();
            outputNames = Arrays.asList(op.getName());
            collectInputs(op, allInputNames, inputDataTypes, constArrays);
            graphBytes = buildGraphBytes(op, allInputNames, inputDataTypes);
        } catch (IOException e) {
            throw new RuntimeException("Error parsing protobuf", e);
        }
        // Resolved outside the try above so that a missing runtime (or a failing init) is
        // reported as itself rather than being disguised as a protobuf parsing error.
        this.graphRunnerService = loadGraphRunnerService()
                .init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes);
    }

    /** Converts the {@link #nodeDef} map to JSON and parses it into a {@link NodeDef}. */
    private NodeDef parseNodeDef() throws IOException {
        String json = new Gson().toJson((Object) this.nodeDef);
        NodeDef.Builder builder = NodeDef.newBuilder();
        JsonFormat.parser().merge(json, builder);
        return builder.build();
    }

    /**
     * Records every input of {@code op} with its data type, splitting the inputs into constants
     * (materialised into {@code constArrays}) and runtime inputs (stored in {@link #inputNames}).
     */
    private void collectInputs(NodeDef op, List<String> allInputNames,
                               Map<String, String> inputDataTypes,
                               Map<String, INDArray> constArrays) {
        Map<String, AttrValue> attrMap = op.getAttrMap();
        this.inputNames = new ArrayList<>();
        for (int i = 0; i < op.getInputCount(); ++i) {
            String inputName = op.getInput(i);
            // An unqualified input reads its dtype from attribute "T"; an input "a/b" reads it
            // from "T" + its last path segment, i.e. "Tb".
            String[] split = inputName.split("/");
            String attrKey = split.length == 1 ? "T" : "T" + split[split.length - 1];
            AttrValue typeAttr = attrMap.get(attrKey);
            if (typeAttr == null) {
                throw new IllegalStateException("TensorFlow op '" + op.getName()
                        + "' has no attribute '" + attrKey
                        + "' giving the data type of input '" + inputName + "'");
            }
            allInputNames.add(inputName);
            inputDataTypes.put(inputName, typeAttr.getType().toString());
            String constKey = String.valueOf(i);
            if (this.constants.containsKey(constKey)) {
                constArrays.put(inputName, Nd4j.create((List) this.constants.get(constKey)));
            } else {
                this.inputNames.add(inputName);
            }
        }
    }

    /**
     * Builds a text-format GraphDef containing one Placeholder per input plus {@code op}, and
     * returns its serialised bytes.
     */
    private static byte[] buildGraphBytes(NodeDef op, List<String> allInputNames,
                                          Map<String, String> inputDataTypes) throws IOException {
        StringBuilder graph = new StringBuilder();
        // Placeholders are emitted last-input-first, ahead of the op node itself.
        for (int i = allInputNames.size() - 1; i >= 0; --i) {
            String inpName = allInputNames.get(i);
            String dtype = inputDataTypes.get(inpName);
            graph.append("node{\nname: \"").append(inpName)
                 .append("\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: ")
                 .append(dtype)
                 .append("}\n}\n}\n");
        }
        graph.append("node{\n").append(op.toString())
             .append("\n}\nversions {\n producer: ").append(GRAPH_DEF_PRODUCER).append("\n}");
        GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
        TextFormat.getParser().merge(graph.toString(), graphDefBuilder);
        return graphDefBuilder.build().toByteString().toByteArray();
    }

    /**
     * Returns the first {@link TFGraphRunnerService} provider found by {@link ServiceLoader}.
     *
     * @throws RuntimeException if no provider is available
     */
    private static TFGraphRunnerService loadGraphRunnerService() {
        Iterator<TFGraphRunnerService> providers =
                ServiceLoader.load(TFGraphRunnerService.class).iterator();
        if (!providers.hasNext()) {
            throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
        }
        return providers.next();
    }

    /**
     * Feeds {@code input} to the first non-constant op input and returns the first output array.
     *
     * @throws IllegalStateException if the op has no non-constant inputs to feed
     */
    private INDArray runGraph(INDArray input) {
        if (this.inputNames == null || this.inputNames.isEmpty()) {
            throw new IllegalStateException(
                    "TFOpLayerImpl has no non-constant inputs to feed; inputNames=" + this.inputNames);
        }
        Map<String, INDArray> feed = new HashMap<>();
        feed.put(this.inputNames.get(0), input);
        Map<String, INDArray> outputs = this.graphRunnerService.run(feed);
        return outputs.values().iterator().next();
    }

    /**
     * Computes the output shape by running the op on a zero-filled array shaped like
     * {@code inputShape}, with every negative dimension replaced by 1.
     *
     * @param inputShape input shape; not modified
     * @return shape of the op's output for that dummy input
     */
    public long[] getOutputShape(long[] inputShape) {
        long[] shape = ArrayUtils.clone(inputShape);
        for (int i = 0; i < shape.length; ++i) {
            if (shape[i] < 0L) {
                shape[i] = 1L;
            }
        }
        INDArray dummyArr = Nd4j.zeros(shape);
        return this.runGraph(dummyArr).shape();
    }

    /**
     * Runs the op on the layer's current input. The {@code training} flag and
     * {@code workspaceMgr} are not used.
     */
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return this.runGraph(this.input);
    }

    /** @return always {@code false} */
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer holds no weight-noise parameters to clear. */
    public void clearNoiseWeightParams() {
    }

    /** @return the map form of the TensorFlow NodeDef */
    public Map getNodeDef() {
        return this.nodeDef;
    }

    /** @return constant inputs keyed by input index (as a string) */
    public Map getConstants() {
        return this.constants;
    }

    /** @return names of the non-constant op inputs */
    public List<String> getInputNames() {
        return this.inputNames;
    }

    /** @return the runner that executes the op graph */
    public TFGraphRunnerService getGraphRunnerService() {
        return this.graphRunnerService;
    }

    /** Replaces the stored NodeDef map only; the existing graph runner is not rebuilt. */
    public void setNodeDef(Map nodeDef) {
        this.nodeDef = nodeDef;
    }

    /** Replaces the stored constants map only; the existing graph runner is not rebuilt. */
    public void setConstants(Map constants) {
        this.constants = constants;
    }

    /** Replaces the list of non-constant input names. */
    public void setInputNames(List<String> inputNames) {
        this.inputNames = inputNames;
    }

    /** Replaces the runner that executes the op graph. */
    public void setGraphRunnerService(TFGraphRunnerService graphRunnerService) {
        this.graphRunnerService = graphRunnerService;
    }

    /**
     * Two layers are equal when their NodeDef map, constants, input names and graph runner are all
     * equal (null-safe), and {@link #canEqual} accepts this instance.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TFOpLayerImpl)) {
            return false;
        }
        TFOpLayerImpl other = (TFOpLayerImpl) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getNodeDef(), other.getNodeDef())
                && Objects.equals(this.getConstants(), other.getConstants())
                && Objects.equals(this.getInputNames(), other.getInputNames())
                && Objects.equals(this.getGraphRunnerService(), other.getGraphRunnerService());
    }

    /** Symmetry hook for {@link #equals}: subclasses may narrow which instances compare equal. */
    protected boolean canEqual(Object other) {
        return other instanceof TFOpLayerImpl;
    }

    /** Combines the same four fields used by {@link #equals} (prime 59, 43 for null fields). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(this.getNodeDef());
        result = result * prime + hashOrDefault(this.getConstants());
        result = result * prime + hashOrDefault(this.getInputNames());
        result = result * prime + hashOrDefault(this.getGraphRunnerService());
        return result;
    }

    /** Returns {@code value.hashCode()}, or 43 when {@code value} is null. */
    private static int hashOrDefault(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "TFOpLayerImpl(nodeDef=" + this.getNodeDef() + ", constants=" + this.getConstants() + ", inputNames=" + this.getInputNames() + ", graphRunnerService=" + this.getGraphRunnerService() + ")";
    }
}

