/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.encoding.target;

import java.io.Serializable;
import java.util.Set;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Preprocessor that rewrites selected features of a vector into numeric values
 * looked up from per-feature {@link TargetEncodingMeta} entries.
 *
 * <p>Features whose index is in {@code handledIndices} are replaced by the value
 * stored for their string form in {@link TargetEncodingMeta#getCategoryMean()},
 * or by {@link TargetEncodingMeta#getGlobalMean()} when no such entry exists.
 * All other features are passed through and must already be {@link Double} values.
 *
 * @param <K> Type of a key in the upstream data.
 * @param <V> Type of a value in the upstream data.
 */
public class TargetEncoderPreprocessor<K, V> extends EncoderPreprocessor<K, V> {
    /** Serialization version identifier. */
    protected static final long serialVersionUID = 6237711236382623481L;

    /** Encoding metadata, addressed by feature index. */
    protected final TargetEncodingMeta[] targetCounters;

    /**
     * Creates a preprocessor backed by the given encoding metadata.
     *
     * @param targetCounters Encoding metadata; the entry at position {@code i} is used for feature {@code i}.
     * @param basePreprocessor Preprocessor whose output is encoded by this instance.
     * @param handledIndices Indices of the features that have to be encoded.
     */
    public TargetEncoderPreprocessor(TargetEncodingMeta[] targetCounters, Preprocessor<K, V> basePreprocessor, Set<Integer> handledIndices) {
        // The parent's first constructor argument is not used by this encoder, hence null.
        super(null, basePreprocessor, handledIndices);
        this.targetCounters = targetCounters;
    }

    /**
     * Runs the base preprocessor and encodes the handled features of its result.
     *
     * @param k Key of the upstream record.
     * @param v Value of the upstream record.
     * @return New labeled vector with numeric features and the label of the base preprocessor's result.
     */
    @Override
    public LabeledVector apply(K k, V v) {
        LabeledVector source = (LabeledVector)this.basePreprocessor.apply(k, v);
        int featureCnt = source.size();
        double[] encoded = new double[featureCnt];

        int idx = 0;
        while (idx < featureCnt) {
            Serializable raw = source.getRaw(idx);

            if (this.handledIndices.contains(idx)) {
                encoded[idx] = this.encodeCategory(idx, raw);
            }
            else {
                // Features that are not handled are expected to be Double already.
                encoded[idx] = (Double)raw;
            }

            ++idx;
        }

        return new LabeledVector(VectorUtils.of(encoded), source.label());
    }

    /**
     * Resolves the numeric replacement for a single handled feature.
     *
     * @param idx Index of the feature, also used to pick the metadata entry.
     * @param raw Raw feature value; its string form is the lookup key.
     * @return Stored value for the category, or the metadata's global mean if the category has no entry.
     */
    private double encodeCategory(int idx, Serializable raw) {
        TargetEncodingMeta meta = this.targetCounters[idx];
        String category = raw.toString();

        // Categories without a stored entry fall back to the global mean.
        if (!meta.getCategoryMean().containsKey(category)) {
            return meta.getGlobalMean();
        }

        return meta.getCategoryMean().get(category);
    }
}

