/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.classification;

import dev.langchain4j.classification.TextClassifier;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class EmbeddingModelTextClassifier<E extends Enum<E>>
implements TextClassifier<E> {
    private final EmbeddingModel embeddingModel;
    private final Map<E, List<Embedding>> exampleEmbeddingsByLabel;
    private final int maxResults;
    private final double minScore;
    private final double meanToMaxScoreRatio;

    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel, Map<E, ? extends Collection<String>> examplesByLabel) {
        this(embeddingModel, examplesByLabel, 1, 0.0, 0.5);
    }

    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel, Map<E, ? extends Collection<String>> examplesByLabel, int maxResults, double minScore, double meanToMaxScoreRatio) {
        this.embeddingModel = (EmbeddingModel)ValidationUtils.ensureNotNull((Object)embeddingModel, (String)"embeddingModel");
        ValidationUtils.ensureNotNull(examplesByLabel, (String)"examplesByLabel");
        this.exampleEmbeddingsByLabel = new HashMap<E, List<Embedding>>();
        examplesByLabel.forEach((label, examples) -> this.exampleEmbeddingsByLabel.put(label, examples.stream().map(example -> (Embedding)embeddingModel.embed(example).content()).collect(Collectors.toList())));
        this.maxResults = ValidationUtils.ensureGreaterThanZero((Integer)maxResults, (String)"maxResults");
        this.minScore = ValidationUtils.ensureBetween((Double)minScore, (double)0.0, (double)1.0, (String)"minScore");
        this.meanToMaxScoreRatio = ValidationUtils.ensureBetween((Double)meanToMaxScoreRatio, (double)0.0, (double)1.0, (String)"meanToMaxScoreRatio");
    }

    public List<E> classify(String text) {
        Embedding textEmbedding = (Embedding)this.embeddingModel.embed(text).content();
        ArrayList labelsWithScores = new ArrayList();
        this.exampleEmbeddingsByLabel.forEach((label, exampleEmbeddings) -> {
            double meanScore = 0.0;
            double maxScore = 0.0;
            for (Embedding exampleEmbedding : exampleEmbeddings) {
                double cosineSimilarity = CosineSimilarity.between((Embedding)textEmbedding, (Embedding)exampleEmbedding);
                double score = RelevanceScore.fromCosineSimilarity((double)cosineSimilarity);
                meanScore += score;
                maxScore = Math.max(score, maxScore);
            }
            labelsWithScores.add(new LabelWithScore(this, (Enum)label, this.aggregatedScore(meanScore /= (double)exampleEmbeddings.size(), maxScore)));
        });
        return labelsWithScores.stream().filter(it -> ((LabelWithScore)it).score >= this.minScore).sorted(Comparator.comparingDouble(labelWithScore -> 1.0 - ((LabelWithScore)labelWithScore).score)).limit(this.maxResults).map(it -> ((LabelWithScore)it).label).collect(Collectors.toList());
    }

    private double aggregatedScore(double meanScore, double maxScore) {
        return this.meanToMaxScoreRatio * meanScore + (1.0 - this.meanToMaxScoreRatio) * maxScore;
    }

    private class LabelWithScore {
        private final E label;
        private final double score;
        final /* synthetic */ EmbeddingModelTextClassifier this$0;

        /*
         * WARNING - Possible parameter corruption
         * WARNING - void declaration
         */
        private LabelWithScore(E score, double d2) {
            void label;
            this.this$0 = (EmbeddingModelTextClassifier)d;
            this.label = label;
            this.score = (double)score;
        }
    }
}

