/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.api.scala.functions

import org.apache.flink.api.common.functions.RichJoinFunction
import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.operators.{GenericDataSinkBase, SingleInputSemanticProperties}
import org.apache.flink.api.common.operators.base.{InnerJoinOperatorBase, MapOperatorBase}
import org.apache.flink.api.common.operators.util.FieldSet
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFields
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsFirst
import org.apache.flink.api.java.functions.FunctionAnnotation.ForwardedFieldsSecond
import org.apache.flink.api.java.io.DiscardingOutputFormat
import org.apache.flink.api.scala._

import org.assertj.core.api.Assertions.{assertThat, fail}
import org.junit.jupiter.api.Test

/**
 * This is a minimal test to verify that semantic annotations are evaluated against the type
 * information and translated correctly to the common data flow API.
 *
 * Note: this currently covers only the constant fields annotations.
 */
class SemanticPropertiesTranslationTest {

  /**
   * A mapper that preserves all fields over a tuple data set.
   *
   * The mapper is annotated with the wildcard `"*"`, so every input field is expected to be
   * forwarded to the output field with the same index.
   */
  @Test
  def translateUnaryFunctionAnnotationTuplesWildCard(): Unit = failOnException {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val input = env.fromElements((3L, "test", 42))
    input
      .map(new WildcardForwardMapper[(Long, String, Int)])
      .output(new DiscardingOutputFormat[(Long, String, Int)])

    val semantics: SingleInputSemanticProperties = mapperFeedingSink(env).getSemanticProperties

    assertForwardedTo(semantics.getForwardingTargetFields(0, 0), 0)
    assertForwardedTo(semantics.getForwardingTargetFields(0, 1), 1)
    assertForwardedTo(semantics.getForwardingTargetFields(0, 2), 2)
  }

  /** A mapper that preserves fields 0, 1, 2 of a tuple data set. */
  @Test
  def translateUnaryFunctionAnnotationTuples1(): Unit = failOnException {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val input = env.fromElements((3L, "test", 42))
    input
      .map(new IndividualForwardMapper[Long, String, Int])
      .output(new DiscardingOutputFormat[(Long, String, Int)])

    val semantics: SingleInputSemanticProperties = mapperFeedingSink(env).getSemanticProperties

    assertForwardedTo(semantics.getForwardingTargetFields(0, 0), 0)
    assertForwardedTo(semantics.getForwardingTargetFields(0, 1), 1)
    assertForwardedTo(semantics.getForwardingTargetFields(0, 2), 2)
  }

  /**
   * A mapper that preserves field 1 of a tuple data set.
   *
   * Fields 0 and 2 are expected to have no forwarding targets at all.
   */
  @Test
  def translateUnaryFunctionAnnotationTuples2(): Unit = failOnException {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val input = env.fromElements((3L, "test", 42))
    input
      .map(new FieldTwoForwardMapper[Long, String, Int])
      .output(new DiscardingOutputFormat[(Long, String, Int)])

    val semantics: SingleInputSemanticProperties = mapperFeedingSink(env).getSemanticProperties

    assertNothingForwarded(semantics.getForwardingTargetFields(0, 0))
    assertForwardedTo(semantics.getForwardingTargetFields(0, 1), 1)
    assertNothingForwarded(semantics.getForwardingTargetFields(0, 2))
  }

  /**
   * A join that preserves one tuple field from each side.
   *
   * Field 1 of the first input is expected at output position 0 and field 0 of the second input
   * at output position 1; the remaining input fields must not be forwarded.
   */
  @Test
  def translateBinaryFunctionAnnotationTuples1(): Unit = failOnException {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val input1 = env.fromElements((3L, "test"))
    val input2 = env.fromElements((3L, 3.1415))

    input1
      .join(input2)
      .where(0)
      .equalTo(0)(new ForwardingTupleJoin[Long, String, Long, Double])
      .output(new DiscardingOutputFormat[(String, Long)])

    val semantics = joinFeedingSink(env).getSemanticProperties

    // First argument is the input index (0 = first join input, 1 = second join input).
    assertNothingForwarded(semantics.getForwardingTargetFields(0, 0))
    assertForwardedTo(semantics.getForwardingTargetFields(0, 1), 0)
    assertForwardedTo(semantics.getForwardingTargetFields(1, 0), 1)
    assertNothingForwarded(semantics.getForwardingTargetFields(1, 1))
  }

