package org.apache.mahout.clustering.lda.cvb;

import java.io.IOException;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.MemoryUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Mapper that evaluates an LDA topic model (CVB0) on a sample of the input documents.
 *
 * <p>Each input document is sampled with probability {@link CVB0Driver#TEST_SET_FRACTION} (every
 * document is used when the fraction is at least 1), using a {@link Random} seeded from
 * {@link CVB0Driver#RANDOM_SEED}. For every sampled document the mapper emits the document's L1
 * norm as the key and the value returned by {@code ModelTrainer.calculatePerplexity} as the value,
 * and increments {@link Counters#SAMPLED_DOCUMENTS}.
 */
public class CachingCVB0PerplexityMapper extends Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> {
    private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class);

    /** Interval handed to {@link MemoryUtil#startMemoryLogger(long)} (presumably milliseconds). */
    private static final long MEMORY_LOGGER_INTERVAL = 5000L;

    /** Trainer used only to compute per-document perplexity against {@link #readModel}. */
    private ModelTrainer modelTrainer;
    /** Model read from {@link CVB0Driver#getModelPaths}, or a freshly constructed one if none exist. */
    private TopicModel readModel;
    /** Maximum inference iterations per document, from {@link CVB0Driver#MAX_ITERATIONS_PER_DOC}. */
    private int maxIters;
    /** Number of topics, from {@link CVB0Driver#NUM_TOPICS}; validated to be positive in setup. */
    private int numTopics;
    /** Fraction of documents to sample, from {@link CVB0Driver#TEST_SET_FRACTION}. */
    private float testFraction;
    /** Seeded source used to decide which documents are sampled. */
    private Random random;
    /** Reusable per-document topic vector; reset to a uniform value before each document. */
    private Vector topicVector;
    private final DoubleWritable outKey = new DoubleWritable();
    private final DoubleWritable outValue = new DoubleWritable();

    /** Job counters reported by this mapper. */
    public enum Counters {
        /** Number of documents that were sampled and scored. */
        SAMPLED_DOCUMENTS
    }

    /**
     * Reads the job configuration and builds the topic model, model trainer and topic vector.
     *
     * @param context the task context supplying the job configuration
     * @throws IOException if the model files cannot be read
     * @throws InterruptedException if interrupted while initializing
     * @throws IllegalStateException if {@link CVB0Driver#NUM_TOPICS} is missing or not positive
     */
    @Override
    public void setup(Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable>.Context context) throws IOException, InterruptedException {
        MemoryUtil.startMemoryLogger(MEMORY_LOGGER_INTERVAL);
        boolean initialized = false;
        try {
            initialize(context.getConfiguration());
            initialized = true;
        } finally {
            if (!initialized) {
                // Mapper.run() only calls cleanup() after setup() returns normally, so a failed
                // setup has to stop the memory logger it started.
                MemoryUtil.stopMemoryLogger();
            }
        }
    }

    /** Populates all per-task state from {@code configuration}; helper for {@link #setup}. */
    private void initialize(Configuration configuration) throws IOException, InterruptedException {
        log.info("Retrieving configuration");
        float eta = configuration.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
        float alpha = configuration.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
        long seed = configuration.getLong(CVB0Driver.RANDOM_SEED, 1234L);
        this.random = RandomUtils.getRandom(seed);
        this.numTopics = configuration.getInt(CVB0Driver.NUM_TOPICS, -1);
        if (this.numTopics <= 0) {
            // Fail fast: a non-positive topic count would otherwise surface later as an opaque
            // NegativeArraySizeException, after the model has already been loaded.
            throw new IllegalStateException("Number of topics must be positive, but "
                    + CVB0Driver.NUM_TOPICS + " is " + this.numTopics);
        }
        int numTerms = configuration.getInt(CVB0Driver.NUM_TERMS, -1);
        int numUpdateThreads = configuration.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
        int numTrainThreads = configuration.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
        this.maxIters = configuration.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
        float modelWeight = configuration.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
        this.testFraction = configuration.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);

        log.info("Initializing read model");
        Path[] modelPaths = CVB0Driver.getModelPaths(configuration);
        if (modelPaths != null && modelPaths.length > 0) {
            this.readModel = new TopicModel(configuration, eta, alpha, (String[]) null, numUpdateThreads, modelWeight, modelPaths);
        } else {
            log.info("No model files found");
            // A separate Random with the same seed is used so model construction does not
            // perturb the sampling sequence of this.random.
            this.readModel = new TopicModel(this.numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null, numTrainThreads, modelWeight);
        }

        log.info("Initializing model trainer");
        this.modelTrainer = new ModelTrainer(this.readModel, null, numTrainThreads, this.numTopics, numTerms);

        log.info("Initializing topic vector");
        this.topicVector = new DenseVector(new double[this.numTopics]);
    }

    /**
     * Stops the read model and the memory logger started in {@link #setup}.
     *
     * @param context the task context (unused)
     * @throws IOException declared by the overridden method
     * @throws InterruptedException declared by the overridden method
     */
    @Override
    public void cleanup(Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable>.Context context) throws IOException, InterruptedException {
        try {
            this.readModel.stop();
        } finally {
            // Stop the logger even if stopping the model fails.
            MemoryUtil.stopMemoryLogger();
        }
    }

    /**
     * Samples the document and, if selected, emits {@code (L1 norm, perplexity)} for it.
     *
     * @param docId the document id (unused)
     * @param document the document's term vector
     * @param context the task context used for counters and output
     * @throws IOException if writing the output fails
     * @throws InterruptedException if interrupted while writing the output
     */
    @Override
    public void map(IntWritable docId, VectorWritable document, Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable>.Context context) throws IOException, InterruptedException {
        // The short-circuit keeps the random sequence untouched when every document is used.
        if (this.testFraction >= 1.0f || this.random.nextFloat() < this.testFraction) {
            context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1L);
            Vector termVector = document.get();
            this.outKey.set(termVector.norm(1.0d));
            // Reset the shared topic vector to a uniform starting value for this document.
            Vector initialTopics = this.topicVector.assign(1.0d / this.numTopics);
            this.outValue.set(this.modelTrainer.calculatePerplexity(termVector, initialTopics, this.maxIters));
            context.write(this.outKey, this.outValue);
        }
    }
}
