/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.BaseClassifierEvaluator;
import com.aliasi.classify.Classification;
import com.aliasi.classify.Classified;
import com.aliasi.classify.RankedClassification;
import com.aliasi.classify.RankedClassifier;

public class RankedClassifierEvaluator<E>
extends BaseClassifierEvaluator<E> {
    boolean mDefectiveRanking = false;
    private final int[][] mRankCounts;

    public RankedClassifierEvaluator(RankedClassifier<E> classifier, String[] categories, boolean storeInputs) {
        super(classifier, categories, storeInputs);
        int len = categories.length;
        this.mRankCounts = new int[len][len];
    }

    @Override
    public void setClassifier(RankedClassifier<E> classifier) {
        if (!this.getClass().equals(RankedClassifierEvaluator.class)) {
            String msg = "Require appropriate classifier type. Evaluator class=" + this.getClass() + " Found classifier.class=" + classifier.getClass();
            throw new IllegalArgumentException(msg);
        }
        this.mClassifier = classifier;
    }

    @Override
    public RankedClassifier<E> classifier() {
        RankedClassifier result = (RankedClassifier)super.classifier();
        return result;
    }

    @Override
    public void handle(Classified<E> classified) {
        E input = classified.getObject();
        Classification refClassification = classified.getClassification();
        String refCategory = refClassification.bestCategory();
        this.validateCategory(refCategory);
        RankedClassification classification = this.classifier().classify(input);
        this.addClassification(refCategory, classification, input);
        this.addRanking(refCategory, classification);
    }

    void addRanking(String refCategory, RankedClassification ranking) {
        int refCategoryIndex = this.categoryToIndex(refCategory);
        if (ranking.size() < this.numCategories()) {
            this.mDefectiveRanking = true;
        }
        for (int rank = 0; rank < this.numCategories() && rank < ranking.size(); ++rank) {
            String category = ranking.category(rank);
            if (!category.equals(refCategory)) continue;
            int[] nArray = this.mRankCounts[refCategoryIndex];
            int n = rank;
            nArray[n] = nArray[n] + 1;
            return;
        }
        int[] nArray = this.mRankCounts[refCategoryIndex];
        int n = this.mCategories.length - 1;
        nArray[n] = nArray[n] + 1;
    }

    public int rankCount(String referenceCategory, int rank) {
        this.validateCategory(referenceCategory);
        int i = this.categoryToIndex(referenceCategory);
        return this.mRankCounts[i][rank];
    }

    public double averageRankReference() {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int rank = 0; rank < this.numCategories(); ++rank) {
                int rankCount = this.mRankCounts[i][rank];
                if (rankCount == 0) continue;
                count += rankCount;
                sum += (double)(rank * rankCount);
            }
        }
        return sum / (double)count;
    }

    public double meanReciprocalRank() {
        double sum = 0.0;
        int numCases = 0;
        for (int i = 0; i < this.numCategories(); ++i) {
            for (int rank = 0; rank < this.numCategories(); ++rank) {
                int rankCount = this.mRankCounts[i][rank];
                if (rankCount == 0) continue;
                numCases += rankCount;
                sum += (double)rankCount / (1.0 + (double)rank);
            }
        }
        return sum / (double)numCases;
    }

    public double averageRank(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!((String)this.mReferenceCategories.get(i)).equals(refCategory)) continue;
            RankedClassification rankedClassification = (RankedClassification)this.mClassifications.get(i);
            int rank = this.getRank(rankedClassification, responseCategory);
            sum += (double)rank;
            ++count;
        }
        return sum / (double)count;
    }

    int categoryToIndex(String category) {
        int result = this.confusionMatrix().getIndex(category);
        if (result < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        return result;
    }

    int getRank(RankedClassification classification, String responseCategory) {
        for (int rank = 0; rank < classification.size(); ++rank) {
            if (!classification.category(rank).equals(responseCategory)) continue;
            return rank;
        }
        return this.mCategories.length - 1;
    }

    @Override
    void baseToString(StringBuilder sb) {
        super.baseToString(sb);
        sb.append("Average Reference Rank=" + this.averageRankReference() + "\n");
    }

    @Override
    void oneVsAllToString(StringBuilder sb, String category, int i) {
        super.oneVsAllToString(sb, category, i);
        sb.append("Rank Histogram=\n");
        this.appendCategoryLine(sb);
        for (int rank = 0; rank < this.numCategories(); ++rank) {
            if (rank > 0) {
                sb.append(',');
            }
            sb.append(this.mRankCounts[i][rank]);
        }
        sb.append("\n");
        sb.append("Average Rank Histogram=\n");
        this.appendCategoryLine(sb);
        for (int j = 0; j < this.numCategories(); ++j) {
            if (j > 0) {
                sb.append(',');
            }
            sb.append(this.averageRank(category, this.categories()[j]));
        }
        sb.append("\n");
    }

    void appendCategoryLine(StringBuilder sb) {
        sb.append("  ");
        for (int i = 0; i < this.numCategories(); ++i) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(this.categories()[i]);
        }
        sb.append("\n  ");
    }
}

