/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.protocol.client;

import java.io.IOException;
import java.util.Map;
import java.util.Optional;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.property.ClientPoolProperty;
import org.apache.iotdb.commons.conf.CommonDescriptor;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.mlnode.rpc.thrift.IMLNodeRPCService;
import org.apache.iotdb.mlnode.rpc.thrift.TCreateTrainingTaskReq;
import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
import org.apache.iotdb.rpc.TConfigurationConst;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.layered.TFramedTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Thrift client for the MLNode RPC service.
 *
 * <p>The constructor opens a framed transport to the endpoint returned by
 * {@code CommonDescriptor.getInstance().getConfig().getTargetMLNodeEndPoint()} and speaks the
 * compact protocol over it. The instance owns that transport; call {@link #close()} (or use
 * try-with-resources) to release it.
 */
public class MLNodeClient
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(MLNodeClient.class);
    private final TTransport transport;
    private final IMLNodeRPCService.Client client;
    /** Message carried by every {@link TException} raised for a transport/RPC failure. */
    public static final String MSG_CONNECTION_FAIL = "Fail to connect to MLNode. Please check status of MLNode";
    private final TsBlockSerde tsBlockSerde = new TsBlockSerde();

    /**
     * Creates a client and opens the connection to the configured MLNode endpoint.
     *
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} if the transport cannot be
     *     created or opened; the underlying {@link TTransportException} is attached as the cause
     */
    public MLNodeClient() throws TException {
        TEndPoint endpoint = CommonDescriptor.getInstance().getConfig().getTargetMLNodeEndPoint();
        try {
            long connectionTimeout = ClientPoolProperty.DefaultProperty.WAIT_CLIENT_TIMEOUT_MS;
            this.transport = new TFramedTransport.Factory().getTransport(
                    new TSocket(TConfigurationConst.defaultTConfiguration, endpoint.getIp(), endpoint.getPort(), (int)connectionTimeout));
            if (!this.transport.isOpen()) {
                this.transport.open();
            }
        }
        catch (TTransportException e) {
            // Keep the original failure as the cause so callers can see why the connection failed.
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
        TCompactProtocol.Factory protocolFactory = new TCompactProtocol.Factory();
        this.client = new IMLNodeRPCService.Client(protocolFactory.getProtocol(this.transport));
    }

    /**
     * Asks the MLNode to create a training task for the given model.
     *
     * @param modelInformation supplies the model id, the auto flag, the query expressions and,
     *     when non-null, the query filter of the request
     * @param modelConfigs configuration entries forwarded in the request
     * @return the status returned by the MLNode
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} if the RPC fails; the
     *     original exception is attached as the cause
     */
    public TSStatus createTrainingTask(ModelInformation modelInformation, Map<String, String> modelConfigs) throws TException {
        try {
            TCreateTrainingTaskReq req = new TCreateTrainingTaskReq(modelInformation.getModelId(), modelInformation.isAuto(), modelConfigs, modelInformation.getQueryExpressions());
            if (modelInformation.getQueryFilter() != null) {
                req.setQueryFilter(modelInformation.getQueryFilter());
            }
            return this.client.createTrainingTask(req);
        }
        catch (TException e) {
            logger.warn("Failed to connect to MLNode from ConfigNode when executing {}", "createTrainingTask", e);
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Asks the MLNode to delete the model with the given id.
     *
     * @param modelId id of the model to delete
     * @return the status returned by the MLNode
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} if the RPC fails; the
     *     original exception is attached as the cause
     */
    public TSStatus deleteModel(String modelId) throws TException {
        try {
            return this.client.deleteModel(new TDeleteModelReq(modelId));
        }
        catch (TException e) {
            logger.warn("Failed to connect to MLNode from ConfigNode when executing {}", "deleteModel", e);
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /**
     * Runs a forecast on the MLNode and returns the deserialized result.
     *
     * @param modelPath model path sent in the forecast request
     * @param inputTsBlock input data; serialized with {@link TsBlockSerde} before sending
     * @return the forecast result deserialized from the response
     * @throws TException with message {@link #MSG_CONNECTION_FAIL} if the RPC itself fails;
     *     with a "Failed to execute forecast task" message (including the status message
     *     reported by the MLNode) if the response status is not success; or wrapping an
     *     {@link IOException} raised during (de)serialization
     */
    public TsBlock forecast(String modelPath, TsBlock inputTsBlock) throws TException {
        try {
            TForecastReq forecastReq = new TForecastReq(modelPath, this.tsBlockSerde.serialize(inputTsBlock));
            TForecastResp resp = this.sendForecast(forecastReq);
            if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
                // Thrown outside the RPC try/catch on purpose: a task failure reported by the
                // MLNode must reach the caller with its reason, not be rewritten as a
                // connection failure.
                throw new TException("Failed to execute forecast task, because: " + resp.status.message);
            }
            return this.tsBlockSerde.deserialize(resp.forecastResult);
        }
        catch (IOException e) {
            throw new TException("An exception occurred while serializing input tsblock", e);
        }
    }

    /** Performs the forecast RPC, mapping any transport/RPC failure to {@link #MSG_CONNECTION_FAIL}. */
    private TForecastResp sendForecast(TForecastReq forecastReq) throws TException {
        try {
            return this.client.forecast(forecastReq);
        }
        catch (TException e) {
            logger.warn("Failed to connect to MLNode from DataNode when executing {}", "forecast", e);
            throw new TException(MSG_CONNECTION_FAIL, e);
        }
    }

    /** Closes the underlying transport. */
    @Override
    public void close() throws Exception {
        // transport is final and always assigned when construction succeeds, so no null guard is needed.
        this.transport.close();
    }
}

