package edu.berkeley.nlp.classify;

import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.CounterMap;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/classify/NaiveBayesClassifier.class */
/**
 * A Naive Bayes classifier whose per-label feature distributions are linearly interpolated
 * with a label-independent backoff distribution.
 *
 * <p>For an input whose extracted feature counts are {@code c(f)}, the log-score of label
 * {@code l} is {@code log P(l) + sum_f c(f) * log((1 - alpha) * P(f | l) + alpha * P(f))}.
 * The returned probabilities are these scores normalized in log space.
 *
 * @param <I> input type
 * @param <F> feature type
 * @param <L> label type
 */
public class NaiveBayesClassifier<I, F, L> implements ProbabilisticClassifier<I, L> {
    /** Interpolation weight given to the backoff distribution when none is specified. */
    private static final double DEFAULT_ALPHA = 0.1d;

    /** Per-label feature distributions, P(f | l). */
    private final CounterMap<L, F> featureProbs;
    /** Label-independent feature distribution, P(f), used for smoothing. */
    private final Counter<F> backoffProbs;
    /** Label prior, P(l). */
    private final Counter<L> labelProbs;
    private final FeatureExtractor<I, F> featureExtractor;
    /** Weight of {@link #backoffProbs} in the interpolated feature probability. */
    private final double alpha;

    /**
     * Trains a {@link NaiveBayesClassifier} from labeled instances.
     *
     * <p>Training uses relative-frequency estimates with no add-k smoothing. Smoothing comes
     * only from interpolating with the backoff distribution at prediction time.
     */
    public static class Factory<I, F, L> implements ProbabilisticClassifierFactory<I, L> {
        private final FeatureExtractor<I, F> featureExtractor;
        private final double alpha;

        /**
         * Creates a factory using the default backoff weight.
         *
         * @param featureExtractor maps inputs to feature counts
         */
        public Factory(FeatureExtractor<I, F> featureExtractor) {
            this(featureExtractor, DEFAULT_ALPHA);
        }

        /**
         * Creates a factory with an explicit backoff weight.
         *
         * @param featureExtractor maps inputs to feature counts
         * @param alpha weight of the backoff distribution, in {@code (0, 1]}
         * @throws IllegalArgumentException if {@code alpha} is outside {@code (0, 1]}
         */
        public Factory(FeatureExtractor<I, F> featureExtractor, double alpha) {
            this.featureExtractor = featureExtractor;
            this.alpha = checkAlpha(alpha);
        }

        /**
         * Estimates the label prior, the per-label feature distributions, and the backoff
         * feature distribution, then builds a classifier from them.
         *
         * @param list the labeled training instances
         * @return a classifier backed by the normalized estimates
         */
        @Override
        public ProbabilisticClassifier<I, L> trainClassifier(List<LabeledInstance<I, L>> list) {
            CounterMap<L, F> featureCounts = new CounterMap<>();
            Counter<F> backoffCounts = new Counter<>();
            Counter<L> labelCounts = new Counter<>();
            for (LabeledInstance<I, L> instance : list) {
                L label = instance.getLabel();
                labelCounts.incrementCount(label, 1.0d);
                Counter<F> features = featureExtractor.extractFeatures(instance.getInput());
                for (F feature : features.keySet()) {
                    double count = features.getCount(feature);
                    backoffCounts.incrementCount(feature, count);
                    featureCounts.incrementCount(label, feature, count);
                }
            }
            // Turn raw counts into distributions: P(f | l), P(l) and P(f).
            featureCounts.normalize();
            labelCounts.normalize();
            backoffCounts.normalize();
            return new NaiveBayesClassifier<>(
                    featureCounts, backoffCounts, labelCounts, featureExtractor, alpha);
        }
    }

