/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ForwardingList;
import com.google.common.util.concurrent.ForwardingFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.SettableFuture;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagBag;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.Weighted;
import org.apache.beam.vendor.grpc.v1_13_1.com.google.protobuf.ByteString;
import org.joda.time.Instant;

/**
 * Reads persistent state from {@link Windmill}. Returns {@code Future}s containing the data that
 * has been read. Will not initiate a read until {@link Future#get} is called, at which point all
 * the pending futures will be read.
 *
 * <p>CAUTION Watch out for escaping references to the reader ending up inside {@link
 * WindmillStateCache}.
 */
class WindmillStateReader {
  /**
   * Ideal maximum bytes in a TagBag response. However, Windmill will always return at least one
   * value if possible irrespective of this limit.
   */
  public static final long MAX_BAG_BYTES = 8L << 20; // 8MB

  /**
   * Ideal maximum bytes in a KeyedGetDataResponse. However, Windmill will always return at least
   * one value if possible irrespective of this limit.
   */
  public static final long MAX_KEY_BYTES = 16L << 20; // 16MB

  /**
   * When combined with a key and computationId, represents the unique address for state managed by
   * Windmill.
   */
  private static class StateTag {
    private enum Kind {
      VALUE,
      BAG,
      WATERMARK;
    }

    private final Kind kind;
    private final ByteString tag;
    private final String stateFamily;

    /**
     * For {@link Kind#BAG} kinds: A previous 'continuation_position' returned by Windmill to signal
     * the resulting bag was incomplete. Sending that position will request the next page of values.
     * Null for first request.
     *
     * <p>Null for other kinds.
     */
    @Nullable private final Long requestPosition;

    private StateTag(
        Kind kind, ByteString tag, String stateFamily, @Nullable Long requestPosition) {
      this.kind = kind;
      this.tag = tag;
      this.stateFamily = Preconditions.checkNotNull(stateFamily);
      this.requestPosition = requestPosition;
    }

    private StateTag(Kind kind, ByteString tag, String stateFamily) {
      this(kind, tag, stateFamily, null);
    }

    @Override
    public boolean equals(Object obj) {
      if (this == obj) {
        return true;
      }

      if (!(obj instanceof StateTag)) {
        return false;
      }

      StateTag that = (StateTag) obj;
      return Objects.equal(this.kind, that.kind)
          && Objects.equal(this.tag, that.tag)
          && Objects.equal(this.stateFamily, that.stateFamily)
          && Objects.equal(this.requestPosition, that.requestPosition);
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(kind, tag, stateFamily, requestPosition);
    }

    @Override
    public String toString() {
      return "Tag("
          + kind
          + ","
          + tag.toStringUtf8()
          + ","
          + stateFamily
          + (requestPosition == null ? "" : ("," + requestPosition.toString()))
          + ")";
    }
  }

  /**
   * An in-memory collection of deserialized values and an optional continuation position to pass to
   * Windmill when fetching the next page of values.
   */
  private static class ValuesAndContPosition<T> {
    private final List<T> values;

    /** Position to pass to next request for next page of values. Null if done. */
    @Nullable private final Long continuationPosition;

    public ValuesAndContPosition(List<T> values, @Nullable Long continuationPosition) {
      this.values = values;
      this.continuationPosition = continuationPosition;
    }
  }

  private final String computation;
  private final ByteString key;
  private final long shardingKey;
  private final long workToken;

  private final MetricTrackingWindmillServerStub server;

  private long bytesRead = 0L;

  public WindmillStateReader(
      MetricTrackingWindmillServerStub server,
      String computation,
      ByteString key,
      long shardingKey,
      long workToken) {
    this.server = server;
    this.computation = computation;
    this.key = key;
    this.shardingKey = shardingKey;
    this.workToken = workToken;
  }

  private static final class CoderAndFuture<ElemT, FutureT> {
    private Coder<ElemT> coder;
    private final SettableFuture<FutureT> future;

    private CoderAndFuture(Coder<ElemT> coder, SettableFuture<FutureT> future) {
      this.coder = coder;
      this.future = future;
    }

    private SettableFuture<FutureT> getFuture() {
      return future;
    }

    private SettableFuture<FutureT> getNonDoneFuture(StateTag stateTag) {
      if (future.isDone()) {
        throw new IllegalStateException("Future for " + stateTag + " is already done");
      }
      return future;
    }

    private Coder<ElemT> getAndClearCoder() {
      if (coder == null) {
        throw new IllegalStateException("Coder has already been cleared from cache");
      }
      Coder<ElemT> result = coder;
      coder = null;
      return result;
    }

