/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.util.Objects;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.OutputFunction;
import smile.tensor.Vector;

/**
 * Output layer of a network in {@code smile.base.mlp}: applies an {@link OutputFunction}
 * to the layer's values and computes the output gradient against a {@link Cost} function.
 */
public class OutputLayer extends Layer {
    private static final long serialVersionUID = 2L;
    /** The cost function used when computing the output gradient. */
    private final Cost cost;
    /** The output function applied by {@link #transform(Vector)}. */
    private final OutputFunction activation;

    /**
     * Constructs an output layer.
     *
     * @param n the number of neurons of this layer (the value printed by {@link #toString()});
     *          forwarded to the {@code Layer} constructor.
     * @param p forwarded to the {@code Layer} constructor; presumably the number of inputs
     *          from the layer below (TODO confirm against {@code Layer}).
     * @param activation the output function; must not be {@code null}.
     * @param cost the cost function; must not be {@code null}.
     * @throws NullPointerException if {@code activation} or {@code cost} is {@code null}.
     * @throws IllegalArgumentException if softmax is combined with mean squared error,
     *         or a linear output is combined with the likelihood cost.
     */
    public OutputLayer(int n, int p, OutputFunction activation, Cost cost) {
        super(n, p);
        // Fail fast: a null output function used to be accepted here and only blew up later.
        Objects.requireNonNull(activation, "activation");
        Objects.requireNonNull(cost, "cost");

        if (cost == Cost.MEAN_SQUARED_ERROR && activation == OutputFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax output function is not allowed with mean squared error cost function");
        }
        if (cost == Cost.LIKELIHOOD && activation == OutputFunction.LINEAR) {
            throw new IllegalArgumentException("Linear output function is not allowed with likelihood cost function");
        }

        this.activation = activation;
        this.cost = cost;
    }

    /**
     * Returns a short description of the form {@code ACTIVATION(n) | COST}.
     *
     * @return the description of this layer.
     */
    @Override
    public String toString() {
        return String.format("%s(%d) | %s", activation.name(), n, cost);
    }

    /**
     * Returns the cost function of this layer.
     *
     * @return the cost function.
     */
    public Cost cost() {
        return cost;
    }

    /**
     * Applies the output function to {@code x} via {@code OutputFunction.f}.
     * Presumably this updates {@code x} in place, since nothing is returned.
     *
     * @param x the vector to transform.
     */
    @Override
    public void transform(Vector x) {
        activation.f(x);
    }

    /**
     * Propagates this layer's output gradient to the layer below through
     * {@code weight.tv(outputGradient, lowerLayerGradient)}. Presumably this computes the
     * transposed weight matrix times the gradient (TODO confirm against the matrix API).
     *
     * @param lowerLayerGradient the buffer that receives the propagated gradient. If it is
     *        {@code null}, there is nothing to write into and the call is a no-op.
     */
    @Override
    public void backpropagate(Vector lowerLayerGradient) {
        if (lowerLayerGradient == null) {
            // tv's result is otherwise discarded, so skipping cannot lose any effect.
            return;
        }
        weight.tv((Vector) outputGradient.get(), lowerLayerGradient);
    }

    /**
     * Computes this layer's output gradient for one training sample and stores it in the
     * layer's output-gradient buffer.
     *
     * <p>The raw error {@code target[i] - output[i]} is written first. Then the output
     * function's {@code g} is applied with the configured cost. Finally the result is
     * optionally scaled by {@code weight}.
     *
     * @param target the desired output; must have the same size as the layer output.
     * @param weight the sample weight. The gradient is scaled only when {@code weight > 0}
     *        and {@code weight != 1.0}; zero, negative and NaN weights leave it unscaled.
     * @throws IllegalArgumentException if {@code target.size()} differs from the output size.
     */
    public void computeOutputGradient(Vector target, double weight) {
        Vector y = (Vector) this.output.get();
        Vector gradient = (Vector) this.outputGradient.get();
        int size = y.size();
        if (target.size() != size) {
            throw new IllegalArgumentException(String.format("Invalid target vector size: %d, expected: %d", target.size(), size));
        }

        for (int i = 0; i < size; ++i) {
            gradient.set(i, target.get(i) - y.get(i));
        }
        activation.g(cost, gradient, y);

        // NOTE(review): a weight of 0 does NOT zero the gradient (it is left unscaled, like 1.0).
        // Kept for compatibility with existing callers; confirm this is intended.
        if (weight > 0.0 && weight != 1.0) {
            gradient.scale(weight);
        }
    }
}

