/*
 * ====================================================================
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 * ====================================================================
 *
 * This software consists of voluntary contributions made by many
 * individuals on behalf of the Apache Software Foundation.  For more
 * information on the Apache Software Foundation, please see
 * <http://www.apache.org/>.
 *
 */

package org.apache.hc.core5.http.impl.nio;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.WritableByteChannel;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ssl.SSLSession;

import org.apache.hc.core5.http.ConnectionClosedException;
import org.apache.hc.core5.http.ContentLengthStrategy;
import org.apache.hc.core5.http.EndpointDetails;
import org.apache.hc.core5.http.EntityDetails;
import org.apache.hc.core5.http.Header;
import org.apache.hc.core5.http.HttpConnection;
import org.apache.hc.core5.http.HttpException;
import org.apache.hc.core5.http.HttpMessage;
import org.apache.hc.core5.http.Message;
import org.apache.hc.core5.http.ProtocolVersion;
import org.apache.hc.core5.http.config.CharCodingConfig;
import org.apache.hc.core5.http.config.H1Config;
import org.apache.hc.core5.http.impl.BasicEndpointDetails;
import org.apache.hc.core5.http.impl.BasicHttpConnectionMetrics;
import org.apache.hc.core5.http.impl.BasicHttpTransportMetrics;
import org.apache.hc.core5.http.impl.CharCodingSupport;
import org.apache.hc.core5.http.impl.DefaultContentLengthStrategy;
import org.apache.hc.core5.http.impl.IncomingEntityDetails;
import org.apache.hc.core5.http.nio.AsyncClientExchangeHandler;
import org.apache.hc.core5.http.nio.CapacityChannel;
import org.apache.hc.core5.http.nio.ContentDecoder;
import org.apache.hc.core5.http.nio.ContentEncoder;
import org.apache.hc.core5.http.nio.NHttpMessageParser;
import org.apache.hc.core5.http.nio.NHttpMessageWriter;
import org.apache.hc.core5.http.nio.SessionInputBuffer;
import org.apache.hc.core5.http.nio.SessionOutputBuffer;
import org.apache.hc.core5.http.nio.command.RequestExecutionCommand;
import org.apache.hc.core5.http.nio.command.ShutdownCommand;
import org.apache.hc.core5.io.CloseMode;
import org.apache.hc.core5.io.SocketTimeoutExceptionFactory;
import org.apache.hc.core5.reactor.Command;
import org.apache.hc.core5.reactor.EventMask;
import org.apache.hc.core5.reactor.IOSession;
import org.apache.hc.core5.reactor.ProtocolIOSession;
import org.apache.hc.core5.reactor.ssl.TlsDetails;
import org.apache.hc.core5.util.Args;
import org.apache.hc.core5.util.Identifiable;
import org.apache.hc.core5.util.Timeout;

/**
 * Base class for HTTP/1.1 connection handlers that carry at most one incoming and one outgoing
 * message body at a time over a single {@link ProtocolIOSession}.
 * <p>
 * This class drives the session read / write events, owns the session input and output buffers,
 * keeps transport and connection metrics and tracks the connection life-cycle state. Message
 * specific behaviour (content codecs, exchange handling, metrics of message heads) is supplied
 * by subclasses through the package-private abstract hooks.
 *
 * @param <IncomingMessage> type of message head parsed from the peer
 * @param <OutgoingMessage> type of message head written to the peer
 */