    private void checkNoCoder() {
      if (coder != null) {
        throw new IllegalStateException("Unexpected coder");
      }
    }
  }

  @VisibleForTesting ConcurrentLinkedQueue<StateTag> pendingLookups = new ConcurrentLinkedQueue<>();
  private ConcurrentHashMap<StateTag, CoderAndFuture<?, ?>> waiting = new ConcurrentHashMap<>();

  private <ElemT, FutureT> Future<FutureT> stateFuture(
      StateTag stateTag, @Nullable Coder<ElemT> coder) {
    CoderAndFuture<ElemT, FutureT> coderAndFuture =
        new CoderAndFuture<>(coder, SettableFuture.<FutureT>create());
    CoderAndFuture<?, ?> existingCoderAndFutureWildcard =
        waiting.putIfAbsent(stateTag, coderAndFuture);
    if (existingCoderAndFutureWildcard == null) {
      // Schedule a new request. It's response is guaranteed to find the future and coder.
      pendingLookups.add(stateTag);
    } else {
      // Piggy-back on the pending or already answered request.
      @SuppressWarnings("unchecked")
      CoderAndFuture<ElemT, FutureT> existingCoderAndFuture =
          (CoderAndFuture<ElemT, FutureT>) existingCoderAndFutureWildcard;
      coderAndFuture = existingCoderAndFuture;
    }

    return wrappedFuture(coderAndFuture.getFuture());
  }

  private <ElemT, FutureT> CoderAndFuture<ElemT, FutureT> getWaiting(
      StateTag stateTag, boolean shouldRemove) {
    CoderAndFuture<?, ?> coderAndFutureWildcard;
    if (shouldRemove) {
      coderAndFutureWildcard = waiting.remove(stateTag);
    } else {
      coderAndFutureWildcard = waiting.get(stateTag);
    }
    if (coderAndFutureWildcard == null) {
      throw new IllegalStateException("Missing future for " + stateTag);
    }
    @SuppressWarnings("unchecked")
    CoderAndFuture<ElemT, FutureT> coderAndFuture =
        (CoderAndFuture<ElemT, FutureT>) coderAndFutureWildcard;
    return coderAndFuture;
  }

  public Future<Instant> watermarkFuture(ByteString encodedTag, String stateFamily) {
    return stateFuture(new StateTag(StateTag.Kind.WATERMARK, encodedTag, stateFamily), null);
  }

  public <T> Future<T> valueFuture(ByteString encodedTag, String stateFamily, Coder<T> coder) {
    return stateFuture(new StateTag(StateTag.Kind.VALUE, encodedTag, stateFamily), coder);
  }

  public <T> Future<Iterable<T>> bagFuture(
      ByteString encodedTag, String stateFamily, Coder<T> elemCoder) {
    // First request has no continuation position.
    StateTag stateTag = new StateTag(StateTag.Kind.BAG, encodedTag, stateFamily);
    // Convert the ValuesAndContPosition<T> to Iterable<T>.
    return valuesToPagingIterableFuture(
        stateTag, elemCoder, this.<T, ValuesAndContPosition<T>>stateFuture(stateTag, elemCoder));
  }

  /**
   * Internal request to fetch the next 'page' of values in a TagBag. Return null if no continuation
   * position is in {@code contStateTag}, which signals there are no more pages.
   */
  @Nullable
  private <T> Future<ValuesAndContPosition<T>> continuationBagFuture(
      StateTag contStateTag, Coder<T> elemCoder) {
    if (contStateTag.requestPosition == null) {
      // We're done.
      return null;
    }
    return stateFuture(contStateTag, elemCoder);
  }

  /**
   * A future which will trigger a GetData request to Windmill for all outstanding futures on the
   * first {@link #get}.
   */
  private static class WrappedFuture<T> extends ForwardingFuture.SimpleForwardingFuture<T> {
    /**
     * The reader we'll use to service the eventual read. Null if read has been fulfilled.
     *
     * <p>NOTE: We must clear this after the read is fulfilled to prevent space leaks.
     */
    @Nullable private WindmillStateReader reader;

    public WrappedFuture(WindmillStateReader reader, Future<T> delegate) {
      super(delegate);
      this.reader = reader;
    }

    @Override
    public T get() throws InterruptedException, ExecutionException {
      if (!delegate().isDone() && reader != null) {
        // Only one thread per reader, so no race here.
        reader.startBatchAndBlock();
      }
      reader = null;
      return super.get();
    }

