/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.fnexecution.control;

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

import java.io.Serializable;
import java.util.Map;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
import org.apache.beam.runners.core.construction.ModelCoders;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.junit.Test;

/** Tests for {@link ProcessBundleDescriptors}. */
public class ProcessBundleDescriptorsTest implements Serializable {

  /**
   * Tests that a stateful Executable stage will wrap a key coder of a stateful transform in a
   * LengthPrefixCoder.
   */
  @Test
  public void testWrapKeyCoderOfStatefulExecutableStageInLengthPrefixCoder() throws Exception {
    // Add another stateful stage with a non-standard key coder
    Pipeline p = Pipeline.create();
    Coder<Void> keycoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(keycoder), is(false));
    p.apply("impulse", Impulse.create())
        .apply(
            "create",
            ParDo.of(
                new DoFn<byte[], KV<Void, String>>() {
                  @ProcessElement
                  public void process(ProcessContext ctxt) {}
                }))
        .setCoder(KvCoder.of(keycoder, StringUtf8Coder.of()))
        .apply(
            "userState",
            ParDo.of(
                new DoFn<KV<Void, String>, KV<Void, String>>() {

                  @StateId("stateId")
                  private final StateSpec<BagState<String>> bufferState =
                      StateSpecs.bag(StringUtf8Coder.of());

                  @ProcessElement
                  public void processElement(
                      @Element KV<Void, String> element,
                      @StateId("stateId") BagState<String> state,
                      OutputReceiver<KV<Void, String>> r) {
                    for (String value : state.read()) {
                      r.output(KV.of(element.getKey(), value));
                    }
                    state.add(element.getValue());
                  }
                }))
        // Force the output to be materialized
        .apply("gbk", GroupByKey.create());

    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
    Optional<ExecutableStage> optionalStage =
        Iterables.tryFind(
            fused.getFusedStages(),
            (ExecutableStage stage) ->
                stage.getUserStates().stream()
                    .anyMatch(spec -> spec.localName().equals("stateId")));
    checkState(optionalStage.isPresent(), "Expected a stage with user state.");
    ExecutableStage stage = optionalStage.get();

    ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
            "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());

    BeamFnApi.ProcessBundleDescriptor pbDescriptor = descriptor.getProcessBundleDescriptor();
    String inputColId = stage.getInputPCollection().getId();
    String inputCoderId = pbDescriptor.getPcollectionsMap().get(inputColId).getCoderId();

    Map<String, RunnerApi.Coder> codersMap = pbDescriptor.getCodersMap();
    RunnerApi.Coder coder = codersMap.get(inputCoderId);
    String keyCoderId = ModelCoders.getKvCoderComponents(coder).keyCoderId();

    assertThat(
        codersMap.get(keyCoderId).getSpec().getUrn(), is(ModelCoders.LENGTH_PREFIX_CODER_URN));

    RunnerApi.Coder orignalCoder = stage.getComponents().getCodersMap().get(inputCoderId);
    String originalKeyCoderId = ModelCoders.getKvCoderComponents(orignalCoder).keyCoderId();
    assertThat(
        stage.getComponents().getCodersMap().get(originalKeyCoderId).getSpec().getUrn(),
        is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));
  }
}
