/*
 * Decompiled with CFR 0.152.
 */
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.ByteArrayWrapper;
import com.knuddels.jtokkit.SpecialEncoder;
import com.knuddels.jtokkit.TokenEncoderLarge;
import com.knuddels.jtokkit.api.IntArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/**
 * Maps byte sequences to integer token ranks and back.
 *
 * <p>Known byte sequences are looked up in per-length hash maps. A piece that is not a single
 * known token is split into tokens by repeatedly merging the adjacent pair of segments whose
 * concatenation has the lowest rank, until no mergeable pair remains.
 */
final class TokenEncoder {
    /** Sentinel returned by the {@code encode} methods when a byte sequence has no known rank. */
    static final int MAX_RANK = Integer.MAX_VALUE - 1;

    /** Marker stored in the rank list at positions that were merged into the segment before them. */
    private static final int DUMMY_RANK = Integer.MAX_VALUE;

    /** Lookup tables indexed by byte length; {@code null} at lengths that have no known token. */
    private final Map<ByteArrayWrapper, Integer>[] encoders;

    /** Reverse mapping from rank to the raw bytes of the token; never modified after construction. */
    private final Map<Integer, byte[]> decoder;

    /** Pieces with at least this many bytes are delegated to {@link TokenEncoderLarge}. */
    private final int veryLargeTokenizerByteThreshold;

    /**
     * Builds the per-length lookup tables and the reverse mapping.
     *
     * <p>For a non-empty encoder the threshold is read from the system property
     * {@code VERY_LARGE_TOKENIZER_BYTE_THRESHOLD} (default {@code 500}). For an empty encoder
     * the property is not read and the threshold stays 0, so every piece is delegated to
     * {@link TokenEncoderLarge}.
     *
     * @param encoder mapping from token bytes to rank
     * @throws NumberFormatException if the system property is set to a non-integer value
     */
    TokenEncoder(Map<byte[], Integer> encoder) {
        if (!encoder.isEmpty()) {
            this.veryLargeTokenizerByteThreshold =
                    Integer.parseInt(System.getProperty("VERY_LARGE_TOKENIZER_BYTE_THRESHOLD", "500"));

            // Group tokens by byte length; the TreeMap gives us the largest length via lastKey().
            TreeMap<Integer, Map<ByteArrayWrapper, Integer>> byLength = new TreeMap<>();
            for (Map.Entry<byte[], Integer> entry : encoder.entrySet()) {
                byte[] bytes = entry.getKey();
                byLength.computeIfAbsent(bytes.length, unused -> new HashMap<>())
                        .put(new ByteArrayWrapper(bytes), entry.getValue());
            }

            // Generic arrays cannot be created directly; only Map<ByteArrayWrapper, Integer>
            // instances are ever stored here, so the unchecked conversion is safe.
            @SuppressWarnings("unchecked")
            Map<ByteArrayWrapper, Integer>[] tables = new Map[byLength.lastKey() + 1];
            for (Map.Entry<Integer, Map<ByteArrayWrapper, Integer>> entry : byLength.entrySet()) {
                tables[entry.getKey()] = entry.getValue();
            }
            this.encoders = tables;

            Map<Integer, byte[]> reverse = new HashMap<>(encoder.size());
            for (Map.Entry<byte[], Integer> entry : encoder.entrySet()) {
                reverse.put(entry.getValue(), entry.getKey());
            }
            this.decoder = reverse;
        } else {
            this.veryLargeTokenizerByteThreshold = 0;
            @SuppressWarnings("unchecked")
            Map<ByteArrayWrapper, Integer>[] tables = new Map[0];
            this.encoders = tables;
            this.decoder = Collections.emptyMap();
        }
    }

