/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.beam.runners.samza.translation;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.samza.runtime.Op;
import org.apache.beam.runners.samza.runtime.OpAdapter;
import org.apache.beam.runners.samza.runtime.OpMessage;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.samza.operators.MessageStream;

/**
 * {@link TransformTranslator} for {@link org.apache.beam.sdk.transforms.Flatten.PCollections}.
 *
 * <p>The message streams of all inputs of the flatten node are combined into one stream via {@link
 * MessageStream#mergeAll}, and that stream is registered as the stream of the transform's output.
 */
class FlattenPCollectionsTranslator<T> implements TransformTranslator<Flatten.PCollections<T>> {
  /**
   * Registers the message stream backing the output of the given flatten transform.
   *
   * <ul>
   *   <li>No inputs: the context's dummy stream is passed through an {@link Op} with an empty
   *       body and the result is registered.
   *   <li>One input: the input's stream is registered for the output as-is.
   *   <li>Several inputs: the input streams are merged and the merged stream is registered.
   * </ul>
   *
   * @param transform the flatten transform whose output is looked up in {@code ctx}
   * @param node the hierarchy node whose inputs are flattened
   * @param ctx the translation context used to look up and register message streams
   * @throws IllegalArgumentException if any input of {@code node} is not a {@link PCollection}
   */
  @Override
  public void translate(
      Flatten.PCollections<T> transform, TransformHierarchy.Node node, TranslationContext ctx) {
    final PCollection<T> output = ctx.getOutput(transform);
    final List<MessageStream<OpMessage<T>>> sources = collectInputStreams(node, ctx);

    final MessageStream<OpMessage<T>> flattened;
    switch (sources.size()) {
      case 0:
        // Nothing to flatten: derive the output from the dummy stream with an Op that has an
        // empty body.
        flattened =
            ctx.getDummyStream()
                .flatMap(OpAdapter.adapt((Op<String, T, Void>) (inputElement, emitter) -> {}));
        break;
      case 1:
        // A single input needs no merge operator; its stream is reused directly.
        flattened = sources.get(0);
        break;
      default:
        flattened = MessageStream.mergeAll(distinctStreams(sources));
        break;
    }

    ctx.registerMessageStream(output, flattened);
  }

  /**
   * Looks up the message stream of every input of {@code node}, in the iteration order of the
   * node's input map.
   *
   * @throws IllegalArgumentException on the first input that is not a {@link PCollection}
   */
  private List<MessageStream<OpMessage<T>>> collectInputStreams(
      TransformHierarchy.Node node, TranslationContext ctx) {
    final List<MessageStream<OpMessage<T>>> streams = new ArrayList<>();
    for (Map.Entry<TupleTag<?>, PValue> entry : node.getInputs().entrySet()) {
      final PValue value = entry.getValue();
      if (value instanceof PCollection) {
        @SuppressWarnings("unchecked")
        final PCollection<T> input = (PCollection<T>) value;
        streams.add(ctx.getMessageStream(input));
        continue;
      }

      throw new IllegalArgumentException(
          String.format(
              "Got non-PCollection input for flatten. Tag: %s. Input: %s. Type: %s",
              entry.getKey(), value, value.getClass()));
    }
    return streams;
  }

  /**
   * Builds the set of streams handed to the merge. When a stream is already present in the set,
   * the repeated occurrence is replaced by the result of an identity {@code map} on it, so that
   * the same input appearing several times contributes several set entries.
   */
  private Set<MessageStream<OpMessage<T>>> distinctStreams(
      List<MessageStream<OpMessage<T>>> streams) {
    final Set<MessageStream<OpMessage<T>>> unique = new HashSet<>();
    for (MessageStream<OpMessage<T>> candidate : streams) {
      if (unique.add(candidate)) {
        continue;
      }
      // Already in the set: add the stream returned by an identity map instead.
      unique.add(candidate.map(m -> m));
    }
    return unique;
  }
}
