/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.encoding;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPartitionData;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor;
import org.jetbrains.annotations.NotNull;

public class EncoderTrainer<K, V>
implements PreprocessingTrainer<K, V, Object[], Vector> {
    private Set<Integer> handledIndices = new HashSet<Integer>();
    private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER;

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Object[]> basePreprocessor) {
        if (this.handledIndices.isEmpty()) {
            throw new RuntimeException("Add indices of handled features");
        }
        try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build((upstream, upstreamSize) -> new EmptyContext(), (upstream, upstreamSize, ctx) -> {
            Map<String, Integer>[] categoryFrequencies = null;
            while (upstream.hasNext()) {
                UpstreamEntry entity = (UpstreamEntry)upstream.next();
                Object[] row = (Object[])basePreprocessor.apply(entity.getKey(), entity.getValue());
                categoryFrequencies = this.calculateFrequencies(row, categoryFrequencies);
            }
            return new EncoderPartitionData().withCategoryFrequencies(categoryFrequencies);
        });){
            Map<String, Integer>[] encodingValues = this.calculateEncodingValuesByFrequencies(dataset);
            switch (this.encoderType) {
                case ONE_HOT_ENCODER: {
                    OneHotEncoderPreprocessor<K, V> oneHotEncoderPreprocessor = new OneHotEncoderPreprocessor<K, V>(encodingValues, basePreprocessor, this.handledIndices);
                    return oneHotEncoderPreprocessor;
                }
                case STRING_ENCODER: {
                    StringEncoderPreprocessor<K, V> stringEncoderPreprocessor = new StringEncoderPreprocessor<K, V>(encodingValues, basePreprocessor, this.handledIndices);
                    return stringEncoderPreprocessor;
                }
            }
            throw new IllegalStateException("Define the type of the resulting prerocessor.");
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private Map<String, Integer>[] calculateEncodingValuesByFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        Map[] frequencies = (Map[])dataset.compute(EncoderPartitionData::categoryFrequencies, (a, b) -> {
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            assert (((Map[])a).length == ((Map[])b).length);
            for (int i = 0; i < ((Map[])a).length; ++i) {
                if (!this.handledIndices.contains(i)) continue;
                int finalI = i;
                a[i].forEach((k, v) -> b[finalI].merge(k, v, (f1, f2) -> f1 + f2));
            }
            return b;
        });
        HashMap[] res = new HashMap[frequencies.length];
        for (int i = 0; i < frequencies.length; ++i) {
            if (!this.handledIndices.contains(i)) continue;
            res[i] = this.transformFrequenciesToEncodingValues(frequencies[i]);
        }
        return res;
    }

    private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
        HashMap resMap = frequencies.entrySet().stream().sorted(Map.Entry.comparingByValue()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new));
        int amountOfLabels = frequencies.size();
        for (Map.Entry entry : resMap.entrySet()) {
            entry.setValue(--amountOfLabels);
        }
        return resMap;
    }

    private Map<String, Integer>[] calculateFrequencies(Object[] row, Map<String, Integer>[] categoryFrequencies) {
        if (categoryFrequencies == null) {
            categoryFrequencies = this.initializeCategoryFrequencies(row);
        } else assert (categoryFrequencies.length == row.length) : "Base preprocessor must return exactly " + categoryFrequencies.length + " features";
        for (int i = 0; i < categoryFrequencies.length; ++i) {
            String strVal;
            if (!this.handledIndices.contains(i)) continue;
            Object featureVal = row[i];
            if (featureVal.equals(Double.NaN)) {
                strVal = "";
                row[i] = strVal;
            } else if (featureVal instanceof String) {
                strVal = (String)featureVal;
            } else if (featureVal instanceof Double) {
                strVal = String.valueOf(featureVal);
            } else {
                throw new RuntimeException("The type " + featureVal.getClass() + " is not supported for the feature values.");
            }
            Map<String, Integer> map = categoryFrequencies[i];
            if (map.containsKey(strVal)) {
                map.put(strVal, map.get(strVal) + 1);
                continue;
            }
            map.put(strVal, 1);
        }
        return categoryFrequencies;
    }

    @NotNull
    private Map<String, Integer>[] initializeCategoryFrequencies(Object[] row) {
        HashMap[] categoryFrequencies = new HashMap[row.length];
        for (int i = 0; i < categoryFrequencies.length; ++i) {
            if (!this.handledIndices.contains(i)) continue;
            categoryFrequencies[i] = new HashMap();
        }
        return categoryFrequencies;
    }

    public EncoderTrainer<K, V> withEncodedFeature(int idx) {
        this.handledIndices.add(idx);
        return this;
    }

    public EncoderTrainer<K, V> withEncoderType(EncoderType type) {
        this.encoderType = type;
        return this;
    }

    public EncoderTrainer<K, V> withEncodedFeatures(Set<Integer> handledIndices) {
        this.handledIndices = handledIndices;
        return this;
    }
}