  /**
   * A join that forwards both complete inputs into a nested output tuple.
   *
   * The fields of the first input are expected at output positions 0 and 1, the fields of the
   * second input at output positions 2 and 3.
   */
  @Test
  def translateBinaryFunctionAnnotationTuples2(): Unit = failOnException {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val input1 = env.fromElements((3L, "test"))
    val input2 = env.fromElements((3L, 42))

    input1
      .join(input2)
      .where(0)
      .equalTo(0)(new ForwardingBasicJoin[(Long, String), (Long, Int)])
      .output(new DiscardingOutputFormat[((Long, String), (Long, Int))])

    val semantics = joinFeedingSink(env).getSemanticProperties

    assertForwardedTo(semantics.getForwardingTargetFields(0, 0), 0)
    assertForwardedTo(semantics.getForwardingTargetFields(0, 1), 1)
    assertForwardedTo(semantics.getForwardingTargetFields(1, 0), 2)
    assertForwardedTo(semantics.getForwardingTargetFields(1, 1), 3)
  }

  /**
   * Runs `body` and turns any thrown [[Exception]] into a test failure.
   *
   * The exception is attached as the cause of the failure so that its stack trace shows up in
   * the test report. Only `Exception` is caught on purpose: assertion failures are
   * `AssertionError`s and propagate unchanged.
   */
  private def failOnException(body: => Unit): Unit = {
    try {
      body
    } catch {
      case e: Exception =>
        fail("Exception in test: " + e.getMessage, e)
    }
  }

  /** Creates the program plan of `env` and returns the first data sink it contains. */
  private def firstSink(env: ExecutionEnvironment): GenericDataSinkBase[_] = {
    val plan = env.createProgramPlan()
    val sink: GenericDataSinkBase[_] = plan.getDataSinks.iterator.next
    sink
  }

  /** Returns the input of the first sink of `env`'s plan, cast to a map operator. */
  private def mapperFeedingSink(env: ExecutionEnvironment): MapOperatorBase[_, _, _] =
    firstSink(env).getInput.asInstanceOf[MapOperatorBase[_, _, _]]

  /** Returns the input of the first sink of `env`'s plan, cast to an inner join operator. */
  private def joinFeedingSink(env: ExecutionEnvironment): InnerJoinOperatorBase[_, _, _, _] =
    firstSink(env).getInput.asInstanceOf[InnerJoinOperatorBase[_, _, _, _]]

  /** Asserts that `targets` is non-null and contains the output position `expected`. */
  private def assertForwardedTo(targets: FieldSet, expected: Int): Unit = {
    assertThat(targets).isNotNull
    assertThat(targets).contains(expected)
  }

  /** Asserts that `targets` is non-null and empty, i.e. the field is not forwarded anywhere. */
  private def assertNothingForwarded(targets: FieldSet): Unit = {
    assertThat(targets).isNotNull
    assertThat(targets).isEmpty()
  }
}

/**
 * Identity mapper whose `@ForwardedFields` annotation uses the wildcard `"*"`.
 *
 * @tparam T
 *   type of the records that are passed through
 */
@ForwardedFields(Array("*"))
class WildcardForwardMapper[T] extends RichMapFunction[T, T] {

  /** Returns the incoming record as it is. */
  override def map(value: T): T = value
}

/**
 * Identity mapper over 3-tuples whose `@ForwardedFields` annotation lists the fields `0`, `1`
 * and `2` individually.
 */
@ForwardedFields(Array("0;1;2"))
class IndividualForwardMapper[X, Y, Z] extends RichMapFunction[(X, Y, Z), (X, Y, Z)] {

  /** Returns the incoming tuple as it is. */
  override def map(value: (X, Y, Z)): (X, Y, Z) = value
}

/**
 * Identity mapper over 3-tuples whose `@ForwardedFields` annotation names only the Scala tuple
 * field `_2`.
 */
@ForwardedFields(Array("_2"))
class FieldTwoForwardMapper[X, Y, Z] extends RichMapFunction[(X, Y, Z), (X, Y, Z)] {

  /** Returns the incoming tuple as it is. */
  override def map(value: (X, Y, Z)): (X, Y, Z) = value
}

/**
 * Join function over two 2-tuples that emits the second field of the first input together with
 * the first field of the second input.
 *
 * The annotations declare the same mapping: `_2 -> _1` for the first input and `_1 -> _2` for
 * the second input.
 */
@ForwardedFieldsFirst(Array("_2 -> _1"))
@ForwardedFieldsSecond(Array("_1 -> _2"))
class ForwardingTupleJoin[A, B, C, D] extends RichJoinFunction[(A, B), (C, D), (B, C)] {

  /** Builds the pair `(first._2, second._1)`. */
  override def join(first: (A, B), second: (C, D)): (B, C) =
    first._2 -> second._1
}

/**
 * Join function that pairs up both inputs unchanged.
 *
 * The annotations declare that the complete first input is forwarded into output field `0` and
 * the complete second input into output field `1`.
 */
@ForwardedFieldsFirst(Array("* -> 0.*"))
@ForwardedFieldsSecond(Array("* -> 1.*"))
class ForwardingBasicJoin[A, B] extends RichJoinFunction[A, B, (A, B)] {

  /** Builds the pair `(first, second)`. */
  override def join(first: A, second: B): (A, B) =
    first -> second
}