    @Override
    public T get(long timeout, TimeUnit unit)
        throws InterruptedException, ExecutionException, TimeoutException {
      if (!delegate().isDone() && reader != null) {
        // Only one thread per reader, so no race here.
        reader.startBatchAndBlock();
      }
      reader = null;
      return super.get(timeout, unit);
    }
  }

  private <T> Future<T> wrappedFuture(final Future<T> future) {
    if (future.isDone()) {
      // If the underlying lookup is already complete, we don't need to create the wrapper.
      return future;
    } else {
      // Otherwise, wrap the true future so we know when to trigger a GetData.
      return new WrappedFuture<>(this, future);
    }
  }

  /** Function to extract an {@link Iterable} from the continuation-supporting page read future. */
  private static class ToIterableFunction<T>
      implements Function<ValuesAndContPosition<T>, Iterable<T>> {
    /**
     * Reader to request continuation pages from, or {@literal null} if no continuation pages
     * required.
     */
    @Nullable private WindmillStateReader reader;

    private final StateTag stateTag;
    private final Coder<T> elemCoder;

    public ToIterableFunction(WindmillStateReader reader, StateTag stateTag, Coder<T> elemCoder) {
      this.reader = reader;
      this.stateTag = stateTag;
      this.elemCoder = elemCoder;
    }

    @Override
    public Iterable<T> apply(ValuesAndContPosition<T> valuesAndContPosition) {
      if (valuesAndContPosition.continuationPosition == null) {
        // Number of values is small enough Windmill sent us the entire bag in one response.
        reader = null;
        return valuesAndContPosition.values;
      } else {
        // Return an iterable which knows how to come back for more.
        StateTag contStateTag =
            new StateTag(
                stateTag.kind,
                stateTag.tag,
                stateTag.stateFamily,
                valuesAndContPosition.continuationPosition);
        return new BagPagingIterable<>(
            reader, valuesAndContPosition.values, contStateTag, elemCoder);
      }
    }
  }

  /**
   * Return future which transforms a {@code ValuesAndContPosition<T>} result into the initial
   * Iterable<T> result expected from the external caller.
   */
  private <T> Future<Iterable<T>> valuesToPagingIterableFuture(
      final StateTag stateTag,
      final Coder<T> elemCoder,
      final Future<ValuesAndContPosition<T>> future) {
    return Futures.lazyTransform(future, new ToIterableFunction<T>(this, stateTag, elemCoder));
  }

  public void startBatchAndBlock() {
    // First, drain work out of the pending lookups into a set. These will be the items we fetch.
    HashSet<StateTag> toFetch = new HashSet<>();
    while (!pendingLookups.isEmpty()) {
      StateTag stateTag = pendingLookups.poll();
      if (stateTag == null) {
        break;
      }

      if (!toFetch.add(stateTag)) {
        throw new IllegalStateException("Duplicate tags being fetched.");
      }
    }

    // If we failed to drain anything, some other thread pulled it off the queue. We have no work
    // to do.
    if (toFetch.isEmpty()) {
      return;
    }

    Windmill.KeyedGetDataRequest request = createRequest(toFetch);
    Windmill.KeyedGetDataResponse response = server.getStateData(computation, request);

    if (response == null) {
      throw new RuntimeException("Windmill unexpectedly returned null for request " + request);
    }

    consumeResponse(request, response, toFetch);
  }

  public long getBytesRead() {
    return bytesRead;
  }

  private Windmill.KeyedGetDataRequest createRequest(Iterable<StateTag> toFetch) {
    Windmill.KeyedGetDataRequest.Builder keyedDataBuilder =
        Windmill.KeyedGetDataRequest.newBuilder()
            .setKey(key)
            .setShardingKey(shardingKey)
            .setWorkToken(workToken);

    for (StateTag stateTag : toFetch) {
      switch (stateTag.kind) {
        case BAG:
          TagBag.Builder bag =
              keyedDataBuilder
                  .addBagsToFetchBuilder()
                  .setTag(stateTag.tag)
                  .setStateFamily(stateTag.stateFamily)
                  .setFetchMaxBytes(MAX_BAG_BYTES);
          if (stateTag.requestPosition != null) {
            // We're asking for the next page.
            bag.setRequestPosition(stateTag.requestPosition);
          }
          break;

        case WATERMARK:
          keyedDataBuilder
              .addWatermarkHoldsToFetchBuilder()
              .setTag(stateTag.tag)
              .setStateFamily(stateTag.stateFamily);
          break;

        case VALUE:
          keyedDataBuilder
              .addValuesToFetchBuilder()
              .setTag(stateTag.tag)
              .setStateFamily(stateTag.stateFamily);
          break;

        default:
          throw new RuntimeException("Unknown kind of tag requested: " + stateTag.kind);
      }
    }

    keyedDataBuilder.setMaxBytes(MAX_KEY_BYTES);

    return keyedDataBuilder.build();
  }