    /**
     * Finds the position of the smallest rank strictly below {@link #MAX_RANK}.
     *
     * <p>Only indices {@code 0 .. ranks.size() - 3} are inspected. Ties resolve to the lowest
     * index because only a strictly smaller rank replaces the current minimum.
     *
     * @param ranks rank list to scan
     * @return index of the minimum rank, or {@code -1} if no inspected rank is below {@code MAX_RANK}
     */
    private static int getMinRankIndex(IntArrayList ranks) {
        int minRankIndex = -1;
        int minRank = MAX_RANK;
        int last = ranks.size() - 3;

        int i = 0;
        // Manually unrolled: handles four entries per iteration while i + 3 <= last.
        for (; i < last - 2; i += 4) {
            int r = ranks.get(i);
            if (r < minRank) {
                minRankIndex = i;
                minRank = r;
            }
            r = ranks.get(i + 1);
            if (r < minRank) {
                minRankIndex = i + 1;
                minRank = r;
            }
            r = ranks.get(i + 2);
            if (r < minRank) {
                minRankIndex = i + 2;
                minRank = r;
            }
            r = ranks.get(i + 3);
            if (r < minRank) {
                minRankIndex = i + 3;
                minRank = r;
            }
        }
        // Remaining (at most three) entries.
        for (; i <= last; i++) {
            int r = ranks.get(i);
            if (r < minRank) {
                minRankIndex = i;
                minRank = r;
            }
        }
        return minRankIndex;
    }

    /**
     * Advances from {@code nextIndex} past entries equal to {@link #DUMMY_RANK}.
     *
     * @return the first index at or after {@code nextIndex} that is not a dummy, or
     *         {@code ranks.size()} if there is none
     */
    private static int getNextIndex(IntArrayList ranks, int nextIndex) {
        while (nextIndex < ranks.size() && ranks.get(nextIndex) == DUMMY_RANK) {
            nextIndex++;
        }
        return nextIndex;
    }

    /**
     * Walks backwards from {@code previousIndex} past entries equal to {@link #DUMMY_RANK}.
     *
     * @return the first index at or before {@code previousIndex} that is not a dummy, or
     *         {@code -1} if there is none
     */
    private static int getPreviousIndex(IntArrayList ranks, int previousIndex) {
        while (previousIndex >= 0 && ranks.get(previousIndex) == DUMMY_RANK) {
            previousIndex--;
        }
        return previousIndex;
    }

    /**
     * Tokenizes {@code byteArray} and returns the number of tokens it consists of.
     *
     * <p>If the whole array is a known token, that is the single result. Otherwise the piece is
     * split here when it is shorter than the configured threshold, or handed to
     * {@link TokenEncoderLarge#calculateTokensLarge} when it is not.
     *
     * @param maxTokenCount upper bound on {@code out.size()} applied while appending tokens of
     *                      a split piece; a piece that is a single known token is appended
     *                      without checking it
     * @param keepEncodings whether token ranks are appended to {@code out}
     * @param byteArray     bytes of the piece to tokenize
     * @param out           receives token ranks when {@code keepEncodings} is true
     * @param ranks         scratch list; cleared and overwritten when the piece is split here
     * @return the number of tokens in the piece
     */
    int addTokensAndGetCount(int maxTokenCount, boolean keepEncodings, byte[] byteArray, IntArrayList out, IntArrayList ranks) {
        ByteArrayWrapper match = new ByteArrayWrapper(byteArray);
        int encoded = encode(match);
        if (encoded != MAX_RANK) {
            if (keepEncodings) {
                out.add(encoded);
            }
            return 1;
        }
        if (match.length() < veryLargeTokenizerByteThreshold) {
            return calculateTokensSmall(maxTokenCount, keepEncodings, out, ranks, match);
        }
        return TokenEncoderLarge.calculateTokensLarge(this, maxTokenCount, keepEncodings, out, match);
    }

    /**
     * Splits a piece that is not itself a known token, using {@code ranks} as working storage.
     *
     * @return the token count reported by {@link #mergeBytesAndGetTokenCount}; it is returned
     *         in full even when fewer tokens were appended because of {@code maxTokenCount}
     */
    private int calculateTokensSmall(int maxTokenCount, boolean keepEncodings, IntArrayList out, IntArrayList ranks, ByteArrayWrapper match) {
        int length = match.length();
        assert length > 1 : "Already filtered out";
        ranks.clear();
        ranks.ensureCapacity(length + 1);

        // ranks[i] = rank of the two bytes starting at i. The last two entries are always
        // MAX_RANK because i + 2 exceeds the piece length there.
        int minRankIndex = -1;
        int minRank = MAX_RANK;
        for (int i = 0; i < length + 1; i++) {
            int encoded = encode(match, i, i + 2);
            if (encoded != MAX_RANK && encoded < minRank) {
                minRankIndex = i;
                minRank = encoded;
            }
            ranks.add(encoded);
        }

        int tokenCount = mergeBytesAndGetTokenCount(match, length, ranks, minRankIndex);

        if (keepEncodings) {
            // Every non-dummy position is a segment boundary; emit the token between boundaries.
            int start = 0;
            for (int end = 1; end < ranks.size() && out.size() < maxTokenCount; end++) {
                if (ranks.get(end) == DUMMY_RANK) {
                    continue;
                }
                int token = encode(match, start, end);
                assert token != MAX_RANK : "Token should not be MAX_RANK";
                out.add(token);
                start = end;
            }
        }
        return tokenCount;
    }

