/*
 * Decompiled with CFR 0.152.
 */
package org.languagetool.languagemodel.bert;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NegotiationType;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
import org.jetbrains.annotations.Nullable;
import org.languagetool.languagemodel.bert.grpc.BertLmGrpc;
import org.languagetool.languagemodel.bert.grpc.BertLmProto;

/**
 * Client for a remote BERT language model served over gRPC (the {@code BertLm} service).
 *
 * <p>Results of {@link #batchScore(List, long)} are memoized in a bounded in-memory cache
 * (at most {@value #CACHE_SIZE} entries) keyed by {@link Request}; {@link #score(Request)}
 * does not use the cache.
 */
public class RemoteLanguageModel {
    /** Maximum number of score lists kept in {@link #cache}. */
    private static final long CACHE_SIZE = 1000L;

    private final BertLmGrpc.BertLmBlockingStub model;
    private final ManagedChannel channel;
    private final Cache<Request, List<Double>> cache = CacheBuilder.newBuilder().maximumSize(CACHE_SIZE).build();

    /**
     * Creates a client and opens a channel to the language model server.
     *
     * @param host server host name
     * @param port server port
     * @param useSSL whether to use TLS; when {@code false} the connection is plaintext
     * @param clientPrivateKey path to the client private key file for mutual TLS, or {@code null}
     * @param clientCertificate path to the client certificate chain file for mutual TLS, or {@code null}
     * @param rootCertificate path to a trusted root certificate file, or {@code null} to keep the
     *                        TLS provider's default trust configuration
     * @throws SSLException if the TLS context cannot be built
     */
    public RemoteLanguageModel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey, @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        this.channel = getChannel(host, port, useSSL, clientPrivateKey, clientCertificate, rootCertificate);
        this.model = BertLmGrpc.newBlockingStub(this.channel);
    }

    /**
     * Builds the Netty channel, plaintext or TLS.
     *
     * <p>Note: a client key pair is configured only when BOTH {@code clientCertificate} and
     * {@code clientPrivateKey} are non-null; if only one is given, it is silently ignored and
     * the connection proceeds without a client certificate.
     */
    private static ManagedChannel getChannel(String host, int port, boolean useSSL, @Nullable String clientPrivateKey, @Nullable String clientCertificate, @Nullable String rootCertificate) throws SSLException {
        NettyChannelBuilder channelBuilder = NettyChannelBuilder.forAddress(host, port);
        if (!useSSL) {
            return channelBuilder.usePlaintext().build();
        }
        SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
        if (rootCertificate != null) {
            sslContextBuilder.trustManager(new File(rootCertificate));
        }
        if (clientCertificate != null && clientPrivateKey != null) {
            // keyManager takes (certificate chain file, private key file) in that order.
            sslContextBuilder.keyManager(new File(clientCertificate), new File(clientPrivateKey));
        }
        return channelBuilder.negotiationType(NegotiationType.TLS).sslContext(sslContextBuilder.build()).build();
    }

    /**
     * Forcefully shuts down the underlying channel via {@code shutdownNow()}; this does not wait
     * for termination.
     */
    public void shutdown() {
        if (this.channel != null) {
            this.channel.shutdownNow();
        }
    }

    /**
     * Scores a batch of requests. Requests found in the cache are answered locally; the rest are
     * sent to the server in a single RPC (no RPC is made when everything is cached), and their
     * results are added to the cache.
     *
     * @param requests requests to score
     * @param timeoutMilliseconds RPC deadline in milliseconds; values {@code <= 0} mean no deadline
     * @return one score list per request, in the same order as {@code requests}
     * @throws TimeoutException if the RPC deadline is exceeded (the gRPC exception is the cause)
     * @throws StatusRuntimeException for any other RPC failure
     * @throws IllegalStateException if the server returns a different number of responses than requested
     */
    public List<List<Double>> batchScore(List<Request> requests, long timeoutMilliseconds) throws TimeoutException {
        // Results are assembled by position (null placeholders for cache misses) rather than by
        // re-looking-up each request, so duplicate requests and concurrent cache eviction/insertion
        // can never shift server responses onto the wrong request.
        List<List<Double>> allResults = new ArrayList<>(requests.size());
        List<Request> uncachedRequests = new ArrayList<>();
        List<Integer> uncachedPositions = new ArrayList<>();
        for (Request request : requests) {
            List<Double> cached = this.cache.getIfPresent(request);
            if (cached == null) {
                uncachedPositions.add(allResults.size());
                uncachedRequests.add(request);
            }
            allResults.add(cached);
        }
        if (uncachedRequests.isEmpty()) {
            return allResults;
        }
        List<List<Double>> fetched = fetchScores(uncachedRequests, timeoutMilliseconds);
        if (fetched.size() != uncachedRequests.size()) {
            throw new IllegalStateException("Expected " + uncachedRequests.size()
                + " responses from language model server, got " + fetched.size());
        }
        for (int k = 0; k < fetched.size(); k++) {
            List<Double> scores = fetched.get(k);
            allResults.set(uncachedPositions.get(k), scores);
            this.cache.put(uncachedRequests.get(k), scores);
        }
        return allResults;
    }

    /**
     * Sends {@code requests} to the server in one batch RPC and returns, for each response, the
     * score list of its first (and, for requests built by {@link Request#convert()}, only) mask.
     */
    private List<List<Double>> fetchScores(List<Request> requests, long timeoutMilliseconds) throws TimeoutException {
        BertLmProto.BatchScoreRequest batch = BertLmProto.BatchScoreRequest.newBuilder()
            .addAllRequests(requests.stream().map(Request::convert).collect(Collectors.toList()))
            .build();
        BertLmGrpc.BertLmBlockingStub stub = timeoutMilliseconds > 0L
            ? this.model.withDeadlineAfter(timeoutMilliseconds, TimeUnit.MILLISECONDS)
            : this.model;
        try {
            return stub.batchScore(batch).getResponsesList().stream()
                .map(response -> response.getScoresList().get(0).getScoreList())
                .collect(Collectors.toList());
        } catch (StatusRuntimeException e) {
            if (e.getStatus().getCode() == Status.Code.DEADLINE_EXCEEDED) {
                // TimeoutException has no (message, cause) constructor; attach the cause explicitly.
                TimeoutException timeout = new TimeoutException(e.getMessage());
                timeout.initCause(e);
                throw timeout;
            }
            throw e;
        }
    }

    /**
     * Scores a single request with a blocking RPC. Bypasses the cache and sets no deadline.
     *
     * @param req the request to score
     * @return the score list of the request's single mask
     * @throws StatusRuntimeException if the RPC fails
     */
    public List<Double> score(Request req) {
        return this.model.score(req.convert()).getScoresList().get(0).getScoreList();
    }

    /**
     * A scoring request: {@code text} with one masked span from {@code start} to {@code end}
     * and the {@code candidates} to score for that span.
     * (NOTE(review): whether {@code end} is exclusive is defined by the server — confirm.)
     *
     * <p>Instances are used as cache keys, so they must not be mutated (fields or the
     * {@code candidates} list) after being passed to {@link RemoteLanguageModel#batchScore}.
     */
    public static class Request {
        public String text;
        public int start;
        public int end;
        public List<String> candidates;

        public Request(String text, int start, int end, List<String> candidates) {
            this.text = text;
            this.start = start;
            this.end = end;
            this.candidates = candidates;
        }

        /** Converts this request into a protobuf {@code ScoreRequest} with exactly one mask. */
        public BertLmProto.ScoreRequest convert() {
            List<BertLmProto.Mask> masks = Arrays.asList(BertLmProto.Mask.newBuilder().setStart(this.start).setEnd(this.end).addAllCandidates(this.candidates).build());
            return BertLmProto.ScoreRequest.newBuilder().setText(this.text).addAllMask(masks).build();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Request other = (Request) o;
            // Null-safe comparison, consistent with Objects.hash in hashCode().
            return this.start == other.start
                && this.end == other.end
                && Objects.equals(this.text, other.text)
                && Objects.equals(this.candidates, other.candidates);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.text, this.start, this.end, this.candidates);
        }

        @Override
        public String toString() {
            return "Request{text='" + this.text + "', start=" + this.start + ", end=" + this.end
                + ", candidates=" + this.candidates + "}";
        }
    }
}