    /**
     * Computes the posterior distribution over the labels seen in training.
     *
     * <p>Features with zero backoff probability (never seen in training) carry no
     * information about the label and would otherwise contribute {@code log(0)} to every
     * label, so they are skipped. Skipping them avoids a {@code NaN} posterior for every
     * label. Zero-count features are also skipped, since they contribute nothing.
     * NOTE(review): this relies on {@code Counter.getCount} returning 0 for absent keys.
     *
     * @param i the input to classify
     * @return a counter mapping each training label to its posterior probability
     */
    @Override
    public Counter<L> getProbabilities(I i) {
        // Extraction does not depend on the label, so do it once per input.
        Counter<F> features = featureExtractor.extractFeatures(i);
        List<F> informativeFeatures = new ArrayList<>();
        for (F feature : features.keySet()) {
            if (features.getCount(feature) != 0.0d && backoffProbs.getCount(feature) > 0.0d) {
                informativeFeatures.add(feature);
            }
        }

        Counter<L> posteriors = new Counter<>();
        List<Double> logScores = new ArrayList<>();
        for (L label : labelProbs.keySet()) {
            double logScore = Math.log(labelProbs.getCount(label));
            for (F feature : informativeFeatures) {
                logScore += features.getCount(feature) * Math.log(getFeatureProb(feature, label));
            }
            logScores.add(logScore);
            posteriors.setCount(label, logScore);
        }

        // Normalize in log space to avoid underflow when exponentiating.
        double logNormalizer = SloppyMath.logAdd(logScores);
        for (L label : labelProbs.keySet()) {
            posteriors.setCount(label, Math.exp(posteriors.getCount(label) - logNormalizer));
        }
        return posteriors;
    }

    /**
     * Returns the interpolated probability
     * {@code (1 - alpha) * P(f | l) + alpha * P(f)}.
     */
    private double getFeatureProb(F f, L l) {
        return ((1.0d - alpha) * featureProbs.getCount(l, f)) + (alpha * backoffProbs.getCount(f));
    }

    /**
     * Returns the label with the highest posterior probability for the input.
     *
     * @param i the input to classify
     * @return the arg-max label of {@link #getProbabilities(Object)}
     */
    @Override
    public L getLabel(I i) {
        return getProbabilities(i).argMax();
    }

    /**
     * Creates a classifier from pre-computed distributions, using the default backoff weight.
     *
     * @param counterMap per-label feature distributions, P(f | l)
     * @param counter backoff feature distribution, P(f)
     * @param counter2 label prior, P(l)
     * @param featureExtractor maps inputs to feature counts
     */
    public NaiveBayesClassifier(CounterMap<L, F> counterMap, Counter<F> counter, Counter<L> counter2, FeatureExtractor<I, F> featureExtractor) {
        this(counterMap, counter, counter2, featureExtractor, DEFAULT_ALPHA);
    }

    /**
     * Creates a classifier from pre-computed distributions with an explicit backoff weight.
     *
     * @param featureProbs per-label feature distributions, P(f | l)
     * @param backoffProbs backoff feature distribution, P(f)
     * @param labelProbs label prior, P(l)
     * @param featureExtractor maps inputs to feature counts
     * @param alpha weight of the backoff distribution, in {@code (0, 1]}
     * @throws IllegalArgumentException if {@code alpha} is outside {@code (0, 1]}
     */
    public NaiveBayesClassifier(CounterMap<L, F> featureProbs, Counter<F> backoffProbs, Counter<L> labelProbs, FeatureExtractor<I, F> featureExtractor, double alpha) {
        this.featureProbs = featureProbs;
        this.backoffProbs = backoffProbs;
        this.labelProbs = labelProbs;
        this.featureExtractor = featureExtractor;
        this.alpha = checkAlpha(alpha);
    }

    /**
     * Validates the backoff weight. A strictly positive weight keeps every feature that has
     * nonzero backoff probability at a nonzero interpolated probability.
     */
    private static double checkAlpha(double alpha) {
        // Negated form so that NaN is rejected as well.
        if (!(alpha > 0.0d && alpha <= 1.0d)) {
            throw new IllegalArgumentException("alpha must be in (0, 1], got " + alpha);
        }
        return alpha;
    }
}