    /**
     * Repeatedly merges the adjacent segment pair with the lowest rank.
     *
     * <p>On each step the segment starting at {@code minRankIndex} absorbs its successor: the
     * successor's slot is set to {@link #DUMMY_RANK}, and the ranks of the merged segment and
     * of the segment before it are recomputed against their new neighbours.
     *
     * @param piece        the bytes being tokenized
     * @param length       current number of segments in {@code piece}
     * @param ranks        rank list, modified in place
     * @param minRankIndex index of the lowest rank in {@code ranks}, or negative if none
     * @return the number of segments left after merging
     */
    int mergeBytesAndGetTokenCount(ByteArrayWrapper piece, int length, IntArrayList ranks, int minRankIndex) {
        assert getMinRankIndex(ranks) == minRankIndex;
        while (minRankIndex >= 0) {
            int previousIndex = getPreviousIndex(ranks, minRankIndex - 1);
            int nextIndex = getNextIndex(ranks, minRankIndex + 1);
            int nextNextIndex = getNextIndex(ranks, nextIndex + 1);
            int nextNextNextIndex = getNextIndex(ranks, nextNextIndex + 1);

            if (previousIndex >= 0) {
                assert ranks.get(previousIndex) != DUMMY_RANK;
                // Previous segment would now pair with the enlarged (merged) segment.
                ranks.set(previousIndex, encode(piece, previousIndex, nextNextIndex));
            }
            assert ranks.get(minRankIndex) != DUMMY_RANK;
            // Merged segment pairs with the segment that followed the absorbed one.
            ranks.set(minRankIndex, encode(piece, minRankIndex, nextNextNextIndex));
            ranks.set(nextIndex, DUMMY_RANK);

            length--;
            if (length < 3) {
                // With two segments left the only candidate spans the whole piece, which
                // encode(piece, start, end) always reports as MAX_RANK, so skip the rescan.
                break;
            }
            minRankIndex = getMinRankIndex(ranks);
        }
        assert getMinRankIndex(ranks) < 0;
        return length;
    }

    /**
     * Looks up the rank of an exact byte sequence.
     *
     * @return the rank, or {@link #MAX_RANK} if the sequence is not a known token
     */
    private int encode(ByteArrayWrapper payload) {
        int size = payload.length();
        if (size < encoders.length) {
            Map<ByteArrayWrapper, Integer> table = encoders[size];
            if (table != null) {
                Integer rank = table.get(payload);
                if (rank != null) {
                    return rank;
                }
            }
        }
        return MAX_RANK;
    }

    /**
     * Looks up the rank of the bytes of {@code piece} in {@code [start, end)}.
     *
     * @return the rank, or {@link #MAX_RANK} if {@code end} is past the end of the piece, if the
     *         range covers the entire piece, or if the bytes are not a known token
     */
    int encode(ByteArrayWrapper piece, int start, int end) {
        if (end > piece.length() || end - start == piece.length()) {
            return MAX_RANK;
        }
        return encode(piece.getBytesBetween(start, end));
    }

    /**
     * Returns the bytes for {@code token}.
     *
     * <p>The regular decoder table is consulted first; on a miss the lookup falls through to
     * {@code specialEncoder.decodeIfPresent(token)}. The table is only read, never written, so
     * this also works when it is the immutable empty map created for an empty encoder.
     *
     * @param token          rank to decode
     * @param specialEncoder fallback used when {@code token} is not in the regular table
     * @return the token bytes, or whatever {@code decodeIfPresent} returns for an unknown token
     *         (presumably {@code null} - confirm against SpecialEncoder)
     */
    byte[] decodeToken(int token, SpecialEncoder specialEncoder) {
        byte[] bytes = decoder.get(token);
        if (bytes != null) {
            return bytes;
        }
        return specialEncoder.decodeIfPresent(token);
    }
}

