/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.iotdb.db.mpp.execution;

import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.exception.MetadataException;
import org.apache.iotdb.commons.path.MeasurementPath;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.db.engine.querycontext.QueryDataSource;
import org.apache.iotdb.db.engine.storagegroup.DataRegion;
import org.apache.iotdb.db.engine.storagegroup.TsFileResource;
import org.apache.iotdb.db.exception.query.QueryProcessException;
import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
import org.apache.iotdb.db.mpp.common.PlanFragmentId;
import org.apache.iotdb.db.mpp.common.QueryId;
import org.apache.iotdb.db.mpp.execution.driver.DataDriver;
import org.apache.iotdb.db.mpp.execution.driver.DataDriverContext;
import org.apache.iotdb.db.mpp.execution.driver.IDriver;
import org.apache.iotdb.db.mpp.execution.exchange.StubSinkHandle;
import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceState;
import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceStateMachine;
import org.apache.iotdb.db.mpp.execution.operator.process.LimitOperator;
import org.apache.iotdb.db.mpp.execution.operator.process.join.TimeJoinOperator;
import org.apache.iotdb.db.mpp.execution.operator.process.join.merge.AscTimeComparator;
import org.apache.iotdb.db.mpp.execution.operator.process.join.merge.SingleColumnMerger;
import org.apache.iotdb.db.mpp.execution.operator.source.SeriesScanOperator;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.InputLocation;
import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
import org.apache.iotdb.db.query.reader.series.SeriesReaderTestUtil;
import org.apache.iotdb.tsfile.exception.write.WriteProcessException;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import org.apache.iotdb.tsfile.read.common.block.column.IntColumn;
import org.apache.iotdb.tsfile.write.schema.MeasurementSchema;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;

