# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
import threading

import numpy as np

from iotdb.table_session_pool import TableSessionPool, TableSessionPoolConfig
from iotdb.utils.IoTDBConstants import TSDataType
from iotdb.utils.NumpyTablet import NumpyTablet
from iotdb.utils.Tablet import ColumnType, Tablet


def prepare_data():
    """Create database ``db1`` with tables ``table0`` and ``table1``, then print the table list.

    A session is taken from the module-level ``session_pool`` and closed in a
    ``finally`` clause, so it is released even when one of the statements fails.
    """
    print("create database")
    # Get a session from the pool
    session = session_pool.get_session()
    try:
        session.execute_non_query_statement("CREATE DATABASE IF NOT EXISTS db1")
        session.execute_non_query_statement('USE "db1"')
        # Both tables share one schema: a tag column, an attribute column and
        # a double-typed field column.
        for table_name in ("table0", "table1"):
            session.execute_non_query_statement(
                "CREATE TABLE "
                + table_name
                + " (id1 string tag, attr1 string attribute, m1 double field)"
            )

        print("now the tables are:")
        # show result
        res = session.execute_query_statement("SHOW TABLES")
        while res.has_next():
            print(res.next())
    finally:
        session.close()


def insert_data(num: int):
    """Insert 30 rows into ``table<num>`` using a ``Tablet`` and a ``NumpyTablet``.

    Rows with timestamps 0-14 are written through a list-based ``Tablet``
    (followed by a flush statement); rows with timestamps 15-29 are written
    through a ``NumpyTablet``. The session is taken from the module-level
    ``session_pool`` and closed in a ``finally`` clause so it is released even
    when an insert fails.

    Args:
        num: Suffix of the target table name (``0`` -> ``table0``).
    """
    print("insert data for table" + str(num))
    table_name = "table" + str(num)
    column_names = [
        "id1",
        "attr1",
        "m1",
    ]
    data_types = [
        TSDataType.STRING,
        TSDataType.STRING,
        TSDataType.DOUBLE,
    ]
    column_types = [ColumnType.TAG, ColumnType.ATTRIBUTE, ColumnType.FIELD]

    # Get a session from the pool
    session = session_pool.get_session()
    try:
        # First batch: plain Python lists, one inner list per row.
        timestamps = list(range(15))
        values = [
            ["id:" + str(row), "attr:" + str(row), row * 1.0] for row in timestamps
        ]
        tablet = Tablet(
            table_name, column_names, data_types, values, timestamps, column_types
        )
        session.insert(tablet)
        session.execute_non_query_statement("FLush")

        # Second batch: column-oriented numpy arrays, one array per column.
        # Timestamps are created as big-endian 64-bit integers.
        np_timestamps = np.arange(15, 30, dtype=np.dtype(">i8"))
        np_values = [
            np.array(["id:{}".format(i) for i in range(15, 30)]),
            np.array(["attr:{}".format(i) for i in range(15, 30)]),
            np.linspace(15.0, 29.0, num=15, dtype=TSDataType.DOUBLE.np_dtype()),
        ]

        np_tablet = NumpyTablet(
            table_name,
            column_names,
            data_types,
            np_values,
            np_timestamps,
            column_types=column_types,
        )
        session.insert(np_tablet)
    finally:
        session.close()


def query_data():
    """Print every row of ``table0`` and then every row of ``table1``.

    The session is taken from the module-level ``session_pool`` and closed in
    a ``finally`` clause so it is released even when a query fails.
    """
    # Get a session from the pool
    session = session_pool.get_session()
    try:
        # Query each table under its own heading; the table named in the
        # heading is the table that is actually queried.
        for table_name in ("table0", "table1"):
            print("get data from " + table_name)
            res = session.execute_query_statement("select * from " + table_name)
            while res.has_next():
                print(res.next())
    finally:
        session.close()


def delete_data():
    """Drop database ``db1`` and print the databases that remain.

    The session is taken from the module-level ``session_pool`` and closed in
    a ``finally`` clause so it is released even when a statement fails.
    """
    session = session_pool.get_session()
    try:
        session.execute_non_query_statement("drop database db1")
        print("data has been deleted. now the databases are:")
        res = session.execute_query_statement("show databases")
        while res.has_next():
            print(res.next())
    finally:
        session.close()


# Create a session pool
username = "root"
password = "root"
node_urls = ["127.0.0.1:6667", "127.0.0.1:6668", "127.0.0.1:6669"]
fetch_size = 1024
database = "db1"
max_pool_size = 5
wait_timeout_in_ms = 3000
config = TableSessionPoolConfig(
    node_urls=node_urls,
    username=username,
    password=password,
    database=database,
    max_pool_size=max_pool_size,
    fetch_size=fetch_size,
    wait_timeout_in_ms=wait_timeout_in_ms,
)
session_pool = TableSessionPool(config)

# Run the whole workflow under try/finally so the pool is closed even when
# one of the steps raises.
try:
    prepare_data()

    # Fill table0 and table1 concurrently; each thread gets its own session
    # from the shared pool inside insert_data().
    insert_thread1 = threading.Thread(target=insert_data, args=(0,))
    insert_thread2 = threading.Thread(target=insert_data, args=(1,))

    insert_thread1.start()
    insert_thread2.start()

    # Wait for both inserts to finish before reading the data back.
    insert_thread1.join()
    insert_thread2.join()

    query_data()
    delete_data()
finally:
    session_pool.close()
print("example is finished!")
