/*
 *  Licensed to the Apache Software Foundation (ASF) under one
 *  or more contributor license agreements.  See the NOTICE file
 *  distributed with this work for additional information
 *  regarding copyright ownership.  The ASF licenses this file
 *  to you under the Apache License, Version 2.0 (the
 *  "License"); you may not use this file except in compliance
 *  with the License.  You may obtain a copy of the License at
 *
 *    https://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing,
 *  software distributed under the License is distributed on an
 *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 *  KIND, either express or implied.  See the License for the
 *  specific language governing permissions and limitations
 *  under the License.
 */

package org.grails.testing.spock

import java.lang.annotation.Annotation
import java.lang.reflect.Method
import java.lang.reflect.Modifier

import groovy.transform.CompileStatic

import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.spockframework.runtime.extension.IGlobalExtension
import org.spockframework.runtime.model.MethodInfo
import org.spockframework.runtime.model.MethodKind
import org.spockframework.runtime.model.SpecInfo

import grails.testing.spring.AutowiredTest
import org.grails.testing.GrailsUnitTest

@CompileStatic
/**
 * Spock global extension that wires Grails testing support into every specification.
 *
 * <p>For each visited spec it:
 * <ul>
 *   <li>adds {@link AutowiredInterceptor} as a setup interceptor when the spec implements {@link AutowiredTest};</li>
 *   <li>adds {@link CleanupContextInterceptor} as a cleanup-spec interceptor when the spec implements
 *       {@link GrailsUnitTest};</li>
 *   <li>maps JUnit Jupiter lifecycle annotations found on methods declared directly on the spec class onto Spock
 *       fixture methods: {@code @BeforeEach} and {@code @BeforeAll} methods are prepended to the spec's setup /
 *       setupSpec methods, and {@code @AfterEach} and {@code @AfterAll} methods are appended to its cleanup /
 *       cleanupSpec methods.</li>
 * </ul>
 */
class TestingSupportExtension implements IGlobalExtension {

    AutowiredInterceptor autowiredInterceptor = new AutowiredInterceptor()
    CleanupContextInterceptor cleanupContextInterceptor = new CleanupContextInterceptor()

    /**
     * Registers the Grails interceptors and JUnit-style fixture methods on the given spec.
     *
     * @param spec the specification being visited
     */
    @Override
    void visitSpec(SpecInfo spec) {
        Class<?> specClass = spec.reflection
        if (AutowiredTest.isAssignableFrom(specClass)) {
            spec.addSetupInterceptor(autowiredInterceptor)
        }
        if (GrailsUnitTest.isAssignableFrom(specClass)) {
            spec.addCleanupSpecInterceptor(cleanupContextInterceptor)
        }
        // Only methods declared directly on the spec class are considered here; methods inherited from
        // superclasses are not inspected by this loop.
        for (Method method : specClass.declaredMethods) {
            if (method.isAnnotationPresent(BeforeEach)) {
                // Prepended so the JUnit-style setup runs before the spec's own setup methods.
                spec.setupMethods.add(0, createJUnitFixtureMethod(spec, method, MethodKind.SETUP, BeforeEach))
            }
            if (method.isAnnotationPresent(AfterEach)) {
                spec.addCleanupMethod(createJUnitFixtureMethod(spec, method, MethodKind.CLEANUP, AfterEach))
            }
            if (method.isAnnotationPresent(BeforeAll)) {
                // Prepended so the JUnit-style setupSpec runs before the spec's own setupSpec methods.
                spec.setupSpecMethods.add(0, createJUnitFixtureMethod(spec, method, MethodKind.SETUP_SPEC, BeforeAll))
            }
            if (method.isAnnotationPresent(AfterAll)) {
                spec.addCleanupSpecMethod(createJUnitFixtureMethod(spec, method, MethodKind.CLEANUP_SPEC, AfterAll))
            }
        }
    }

    /**
     * Builds a {@link MethodInfo} describing {@code method} as a fixture of the given kind on {@code specInfo}.
     *
     * @param specInfo the owning spec
     * @param method   the reflective method to wrap
     * @param kind     the Spock method kind to assign
     * @param name     the name to give the method info
     * @return the populated method info
     */
    private static MethodInfo createMethod(SpecInfo specInfo, Method method, MethodKind kind, String name) {
        MethodInfo methodInfo = new MethodInfo()
        methodInfo.parent = specInfo
        methodInfo.name = name
        methodInfo.reflection = method
        methodInfo.kind = kind
        return methodInfo
    }

    /**
     * Wraps a JUnit-annotated method as a Spock fixture method. The method is marked excluded when a subclass
     * within the spec's hierarchy redeclares it with the same annotation. Invoking the inherited reflective method
     * would dispatch to the override anyway, so it would otherwise run twice.
     *
     * @param specInfo   the owning spec
     * @param method     the annotated method
     * @param kind       the Spock fixture kind
     * @param annotation the JUnit lifecycle annotation that triggered the mapping
     * @return the fixture method info
     */
    private static MethodInfo createJUnitFixtureMethod(SpecInfo specInfo, Method method, MethodKind kind, Class<? extends Annotation> annotation) {
        MethodInfo methodInfo = createMethod(specInfo, method, kind, method.name)
        methodInfo.excluded = isOverriddenJUnitFixtureMethod(specInfo, method, annotation)
        return methodInfo
    }

    /**
     * Checks whether {@code method} is redeclared, with the same name, parameter types and lifecycle annotation,
     * by a class between the concrete spec class and the method's declaring class (exclusive).
     *
     * <p>Previously this walked {@code specInfo.class}, which is the {@code SpecInfo} metadata type itself rather
     * than the spec class. The check therefore never inspected the spec hierarchy.
     *
     * @param specInfo   the spec whose concrete class is the bottom of the hierarchy walk
     * @param method     the candidate fixture method
     * @param annotation the lifecycle annotation the overriding method must also carry
     * @return {@code true} if a subclass redeclares the method; {@code false} otherwise
     */
    private static boolean isOverriddenJUnitFixtureMethod(SpecInfo specInfo, Method method, Class<? extends Annotation> annotation) {
        // Private methods cannot be overridden.
        if (Modifier.isPrivate(method.modifiers)) return false

        Class<?> specClass = specInfo.reflection
        Class<?> declaringClass = method.declaringClass
        // A method that is not part of the spec's own hierarchy cannot be overridden by it.
        if (specClass == null || !declaringClass.isAssignableFrom(specClass)) return false

        // Walk from the concrete spec class up to, but not including, the declaring class. When the method is
        // declared on the spec class itself, there is nothing below it to override it.
        for (Class<?> currClass = specClass; currClass != null && currClass != declaringClass; currClass = currClass.superclass) {
            for (Method currMethod : currClass.declaredMethods) {
                if (!currMethod.isAnnotationPresent(annotation)) continue
                if (currMethod.name != method.name) continue
                if (!Arrays.equals(currMethod.parameterTypes, method.parameterTypes)) continue
                return true
            }
        }

        return false
    }
}