import static org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext.createFragmentInstanceContext;
import static org.apache.iotdb.db.mpp.execution.schedule.DriverTaskThread.EXECUTION_TIME_SLICE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * End-to-end test of {@link DataDriver}: the driver runs LIMIT(TIME_JOIN(SCAN, SCAN)) over sensor0
 * and sensor1 of device0 in {@link #DATA_DRIVER_TEST_SG}. The TsBlocks it emits are collected by a
 * {@link StubSinkHandle} and checked row by row.
 */
public class DataDriverTest {

  private static final String DATA_DRIVER_TEST_SG = "root.DataDriverTest";

  /** Row limit applied by the {@link LimitOperator} at the root of the operator tree. */
  private static final int LIMIT = 250;

  /**
   * Row count this test expects in every emitted TsBlock except the last one. The last block
   * holds the remainder up to {@link #LIMIT}.
   */
  private static final int FULL_BLOCK_SIZE = 20;

  private final List<String> deviceIds = new ArrayList<>();
  private final List<MeasurementSchema> measurementSchemas = new ArrayList<>();

  private final List<TsFileResource> seqResources = new ArrayList<>();
  private final List<TsFileResource> unSeqResources = new ArrayList<>();

  /**
   * Prepares the test data set for {@link #DATA_DRIVER_TEST_SG} through {@code
   * SeriesReaderTestUtil.setUp}. The schema, device and TsFile resource lists are passed in for
   * that utility to fill (presumably; see its implementation).
   */
  @Before
  public void setUp() throws MetadataException, IOException, WriteProcessException {
    SeriesReaderTestUtil.setUp(
        measurementSchemas, deviceIds, seqResources, unSeqResources, DATA_DRIVER_TEST_SG);
  }

  /** Cleans up the resources from {@link #setUp()} via {@code SeriesReaderTestUtil.tearDown}. */
  @After
  public void tearDown() throws IOException {
    SeriesReaderTestUtil.tearDown(seqResources, unSeqResources);
  }

  /**
   * Builds LIMIT(TIME_JOIN(SCAN sensor0, SCAN sensor1)) on device0 and drives it with a {@link
   * DataDriver} until it finishes. Verifies that:
   *
   * <ul>
   *   <li>the driver never reports itself blocked;
   *   <li>the fragment instance is RUNNING while processing and FLUSHING afterwards;
   *   <li>the sink receives exactly {@link #LIMIT} rows with the expected values.
   * </ul>
   */
  @Test
  public void batchTest() {
    ExecutorService instanceNotificationExecutor =
        IoTDBThreadPoolFactory.newFixedThreadPool(1, "test-instance-notification");
    try {
      MeasurementPath measurementPath1 =
          new MeasurementPath(DATA_DRIVER_TEST_SG + ".device0.sensor0", TSDataType.INT32);
      Set<String> allSensors = new HashSet<>();
      allSensors.add("sensor0");
      allSensors.add("sensor1");
      QueryId queryId = new QueryId("stub_query");
      FragmentInstanceId instanceId =
          new FragmentInstanceId(new PlanFragmentId(queryId, 0), "stub-instance");
      FragmentInstanceStateMachine stateMachine =
          new FragmentInstanceStateMachine(instanceId, instanceNotificationExecutor);
      FragmentInstanceContext fragmentInstanceContext =
          createFragmentInstanceContext(instanceId, stateMachine);

      // Operator contexts are registered in order; the getOperatorContexts().get(k) calls below
      // rely on this order (0: scan1, 1: scan2, 2: time join, 3: limit).
      PlanNodeId planNodeId1 = new PlanNodeId("1");
      fragmentInstanceContext.addOperatorContext(
          1, planNodeId1, SeriesScanOperator.class.getSimpleName());
      PlanNodeId planNodeId2 = new PlanNodeId("2");
      fragmentInstanceContext.addOperatorContext(
          2, planNodeId2, SeriesScanOperator.class.getSimpleName());
      fragmentInstanceContext.addOperatorContext(
          3, new PlanNodeId("3"), TimeJoinOperator.class.getSimpleName());
      fragmentInstanceContext.addOperatorContext(
          4, new PlanNodeId("4"), LimitOperator.class.getSimpleName());
      SeriesScanOperator seriesScanOperator1 =
          new SeriesScanOperator(
              planNodeId1,
              measurementPath1,
              allSensors,
              TSDataType.INT32,
              fragmentInstanceContext.getOperatorContexts().get(0),
              null,
              null,
              true);

      MeasurementPath measurementPath2 =
          new MeasurementPath(DATA_DRIVER_TEST_SG + ".device0.sensor1", TSDataType.INT32);
      SeriesScanOperator seriesScanOperator2 =
          new SeriesScanOperator(
              planNodeId2,
              measurementPath2,
              allSensors,
              TSDataType.INT32,
              fragmentInstanceContext.getOperatorContexts().get(1),
              null,
              null,
              true);

      TimeJoinOperator timeJoinOperator =
          new TimeJoinOperator(
              fragmentInstanceContext.getOperatorContexts().get(2),
              Arrays.asList(seriesScanOperator1, seriesScanOperator2),
              Ordering.ASC,
              Arrays.asList(TSDataType.INT32, TSDataType.INT32),
              Arrays.asList(
                  new SingleColumnMerger(new InputLocation(0, 0), new AscTimeComparator()),
                  new SingleColumnMerger(new InputLocation(1, 0), new AscTimeComparator())),
              new AscTimeComparator());

      LimitOperator limitOperator =
          new LimitOperator(
              fragmentInstanceContext.getOperatorContexts().get(3), LIMIT, timeJoinOperator);

      // The data region is mocked so that a query for exactly these paths/device/context returns
      // the TsFile resources prepared in setUp().
      DataRegion dataRegion = Mockito.mock(DataRegion.class);

      List<PartialPath> pathList = ImmutableList.of(measurementPath1, measurementPath2);
      String deviceId = DATA_DRIVER_TEST_SG + ".device0";

      Mockito.when(dataRegion.query(pathList, deviceId, fragmentInstanceContext, null))
          .thenReturn(new QueryDataSource(seqResources, unSeqResources));

      DataDriverContext driverContext =
          new DataDriverContext(
              fragmentInstanceContext,
              pathList,
              null,
              dataRegion,
              ImmutableList.of(seriesScanOperator1, seriesScanOperator2));

      StubSinkHandle sinkHandle = new StubSinkHandle(fragmentInstanceContext);
      IDriver dataDriver = null;
      try {
        dataDriver = new DataDriver(limitOperator, sinkHandle, driverContext);
        assertEquals(fragmentInstanceContext.getId(), dataDriver.getInfo());

        assertFalse(dataDriver.isFinished());

        while (!dataDriver.isFinished()) {
          assertEquals(FragmentInstanceState.RUNNING, stateMachine.getState());
          ListenableFuture<?> blocked = dataDriver.processFor(EXECUTION_TIME_SLICE);
          // The driver must never report itself blocked in this setup.
          assertTrue(blocked.isDone());
        }

        assertEquals(FragmentInstanceState.FLUSHING, stateMachine.getState());

        verifyResult(sinkHandle.getTsBlocks());
      } finally {
        if (dataDriver != null) {
          dataDriver.close();
        }
      }
    } catch (IllegalPathException | QueryProcessException e) {
      // Put the cause into the assertion message so it shows up in the JUnit report,
      // not only on stderr.
      e.printStackTrace();
      fail("batchTest failed with unexpected exception: " + e);
    } finally {
      instanceNotificationExecutor.shutdown();
    }
  }

  /**
   * Checks the TsBlocks collected by the sink. It expects {@link #LIMIT} rows in total, split into
   * blocks of {@link #FULL_BLOCK_SIZE} rows with the last block holding the remainder. Timestamps
   * must be consecutive starting at 0, and both INT32 columns must match {@link
   * #expectedValue(long)}.
   */
  private static void verifyResult(List<TsBlock> result) {
    int expectedBlockCount = (LIMIT + FULL_BLOCK_SIZE - 1) / FULL_BLOCK_SIZE;
    assertEquals(expectedBlockCount, result.size());

    for (int i = 0; i < result.size(); i++) {
      TsBlock tsBlock = result.get(i);
      assertEquals(2, tsBlock.getValueColumnCount());
      assertTrue(tsBlock.getColumn(0) instanceof IntColumn);
      assertTrue(tsBlock.getColumn(1) instanceof IntColumn);

      // Every block is full except the last, which holds whatever remains under the limit.
      int expectedRows = Math.min(FULL_BLOCK_SIZE, LIMIT - i * FULL_BLOCK_SIZE);
      assertEquals(expectedRows, tsBlock.getPositionCount());

      for (int j = 0; j < tsBlock.getPositionCount(); j++) {
        long expectedTime = j + (long) FULL_BLOCK_SIZE * i;
        assertEquals(expectedTime, tsBlock.getTimeByIndex(j));
        long value = expectedValue(expectedTime);
        assertEquals(value, tsBlock.getColumn(0).getInt(j));
        assertEquals(value, tsBlock.getColumn(1).getInt(j));
      }
    }
  }

  /**
   * Returns the value this test expects in both sensor columns at {@code time}. There are three
   * bands: {@code 20000 + t}, {@code 10000 + t} and {@code t}. These presumably reflect sequence
   * and unsequence files written by {@code SeriesReaderTestUtil.setUp} overriding one another.
   * Keep this table in sync with that utility.
   */
  private static long expectedValue(long time) {
    if (time < 200) {
      return 20000 + time;
    }
    if (time < 260 || (time >= 300 && time < 380) || time >= 400) {
      return 10000 + time;
    }
    return time;
  }
}
