/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link ParamInitializer} for the {@link Bidirectional} layer wrapper.
 * <p>
 * The flattened parameter (and gradient) view is split into two equal halves. The first half is
 * handed to the wrapped layer's initializer for the forward direction and the second half for the
 * backward direction. Every key reported by the wrapped initializer is prefixed with
 * {@link #FORWARD_PREFIX} (first half) or {@link #BACKWARD_PREFIX} (second half).
 * <p>
 * The key lists returned by {@link #paramKeys}, {@link #weightKeys} and {@link #biasKeys} are
 * computed on first use and then cached in instance fields without synchronization.
 */
public class BidirectionalParamInitializer
implements ParamInitializer {
    /** Prefix applied to keys belonging to the forward-direction half of the parameters. */
    public static final String FORWARD_PREFIX = "f";
    /** Prefix applied to keys belonging to the backward-direction half of the parameters. */
    public static final String BACKWARD_PREFIX = "b";
    /** The Bidirectional configuration this initializer was created for. */
    private final Bidirectional layer;
    /** The wrapped layer (Bidirectional#getFwd()); its initializer is used for both halves. */
    private final Layer underlying;
    private List<String> paramKeys;
    private List<String> weightKeys;
    private List<String> biasKeys;

    /**
     * Creates an initializer for the given Bidirectional layer configuration.
     *
     * @param layer the Bidirectional configuration whose forward layer is used as the template for
     *              both directions
     */
    public BidirectionalParamInitializer(Bidirectional layer) {
        this.layer = layer;
        this.underlying = this.underlying(layer);
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return this.numParams(conf.getLayer());
    }

    /**
     * Returns twice the parameter count of the wrapped layer (one copy per direction).
     *
     * @param layer must be a {@link Bidirectional}; otherwise the cast in
     *              {@link #underlying(Layer)} throws {@link ClassCastException}
     */
    @Override
    public long numParams(Layer layer) {
        Layer u = this.underlying(layer);
        return 2L * u.initializer().numParams(u);
    }

    /** Returns the wrapped layer's parameter keys, first with the forward then the backward prefix. */
    @Override
    public List<String> paramKeys(Layer layer) {
        if (this.paramKeys == null) {
            Layer u = this.underlying(layer);
            this.paramKeys = this.withPrefixes(u.initializer().paramKeys(u));
        }
        return this.paramKeys;
    }

    /** Returns the wrapped layer's weight keys, first with the forward then the backward prefix. */
    @Override
    public List<String> weightKeys(Layer layer) {
        if (this.weightKeys == null) {
            Layer u = this.underlying(layer);
            this.weightKeys = this.withPrefixes(u.initializer().weightKeys(u));
        }
        return this.weightKeys;
    }

    /** Returns the wrapped layer's bias keys, first with the forward then the backward prefix. */
    @Override
    public List<String> biasKeys(Layer layer) {
        if (this.biasKeys == null) {
            Layer u = this.underlying(layer);
            // Must query biasKeys (not weightKeys); otherwise weights would be reported as biases.
            this.biasKeys = this.withPrefixes(u.initializer().biasKeys(u));
        }
        return this.biasKeys;
    }

    /** Checks against the cached weight keys of the Bidirectional layer this instance was built for. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return this.weightKeys(this.layer).contains(key);
    }

    /** Checks against the cached bias keys of the Bidirectional layer this instance was built for. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return this.biasKeys(this.layer).contains(key);
    }

    /**
     * Initializes both directions from the two halves of {@code paramsView}.
     * <p>
     * Side effect: clears {@code conf}'s variables and replaces them with the prefixed variables
     * registered by the wrapped initializer for each direction.
     *
     * @return the forward-direction parameters (keys prefixed with {@link #FORWARD_PREFIX}) followed
     *         by the backward-direction parameters (keys prefixed with {@link #BACKWARD_PREFIX})
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        long n = paramsView.length() / 2L;
        // Flatten to rank 1 so each half can be selected with a single interval index.
        INDArray paramsReshape = paramsView.reshape(new long[]{paramsView.length()});
        INDArray forwardView = paramsReshape.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, n)});
        INDArray backwardView = paramsReshape.get(new INDArrayIndex[]{NDArrayIndex.interval(n, 2L * n)});
        conf.clearVariables();
        // One clone per direction, each carrying the wrapped layer instead of the Bidirectional
        // wrapper, so each direction's variables are collected separately before being merged.
        NeuralNetConfiguration c1 = conf.clone();
        NeuralNetConfiguration c2 = conf.clone();
        c1.setLayer(this.underlying);
        c2.setLayer(this.underlying);
        Map<String, INDArray> origFwd = this.underlying.initializer().init(c1, forwardView, initializeParams);
        Map<String, INDArray> origBwd = this.underlying.initializer().init(c2, backwardView, initializeParams);
        List<String> variables = this.addPrefixes(c1.getVariables(), c2.getVariables());
        conf.setVariables(variables);
        return this.addPrefixes(origFwd, origBwd);
    }

    /**
     * Merges two maps into a new insertion-ordered map. Keys of {@code fwd} get
     * {@link #FORWARD_PREFIX} and come first; keys of {@code bwd} get {@link #BACKWARD_PREFIX}.
     */
    private <T> Map<String, T> addPrefixes(Map<String, T> fwd, Map<String, T> bwd) {
        Map<String, T> out = new LinkedHashMap<>();
        for (Map.Entry<String, T> e : fwd.entrySet()) {
            out.put(FORWARD_PREFIX + e.getKey(), e.getValue());
        }
        for (Map.Entry<String, T> e : bwd.entrySet()) {
            out.put(BACKWARD_PREFIX + e.getKey(), e.getValue());
        }
        return out;
    }

    /**
     * Concatenates {@code fwd} (each entry prefixed with {@link #FORWARD_PREFIX}) and {@code bwd}
     * (each entry prefixed with {@link #BACKWARD_PREFIX}) into a new mutable list.
     */
    private List<String> addPrefixes(List<String> fwd, List<String> bwd) {
        List<String> out = new ArrayList<>(fwd.size() + bwd.size());
        for (String s : fwd) {
            out.add(FORWARD_PREFIX + s);
        }
        for (String s : bwd) {
            out.add(BACKWARD_PREFIX + s);
        }
        return out;
    }

    /**
     * Splits {@code gradientView} into forward/backward halves and lets the wrapped layer's
     * initializer map each half to named gradient views.
     *
     * @return forward-direction gradients (prefixed {@link #FORWARD_PREFIX}) followed by
     *         backward-direction gradients (prefixed {@link #BACKWARD_PREFIX})
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        long n = gradientView.length() / 2L;
        INDArray gradientsViewReshape = gradientView.reshape(new long[]{gradientView.length()});
        INDArray forwardView = gradientsViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval(0L, n)});
        INDArray backwardView = gradientsViewReshape.get(new INDArrayIndex[]{NDArrayIndex.interval(n, 2L * n)});
        // Mirror init(): give the wrapped initializer a configuration whose layer is the wrapped
        // layer rather than the Bidirectional wrapper. NOTE(review): wrapped initializers may cast
        // conf.getLayer() to their own layer type; the caller's conf is left unmodified.
        NeuralNetConfiguration underlyingConf = conf.clone();
        underlyingConf.setLayer(this.underlying);
        Map<String, INDArray> origFwd = this.underlying.initializer().getGradientsFromFlattened(underlyingConf, forwardView);
        Map<String, INDArray> origBwd = this.underlying.initializer().getGradientsFromFlattened(underlyingConf, backwardView);
        return this.addPrefixes(origFwd, origBwd);
    }

    /** Returns the forward layer of {@code layer}, which must be a {@link Bidirectional}. */
    private Layer underlying(Layer layer) {
        Bidirectional b = (Bidirectional) layer;
        return b.getFwd();
    }

    /** Returns {@code orig} prefixed with {@link #FORWARD_PREFIX}, followed by it prefixed with {@link #BACKWARD_PREFIX}. */
    private List<String> withPrefixes(List<String> orig) {
        return this.addPrefixes(orig, orig);
    }
}