  private void consumeResponse(
      Windmill.KeyedGetDataRequest request,
      Windmill.KeyedGetDataResponse response,
      Set<StateTag> toFetch) {
    bytesRead += response.getSerializedSize();

    if (response.getFailed()) {
      // Set up all the futures for this key to throw an exception:
      KeyTokenInvalidException keyTokenInvalidException =
          new KeyTokenInvalidException(key.toStringUtf8());
      for (StateTag stateTag : toFetch) {
        waiting.get(stateTag).future.setException(keyTokenInvalidException);
      }
      return;
    }

    if (!key.equals(response.getKey())) {
      throw new RuntimeException("Expected data for key " + key + " but was " + response.getKey());
    }

    for (Windmill.TagBag bag : response.getBagsList()) {
      StateTag stateTag =
          new StateTag(
              StateTag.Kind.BAG,
              bag.getTag(),
              bag.getStateFamily(),
              bag.hasRequestPosition() ? bag.getRequestPosition() : null);
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeBag(bag, stateTag);
    }

    for (Windmill.WatermarkHold hold : response.getWatermarkHoldsList()) {
      StateTag stateTag =
          new StateTag(StateTag.Kind.WATERMARK, hold.getTag(), hold.getStateFamily());
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeWatermark(hold, stateTag);
    }

    for (Windmill.TagValue value : response.getValuesList()) {
      StateTag stateTag = new StateTag(StateTag.Kind.VALUE, value.getTag(), value.getStateFamily());
      if (!toFetch.remove(stateTag)) {
        throw new IllegalStateException(
            "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch);
      }
      consumeTagValue(value, stateTag);
    }

    if (!toFetch.isEmpty()) {
      throw new IllegalStateException(
          "Didn't receive responses for all pending fetches. Missing: " + toFetch);
    }
  }

  @VisibleForTesting
  static class WeightedList<T> extends ForwardingList<T> implements Weighted {
    private List<T> delegate;
    long weight;

    WeightedList(List<T> delegate) {
      this.delegate = delegate;
      this.weight = 0;
    }

    @Override
    protected List<T> delegate() {
      return delegate;
    }

    @Override
    public boolean add(T elem) {
      throw new UnsupportedOperationException("Must use AddWeighted()");
    }

    @Override
    public long getWeight() {
      return weight;
    }

    public void addWeighted(T elem, long weight) {
      delegate.add(elem);
      this.weight += weight;
    }
  }

  /** The deserialized values in {@code bag} as a read-only array list. */
  private <T> List<T> bagPageValues(TagBag bag, Coder<T> elemCoder) {
    if (bag.getValuesCount() == 0) {
      return new WeightedList<T>(Collections.<T>emptyList());
    }

    WeightedList<T> valueList = new WeightedList<>(new ArrayList<T>(bag.getValuesCount()));
    for (ByteString value : bag.getValuesList()) {
      try {
        valueList.addWeighted(
            elemCoder.decode(value.newInput(), Coder.Context.OUTER), value.size());
      } catch (IOException e) {
        throw new IllegalStateException("Unable to decode tag list using " + elemCoder, e);
      }
    }
    return valueList;
  }

  private <T> void consumeBag(TagBag bag, StateTag stateTag) {
    boolean shouldRemove;
    if (stateTag.requestPosition == null) {
      // This is the response for the first page.
      // Leave the future in the cache so subsequent requests for the first page
      // can return immediately.
      shouldRemove = false;
    } else {
      // This is a response for a subsequent page.
      // Don't cache the future since we may need to make multiple requests with different
      // continuation positions.
      shouldRemove = true;
    }
    CoderAndFuture<T, ValuesAndContPosition<T>> coderAndFuture = getWaiting(stateTag, shouldRemove);
    SettableFuture<ValuesAndContPosition<T>> future = coderAndFuture.getNonDoneFuture(stateTag);
    Coder<T> coder = coderAndFuture.getAndClearCoder();
    List<T> values = this.<T>bagPageValues(bag, coder);
    future.set(
        new ValuesAndContPosition<T>(
            values, bag.hasContinuationPosition() ? bag.getContinuationPosition() : null));
  }

  private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag stateTag) {
    CoderAndFuture<Void, Instant> coderAndFuture = getWaiting(stateTag, false);
    SettableFuture<Instant> future = coderAndFuture.getNonDoneFuture(stateTag);
    // No coders for watermarks
    coderAndFuture.checkNoCoder();

    Instant hold = null;
    for (long timestamp : watermarkHold.getTimestampsList()) {
      Instant instant = new Instant(TimeUnit.MICROSECONDS.toMillis(timestamp));
      // TIMESTAMP_MAX_VALUE represents infinity, and windmill will return it if no hold is set, so
      // don't treat it as a hold here.
      if (instant.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)
          && (hold == null || instant.isBefore(hold))) {
        hold = instant;
      }
    }

