/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.loss.LogLoss;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;

/**
 * Base class for gradient-boosting trainers that solve binary classification.
 *
 * <p>The two distinct labels found in the training data ("external" labels) are mapped onto the
 * internal labels {@code 0.0} and {@code 1.0} by {@link #externalLabelToInternal}, and
 * {@link #internalLabelToExternal} maps a raw model output back to one of the two external labels.
 */
public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
    /** External label that is mapped to internal label {@code 0.0}; set by {@code learnLabels}. */
    private double externalFirstCls;

    /** External label that is mapped to internal label {@code 1.0}; set by {@code learnLabels}. */
    private double externalSecondCls;

    /**
     * Creates a trainer that uses {@link LogLoss} as its loss.
     *
     * @param gradStepSize Gradient step size; forwarded unchanged to {@link GDBTrainer}.
     * @param cntOfIterations Count of iterations; forwarded unchanged to {@link GDBTrainer}.
     */
    public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations) {
        super(gradStepSize, cntOfIterations, new LogLoss());
    }

    /**
     * Creates a trainer with an explicit loss.
     *
     * @param gradStepSize Gradient step size; forwarded unchanged to {@link GDBTrainer}.
     * @param cntOfIterations Count of iterations; forwarded unchanged to {@link GDBTrainer}.
     * @param loss Loss forwarded unchanged to {@link GDBTrainer}.
     */
    public GDBBinaryClassifierTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
        super(gradStepSize, cntOfIterations, loss);
    }

    /**
     * Collects the distinct labels of the training data and, if there are exactly two of them,
     * stores them as the first and second external classes.
     *
     * <p>The smaller label (in {@link Double#compare} order) becomes the first class (internal
     * {@code 0.0}) and the larger one the second class (internal {@code 1.0}). The mapping therefore
     * does not depend on the order in which labels appear in the partitions.
     *
     * @param builder Dataset builder over the training data.
     * @param featureExtractor Feature extractor.
     * @param lExtractor Label extractor.
     * @param <V> Type of an upstream value.
     * @param <K> Type of an upstream key.
     * @return {@code true} if exactly two distinct labels were found; {@code false} otherwise. In
     *     the {@code false} case the stored classes are left unchanged.
     */
    @Override
    protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lExtractor) {
        // The dataset is only needed to gather labels, so close it once that is done instead of
        // leaking it.
        Set<Double> uniqLabels = computeAndClose(
            builder.build(
                new EmptyContextBuilder<>(),
                new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor)),
            dataset -> dataset.compute(
                part -> Arrays.stream(part.labels()).boxed().collect(Collectors.toSet()),
                GDBBinaryClassifierTrainer::mergeLabelSets));

        if (uniqLabels == null || uniqLabels.size() != 2) {
            return false;
        }

        // Sort so that the external -> internal class mapping is deterministic.
        double[] lbls = uniqLabels.stream().mapToDouble(Double::doubleValue).sorted().toArray();
        this.externalFirstCls = lbls[0];
        this.externalSecondCls = lbls[1];
        return true;
    }

    /**
     * Unions two partial label sets. If either set is {@code null}, the other one is returned as is.
     * Otherwise {@code a} is extended with {@code b} and returned.
     */
    private static Set<Double> mergeLabelSets(Set<Double> a, Set<Double> b) {
        if (a == null) {
            return b;
        }
        if (b == null) {
            return a;
        }
        a.addAll(b);
        return a;
    }

    /**
     * Applies {@code action} to {@code rsrc} and closes {@code rsrc} afterwards, even if
     * {@code action} throws.
     *
     * @param rsrc Resource to use and close; a {@code null} resource is passed to {@code action}
     *     and not closed.
     * @param action Computation to run on the resource.
     * @return Result of {@code action}.
     * @throws RuntimeException Rethrown as is if {@code action} or {@code close()} throws one. A
     *     checked exception thrown by {@code close()} is wrapped with the original as its cause.
     */
    private static <T extends AutoCloseable, R> R computeAndClose(T rsrc,
        Function<? super T, ? extends R> action) {
        try (T r = rsrc) {
            return action.apply(r);
        }
        catch (RuntimeException e) {
            throw e;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Maps an external label to the internal label.
     *
     * @param x External label.
     * @return {@code 0.0} if {@code x} is the first class, {@code 1.0} otherwise.
     */
    @Override
    protected double externalLabelToInternal(double x) {
        // Double.compare agrees with Double.equals, which the label Set used to tell the classes
        // apart: it distinguishes 0.0 from -0.0 and treats NaN as equal to itself. Plain '=='
        // does neither.
        return Double.compare(x, this.externalFirstCls) == 0 ? 0.0 : 1.0;
    }

    /**
     * Maps a raw model output back to one of the two external labels.
     *
     * @param indent Raw model output; it is passed through the logistic sigmoid.
     * @return The first class if the sigmoid is below {@code 0.5}, otherwise the second class
     *     (including when the sigmoid is NaN).
     */
    @Override
    protected double internalLabelToExternal(double indent) {
        double sigma = 1.0 / (1.0 + Math.exp(-indent));
        return sigma < 0.5 ? this.externalFirstCls : this.externalSecondCls;
    }
}