abstract class AbstractHttp1StreamDuplexer<IncomingMessage extends HttpMessage, OutgoingMessage extends HttpMessage>
        implements Identifiable, HttpConnection {

    /**
     * Connection life-cycle states. Declaration order matters: states are compared with
     * {@code compareTo}, so a later constant means "further along towards shutdown".
     */
    private enum ConnectionState { READY, ACTIVE, GRACEFUL_SHUTDOWN, SHUTDOWN}

    private final ProtocolIOSession ioSession;
    private final H1Config h1Config;
    private final SessionInputBufferImpl inbuf;
    private final SessionOutputBufferImpl outbuf;
    private final BasicHttpTransportMetrics inTransportMetrics;
    private final BasicHttpTransportMetrics outTransportMetrics;
    private final BasicHttpConnectionMetrics connMetrics;
    private final NHttpMessageParser<IncomingMessage> incomingMessageParser;
    private final NHttpMessageWriter<OutgoingMessage> outgoingMessageWriter;
    private final ContentLengthStrategy incomingContentStrategy;
    private final ContentLengthStrategy outgoingContentStrategy;
    // Scratch buffer that carries decoded body content from the decoder to consumeData()
    private final ByteBuffer contentBuffer;
    // Number of requestSessionOutput() calls not yet accounted for by onOutput()
    private final AtomicInteger outputRequests;

    // Incoming message whose body is still being decoded; null between messages
    private volatile Message<IncomingMessage, ContentDecoder> incomingMessage;
    // Outgoing message whose body is still being encoded; null when no body is being written
    private volatile Message<OutgoingMessage, ContentEncoder> outgoingMessage;
    private volatile ConnectionState connState;
    // Input flow-control window created for the most recently parsed incoming message head
    private volatile CapacityWindow capacityWindow;

    // Protocol version of the most recently parsed incoming message head (null before the first)
    private volatile ProtocolVersion version;
    // Created lazily by getEndpointDetails()
    private volatile EndpointDetails endpointDetails;

    /**
     * Creates a duplexer bound to the given I/O session. The connection starts in the
     * {@code READY} state and becomes {@code ACTIVE} in {@link #onConnect(ByteBuffer)}.
     *
     * @param ioSession the I/O session; must not be {@code null}
     * @param h1Config HTTP/1.1 configuration; {@code H1Config.DEFAULT} is used when {@code null}
     * @param charCodingConfig character coding configuration for the session buffers
     * @param incomingMessageParser parser for incoming message heads
     * @param outgoingMessageWriter writer for outgoing message heads
     * @param incomingContentStrategy length strategy for incoming bodies;
     *        {@code DefaultContentLengthStrategy.INSTANCE} when {@code null}
     * @param outgoingContentStrategy length strategy for outgoing bodies;
     *        {@code DefaultContentLengthStrategy.INSTANCE} when {@code null}
     */
    AbstractHttp1StreamDuplexer(
            final ProtocolIOSession ioSession,
            final H1Config h1Config,
            final CharCodingConfig charCodingConfig,
            final NHttpMessageParser<IncomingMessage> incomingMessageParser,
            final NHttpMessageWriter<OutgoingMessage> outgoingMessageWriter,
            final ContentLengthStrategy incomingContentStrategy,
            final ContentLengthStrategy outgoingContentStrategy) {
        this.ioSession = Args.notNull(ioSession, "I/O session");
        this.h1Config = h1Config != null ? h1Config : H1Config.DEFAULT;
        final int bufferSize = this.h1Config.getBufferSize();
        this.inbuf = new SessionInputBufferImpl(bufferSize, bufferSize < 512 ? bufferSize : 512,
                this.h1Config.getMaxLineLength(),
                CharCodingSupport.createDecoder(charCodingConfig));
        this.outbuf = new SessionOutputBufferImpl(bufferSize, bufferSize < 512 ? bufferSize : 512,
                CharCodingSupport.createEncoder(charCodingConfig));
        this.inTransportMetrics = new BasicHttpTransportMetrics();
        this.outTransportMetrics = new BasicHttpTransportMetrics();
        this.connMetrics = new BasicHttpConnectionMetrics(inTransportMetrics, outTransportMetrics);
        this.incomingMessageParser = incomingMessageParser;
        this.outgoingMessageWriter = outgoingMessageWriter;
        this.incomingContentStrategy = incomingContentStrategy != null ? incomingContentStrategy :
                DefaultContentLengthStrategy.INSTANCE;
        this.outgoingContentStrategy = outgoingContentStrategy != null ? outgoingContentStrategy :
                DefaultContentLengthStrategy.INSTANCE;
        this.contentBuffer = ByteBuffer.allocate(this.h1Config.getBufferSize());
        this.outputRequests = new AtomicInteger(0);
        this.connState = ConnectionState.READY;
    }

    /**
     * Returns the identifier of the underlying I/O session.
     */
    @Override
    public String getId() {
        return ioSession.getId();
    }

    /**
     * Shuts the session down. A graceful shutdown marks the connection as shutting down and
     * enqueues a graceful {@link ShutdownCommand}; any other mode closes the session at once.
     *
     * @param closeMode how to shut the session down
     */
    void shutdownSession(final CloseMode closeMode) {
        if (closeMode == CloseMode.GRACEFUL) {
            connState = ConnectionState.GRACEFUL_SHUTDOWN;
            ioSession.enqueue(ShutdownCommand.GRACEFUL, Command.Priority.NORMAL);
        } else {
            connState = ConnectionState.SHUTDOWN;
            ioSession.close();
        }
    }

    /**
     * Shuts the session down because of an error: notifies the subclass through
     * {@link #terminate(Exception)} and closes the session even if that call throws.
     *
     * @param exception the cause of the shutdown
     */
    void shutdownSession(final Exception exception) {
        connState = ConnectionState.SHUTDOWN;
        try {
            terminate(exception);
        } finally {
            ioSession.close();
        }
    }

    // ---- Subclass hooks, invoked from the event-handling methods of this class ----

    /** Invoked from {@link #onDisconnect()} before queued commands are disposed of. */
    abstract void disconnected();

    /** Invoked from {@link #shutdownSession(Exception)} before the session is closed. */
    abstract void terminate(final Exception exception);

    abstract void updateInputMetrics(IncomingMessage incomingMessage, BasicHttpConnectionMetrics connMetrics);

    abstract void updateOutputMetrics(OutgoingMessage outgoingMessage, BasicHttpConnectionMetrics connMetrics);

    /**
     * Receives a parsed incoming message head; {@code entityDetails} is {@code null} when no
     * content decoder was created for the message.
     */
    abstract void consumeHeader(IncomingMessage messageHead, EntityDetails entityDetails) throws HttpException, IOException;

    /** Returns {@code true} if a content decoder should be created for the incoming message. */
    abstract boolean handleIncomingMessage(IncomingMessage incomingMessage) throws HttpException;

    /** Returns {@code true} if a content encoder should be created for the outgoing message. */
    abstract boolean handleOutgoingMessage(OutgoingMessage outgoingMessage) throws HttpException;

    /** Creates a decoder for an incoming body, or returns {@code null} if there is none. */
    abstract ContentDecoder createContentDecoder(
            long contentLength,
            ReadableByteChannel channel,
            SessionInputBuffer buffer,
            BasicHttpTransportMetrics metrics) throws HttpException;

    /** Creates an encoder for an outgoing body, or returns {@code null} if there is none. */
    abstract ContentEncoder createContentEncoder(
            long contentLength,
            WritableByteChannel channel,
            SessionOutputBuffer buffer,
            BasicHttpTransportMetrics metrics) throws HttpException;

    abstract void consumeData(ByteBuffer src) throws HttpException, IOException;

    /** Invoked when the input capacity window is exhausted before the body is complete. */
    abstract void updateCapacity(CapacityChannel capacityChannel) throws HttpException, IOException;

    abstract void dataEnd(List<? extends Header> trailers) throws HttpException, IOException;

    abstract boolean isOutputReady();

    abstract void produceOutput() throws HttpException, IOException;

    abstract void execute(RequestExecutionCommand executionCommand) throws HttpException, IOException;

    abstract void inputEnd() throws HttpException, IOException;

    abstract void outputEnd() throws HttpException, IOException;

    abstract boolean inputIdle();

    abstract boolean outputIdle();

    /** Returns {@code false} if the timeout should be treated as a socket timeout failure. */
    abstract boolean handleTimeout();

    /**
     * Polls commands queued on the I/O session. Shutdown commands are applied, an unknown
     * command type raises an {@link HttpException}, and at most one request execution is
     * started per call. Request executions that arrive once a shutdown has begun are cancelled.
     */
    private void processCommands() throws HttpException, IOException {
        for (;;) {
            final Command command = ioSession.poll();
            if (command == null) {
                return;
            }
            if (command instanceof ShutdownCommand) {
                final ShutdownCommand shutdownCommand = (ShutdownCommand) command;
                requestShutdown(shutdownCommand.getType());
            } else if (command instanceof RequestExecutionCommand) {
                if (connState.compareTo(ConnectionState.GRACEFUL_SHUTDOWN) >= 0) {
                    command.cancel();
                } else {
                    execute((RequestExecutionCommand) command);
                    // Only one exchange at a time on an HTTP/1.1 connection
                    return;
                }
            } else {
                throw new HttpException("Unexpected command: " + command.getClass());
            }
        }
    }

    /**
     * Handles session connect: pushes any pre-read bytes into the input buffer, marks the
     * connection active and processes queued commands.
     *
     * @param prefeed bytes already read from the session, or {@code null}
     */
    public final void onConnect(final ByteBuffer prefeed) throws HttpException, IOException {
        if (prefeed != null) {
            inbuf.put(prefeed);
        }
        connState = ConnectionState.ACTIVE;
        processCommands();
    }

    /**
     * Handles a read event. The method repeatedly fills the input buffer, parses message heads
     * and decodes message bodies until a pass makes no progress or the connection is shut down.
     * When the peer closes its end with no buffered input left, an idle connection is shut down
     * gracefully. Otherwise the session is terminated with a {@link ConnectionClosedException}.
     */
    public final void onInput() throws HttpException, IOException {
        while (connState.compareTo(ConnectionState.SHUTDOWN) < 0) {
            int totalBytesRead = 0;
            int messagesReceived = 0;
            if (incomingMessage == null) {

                // Shutting down and nothing left to consume: stop listening for input
                if (connState.compareTo(ConnectionState.GRACEFUL_SHUTDOWN) >= 0 && inputIdle()) {
                    ioSession.clearEvent(SelectionKey.OP_READ);
                    return;
                }

                int bytesRead;
                do {
                    bytesRead = inbuf.fill(ioSession.channel());
                    if (bytesRead > 0) {
                        totalBytesRead += bytesRead;
                        inTransportMetrics.incrementBytesTransferred(bytesRead);
                    }
                    final IncomingMessage messageHead = incomingMessageParser.parse(inbuf, bytesRead == -1);
                    if (messageHead != null) {
                        messagesReceived++;
                        incomingMessageParser.reset();

                        this.version = messageHead.getVersion();

                        updateInputMetrics(messageHead, connMetrics);
                        final ContentDecoder contentDecoder;
                        if (handleIncomingMessage(messageHead)) {
                            final long len = incomingContentStrategy.determineLength(messageHead);
                            contentDecoder = createContentDecoder(len, ioSession.channel(), inbuf, inTransportMetrics);
                            consumeHeader(messageHead, contentDecoder != null ? new IncomingEntityDetails(messageHead, len) : null);
                        } else {
                            consumeHeader(messageHead, null);
                            contentDecoder = null;
                        }
                        capacityWindow = new CapacityWindow(h1Config.getInitialWindowSize(), ioSession);
                        if (contentDecoder != null) {
                            // Message has a body: decode it below
                            incomingMessage = new Message<>(messageHead, contentDecoder);
                            break;
                        }
                        // Message without a body is complete as soon as its head is parsed
                        inputEnd();
                        if (connState.compareTo(ConnectionState.ACTIVE) == 0) {
                            ioSession.setEvent(SelectionKey.OP_READ);
                        } else {
                            break;
                        }
                    }
                } while (bytesRead > 0);

                // Peer closed its end and there is no buffered input left to parse
                if (bytesRead == -1 && !inbuf.hasData()) {
                    if (outputIdle() && inputIdle()) {
                        requestShutdown(CloseMode.GRACEFUL);
                    } else {
                        shutdownSession(new ConnectionClosedException("Connection closed by peer"));
                    }
                    return;
                }
            }

            if (incomingMessage != null) {
                final ContentDecoder contentDecoder = incomingMessage.getBody();

                // At present the consumer can be forced to consume data
                // over its declared capacity in order to avoid having
                // unprocessed message body content stuck in the session
                // input buffer
                int bytesRead;
                while ((bytesRead = contentDecoder.read(contentBuffer)) > 0) {
                    totalBytesRead += bytesRead;
                    contentBuffer.flip();
                    consumeData(contentBuffer);
                    contentBuffer.clear();
                    final int capacity = capacityWindow.removeCapacity(bytesRead);
                    if (capacity <= 0) {
                        if (!contentDecoder.isCompleted()) {
                            updateCapacity(capacityWindow);
                        }
                        break;
                    }
                }
                if (contentDecoder.isCompleted()) {
                    dataEnd(contentDecoder.getTrailers());
                    // Stop late capacity updates from re-enabling reads for a finished message
                    capacityWindow.close();
                    incomingMessage = null;
                    ioSession.setEvent(SelectionKey.OP_READ);
                    inputEnd();
                }
            }
            // No progress made in this pass: wait for the next read event
            if (totalBytesRead == 0 && messagesReceived == 0) {
                break;
            }
        }
    }

    /**
     * Handles a write event. It flushes buffered output, lets the subclass produce more output
     * and clears write interest when no output is pending. When the outgoing message is complete,
     * it either processes queued commands or completes a graceful shutdown.
     */
    public final void onOutput() throws IOException, HttpException {
        ioSession.getLock().lock();
        try {
            if (outbuf.hasData()) {
                flushOutputBuffer();
            }
        } finally {
            ioSession.getLock().unlock();
        }
        if (connState.compareTo(ConnectionState.SHUTDOWN) < 0) {
            // Snapshot the request counter BEFORE producing output. Any requestSessionOutput()
            // call made while produceOutput() runs then makes the compareAndSet below fail, so
            // OP_WRITE stays set and that request is served on the next write event instead of
            // being counted as already handled.
            final int pendingOutputRequests = outputRequests.get();
            produceOutput();
            final boolean outputPending = isOutputReady();
            final boolean outputEnd;
            ioSession.getLock().lock();
            try {
                if (!outputPending && !outbuf.hasData() && outputRequests.compareAndSet(pendingOutputRequests, 0)) {
                    ioSession.clearEvent(SelectionKey.OP_WRITE);
                } else {
                    outputRequests.addAndGet(-pendingOutputRequests);
                }
                outputEnd = outgoingMessage == null && !outbuf.hasData();
            } finally {
                ioSession.getLock().unlock();
            }
            if (outputEnd) {
                outputEnd();
                if (connState.compareTo(ConnectionState.ACTIVE) == 0) {
                    processCommands();
                } else if (connState.compareTo(ConnectionState.GRACEFUL_SHUTDOWN) >= 0 && inputIdle() && outputIdle()) {
                    connState = ConnectionState.SHUTDOWN;
                }
            }
        }
        if (connState.compareTo(ConnectionState.SHUTDOWN) >= 0) {
            ioSession.close();
        }
    }

    /**
     * Handles a session timeout. Unless the subclass handles it, the timeout is treated as a
     * failure through {@link #onException(Exception)}.
     */
    public final void onTimeout(final Timeout timeout) throws IOException, HttpException {
        if (!handleTimeout()) {
            onException(SocketTimeoutExceptionFactory.create(timeout));
        }
    }

    /**
     * Handles an I/O or protocol failure: shuts the session down and disposes of every command
     * still queued on it. Pending request executions are failed with {@code ex}.
     */
    public final void onException(final Exception ex) {
        shutdownSession(ex);
        for (Command command = ioSession.poll(); command != null; command = ioSession.poll()) {
            failOrCancel(command, ex);
        }
    }

    /**
     * Handles session disconnect: notifies the subclass and disposes of every command still
     * queued on the session. Pending request executions are failed with a
     * {@link ConnectionClosedException}.
     */
    public final void onDisconnect() {
        disconnected();
        for (Command command = ioSession.poll(); command != null; command = ioSession.poll()) {
            // One exception instance per command, so handlers never share an instance
            failOrCancel(command, new ConnectionClosedException());
        }
    }

    /**
     * Disposes of a command that can no longer be executed. A request execution has its
     * exchange handler failed with {@code cause}, and the handler's resources are released even
     * if that callback throws. Any other command is cancelled.
     */
    private static void failOrCancel(final Command command, final Exception cause) {
        if (command instanceof RequestExecutionCommand) {
            final AsyncClientExchangeHandler exchangeHandler = ((RequestExecutionCommand) command).getExchangeHandler();
            try {
                exchangeHandler.failed(cause);
            } finally {
                exchangeHandler.releaseResources();
            }
        } else {
            command.cancel();
        }
    }

    /**
     * Moves the connection towards shutdown and requests a write event so that
     * {@link #onOutput()} can act on the new state. A graceful request only takes effect on
     * an {@code ACTIVE} connection.
     */
    void requestShutdown(final CloseMode closeMode) {
        switch (closeMode) {
            case GRACEFUL:
                if (connState == ConnectionState.ACTIVE) {
                    connState = ConnectionState.GRACEFUL_SHUTDOWN;
                }
                break;
            case IMMEDIATE:
                connState = ConnectionState.SHUTDOWN;
                break;
        }
        ioSession.setEvent(SelectionKey.OP_WRITE);
    }

    /**
     * Writes an outgoing message head into the output buffer. Unless {@code endStream} is set,
     * it also sets up a content encoder for the body when the subclass asks for one.
     *
     * @param messageHead the message head to write
     * @param endStream {@code true} if the message has no body
     * @param flushMode {@code FlushMode.IMMEDIATE} to flush the buffer to the channel at once
     */
    void commitMessageHead(
            final OutgoingMessage messageHead,
            final boolean endStream,
            final FlushMode flushMode) throws HttpException, IOException {
        ioSession.getLock().lock();
        try {
            outgoingMessageWriter.write(messageHead, outbuf);
            updateOutputMetrics(messageHead, connMetrics);
            if (!endStream) {
                final ContentEncoder contentEncoder;
                if (handleOutgoingMessage(messageHead)) {
                    final long len = outgoingContentStrategy.determineLength(messageHead);
                    contentEncoder = createContentEncoder(len, ioSession.channel(), outbuf, outTransportMetrics);
                } else {
                    contentEncoder = null;
                }
                if (contentEncoder != null) {
                    outgoingMessage = new Message<>(messageHead, contentEncoder);
                }
            }
            outgoingMessageWriter.reset();
            if (flushMode == FlushMode.IMMEDIATE) {
                flushOutputBuffer();
            }
            ioSession.setEvent(EventMask.WRITE);
        } finally {
            ioSession.getLock().unlock();
        }
    }

    /** Requests read events on the session. */
    void requestSessionInput() {
        ioSession.setEvent(SelectionKey.OP_READ);
    }

    /** Records an output request and requests write events on the session. */
    void requestSessionOutput() {
        outputRequests.incrementAndGet();
        ioSession.setEvent(SelectionKey.OP_WRITE);
    }

    Timeout getSessionTimeout() {
        return ioSession.getSocketTimeout();
    }

    void setSessionTimeout(final Timeout timeout) {
        ioSession.setSocketTimeout(timeout);
    }

    /**
     * Flushes buffered output if there is any; otherwise drops write interest on the session.
     */
    void suspendSessionOutput() throws IOException {
        ioSession.getLock().lock();
        try {
            if (outbuf.hasData()) {
                flushOutputBuffer();
            } else {
                ioSession.clearEvent(SelectionKey.OP_WRITE);
            }
        } finally {
            ioSession.getLock().unlock();
        }
    }

    /**
     * Writes the contents of the output buffer to the session channel and records the bytes
     * written in the outbound transport metrics. Callers must hold the I/O session lock.
     */
    private void flushOutputBuffer() throws IOException {
        final int bytesWritten = outbuf.flush(ioSession.channel());
        if (bytesWritten > 0) {
            outTransportMetrics.incrementBytesTransferred(bytesWritten);
        }
    }

    /**
     * Encodes body content of the current outgoing message.
     *
     * @param src content to write
     * @return the number of bytes accepted by the content encoder
     * @throws ClosedChannelException if no outgoing message body is in progress
     */
    int streamOutput(final ByteBuffer src) throws IOException {
        ioSession.getLock().lock();
        try {
            if (outgoingMessage == null) {
                throw new ClosedChannelException();
            }
            final ContentEncoder contentEncoder = outgoingMessage.getBody();
            final int bytesWritten = contentEncoder.write(src);
            if (bytesWritten > 0) {
                ioSession.setEvent(SelectionKey.OP_WRITE);
            }
            return bytesWritten;
        } finally {
            ioSession.getLock().unlock();
        }
    }

    /**
     * Result of {@link #endOutputStream(List)}. {@code NONE} means no outgoing body was in
     * progress. {@code CHUNK_CODED} means the body was written by a {@code ChunkEncoder}.
     * {@code MESSAGE_HEAD} covers any other encoder.
     */
    enum MessageDelineation { NONE, CHUNK_CODED, MESSAGE_HEAD}

    /**
     * Completes the current outgoing message body, if any, with the given trailers and
     * detaches it from the connection.
     *
     * @param trailers trailers to pass to the content encoder
     * @return how the body was delineated, or {@code NONE} if there was no body in progress
     */
    MessageDelineation endOutputStream(final List<? extends Header> trailers) throws IOException {
        ioSession.getLock().lock();
        try {
            if (outgoingMessage == null) {
                return MessageDelineation.NONE;
            }
            final ContentEncoder contentEncoder = outgoingMessage.getBody();
            contentEncoder.complete(trailers);
            ioSession.setEvent(SelectionKey.OP_WRITE);
            outgoingMessage = null;
            return contentEncoder instanceof ChunkEncoder
                            ? MessageDelineation.CHUNK_CODED
                            : MessageDelineation.MESSAGE_HEAD;
        } finally {
            ioSession.getLock().unlock();
        }
    }

    /**
     * Returns {@code true} if there is no outgoing body in progress or its encoder reports
     * completion.
     */
    boolean isOutputCompleted() {
        ioSession.getLock().lock();
        try {
            if (outgoingMessage == null) {
                return true;
            }
            final ContentEncoder contentEncoder = outgoingMessage.getBody();
            return contentEncoder.isCompleted();
        } finally {
            ioSession.getLock().unlock();
        }
    }

    /** Enqueues a graceful shutdown command with normal priority. */
    @Override
    public void close() throws IOException {
        ioSession.enqueue(ShutdownCommand.GRACEFUL, Command.Priority.NORMAL);
    }

    /** Enqueues a shutdown command with the given mode and immediate priority. */
    @Override
    public void close(final CloseMode closeMode) {
        ioSession.enqueue(new ShutdownCommand(closeMode), Command.Priority.IMMEDIATE);
    }

    /** Returns {@code true} only while the connection is in the {@code ACTIVE} state. */
    @Override
    public boolean isOpen() {
        return connState == ConnectionState.ACTIVE;
    }

    @Override
    public Timeout getSocketTimeout() {
        return ioSession.getSocketTimeout();
    }

    @Override
    public void setSocketTimeout(final Timeout timeout) {
        ioSession.setSocketTimeout(timeout);
    }

    /**
     * Returns endpoint details. They are created on the first call and cached, so the socket
     * timeout they report is the one in effect at that first call.
     */
    @Override
    public EndpointDetails getEndpointDetails() {
        if (endpointDetails == null) {
            endpointDetails = new BasicEndpointDetails(
                    ioSession.getRemoteAddress(),
                    ioSession.getLocalAddress(),
                    connMetrics,
                    ioSession.getSocketTimeout());
        }
        return endpointDetails;
    }

    /**
     * Returns the protocol version of the last incoming message head, or {@code null} if none
     * has been parsed yet.
     */
    @Override
    public ProtocolVersion getProtocolVersion() {
        return version;
    }

    @Override
    public SocketAddress getRemoteAddress() {
        return ioSession.getRemoteAddress();
    }

    @Override
    public SocketAddress getLocalAddress() {
        return ioSession.getLocalAddress();
    }

    /** Returns the TLS session, or {@code null} if the I/O session has no TLS details. */
    @Override
    public SSLSession getSSLSession() {
        final TlsDetails tlsDetails = ioSession.getTlsDetails();
        return tlsDetails != null ? tlsDetails.getSSLSession() : null;
    }

    /** Appends a diagnostic description of the connection state to {@code buf}. */
    void appendState(final StringBuilder buf) {
        buf.append("connState=").append(connState)
                .append(", inbuf=").append(inbuf)
                .append(", outbuf=").append(outbuf)
                .append(", inputWindow=").append(capacityWindow != null ? capacityWindow.getWindow() : 0);
    }

    /**
     * Input flow-control window for a single incoming message. Consuming content shrinks the
     * window and read interest is dropped once it is exhausted. Capacity updates grow the
     * window and re-enable reads until the window is closed.
     */
    static class CapacityWindow implements CapacityChannel {
        private final IOSession ioSession;
        private int window;
        private boolean closed;

        CapacityWindow(final int window, final IOSession ioSession) {
            this.window = window;
            this.ioSession = ioSession;
        }

        /**
         * Grows the window by {@code increment} and requests read events. Non-positive
         * increments and updates after {@link #close()} are ignored.
         */
        @Override
        public synchronized void update(final int increment) throws IOException {
            if (closed) {
                return;
            }
            if (increment > 0) {
                updateWindow(increment);
                ioSession.setEvent(SelectionKey.OP_READ);
            }
        }

        /**
         * Internal method for removing capacity. We don't need to check
         * if this channel is closed in it.
         *
         * @return the remaining window; read interest is cleared when it is not positive
         */
        synchronized int removeCapacity(final int delta) {
            updateWindow(-delta);
            if (window <= 0) {
                ioSession.clearEvent(SelectionKey.OP_READ);
            }
            return window;
        }

        /**
         * Adds {@code delta} to the window, saturating at {@code Integer.MIN_VALUE} /
         * {@code Integer.MAX_VALUE} instead of wrapping around.
         */
        private void updateWindow(final int delta) {
            int newValue = window + delta;
            // Same overflow test as Math.addExact (operands of equal sign, result of the other
            // sign), but saturate rather than throw
            if (((window ^ newValue) & (delta ^ newValue)) < 0) {
                newValue = delta < 0 ? Integer.MIN_VALUE : Integer.MAX_VALUE;
            }
            window = newValue;
        }

        /**
         * Closes the capacity channel, preventing user code from accidentally requesting
         * read events outside of the context of the request the channel was created for
         */
        synchronized void close() {
            closed = true;
        }

        // visible for testing; synchronized like every other accessor of 'window' because
        // appendState() may call it from a thread other than the one updating the window
        synchronized int getWindow() {
            return window;
        }
    }
}
