/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.fn.harness;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.fn.harness.DoFnPTransformRunnerFactory.Context;
import org.apache.beam.fn.harness.state.FnApiStateAccessor;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.joda.time.Instant;

/**
 * A {@link DoFnRunner} specific to integrating with the Fn Api. This is to remove the layers of
 * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
 * differently.
 *
 * <p>Lifecycle as implemented here: the constructor invokes the {@link DoFn}'s setup method,
 * {@link #startBundle} creates a fresh {@link FnApiStateAccessor} and invokes the start bundle
 * method, {@link #processElement} invokes the process element method once per window of the
 * input element, and {@link #finishBundle} invokes the finish bundle method and then finalizes
 * and discards the state accessor.
 */
public class FnApiDoFnRunner<InputT, OutputT>
    implements DoFnPTransformRunnerFactory.DoFnPTransformRunner<InputT> {
  /** A registrar which provides a factory to handle Java {@link DoFn}s. */
  @AutoService(PTransformRunnerFactory.Registrar.class)
  public static class Registrar implements PTransformRunnerFactory.Registrar {
    /** Maps the ParDo transform URN to a {@link Factory} that builds {@link FnApiDoFnRunner}s. */
    @Override
    public Map<String, PTransformRunnerFactory> getPTransformRunnerFactories() {
      return ImmutableMap.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, new Factory());
    }
  }

  /** Creates {@link FnApiDoFnRunner}s from a {@link Context}. */
  static class Factory<InputT, OutputT>
      extends DoFnPTransformRunnerFactory<
          InputT, InputT, OutputT, FnApiDoFnRunner<InputT, OutputT>> {
    @Override
    public FnApiDoFnRunner<InputT, OutputT> createRunner(Context<InputT, OutputT> context) {
      return new FnApiDoFnRunner<>(context);
    }
  }

  //////////////////////////////////////////////////////////////////////////////////////////////////

  /** The context this runner was created from; the source of the DoFn, coders and consumers. */
  private final Context<InputT, OutputT> context;

  /** Consumers registered in the context for the main output tag. */
  private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;

  /** Created in {@link #startBundle} and reset to null at the end of {@link #finishBundle}. */
  private FnApiStateAccessor stateAccessor;

  private final DoFnSignature doFnSignature;
  private final DoFnInvoker<InputT, OutputT> doFnInvoker;
  private final DoFn<InputT, OutputT>.StartBundleContext startBundleContext;
  private final ProcessBundleContext processContext;
  private final DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext;

  /** Only valid during {@link #processElement}, null otherwise. */
  private WindowedValue<InputT> currentElement;

  /** Only valid during {@link #processElement}, null otherwise. */
  private BoundedWindow currentWindow;

  /** Non-null only if the input coder is a {@link SchemaCoder}. */
  @Nullable private final SchemaCoder<InputT> schemaCoder;

  /** Non-null only if the main output coder is a {@link SchemaCoder}. */
  @Nullable private final SchemaCoder<OutputT> mainOutputSchemaCoder;

  /** Non-null only if the process element method declares a {@link Row} parameter. */
  @Nullable private final FieldAccessDescriptor fieldAccessDescriptor;

  /**
   * Builds the runner from the given context, invokes the {@link DoFn}'s setup method and
   * prepares the contexts handed to the start bundle, process element and finish bundle methods.
   *
   * @param context supplies the {@link DoFn}, pipeline options, coders and output consumers
   * @throws IllegalArgumentException if the process element method declares a {@link Row}
   *     parameter but the input coder is not a {@link SchemaCoder}, or if the parameter refers
   *     to a field access id for which no suitable declaration exists
   */
  FnApiDoFnRunner(Context<InputT, OutputT> context) {
    this.context = context;

    this.mainOutputConsumers =
        (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
            (Collection) context.tagToConsumer.get(context.mainOutputTag);
    this.doFnSignature = DoFnSignatures.signatureForDoFn(context.doFn);
    this.doFnInvoker = DoFnInvokers.invokerFor(context.doFn);
    this.doFnInvoker.invokeSetup();

    this.startBundleContext =
        this.context.doFn.new StartBundleContext() {
          @Override
          public PipelineOptions getPipelineOptions() {
            return context.pipelineOptions;
          }
        };
    this.processContext = new ProcessBundleContext();
    this.finishBundleContext =
        this.context.doFn.new FinishBundleContext() {
          @Override
          public PipelineOptions getPipelineOptions() {
            return context.pipelineOptions;
          }

          @Override
          public void output(OutputT output, Instant timestamp, BoundedWindow window) {
            // There is no current element while finishing a bundle, so no pane is propagated.
            outputTo(
                mainOutputConsumers,
                WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
          }

          @Override
          public <T> void output(
              TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
            Collection<FnDataReceiver<WindowedValue<T>>> consumers = consumersFor(tag);
            outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
          }
        };

    this.schemaCoder =
        (context.inputCoder instanceof SchemaCoder)
            ? (SchemaCoder<InputT>) context.inputCoder
            : null;
    if (context.outputCoders != null) {
      Coder<OutputT> outputCoder = (Coder<OutputT>) context.outputCoders.get(context.mainOutputTag);
      this.mainOutputSchemaCoder =
          (outputCoder instanceof SchemaCoder) ? (SchemaCoder<OutputT>) outputCoder : null;
    } else {
      this.mainOutputSchemaCoder = null;
    }

    // Reuse the signature stored above instead of looking it up again for the row parameter.
    RowParameter rowParameter = doFnSignature.processElement().getRowParameter();
    FieldAccessDescriptor descriptor = null;
    if (rowParameter != null) {
      checkArgument(
          schemaCoder != null,
          "Cannot access object as a row if the input PCollection does not have a schema. "
              + "DoFn %s Coder %s",
          context.doFn.getClass(),
          context.inputCoder.getClass());
      String id = rowParameter.fieldAccessId();
      if (id == null) {
        // This is the case where no FieldId is defined, just an @Element Row row. Default to all
        // fields accessed.
        descriptor = FieldAccessDescriptor.withAllFields();
      } else {
        // In this case, we expect to have a FieldAccessDescriptor defined in the class.
        FieldAccessDeclaration fieldAccessDeclaration =
            doFnSignature.fieldAccessDeclarations().get(id);
        checkArgument(
            fieldAccessDeclaration != null, "No FieldAccessDescriptor defined with id %s", id);
        checkArgument(fieldAccessDeclaration.field().getType().equals(FieldAccessDescriptor.class));
        try {
          descriptor = (FieldAccessDescriptor) fieldAccessDeclaration.field().get(context.doFn);
        } catch (IllegalAccessException e) {
          throw new RuntimeException(e);
        }
      }
      // Resolve the FieldAccessDescriptor. This converts all field names into field ids.
      descriptor = descriptor.resolve(schemaCoder.getSchema());
    }
    this.fieldAccessDescriptor = descriptor;
  }

  /**
   * Creates the {@link FnApiStateAccessor} used for this bundle and then invokes the {@link
   * DoFn}'s start bundle method.
   */
  @Override
  public void startBundle() {
    this.stateAccessor =
        new FnApiStateAccessor(
            context.pipelineOptions,
            context.ptransformId,
            context.processBundleInstructionId,
            context.tagToSideInputSpecMap,
            context.beamFnStateClient,
            context.keyCoder,
            (Coder<BoundedWindow>) context.windowCoder,
            // Suppliers so the accessor always sees the element and window being processed.
            () -> currentElement,
            () -> currentWindow);

    doFnInvoker.invokeStartBundle(startBundleContext);
  }

  /**
   * Invokes the {@link DoFn}'s process element method once for every window of the given
   * element. The current element and window are cleared afterwards, even if processing throws.
   *
   * @param elem the windowed input element to process
   */
  @Override
  public void processElement(WindowedValue<InputT> elem) {
    currentElement = elem;
    try {
      Iterator<? extends BoundedWindow> windows = elem.getWindows().iterator();
      while (windows.hasNext()) {
        currentWindow = windows.next();
        doFnInvoker.invokeProcessElement(processContext);
      }
    } finally {
      currentElement = null;
      currentWindow = null;
    }
  }

  /**
   * Invokes the {@link DoFn}'s finish bundle method, then finalizes and discards the state
   * accessor created in {@link #startBundle}.
   */
  @Override
  public void finishBundle() {
    doFnInvoker.invokeFinishBundle(finishBundleContext);

    // TODO: Support caching state data across bundle boundaries.
    this.stateAccessor.finalizeState();
    this.stateAccessor = null;
  }

  /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
  private <T> void outputTo(
      Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
    try {
      for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
        consumer.accept(output);
      }
    } catch (Throwable t) {
      throw UserCodeException.wrap(t);
    }
  }

  /**
   * Looks up the consumers registered in the context for the given output tag.
   *
   * @param tag the output tag to look up
   * @return the consumers registered for {@code tag}
   * @throws IllegalArgumentException if the lookup yields null for {@code tag}
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private <T> Collection<FnDataReceiver<WindowedValue<T>>> consumersFor(TupleTag<T> tag) {
    // The raw cast mirrors how the main output consumers are obtained in the constructor.
    Collection<FnDataReceiver<WindowedValue<T>>> consumers =
        (Collection) context.tagToConsumer.get(tag);
    if (consumers == null) {
      throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
    }
    return consumers;
  }

  /**
   * Provides arguments for a {@link DoFnInvoker} for {@link DoFn.ProcessElement @ProcessElement}.
   *
   * <p>Element, timestamp, pane and window accessors read the enclosing runner's current element
   * and window, which are only set while {@link #processElement} is running.
   */
  private class ProcessBundleContext extends DoFn<InputT, OutputT>.ProcessContext
      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {

    private ProcessBundleContext() {
      // ProcessContext is an inner class of DoFn, so it needs the DoFn as its outer instance.
      context.doFn.super();
    }

    @Override
    public BoundedWindow window() {
      return currentWindow;
    }

    @Override
    public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
      return pane();
    }

    @Override
    public DoFn<InputT, OutputT>.StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) {
      throw new UnsupportedOperationException(
          "Cannot access StartBundleContext outside of @StartBundle method.");
    }

    @Override
    public DoFn<InputT, OutputT>.FinishBundleContext finishBundleContext(
        DoFn<InputT, OutputT> doFn) {
      throw new UnsupportedOperationException(
          "Cannot access FinishBundleContext outside of @FinishBundle method.");
    }

    @Override
    public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
      return this;
    }

    @Override
    public InputT element(DoFn<InputT, OutputT> doFn) {
      return element();
    }

    /**
     * Converts the current element to a {@link Row} using the input schema coder's to-row
     * function. The {@code id} argument is not consulted here.
     *
     * @throws IllegalStateException if the resolved field access descriptor does not cover all
     *     fields
     */
    @Override
    public Row asRow(@Nullable String id) {
      checkState(fieldAccessDescriptor.allFields());
      return schemaCoder.getToRowFunction().apply(element());
    }

    @Override
    public Instant timestamp(DoFn<InputT, OutputT> doFn) {
      return timestamp();
    }

    @Override
    public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
      throw new UnsupportedOperationException(
          "Cannot access time domain outside of @ProcessTimer method.");
    }

    @Override
    public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
      return DoFnOutputReceivers.windowedReceiver(this, null);
    }

    @Override
    public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
      return DoFnOutputReceivers.rowReceiver(this, null, mainOutputSchemaCoder);
    }

    @Override
    public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
      return DoFnOutputReceivers.windowedMultiReceiver(this, context.outputCoders);
    }

    @Override
    public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(DoFn<InputT, OutputT> doFn) {
      throw new UnsupportedOperationException("TODO: Add support for timers");
    }

    @Override
    public RestrictionTracker<?, ?> restrictionTracker() {
      throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
    }

    /**
     * Binds the {@link StateSpec} held in the {@link DoFn} field declared for {@code stateId} to
     * the current bundle's state accessor.
     *
     * @throws NullPointerException if the signature has no state declaration for {@code stateId}
     */
    @Override
    public State state(String stateId) {
      StateDeclaration stateDeclaration = doFnSignature.stateDeclarations().get(stateId);
      checkNotNull(stateDeclaration, "No state declaration found for %s", stateId);
      StateSpec<?> spec;
      try {
        spec = (StateSpec<?>) stateDeclaration.field().get(context.doFn);
      } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
      }
      return spec.bind(stateId, stateAccessor);
    }

    @Override
    public Timer timer(String timerId) {
      throw new UnsupportedOperationException("TODO: Add support for timers");
    }

    @Override
    public PipelineOptions getPipelineOptions() {
      return context.pipelineOptions;
    }

    @Override
    public PipelineOptions pipelineOptions() {
      return context.pipelineOptions;
    }

    /** Emits to the main output with the current element's timestamp, window and pane. */
    @Override
    public void output(OutputT output) {
      outputTo(
          mainOutputConsumers,
          WindowedValue.of(
              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
    }

    /** Emits to the main output with the given timestamp and the current window and pane. */
    @Override
    public void outputWithTimestamp(OutputT output, Instant timestamp) {
      outputTo(
          mainOutputConsumers,
          WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
    }

    /**
     * Emits to the consumers of the given tag with the current element's timestamp, window and
     * pane.
     *
     * @throws IllegalArgumentException if no consumers are known for {@code tag}
     */
    @Override
    public <T> void output(TupleTag<T> tag, T output) {
      Collection<FnDataReceiver<WindowedValue<T>>> consumers = consumersFor(tag);
      outputTo(
          consumers,
          WindowedValue.of(
              output, currentElement.getTimestamp(), currentWindow, currentElement.getPane()));
    }

    /**
     * Emits to the consumers of the given tag with the given timestamp and the current window and
     * pane.
     *
     * @throws IllegalArgumentException if no consumers are known for {@code tag}
     */
    @Override
    public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
      Collection<FnDataReceiver<WindowedValue<T>>> consumers = consumersFor(tag);
      outputTo(
          consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
    }

    @Override
    public InputT element() {
      return currentElement.getValue();
    }

    /** Reads the side input for the given view in the current window via the state accessor. */
    @Override
    public <T> T sideInput(PCollectionView<T> view) {
      return stateAccessor.get(view, currentWindow);
    }

    @Override
    public Instant timestamp() {
      return currentElement.getTimestamp();
    }

    @Override
    public PaneInfo pane() {
      return currentElement.getPane();
    }

    @Override
    public void updateWatermark(Instant watermark) {
      throw new UnsupportedOperationException("TODO: Add support for SplittableDoFn");
    }
  }
}
