/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.extraction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.transform.Transform;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.sort.QuickSort;

/**
 * Bag-of-words feature extraction.
 * <p>
 * Each text is split into words by a user-supplied tokenizer and mapped to an
 * integer vector over a fixed vocabulary. The i-th element corresponds to the
 * i-th vocabulary word. Words outside the vocabulary are ignored. By default
 * each element is the number of occurrences of the word. In binary mode it is
 * 1 if the word occurs at least once and 0 otherwise.
 * <p>
 * As a {@link Transform}, the texts of all configured input columns are
 * tokenized and accumulated into a single vector. The vector is returned as a
 * tuple whose integer fields are named {@code "BoW_" + word}.
 */
public class BagOfWords implements Transform {
    /** Splits a text into words. */
    private final Function<String, String[]> tokenizer;
    /** The vocabulary, in feature order. */
    private final String[] words;
    /** Maps each vocabulary word to its index in {@link #words}. */
    private final Map<String, Integer> featureIndex;
    /** If true, features are 0/1 presence indicators instead of counts. */
    private final boolean binary;
    /** Output schema of {@link #apply(Tuple)}: one integer field per vocabulary word. */
    private final StructType schema;
    /** Text columns read by {@link #apply(Tuple)}; null if none were configured. */
    private final String[] columns;

    /**
     * Creates a count-based bag-of-words for raw text.
     * Instances created this way have no input columns, so only
     * {@link #apply(String)} may be used. {@link #apply(Tuple)} throws.
     *
     * @param tokenizer the function that splits a text into words.
     * @param words the vocabulary. It must not contain duplicates.
     * @throws IllegalArgumentException if {@code words} contains a duplicate.
     */
    public BagOfWords(Function<String, String[]> tokenizer, String[] words) {
        this(null, tokenizer, words, false);
    }

    /**
     * Creates a bag-of-words transform.
     *
     * @param columns the text columns read by {@link #apply(Tuple)}; may be null
     *                if only {@link #apply(String)} will be used.
     * @param tokenizer the function that splits a text into words.
     * @param words the vocabulary. It must not contain duplicates.
     * @param binary if true, output 0/1 presence indicators instead of counts.
     * @throws IllegalArgumentException if {@code words} contains a duplicate.
     */
    public BagOfWords(String[] columns, Function<String, String[]> tokenizer, String[] words, boolean binary) {
        this.tokenizer = Objects.requireNonNull(tokenizer, "tokenizer");
        Objects.requireNonNull(words, "words");
        // Defensive copies: the arrays back featureIndex and schema, so later
        // mutation by the caller must not be able to desynchronize them.
        this.columns = columns == null ? null : columns.clone();
        this.words = words.clone();
        this.binary = binary;
        this.featureIndex = new HashMap<>(2 * this.words.length);
        for (int i = 0; i < this.words.length; ++i) {
            if (this.featureIndex.putIfAbsent(this.words[i], i) != null) {
                throw new IllegalArgumentException("Duplicated word:" + this.words[i]);
            }
        }
        StructField[] fields = Arrays.stream(this.words)
                .map(word -> new StructField("BoW_" + word, DataTypes.IntType))
                .toArray(StructField[]::new);
        this.schema = new StructType(fields);
    }

    /**
     * Returns the vocabulary in feature order.
     *
     * @return a copy of the vocabulary; modifying it does not affect this object.
     */
    public String[] features() {
        return this.words.clone();
    }

    /**
     * Learns a vocabulary from a data frame and returns a count-based transform.
     * <p>
     * The vocabulary holds the {@code k} words that occur most often across all
     * rows of the given columns, ordered by descending count. Ties are broken
     * alphabetically so that the result is deterministic. Null texts and null
     * tokens are ignored.
     *
     * @param data the training data.
     * @param tokenizer the function that splits a text into words.
     * @param k the maximum vocabulary size.
     * @param columns the text columns to learn from; also the input columns of
     *                the returned transform.
     * @return the fitted bag-of-words transform (non-binary).
     * @throws IllegalArgumentException if {@code k} is negative.
     */
    public static BagOfWords fit(DataFrame data, Function<String, String[]> tokenizer, int k, String ... columns) {
        if (k < 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }

        Map<String, Integer> counts = new HashMap<>();
        for (String column : columns) {
            for (String text : data.column(column).toStringArray()) {
                if (text == null) continue; // missing value: contributes no words
                for (String word : tokenizer.apply(text)) {
                    if (word != null) {
                        counts.merge(word, 1, Integer::sum);
                    }
                }
            }
        }

        // Most frequent first; alphabetical tie-break makes the selected
        // vocabulary independent of HashMap iteration order.
        List<Map.Entry<String, Integer>> ranked = new ArrayList<>(counts.entrySet());
        ranked.sort(Map.Entry.<String, Integer>comparingByValue(Comparator.reverseOrder())
                .thenComparing(Map.Entry.<String, Integer>comparingByKey()));

        int size = Math.min(k, ranked.size());
        String[] vocabulary = new String[size];
        for (int i = 0; i < size; i++) {
            vocabulary[i] = ranked.get(i).getKey();
        }
        return new BagOfWords(columns, tokenizer, vocabulary, false);
    }

    /**
     * Converts the configured text columns of a tuple to a bag-of-words tuple.
     * Counts, or presence flags in binary mode, are accumulated across all
     * configured columns.
     *
     * @param x the input tuple.
     * @return a tuple with one integer field per vocabulary word.
     * @throws IllegalStateException if this object was created without input columns.
     */
    public Tuple apply(Tuple x) {
        if (this.columns == null) {
            throw new IllegalStateException("BagOfWords has no input columns; use apply(String) instead");
        }
        int[] bag = new int[this.words.length];
        for (String column : this.columns) {
            accumulate(x.getString(column), bag);
        }
        return Tuple.of(this.schema, bag);
    }

    /**
     * Converts a text to its bag-of-words vector.
     *
     * @param text the input text; null is treated as containing no words.
     * @return a vector whose i-th element is the count, or the 0/1 presence flag
     *         in binary mode, of the i-th vocabulary word.
     */
    public int[] apply(String text) {
        int[] bag = new int[this.words.length];
        accumulate(text, bag);
        return bag;
    }

    /** Tokenizes {@code text} and adds its vocabulary words to {@code bag}; a null text adds nothing. */
    private void accumulate(String text, int[] bag) {
        if (text == null) {
            return;
        }
        for (String word : this.tokenizer.apply(text)) {
            Integer index = this.featureIndex.get(word);
            if (index == null) {
                continue; // out-of-vocabulary word
            }
            if (this.binary) {
                bag[index] = 1;
            } else {
                bag[index]++;
            }
        }
    }
}

