package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.distance.HammingDistance;
import smile.util.IntSet;

/* loaded from: input_file:smile/clustering/KModes.class */
/**
 * K-Modes clustering of categorical data.
 *
 * <p>Each observation is a vector of integer-coded categorical attributes.
 * Distances are Hamming distances, and each cluster center is the per-column
 * mode: the most frequent value of that attribute among the cluster members.
 */
public class KModes extends CentroidClustering<int[], int[]> {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(KModes.class);

    /**
     * Encoder for one categorical column. It maps the column's distinct values
     * onto the dense index range {@code [0, k)} so modes can be counted in a
     * plain array.
     *
     * <p>NOTE: the constructor re-encodes the given array in place (unless its
     * values already are exactly {@code 0..k-1}). Callers must pass an array
     * they own.
     */
    public static class Codec {
        /** The number of distinct values in the column. */
        public final int k;
        /** The column, re-encoded to dense indices in {@code [0, k)}. */
        public final int[] x;
        /** Bidirectional mapping between original values and dense indices. */
        public final IntSet encoder;

        /**
         * Builds the encoder and re-encodes {@code x} in place.
         *
         * @param x the column values; overwritten with dense indices.
         * @throws IllegalArgumentException if {@code x} is empty.
         */
        public Codec(int[] x) {
            if (x.length == 0) {
                throw new IllegalArgumentException("Empty column");
            }

            int[] values = MathEx.unique(x);
            Arrays.sort(values);
            this.x = x;
            this.k = values.length;
            this.encoder = new IntSet(values);

            // The values are sorted and distinct, so first == 0 and last == k-1
            // means the column already is exactly 0..k-1: nothing to re-encode.
            int last = values.length - 1;
            if (values[0] == 0 && values[last] == last) {
                return;
            }

            for (int i = 0; i < x.length; i++) {
                x[i] = encoder.indexOf(x[i]);
            }
        }

        /**
         * Maps a dense index back to the original categorical value.
         *
         * @param index a dense index in {@code [0, k)}.
         * @return the original value.
         */
        public int valueOf(int index) {
            return encoder.valueOf(index);
        }
    }

    /**
     * Constructor.
     *
     * @param distortion the total distortion (sum of Hamming distances).
     * @param centroids the cluster modes.
     * @param y the cluster label of each observation.
     */
    public KModes(double distortion, int[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /** Returns the Hamming distance between two categorical vectors. */
    @Override
    public double distance(int[] x, int[] y) {
        return HammingDistance.d(x, y);
    }

    /**
     * Fits K-Modes clustering with at most 100 iterations.
     *
     * @param data the categorical observations, one row per observation.
     * @param k the number of clusters.
     * @return the model.
     */
    public static KModes fit(int[][] data, int k) {
        return fit(data, k, 100);
    }

    /**
     * Fits K-Modes clustering.
     *
     * <p>Iterates "update modes, then reassign" until the distortion stops
     * decreasing or {@code maxIter} iterations have run.
     *
     * @param data the categorical observations, one row per observation.
     * @param k the number of clusters.
     * @param maxIter the maximum number of iterations.
     * @return the model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *         or {@code data} is empty.
     */
    public static KModes fit(int[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int n = data.length;
        int d = data[0].length;

        // One codec per column. Each gets a private copy of its column because
        // Codec re-encodes its input in place.
        Codec[] codec = IntStream.range(0, d).parallel().mapToObj(j -> {
            int[] column = new int[n];
            for (int i = 0; i < n; i++) {
                column[i] = data[i][j];
            }
            return new Codec(column);
        }).toArray(Codec[]::new);

        int[] y = new int[n];
        int[][] medoids = new int[k][];
        double distortion = MathEx.sum(seed(data, medoids, y, HammingDistance::d));
        logger.info(String.format("Distortion after initialization: %d", (int) distortion));

        int[][] centroids = new int[k][d];
        double diff = Integer.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > 0; iter++) {
            updateCentroids(centroids, data, y, codec);

            double wcss = assign(y, data, centroids, HammingDistance::d);
            logger.info(String.format("Distortion after %3d iterations: %d", iter, (int) wcss));

            diff = distortion - wcss;
            distortion = wcss;
        }

        // If the loop stopped on maxIter while still improving, the centroids
        // were computed before the last reassignment; refresh them.
        if (diff > 0) {
            updateCentroids(centroids, data, y, codec);
        }

        return new KModes(distortion, centroids, y);
    }

    /**
     * Recomputes each cluster's mode, column by column, from the current labels.
     *
     * @param centroids the cluster modes; overwritten in place, one row per cluster.
     * @param data the observations (used only for the row count; values come from the codecs).
     * @param y the cluster label of each observation.
     * @param codec the per-column encoders holding the dense-index columns.
     */
    private static void updateCentroids(int[][] centroids, int[][] data, int[] y, Codec[] codec) {
        int n = data.length;
        int k = centroids.length;
        int d = centroids[0].length;

        // Each task writes only its own centroid row, so the parallel tasks share no written state.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            int[] centroid = centroids[cluster];
            for (int j = 0; j < d; j++) {
                Codec column = codec[j];
                if (column.k == 1) {
                    // Constant column: the mode is its only value. Without this the
                    // entry would keep its initial 0 even when that value is not 0.
                    centroid[j] = column.valueOf(0);
                    continue;
                }

                int[] count = new int[column.k];
                int[] x = column.x;
                for (int i = 0; i < n; i++) {
                    if (y[i] == cluster) {
                        count[x[i]]++;
                    }
                }
                // For an empty cluster all counts are zero; the result is whatever
                // index MathEx.whichMax picks among ties.
                centroid[j] = column.valueOf(MathEx.whichMax(count));
            }
        });
    }
}