    future.set(hold);
  }

  private <T> void consumeTagValue(TagValue tagValue, StateTag stateTag) {
    CoderAndFuture<T, T> coderAndFuture = getWaiting(stateTag, false);
    SettableFuture<T> future = coderAndFuture.getNonDoneFuture(stateTag);
    Coder<T> coder = coderAndFuture.getAndClearCoder();

    if (tagValue.hasValue()
        && tagValue.getValue().hasData()
        && !tagValue.getValue().getData().isEmpty()) {
      InputStream inputStream = tagValue.getValue().getData().newInput();
      try {
        T value = coder.decode(inputStream, Coder.Context.OUTER);
        future.set(value);
      } catch (IOException e) {
        throw new IllegalStateException("Unable to decode value using " + coder, e);
      }
    } else {
      future.set(null);
    }
  }

  /**
   * An iterable over elements backed by paginated GetData requests to Windmill. The iterable may be
   * iterated over an arbitrary number of times and multiple iterators may be active simultaneously.
   *
   * <p>There are two pattern we wish to support with low -memory and -latency:
   *
   * <ol>
   *   <li>Re-iterate over the initial elements multiple times (eg Iterables.first). We'll cache the
   *       initial 'page' of values returned by Windmill from our first request for the lifetime of
   *       the iterable.
   *   <li>Iterate through all elements of a very large collection. We'll send the GetData request
   *       for the next page when the current page is begun. We'll discard intermediate pages and
   *       only retain the first. Thus the maximum memory pressure is one page plus one page per
   *       call to iterator.
   * </ol>
   */
  private static class BagPagingIterable<T> implements Iterable<T> {
    /**
     * The reader we will use for scheduling continuation pages.
     *
     * <p>NOTE We've made this explicit to remind us to be careful not to cache the iterable.
     */
    private final WindmillStateReader reader;

    /** Initial values returned for the first page. Never reclaimed. */
    private final List<T> firstPage;

    /** State tag with continuation position set for second page. */
    private final StateTag secondPagePos;

    /** Coder for elements. */
    private final Coder<T> elemCoder;

    private BagPagingIterable(
        WindmillStateReader reader, List<T> firstPage, StateTag secondPagePos, Coder<T> elemCoder) {
      this.reader = reader;
      this.firstPage = firstPage;
      this.secondPagePos = secondPagePos;
      this.elemCoder = elemCoder;
    }

    @Override
    public Iterator<T> iterator() {
      return new AbstractIterator<T>() {
        private Iterator<T> currentPage = firstPage.iterator();
        private StateTag nextPagePos = secondPagePos;
        private Future<ValuesAndContPosition<T>> pendingNextPage =
            // NOTE: The results of continuation page reads are never cached.
            reader.continuationBagFuture(nextPagePos, elemCoder);

        @Override
        protected T computeNext() {
          while (true) {
            if (currentPage.hasNext()) {
              return currentPage.next();
            }
            if (pendingNextPage == null) {
              return endOfData();
            }

            ValuesAndContPosition<T> valuesAndContPosition;
            try {
              valuesAndContPosition = pendingNextPage.get();
            } catch (InterruptedException | ExecutionException e) {
              if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
              }
              throw new RuntimeException("Unable to read value from state", e);
            }
            currentPage = valuesAndContPosition.values.iterator();
            nextPagePos =
                new StateTag(
                    nextPagePos.kind,
                    nextPagePos.tag,
                    nextPagePos.stateFamily,
                    valuesAndContPosition.continuationPosition);
            pendingNextPage =
                // NOTE: The results of continuation page reads are never cached.
                reader.continuationBagFuture(nextPagePos, elemCoder);
          }
        }
      };
    }
  }
}
