/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.transform.sequence.window;

import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.TimeMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * A {@link WindowFunction} that splits a sequence into time windows of a fixed size, where consecutive windows
 * start {@code windowSeparation} apart. When the window size is larger than the separation, windows overlap and
 * a single time step can appear in more than one output window.
 * <p>
 * The input sequence is assumed to already be ordered by the time column. {@link #setInputSchema(Schema)} must be
 * called before {@link #applyToSequence(List)}.
 * <p>
 * The window size, window separation and offset are converted to whole milliseconds at construction time using
 * {@link TimeUnit#convert(long, TimeUnit)}; those derived values (and the input schema / time zone) are excluded
 * from JSON serialization and from equals/hashCode.
 */
@JsonIgnoreProperties({"inputSchema", "offsetAmountMilliseconds", "windowSizeMilliseconds",
                "windowSeparationMilliseconds", "timeZone"})
@EqualsAndHashCode(exclude = {"inputSchema", "offsetAmountMilliseconds", "windowSizeMilliseconds",
                "windowSeparationMilliseconds", "timeZone"})
@Data
public class OverlappingTimeWindowFunction implements WindowFunction {

    private final String timeColumn;
    private final long windowSize;
    private final TimeUnit windowSizeUnit;
    private final long windowSeparation;
    private final TimeUnit windowSeparationUnit;
    private final long offsetAmount;
    private final TimeUnit offsetUnit;
    private final boolean addWindowStartTimeColumn;
    private final boolean addWindowEndTimeColumn;
    private final boolean excludeEmptyWindows;
    private Schema inputSchema;

    //Derived values: the configured size/separation/offset expressed in milliseconds
    private final long offsetAmountMilliseconds;
    private final long windowSizeMilliseconds;
    private final long windowSeparationMilliseconds;

    //Time zone of the time column; read from the column metadata in setInputSchema
    private DateTimeZone timeZone;

    /**
     * Constructor with zero offset
     *
     * @param timeColumn           Name of the column that contains the time values (must be a time column)
     * @param windowSize           Numerical quantity for the size of the time window (used in conjunction with windowSizeUnit)
     * @param windowSizeUnit       Unit of the time window
     * @param windowSeparation     The separation between consecutive window start times (used in conjunction with windowSeparationUnit)
     * @param windowSeparationUnit Unit for the separation between windows
     */
    public OverlappingTimeWindowFunction(String timeColumn, long windowSize, TimeUnit windowSizeUnit,
                    long windowSeparation, TimeUnit windowSeparationUnit) {
        this(timeColumn, windowSize, windowSizeUnit, windowSeparation, windowSeparationUnit, 0, null);
    }

    /**
     * Constructor with zero offset, ability to add window start/end time columns
     *
     * @param timeColumn               Name of the column that contains the time values (must be a time column)
     * @param windowSize               Numerical quantity for the size of the time window (used in conjunction with windowSizeUnit)
     * @param windowSizeUnit           Unit of the time window
     * @param windowSeparation         The separation between consecutive window start times (used in conjunction with windowSeparationUnit)
     * @param windowSeparationUnit     Unit for the separation between windows
     * @param addWindowStartTimeColumn If true: add a time column (name: "windowStartTime") that contains the start time
     *                                 of the window
     * @param addWindowEndTimeColumn   If true: add a time column (name: "windowEndTime") that contains the end time
     *                                 of the window
     */
    public OverlappingTimeWindowFunction(String timeColumn, long windowSize, TimeUnit windowSizeUnit,
                    long windowSeparation, TimeUnit windowSeparationUnit, boolean addWindowStartTimeColumn,
                    boolean addWindowEndTimeColumn) {
        this(timeColumn, windowSize, windowSizeUnit, windowSeparation, windowSeparationUnit, 0, null,
                        addWindowStartTimeColumn, addWindowEndTimeColumn, false);
    }

    /**
     * Constructor with optional offset
     *
     * @param timeColumn           Name of the column that contains the time values (must be a time column)
     * @param windowSize           Numerical quantity for the size of the time window (used in conjunction with windowSizeUnit)
     * @param windowSizeUnit       Unit of the time window
     * @param windowSeparation     The separation between consecutive window start times (used in conjunction with windowSeparationUnit)
     * @param windowSeparationUnit Unit for the separation between windows
     * @param offset               Optional offset amount, to shift start/end of the time window forward or back
     * @param offsetUnit           Optional offset unit for the offset amount.
     */
    public OverlappingTimeWindowFunction(String timeColumn, long windowSize, TimeUnit windowSizeUnit,
                    long windowSeparation, TimeUnit windowSeparationUnit, long offset, TimeUnit offsetUnit) {
        this(timeColumn, windowSize, windowSizeUnit, windowSeparation, windowSeparationUnit, offset, offsetUnit, false,
                        false, false);
    }

    /**
     * Constructor with optional offset, ability to add window start/end time columns
     *
     * @param timeColumn               Name of the column that contains the time values (must be a time column)
     * @param windowSize               Numerical quantity for the size of the time window (used in conjunction with windowSizeUnit)
     * @param windowSizeUnit           Unit of the time window
     * @param windowSeparation         The separation between consecutive window start times (used in conjunction with windowSeparationUnit)
     * @param windowSeparationUnit     Unit for the separation between windows
     * @param offset                   Optional offset amount, to shift start/end of the time window forward or back
     * @param offsetUnit               Optional offset unit for the offset amount.
     * @param addWindowStartTimeColumn If true: add a time column (name: "windowStartTime") that contains the start time
     *                                 of the window
     * @param addWindowEndTimeColumn   If true: add a time column (name: "windowEndTime") that contains the end time
     *                                 of the window
     * @param excludeEmptyWindows      If true: exclude any windows that don't have any values in them
     */
    public OverlappingTimeWindowFunction(@JsonProperty("timeColumn") String timeColumn,
                    @JsonProperty("windowSize") long windowSize,
                    @JsonProperty("windowSizeUnit") TimeUnit windowSizeUnit,
                    @JsonProperty("windowSeparation") long windowSeparation,
                    @JsonProperty("windowSeparationUnit") TimeUnit windowSeparationUnit,
                    @JsonProperty("offset") long offset, @JsonProperty("offsetUnit") TimeUnit offsetUnit,
                    @JsonProperty("addWindowStartTimeColumn") boolean addWindowStartTimeColumn,
                    @JsonProperty("addWindowEndTimeColumn") boolean addWindowEndTimeColumn,
                    @JsonProperty("excludeEmptyWindows") boolean excludeEmptyWindows) {
        this.timeColumn = timeColumn;
        this.windowSize = windowSize;
        this.windowSizeUnit = windowSizeUnit;
        this.windowSeparation = windowSeparation;
        this.windowSeparationUnit = windowSeparationUnit;
        this.offsetAmount = offset;
        this.offsetUnit = offsetUnit;
        this.addWindowStartTimeColumn = addWindowStartTimeColumn;
        this.addWindowEndTimeColumn = addWindowEndTimeColumn;
        this.excludeEmptyWindows = excludeEmptyWindows;

        //A missing offset unit is treated the same as "no offset"
        if (offsetAmount == 0 || offsetUnit == null)
            this.offsetAmountMilliseconds = 0;
        else {
            this.offsetAmountMilliseconds = TimeUnit.MILLISECONDS.convert(offset, offsetUnit);
        }

        this.windowSizeMilliseconds = TimeUnit.MILLISECONDS.convert(windowSize, windowSizeUnit);
        this.windowSeparationMilliseconds = TimeUnit.MILLISECONDS.convert(windowSeparation, windowSeparationUnit);
    }

    /**
     * Creates the window function from the values collected by a {@link Builder}.
     */
    private OverlappingTimeWindowFunction(Builder builder) {
        this(builder.timeColumn, builder.windowSize, builder.windowSizeUnit, builder.windowSeparation,
                        builder.windowSeparationUnit, builder.offsetAmount, builder.offsetUnit,
                        builder.addWindowStartTimeColumn, builder.addWindowEndTimeColumn, builder.excludeEmptyWindows);
    }

    /**
     * Sets the schema of the sequences this function will be applied to, and reads the time zone of the time
     * column from its metadata.
     *
     * @param schema Input schema; must be a {@link SequenceSchema} containing the time column
     * @throws IllegalArgumentException if the schema is not a {@link SequenceSchema}
     * @throws IllegalStateException    if the time column is missing, or is not of type {@link ColumnType#Time}
     */
    @Override
    public void setInputSchema(Schema schema) {
        if (!(schema instanceof SequenceSchema))
            throw new IllegalArgumentException(
                            "Invalid schema: OverlappingTimeWindowFunction can only operate on SequenceSchema");
        if (!schema.hasColumn(timeColumn))
            throw new IllegalStateException("Input schema does not have a column with name \"" + timeColumn + "\"");

        if (schema.getMetaData(timeColumn).getColumnType() != ColumnType.Time)
            throw new IllegalStateException("Invalid column: column \"" + timeColumn + "\" is not of type "
                            + ColumnType.Time + "; is " + schema.getMetaData(timeColumn).getColumnType());

        this.inputSchema = schema;

        timeZone = ((TimeMetaData) schema.getMetaData(timeColumn)).getTimeZone();
    }

    /**
     * @return The schema most recently passed to {@link #setInputSchema(Schema)}, or null if none has been set
     */
    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }

    /**
     * Computes the output schema: the input schema, followed by a "windowStartTime" and/or "windowEndTime" time
     * column when the corresponding option is enabled.
     *
     * @param inputSchema Schema of the input data
     * @return The input schema itself if no columns are added, otherwise a new schema with the extra column(s) appended
     */
    @Override
    public Schema transform(Schema inputSchema) {
        if (!addWindowStartTimeColumn && !addWindowEndTimeColumn)
            return inputSchema;

        List<ColumnMetaData> newMeta = new ArrayList<>(inputSchema.getColumnMetaData());

        if (addWindowStartTimeColumn) {
            newMeta.add(new TimeMetaData("windowStartTime"));
        }

        if (addWindowEndTimeColumn) {
            newMeta.add(new TimeMetaData("windowEndTime"));
        }

        return inputSchema.newSchema(newMeta);
    }

    @Override
    public String toString() {
        return "OverlappingTimeWindowFunction(columnName=\"" + timeColumn + "\",windowSize=" + windowSize
                        + windowSizeUnit + ",windowSeparation=" + windowSeparation + windowSeparationUnit + ",offset="
                        + offsetAmount + (offsetAmount != 0 && offsetUnit != null ? offsetUnit : "")
                        + (addWindowStartTimeColumn ? ",addWindowStartTimeColumn=true" : "")
                        + (addWindowEndTimeColumn ? ",addWindowEndTimeColumn=true" : "")
                        + (excludeEmptyWindows ? ",excludeEmptyWindows=true" : "") + ")";
    }


    /**
     * Splits the given sequence into (possibly overlapping) windows.
     * <p>
     * The first window produced is the earliest one that can contain the first time step; the last is the one
     * that starts at the window border at or before the last time step. If enabled, the window start and/or end
     * time (in milliseconds) is appended to every time step placed in a window.
     *
     * @param sequence Sequence to window; assumed to be ordered by the time column. A null or empty sequence
     *                 produces an empty list of windows
     * @return List of windows, each window being a list of time steps
     * @throws IllegalStateException if the input schema has not been set, or if the window separation is not a
     *                               positive number of milliseconds
     */
    @Override
    public List<List<List<Writable>>> applyToSequence(List<List<Writable>> sequence) {
        if (inputSchema == null)
            throw new IllegalStateException(
                            "Input schema has not been set: setInputSchema(Schema) must be called before applyToSequence");
        //A separation of 0ms (e.g. a sub-millisecond value truncated during conversion) would cause a division by
        // zero below, and a negative separation would make the window loop run forever
        if (windowSeparationMilliseconds <= 0)
            throw new IllegalStateException("Invalid window separation: " + windowSeparation + " "
                            + windowSeparationUnit + " is not a positive number of milliseconds");

        List<List<List<Writable>>> out = new ArrayList<>();

        //Nothing to window: avoids an IndexOutOfBoundsException when looking up the first/last time step
        if (sequence == null || sequence.isEmpty())
            return out;

        int timeColumnIdx = inputSchema.getIndexOfColumn(this.timeColumn);

        //We are assuming here that the sequence is already ordered (as is usually the case)

        //NOTE(review): window borders are derived from (time + offset), but time steps are compared against the
        // window end using their raw time values below - confirm this is the intended offset semantics

        //First: work out the window to start on. The window to start on is the first window that includes the first time step values
        long firstTimeStepTimePlusOffset = sequence.get(0).get(timeColumnIdx).toLong() + offsetAmountMilliseconds;
        //Round down to time where a window starts/ends
        long windowBorder = floorToMultiple(firstTimeStepTimePlusOffset, windowSeparationMilliseconds);
        //At this windowBorder time: the window that _ends_ at windowBorder does NOT include the first time step
        // Therefore the window that ends at windowBorder+1*windowSeparation is first window that includes the first data point

        //Second: work out the window to end on. The window to end on is the last window that includes the last time step values
        long lastTimeStepTimePlusOffset =
                        sequence.get(sequence.size() - 1).get(timeColumnIdx).toLong() + offsetAmountMilliseconds;
        //The window that _starts_ at this time is the last window to include the last time step
        long lastWindowStartTime = floorToMultiple(lastTimeStepTimePlusOffset, windowSeparationMilliseconds);

        long currentWindowStartTime = windowBorder + windowSeparationMilliseconds - windowSizeMilliseconds;
        long nextWindowStartTime = currentWindowStartTime + windowSeparationMilliseconds;
        long currentWindowEndTime = currentWindowStartTime + windowSizeMilliseconds;
        List<List<Writable>> currentWindow = new ArrayList<>();

        int currentWindowStartIdx = 0;
        int sequenceLength = sequence.size();
        boolean foundIndexForNextWindowStart = false;
        while (currentWindowStartTime <= lastWindowStartTime) {

            for (int i = currentWindowStartIdx; i < sequenceLength; i++) {
                List<Writable> timeStep = sequence.get(i);
                long currentTime = timeStep.get(timeColumnIdx).toLong();

                //As we go through: let's keep track of the index of the first element in the next window
                if (!foundIndexForNextWindowStart && currentTime >= nextWindowStartTime) {
                    foundIndexForNextWindowStart = true;
                    currentWindowStartIdx = i;
                }
                boolean nextWindow = false;
                if (currentTime < currentWindowEndTime) {
                    //This time step is included in the current window
                    if (addWindowStartTimeColumn || addWindowEndTimeColumn) {
                        //Copy before appending, so the caller's time step is not modified
                        List<Writable> timeStep2 = new ArrayList<>(timeStep);
                        if (addWindowStartTimeColumn)
                            timeStep2.add(new LongWritable(currentWindowStartTime));
                        if (addWindowEndTimeColumn)
                            timeStep2.add(new LongWritable(currentWindowEndTime));
                        currentWindow.add(timeStep2);
                    } else {
                        currentWindow.add(timeStep);
                    }
                } else {
                    //This time step is NOT included in the current window -> done with the current window -> start the next window
                    nextWindow = true;
                }

                //Once we reach the end of the input sequence: we might have added it to the current time step, but still
                // need to create the next window
                if (i == sequenceLength - 1)
                    nextWindow = true;

                if (nextWindow) {
                    if (!(excludeEmptyWindows && currentWindow.isEmpty()))
                        out.add(currentWindow);
                    currentWindow = new ArrayList<>();
                    currentWindowStartTime = currentWindowStartTime + windowSeparationMilliseconds;
                    currentWindowEndTime = currentWindowStartTime + windowSizeMilliseconds;
                    foundIndexForNextWindowStart = false;
                    nextWindowStartTime = currentWindowStartTime + windowSeparationMilliseconds;
                    break;
                }
            }
        }

        return out;
    }

    /**
     * Rounds a time value down (towards negative infinity) to the nearest multiple of the given step.
     * <p>
     * The {@code %} operator takes the sign of the dividend, so a plain {@code time - (time % step)} rounds
     * negative times up instead of down; the remainder is therefore shifted into the range [0, step) first.
     *
     * @param time Time value, may be negative
     * @param step Step size; must be positive
     * @return Largest multiple of {@code step} that is less than or equal to {@code time}
     */
    private static long floorToMultiple(long time, long step) {
        long remainder = time % step;
        if (remainder < 0)
            remainder += step;
        return time - remainder;
    }

    /**
     * Builder for {@link OverlappingTimeWindowFunction}. The time column, window size and window separation must
     * be set before calling {@link #build()}; the offset and the boolean options are optional.
     */
    public static class Builder {
        private String timeColumn;
        //-1 is used as the "not set" marker for window size and separation; see build()
        private long windowSize = -1;
        private TimeUnit windowSizeUnit;
        private long windowSeparation = -1;
        private TimeUnit windowSeparationUnit;
        private long offsetAmount;
        private TimeUnit offsetUnit;
        private boolean addWindowStartTimeColumn = false;
        private boolean addWindowEndTimeColumn = false;
        private boolean excludeEmptyWindows = false;

        /**
         * @param timeColumn Name of the column that contains the time values
         */
        public Builder timeColumn(String timeColumn) {
            this.timeColumn = timeColumn;
            return this;
        }

        /**
         * @param windowSize     Numerical quantity for the size of the time window
         * @param windowSizeUnit Unit of the time window size
         */
        public Builder windowSize(long windowSize, TimeUnit windowSizeUnit) {
            this.windowSize = windowSize;
            this.windowSizeUnit = windowSizeUnit;
            return this;
        }

        /**
         * @param windowSeparation     The separation between consecutive window start times
         * @param windowSeparationUnit Unit for the separation between windows
         */
        public Builder windowSeparation(long windowSeparation, TimeUnit windowSeparationUnit) {
            this.windowSeparation = windowSeparation;
            this.windowSeparationUnit = windowSeparationUnit;
            return this;
        }

        /**
         * @param offsetAmount Optional offset amount, to shift start/end of the time window forward or back
         * @param offsetUnit   Unit for the offset amount
         */
        public Builder offset(long offsetAmount, TimeUnit offsetUnit) {
            this.offsetAmount = offsetAmount;
            this.offsetUnit = offsetUnit;
            return this;
        }

        /**
         * @param addWindowStartTimeColumn If true: add a time column (name: "windowStartTime") that contains the
         *                                 start time of the window
         */
        public Builder addWindowStartTimeColumn(boolean addWindowStartTimeColumn) {
            this.addWindowStartTimeColumn = addWindowStartTimeColumn;
            return this;
        }

        /**
         * @param addWindowEndTimeColumn If true: add a time column (name: "windowEndTime") that contains the end
         *                               time of the window
         */
        public Builder addWindowEndTimeColumn(boolean addWindowEndTimeColumn) {
            this.addWindowEndTimeColumn = addWindowEndTimeColumn;
            return this;
        }

        /**
         * @param excludeEmptyWindows If true: exclude any windows that don't have any values in them
         */
        public Builder excludeEmptyWindows(boolean excludeEmptyWindows) {
            this.excludeEmptyWindows = excludeEmptyWindows;
            return this;
        }

        /**
         * Creates the window function from the configured values.
         *
         * @return A new {@link OverlappingTimeWindowFunction}
         * @throws IllegalStateException if the time column, the window size/unit or the window separation/unit
         *                               has not been set
         */
        public OverlappingTimeWindowFunction build() {
            if (timeColumn == null)
                throw new IllegalStateException("Time column is null (not specified)");
            if (windowSize == -1 || windowSizeUnit == null)
                throw new IllegalStateException("Window size/unit not set");
            if (windowSeparation == -1 || windowSeparationUnit == null)
                throw new IllegalStateException("Window separation and/or unit not set");
            return new OverlappingTimeWindowFunction(this);
        }
    }
}
