/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.plugin.flink.readclient;

import java.io.IOException;
import java.lang.invoke.LambdaMetafactory;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import org.apache.celeborn.client.ShuffleClientImpl;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.exception.DriverChangedException;
import org.apache.celeborn.common.exception.PartitionUnRetryAbleException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbChangeLocationPartitionInfo;
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
import org.apache.celeborn.common.protocol.PbPartitionLocation;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.ReviveRequest;
import org.apache.celeborn.common.protocol.message.ControlMessages$Revive$;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
import org.apache.celeborn.plugin.flink.readclient.CelebornBufferStream;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.celeborn.shaded.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.celeborn.shaded.io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.reflect.ClassTag$;

public class FlinkShuffleClientImpl
extends ShuffleClientImpl {
    public static final Logger logger = LoggerFactory.getLogger(FlinkShuffleClientImpl.class);
    private static volatile FlinkShuffleClientImpl _instance;
    private static volatile boolean initialized;
    private FlinkTransportClientFactory flinkTransportClientFactory;
    private ReadClientHandler readClientHandler = new ReadClientHandler();
    private ConcurrentHashMap<String, TransportClient> currentClient = JavaUtils.newConcurrentHashMap();
    private long driverTimestamp;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     * Converted monitor instructions to comments
     * Lifted jumps to return sites
     */
    public static FlinkShuffleClientImpl get(String appUniqueId, String driverHost, int port, long driverTimestamp, CelebornConf conf, UserIdentifier userIdentifier) throws DriverChangedException {
        if (null == _instance || !initialized || FlinkShuffleClientImpl._instance.driverTimestamp < driverTimestamp) {
            Class<FlinkShuffleClientImpl> clazz = FlinkShuffleClientImpl.class;
            // MONITORENTER : org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl.class
            if (null == _instance) {
                _instance = new FlinkShuffleClientImpl(appUniqueId, driverHost, port, driverTimestamp, conf, userIdentifier);
                initialized = true;
            } else if (!initialized || FlinkShuffleClientImpl._instance.driverTimestamp < driverTimestamp) {
                _instance.shutdown();
                _instance = new FlinkShuffleClientImpl(appUniqueId, driverHost, port, driverTimestamp, conf, userIdentifier);
                initialized = true;
            }
            // MONITOREXIT : clazz
        }
        if (driverTimestamp >= FlinkShuffleClientImpl._instance.driverTimestamp) return _instance;
        String format = "Driver reinitialized or changed driverHost-port-driverTimestamp to %s-%s-%s";
        String message = String.format(format, driverHost, port, driverTimestamp);
        logger.warn(message);
        throw new DriverChangedException(message);
    }

    @Override
    public void shutdown() {
        super.shutdown();
        if (this.flinkTransportClientFactory != null) {
            this.flinkTransportClientFactory.close();
        }
        if (this.readClientHandler != null) {
            this.readClientHandler.close();
        }
    }

    public FlinkShuffleClientImpl(String appUniqueId, String driverHost, int port, long driverTimestamp, CelebornConf conf, UserIdentifier userIdentifier) {
        super(appUniqueId, conf, userIdentifier);
        String module = "data";
        TransportConf dataTransportConf = Utils.fromCelebornConf(conf, module, conf.getInt("celeborn." + module + ".io.threads", 8));
        TransportContext context = new TransportContext(dataTransportConf, this.readClientHandler, conf.clientCloseIdleConnections());
        this.flinkTransportClientFactory = new FlinkTransportClientFactory(context, conf.clientFetchMaxRetriesForEachReplica());
        this.setupLifecycleManagerRef(driverHost, port);
        this.driverTimestamp = driverTimestamp;
    }

    public CelebornBufferStream readBufferedPartition(int shuffleId, int partitionId, int subPartitionIndexStart, int subPartitionIndexEnd) throws IOException {
        String shuffleKey = Utils.makeShuffleKey(this.appUniqueId, shuffleId);
        ShuffleClientImpl.ReduceFileGroups fileGroups = this.updateFileGroup(shuffleId, partitionId);
        if (fileGroups.partitionGroups.size() == 0 || !fileGroups.partitionGroups.containsKey(partitionId)) {
            logger.error("Shuffle data is empty for shuffle {} partitionId {}.", (Object)shuffleId, (Object)partitionId);
            throw new PartitionUnRetryAbleException(partitionId + " may be lost.");
        }
        PartitionLocation[] partitionLocations = fileGroups.partitionGroups.get(partitionId).toArray(new PartitionLocation[0]);
        Arrays.sort(partitionLocations, Comparator.comparingInt(PartitionLocation::getEpoch));
        logger.debug("readBufferedPartition shuffleKey:{} partitionid:{} partitionLocation:{}", new Object[]{shuffleKey, partitionId, partitionLocations});
        return CelebornBufferStream.create(this, this.flinkTransportClientFactory, shuffleKey, partitionLocations, subPartitionIndexStart, subPartitionIndexEnd);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    protected ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int partitionId) throws CelebornIOException {
        ShuffleClientImpl.ReduceFileGroups reduceFileGroups = (ShuffleClientImpl.ReduceFileGroups)this.reduceFileGroupsMap.computeIfAbsent(Integer.valueOf((int)shuffleId), (Function<Integer, Tuple2>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, lambda$updateFileGroup$0(java.lang.Integer ), (Ljava/lang/Integer;)Lscala/Tuple2;)())._1;
        if (reduceFileGroups.partitionIds != null && reduceFileGroups.partitionIds.contains(partitionId)) {
            logger.debug("use cached file groups for partition: {}", (Object)Utils.makeReducerKey(shuffleId, partitionId));
        } else {
            ShuffleClientImpl.ReduceFileGroups reduceFileGroups2 = reduceFileGroups;
            synchronized (reduceFileGroups2) {
                if (reduceFileGroups.partitionIds != null && reduceFileGroups.partitionIds.contains(partitionId)) {
                    logger.debug("use cached file groups for partition: {}", (Object)Utils.makeReducerKey(shuffleId, partitionId));
                } else {
                    Tuple2<ShuffleClientImpl.ReduceFileGroups, String> fileGroups = this.loadFileGroupInternal(shuffleId);
                    ShuffleClientImpl.ReduceFileGroups newGroups = (ShuffleClientImpl.ReduceFileGroups)fileGroups._1;
                    if (newGroups == null) {
                        throw new CelebornIOException(this.loadFileGroupException(shuffleId, partitionId, (String)fileGroups._2));
                    }
                    if (!newGroups.partitionIds.contains(partitionId)) {
                        throw new CelebornIOException(String.format("Shuffle data lost for shuffle %d partition %d.", shuffleId, partitionId));
                    }
                    reduceFileGroups.update(newGroups);
                }
            }
        }
        return reduceFileGroups;
    }

    public ReadClientHandler getReadClientHandler() {
        return this.readClientHandler;
    }

    public int pushDataToLocation(final int shuffleId, final int mapId, final int attemptId, int partitionId, ByteBuf data, final PartitionLocation location, Runnable closeCallBack) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        final PushState pushState = this.getPushState(mapKey);
        final int nextBatchId = pushState.nextBatchId();
        int totalLength = data.readableBytes();
        data.markWriterIndex();
        data.writerIndex(0);
        data.writeInt(partitionId);
        data.writeInt(attemptId);
        data.writeInt(nextBatchId);
        data.writeInt(totalLength - 16);
        data.resetWriterIndex();
        logger.debug("Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.", new Object[]{totalLength, this.appUniqueId, shuffleId, mapId, attemptId, partitionId, nextBatchId});
        this.limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
        pushState.addBatch(nextBatchId, location.hostAndPushPort());
        NettyManagedBuffer buffer = new NettyManagedBuffer(data);
        String shuffleKey = Utils.makeShuffleKey(this.appUniqueId, shuffleId);
        PushData pushData = new PushData(PRIMARY_MODE, shuffleKey, location.getUniqueId(), buffer);
        RpcResponseCallback callback = new RpcResponseCallback(){

            @Override
            public void onSuccess(ByteBuffer response) {
                pushState.removeBatch(nextBatchId, location.hostAndPushPort());
                logger.debug("Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.", new Object[]{location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId});
            }

            @Override
            public void onFailure(Throwable e) {
                pushState.removeBatch(nextBatchId, location.hostAndPushPort());
                if (pushState.exception.get() != null) {
                    return;
                }
                String errorMsg = String.format("Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.", location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId);
                pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
            }
        };
        try {
            TransportClient client = this.createClientWaitingInFlightRequest(location, mapKey, pushState);
            client.pushData(pushData, this.pushDataTimeout, callback, closeCallBack);
        }
        catch (Exception e) {
            logger.error("Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.", new Object[]{shuffleId, mapId, attemptId, partitionId, nextBatchId, location, e});
            callback.onFailure(new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY, (Throwable)e));
        }
        return totalLength;
    }

    private TransportClient createClientWaitingInFlightRequest(PartitionLocation location, String mapKey, PushState pushState) throws IOException, InterruptedException {
        TransportClient client = this.dataClientFactory.createClient(location.getHost(), location.getPushPort(), location.getId());
        if (this.currentClient.get(mapKey) != client) {
            if (this.currentClient.get(mapKey) != null) {
                this.limitZeroInFlight(mapKey, pushState);
            }
            this.currentClient.put(mapKey, client);
        }
        return this.currentClient.get(mapKey);
    }

    public Optional<PartitionLocation> pushDataHandShake(int shuffleId, int mapId, int attemptId, int numPartitions, int bufferSize, PartitionLocation location) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        PushState pushState = this.pushStates.computeIfAbsent(mapKey, s -> new PushState(this.conf));
        return this.retrySendMessage(() -> {
            ByteBuffer pushDataHandShakeResponse;
            String shuffleKey = Utils.makeShuffleKey(this.appUniqueId, shuffleId);
            logger.info("PushDataHandShake shuffleKey {} attemptId {} locationId {}", new Object[]{shuffleKey, attemptId, location.getUniqueId()});
            logger.debug("PushDataHandShake location {}", (Object)location);
            TransportClient client = this.createClientWaitingInFlightRequest(location, mapKey, pushState);
            try {
                pushDataHandShakeResponse = client.sendRpcSync(new TransportMessage(MessageType.PUSH_DATA_HAND_SHAKE, PbPushDataHandShake.newBuilder().setMode(PbPartitionLocation.Mode.forNumber(PRIMARY_MODE)).setShuffleKey(shuffleKey).setPartitionUniqueId(location.getUniqueId()).setAttemptId(attemptId).setNumPartitions(numPartitions).setBufferSize(bufferSize).build().toByteArray()).toByteBuffer(), this.conf.pushDataTimeoutMs());
            }
            catch (IOException e) {
                return this.revive(shuffleId, mapId, attemptId, location);
            }
            if (pushDataHandShakeResponse.hasRemaining() && pushDataHandShakeResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
                return this.revive(shuffleId, mapId, attemptId, location);
            }
            return Optional.empty();
        });
    }

    public Optional<PartitionLocation> regionStart(int shuffleId, int mapId, int attemptId, PartitionLocation location, int currentRegionIdx, boolean isBroadcast) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        PushState pushState = this.pushStates.computeIfAbsent(mapKey, s -> new PushState(this.conf));
        return this.retrySendMessage(() -> {
            ByteBuffer regionStartResponse;
            String shuffleKey = Utils.makeShuffleKey(this.appUniqueId, shuffleId);
            logger.info("RegionStart for shuffle {} regionId {} attemptId {} locationId {}.", new Object[]{shuffleId, currentRegionIdx, attemptId, location.getUniqueId()});
            logger.debug("RegionStart  for location {}.", (Object)location.toString());
            TransportClient client = this.createClientWaitingInFlightRequest(location, mapKey, pushState);
            try {
                regionStartResponse = client.sendRpcSync(new TransportMessage(MessageType.REGION_START, PbRegionStart.newBuilder().setMode(PbPartitionLocation.Mode.forNumber(PRIMARY_MODE)).setShuffleKey(shuffleKey).setPartitionUniqueId(location.getUniqueId()).setAttemptId(attemptId).setCurrentRegionIndex(currentRegionIdx).setIsBroadcast(isBroadcast).build().toByteArray()).toByteBuffer(), this.conf.pushDataTimeoutMs());
            }
            catch (IOException e) {
                return this.revive(shuffleId, mapId, attemptId, location);
            }
            if (regionStartResponse.hasRemaining() && regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
                return this.revive(shuffleId, mapId, attemptId, location);
            }
            return Optional.empty();
        });
    }

    public Optional<PartitionLocation> revive(int shuffleId, int mapId, int attemptId, PartitionLocation location) throws CelebornIOException {
        HashSet<Integer> mapIds = new HashSet<Integer>();
        mapIds.add(mapId);
        ArrayList<ReviveRequest> requests = new ArrayList<ReviveRequest>();
        ReviveRequest req = new ReviveRequest(shuffleId, mapId, attemptId, location.getId(), location.getEpoch(), location, StatusCode.HARD_SPLIT);
        requests.add(req);
        PbChangeLocationResponse response = (PbChangeLocationResponse)this.lifecycleManagerRef.askSync(ControlMessages$Revive$.MODULE$.apply(shuffleId, mapIds, requests), this.conf.clientRpcRequestPartitionLocationAskTimeout(), ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
        PbChangeLocationPartitionInfo partitionInfo = response.getPartitionInfo(0);
        StatusCode respStatus = Utils.toStatusCode(partitionInfo.getStatus());
        if (StatusCode.SUCCESS.equals((Object)respStatus)) {
            logger.debug("revive new partition:{}", (Object)partitionInfo.getPartition());
            return Optional.of(PbSerDeUtils.fromPbPartitionLocation(partitionInfo.getPartition()));
        }
        logger.error("Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.", new Object[]{shuffleId, mapId, attemptId, location.getId(), location.getEpoch()});
        throw new CelebornIOException("RegionStart revive failed");
    }

    public void regionFinish(int shuffleId, int mapId, int attemptId, PartitionLocation location) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        PushState pushState = this.pushStates.computeIfAbsent(mapKey, s -> new PushState(this.conf));
        this.retrySendMessage(() -> {
            String shuffleKey = Utils.makeShuffleKey(this.appUniqueId, shuffleId);
            logger.info("RegionFinish for shuffle {} map {} attemptId {} locationId {}.", new Object[]{shuffleId, mapId, attemptId, location.getUniqueId()});
            logger.debug("RegionFinish for location {}.", (Object)location);
            TransportClient client = this.createClientWaitingInFlightRequest(location, mapKey, pushState);
            client.sendRpcSync(new TransportMessage(MessageType.REGION_FINISH, PbRegionFinish.newBuilder().setMode(PbPartitionLocation.Mode.forNumber(PRIMARY_MODE)).setShuffleKey(shuffleKey).setPartitionUniqueId(location.getUniqueId()).setAttemptId(attemptId).build().toByteArray()).toByteBuffer(), this.conf.pushDataTimeoutMs());
            return null;
        });
    }

    private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier) throws IOException {
        int retryTimes = 0;
        boolean isSuccess = false;
        Throwable currentException = null;
        R result = null;
        while (!Thread.currentThread().isInterrupted() && !isSuccess && retryTimes < this.conf.networkIoMaxRetries("push")) {
            logger.debug("RetrySendMessage  retry times {}.", (Object)retryTimes);
            try {
                result = supplier.get();
                isSuccess = true;
            }
            catch (Exception e) {
                currentException = e;
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                if (!this.shouldRetry(e)) break;
                ++retryTimes;
                Uninterruptibles.sleepUninterruptibly(this.conf.networkIoRetryWaitMs("push"), TimeUnit.MILLISECONDS);
            }
        }
        if (!isSuccess) {
            if (currentException instanceof IOException) {
                throw (IOException)currentException;
            }
            throw new CelebornIOException(currentException.getMessage(), currentException);
        }
        return result;
    }

    private boolean shouldRetry(Throwable e) {
        boolean isIOException = e instanceof IOException || e instanceof TimeoutException || e.getCause() != null && e.getCause() instanceof TimeoutException || e.getCause() != null && e.getCause() instanceof IOException || e instanceof RuntimeException && e.getMessage() != null && e.getMessage().startsWith(IOException.class.getName());
        return isIOException;
    }

    @Override
    public void cleanup(int shuffleId, int mapId, int attemptId) {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        super.cleanup(shuffleId, mapId, attemptId);
        if (this.currentClient != null) {
            this.currentClient.remove(mapKey);
        }
    }

    public void setDataClientFactory(TransportClientFactory dataClientFactory) {
        this.dataClientFactory = dataClientFactory;
    }

    @Override
    @VisibleForTesting
    public TransportClientFactory getDataClientFactory() {
        return this.flinkTransportClientFactory;
    }

    private static /* synthetic */ Tuple2 lambda$updateFileGroup$0(Integer id) {
        return Tuple2.apply((Object)new ShuffleClientImpl.ReduceFileGroups(), null);
    }

    static {
        initialized = false;
    }

    @FunctionalInterface
    static interface ThrowingExceptionSupplier<R, E extends Exception> {
        public R get() throws E;
    }
}

