/*
 * Decompiled with CFR 0.152.
 */
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.ByteArrayList;
import com.knuddels.jtokkit.SpecialEncoder;
import com.knuddels.jtokkit.TokenEncoder;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingResult;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import com.knuddels.jtokkit.api.IntArrayList;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * {@link Encoding} that splits input text with a regular expression and byte-pair encodes
 * each matched piece through a {@link TokenEncoder}. Special tokens are handled by a separate
 * {@link SpecialEncoder}.
 *
 * <p>All fields are assigned once in the constructor and are {@code final}.
 */
class GptBytePairEncoding
implements Encoding {
    /** Encoder for ordinary tokens; also used (together with {@link #specialEncoder}) to decode tokens. */
    final TokenEncoder encoder;
    /** Name reported by {@link #getName()}. */
    private final String name;
    /** Pattern whose successive matches define the pieces that are byte-pair encoded independently. */
    private final Pattern pattern;
    /** Encoder for special tokens; consulted before non-ordinary encoding and during decoding. */
    private final SpecialEncoder specialEncoder;

    /**
     * Creates an encoding from the given parameters.
     *
     * @param params supplies the encoding name, the split pattern, the ordinary-token encoder
     *               data and the special-token encoder data
     */
    GptBytePairEncoding(GptBytePairEncodingParams params) {
        this.name = params.getName();
        this.pattern = params.getPattern();
        this.encoder = new TokenEncoder(params.getEncoder());
        this.specialEncoder = new SpecialEncoder(params.getSpecialTokensEncoder());
    }

    /**
     * Encodes {@code text} with no token limit.
     *
     * @param text the text to encode; {@code null} yields an empty list
     * @return the tokens for {@code text}
     */
    @Override
    public IntArrayList encode(String text) {
        return encode(text, Integer.MAX_VALUE).getTokens();
    }

    /**
     * Encodes {@code text}, producing at most roughly {@code maxTokenCount} tokens.
     * The text is first passed to {@link SpecialEncoder#checkForSpecialTokens(String)}
     * (presumably rejecting input that contains special tokens; confirm in SpecialEncoder).
     *
     * @param text          the text to encode; {@code null} yields an empty, non-truncated result
     * @param maxTokenCount the token budget; {@link Integer#MAX_VALUE} means unlimited
     * @return the tokens plus a flag telling whether the text was cut short
     */
    @Override
    public EncodingResult encode(String text, int maxTokenCount) {
        return encodeInternal(text, maxTokenCount, true).toEncodingResult();
    }

    /**
     * Shared implementation of {@link #encode(String, int)} and {@link #countTokens(String)}:
     * runs the special-token check, then performs ordinary encoding.
     */
    private InternalResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new InternalResult(new IntArrayList(0), false);
        }
        specialEncoder.checkForSpecialTokens(text);
        return encodeOrdinaryInternal(text, maxTokenCount, keepEncodings);
    }

    /**
     * Encodes {@code text} as ordinary text with no token limit, skipping the special-token check.
     *
     * @param text the text to encode; {@code null} yields an empty list
     * @return the tokens for {@code text}
     */
    @Override
    public IntArrayList encodeOrdinary(String text) {
        return encodeOrdinary(text, Integer.MAX_VALUE).getTokens();
    }

    /**
     * Encodes {@code text} as ordinary text under a token budget, skipping the special-token check.
     *
     * @param text          the text to encode; {@code null} yields an empty, non-truncated result
     * @param maxTokenCount the token budget; {@link Integer#MAX_VALUE} means unlimited
     * @return the tokens plus a flag telling whether the text was cut short
     */
    @Override
    public EncodingResult encodeOrdinary(String text, int maxTokenCount) {
        return encodeOrdinaryInternal(text, maxTokenCount, true).toEncodingResult();
    }

    /**
     * Ordinary-encodes {@code text} and, when tokens are kept under a finite budget, trims the
     * result so that it decodes to an exact prefix of {@code text}.
     *
     * @param keepEncodings whether token ids are needed (false when only counting)
     */
    private InternalResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new InternalResult(new IntArrayList(0), false);
        }
        IntArrayList out = new IntArrayList();
        int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, out);
        if (keepEncodings && maxTokenCount != Integer.MAX_VALUE) {
            // The budget can cut the token stream at an arbitrary byte, e.g. inside a multi-byte
            // UTF-8 sequence. Such a prefix decodes (via new String(..., UTF_8)) to text containing
            // replacement characters, so it is not a prefix of the input. Drop trailing tokens
            // until the decoded tokens line up with the start of text. The empty prefix decodes
            // to "" and always matches, so this loop returns.
            for (int size = out.size(); size >= 0; --size) {
                IntArrayList prefix = new IntArrayList(size);
                for (int i = 0; i < size; ++i) {
                    prefix.add(out.get(i));
                }
                String decoded = decode(prefix);
                if (text.startsWith(decoded)) {
                    // Truncated exactly when the kept tokens cover less than the whole input.
                    return new InternalResult(prefix, text.length() > decoded.length());
                }
            }
        }
        return new InternalResult(out, tokenCount, false);
    }

    /**
     * Splits {@code text} into pattern matches and hands each match's UTF-8 bytes to
     * {@link TokenEncoder#addTokensAndGetCount}, accumulating the counts it returns.
     * Matching stops once the running count reaches {@code maxTokenCount} or no match is left.
     *
     * @param text          the text to encode (must be non-null)
     * @param maxTokenCount the token budget, forwarded to the encoder
     * @param keepEncodings forwarded to the encoder; presumably controls whether ids are
     *                      appended to {@code out} or only counted (confirm in TokenEncoder)
     * @param out           receives the encoded tokens
     * @return the accumulated token count
     */
    int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) {
        // Created once and passed to every encoder call (a reusable buffer for TokenEncoder).
        IntArrayList ranks = new IntArrayList();
        Matcher matcher = pattern.matcher(text);
        int tokenCount = 0;
        while (tokenCount < maxTokenCount && matcher.find()) {
            byte[] piece = matcher.group().getBytes(StandardCharsets.UTF_8);
            tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, piece, out, ranks);
        }
        return tokenCount;
    }

    /**
     * Counts the tokens {@link #encode(String)} would produce, without keeping the token ids.
     *
     * @param text the text to measure; {@code null} counts as 0 tokens
     * @return the token count
     */
    @Override
    public int countTokens(String text) {
        return encodeInternal(text, Integer.MAX_VALUE, false).toTokenCount();
    }

    /**
     * Counts the tokens {@link #encodeOrdinary(String)} would produce, without keeping the ids.
     *
     * @param text the text to measure; {@code null} counts as 0 tokens
     * @return the token count
     */
    @Override
    public int countTokensOrdinary(String text) {
        return encodeOrdinaryInternal(text, Integer.MAX_VALUE, false).toTokenCount();
    }

    /**
     * Decodes {@code tokens} back to a string, interpreting the concatenated token bytes as UTF-8.
     *
     * @param tokens the tokens to decode
     * @return the decoded text
     * @throws NullPointerException if a token is unknown to both encoders
     */
    @Override
    public String decode(IntArrayList tokens) {
        return new String(decodeBytes(tokens), StandardCharsets.UTF_8);
    }

    /**
     * Decodes {@code tokens} to the concatenation of their raw byte sequences.
     *
     * @param tokens the tokens to decode
     * @return the concatenated bytes of all tokens, in order
     * @throws NullPointerException if a token is unknown to both encoders
     */
    @Override
    public byte[] decodeBytes(IntArrayList tokens) {
        // 10 bytes per token is a sizing guess; presumably ByteArrayList grows past it as needed.
        ByteArrayList out = new ByteArrayList(10 * tokens.size());
        for (int i = 0; i < tokens.size(); ++i) {
            for (byte b : decodeToken(tokens.get(i))) {
                out.add(b);
            }
        }
        return out.toArray();
    }

    /** Returns the name this encoding was created with. */
    @Override
    public String getName() {
        return name;
    }

    /**
     * Looks up the bytes for a single token via the ordinary encoder (which also consults the
     * special encoder).
     *
     * @throws NullPointerException with the offending token in the message if it is unknown
     */
    private byte[] decodeToken(int token) {
        byte[] decodedToken = encoder.decodeToken(token, specialEncoder);
        return Objects.requireNonNull(decodedToken, "Unknown token for decoding: " + token);
    }

    /**
     * Intermediate result of an encoding pass: the kept tokens, a truncation flag, and a
     * token count that may differ from {@code tokens.size()} when ids were not kept.
     */
    private static final class InternalResult {
        private final IntArrayList tokens;
        private final boolean truncated;
        private final int tokenCount;

        /** Creates a result whose token count is {@code tokens.size()}. */
        private InternalResult(IntArrayList tokens, boolean truncated) {
            this(tokens, -1, truncated);
        }

        /**
         * @param tokenCount the token count, or a negative value to use {@code tokens.size()}
         */
        private InternalResult(IntArrayList tokens, int tokenCount, boolean truncated) {
            this.tokens = tokens;
            this.truncated = truncated;
            this.tokenCount = tokenCount < 0 ? tokens.size() : tokenCount;
        }

        /**
         * Converts to the public result type.
         *
         * @throws IllegalStateException if the token count disagrees with the kept token list
         */
        private EncodingResult toEncodingResult() {
            if (tokens.size() != tokenCount) {
                throw new IllegalStateException("Token count does not match token list size (tokenCount=" + tokenCount + ", tokens size=" + tokens.size() + ")");
            }
            return new EncodingResult(tokens, truncated);
        }

        /** Returns the token count. */
        private int toTokenCount() {
            return tokenCount;
        }
    }
}

