/*
 * Copyright 2008-2014 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.codehaus.groovy.transform;

import groovy.lang.Newify;
import org.codehaus.groovy.GroovyBugError;
import org.codehaus.groovy.ast.ASTNode;
import org.codehaus.groovy.ast.AnnotatedNode;
import org.codehaus.groovy.ast.AnnotationNode;
import org.codehaus.groovy.ast.ClassCodeExpressionTransformer;
import org.codehaus.groovy.ast.ClassNode;
import org.codehaus.groovy.ast.FieldNode;
import org.codehaus.groovy.ast.MethodNode;
import org.codehaus.groovy.ast.expr.ClassExpression;
import org.codehaus.groovy.ast.expr.ClosureExpression;
import org.codehaus.groovy.ast.expr.ConstantExpression;
import org.codehaus.groovy.ast.expr.ConstructorCallExpression;
import org.codehaus.groovy.ast.expr.DeclarationExpression;
import org.codehaus.groovy.ast.expr.Expression;
import org.codehaus.groovy.ast.expr.ListExpression;
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.control.CompilePhase;
import org.codehaus.groovy.control.SourceUnit;

import java.util.HashSet;
import java.util.List;
import java.util.Arrays;
import java.util.Set;

import static org.codehaus.groovy.ast.ClassHelper.make;
import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
import static org.codehaus.groovy.ast.tools.GeneralUtils.classX;

/**
 * Handles generation of code for the @Newify annotation.
 * <p>
 * Within the scope of the annotated class, method, constructor, field or local
 * variable declaration, this transformation rewrites:
 * <ul>
 *     <li>Ruby-style {@code Foo.new(args)} calls into {@code new Foo(args)} unless the
 *     annotation's {@code auto} member is explicitly {@code false};</li>
 *     <li>unqualified {@code Foo(args)} calls into {@code new Foo(args)} when {@code Foo}
 *     matches (by simple name) one of the classes listed in the annotation's {@code value}.</li>
 * </ul>
 * The scoped settings ({@code auto} and the class list) are kept in instance fields while
 * the relevant part of the AST is visited and restored afterwards, so an instance is not
 * safe for concurrent use.
 *
 * @author Paul King
 */
@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
public class NewifyASTTransformation extends ClassCodeExpressionTransformer implements ASTTransformation {
    private static final ClassNode MY_TYPE = make(Newify.class);
    private static final String MY_NAME = MY_TYPE.getNameWithoutPackage();
    private static final String BASE_BAD_PARAM_ERROR = "Error during @" + MY_NAME +
            " processing. Annotation parameter must be a class or list of classes but found ";

    /** Source unit being processed; used for error reporting and same-unit class lookup. */
    private SourceUnit source;
    /**
     * Classes eligible for {@code Foo(args)} rewriting in the current scope, or null when
     * outside any scope. May contain non-class entries that were already reported as errors.
     */
    private ListExpression classesToNewify;
    /**
     * Annotated local declaration still awaiting transformation. While non-null, method calls
     * are not rewritten; it is cleared once a matching declaration is transformed.
     */
    private DeclarationExpression candidate;
    /** Whether {@code Foo.new(args)} calls are rewritten in the current scope. */
    private boolean auto;

    /**
     * Entry point invoked by the compiler for each {@code @Newify} usage.
     *
     * @param nodes  expected to be {@code [AnnotationNode, AnnotatedNode]}
     * @param source the source unit being compiled
     * @throws GroovyBugError if {@code nodes} has an unexpected shape or the annotation is not {@code @Newify}
     */
    public void visit(ASTNode[] nodes, SourceUnit source) {
        this.source = source;
        if (nodes.length != 2 || !(nodes[0] instanceof AnnotationNode) || !(nodes[1] instanceof AnnotatedNode)) {
            internalError("Expecting [AnnotationNode, AnnotatedClass] but got: " + Arrays.asList(nodes));
        }

        AnnotatedNode parent = (AnnotatedNode) nodes[1];
        AnnotationNode node = (AnnotationNode) nodes[0];
        if (!MY_TYPE.equals(node.getClassNode())) {
            internalError("Transformation called from wrong annotation: " + node.getClassNode().getName());
        }

        boolean autoFlag = determineAutoFlag(node.getMember("auto"));
        Expression value = node.getMember("value");

        if (parent instanceof ClassNode) {
            newifyClass((ClassNode) parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof MethodNode || parent instanceof FieldNode) {
            newifyMethodOrField(parent, autoFlag, determineClasses(value, false));
        } else if (parent instanceof DeclarationExpression) {
            // local declarations may name classes of this source unit not yet resolved
            newifyDeclaration((DeclarationExpression) parent, autoFlag, determineClasses(value, true));
        }
    }

    /**
     * Applies an annotation found on a local variable declaration. The whole declaring class is
     * visited, but method calls are only rewritten once the declaration has been reached.
     */
    private void newifyDeclaration(DeclarationExpression de, boolean autoFlag, ListExpression list) {
        ClassNode cNode = de.getDeclaringClass();
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        candidate = de;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            // clear the candidate even if the declaration was never reached, so it
            // cannot suppress rewriting during later visits with this instance
            candidate = null;
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Returns false only when the {@code auto} member is an explicit constant {@code false};
     * an absent member means true.
     */
    private boolean determineAutoFlag(Expression autoExpr) {
        return !(autoExpr instanceof ConstantExpression
                && Boolean.FALSE.equals(((ConstantExpression) autoExpr).getValue()));
    }

    /**
     * Normalizes the annotation's {@code value} member into a list of class expressions,
     * reporting an error for each entry that is not a class.
     * <p>
     * When {@code searchSourceUnit} is true (used for local declarations), plain variable
     * references are resolved against the classes of the current source unit. This allows a
     * non-strict mode in scripts, because parsing is not complete at that point. Invalid
     * entries stay in the returned list after their error is reported.
     *
     * @param expr             the {@code value} member, possibly null
     * @param searchSourceUnit whether to resolve variable references against this source unit
     * @return the (possibly empty) list of classes; for a list-valued member, the member itself
     */
    private ListExpression determineClasses(Expression expr, boolean searchSourceUnit) {
        ListExpression list = new ListExpression();
        if (expr instanceof ClassExpression) {
            list.addExpression(expr);
        } else if (expr instanceof VariableExpression && searchSourceUnit) {
            VariableExpression ve = (VariableExpression) expr;
            ClassNode fromSourceUnit = getSourceUnitClass(ve);
            if (fromSourceUnit != null) {
                ClassExpression found = classX(fromSourceUnit);
                found.setSourcePosition(ve);
                list.addExpression(found);
            } else {
                addError(BASE_BAD_PARAM_ERROR + "an unresolvable reference to '" + ve.getName() + "'.", expr);
            }
        } else if (expr instanceof ListExpression) {
            list = (ListExpression) expr;
            final List<Expression> expressions = list.getExpressions();
            for (int i = 0; i < expressions.size(); i++) {
                Expression next = expressions.get(i);
                if (next instanceof VariableExpression && searchSourceUnit) {
                    VariableExpression ve = (VariableExpression) next;
                    ClassNode fromSourceUnit = getSourceUnitClass(ve);
                    if (fromSourceUnit != null) {
                        ClassExpression found = classX(fromSourceUnit);
                        found.setSourcePosition(ve);
                        // replace the reference in place with the resolved class
                        expressions.set(i, found);
                    } else {
                        addError(BASE_BAD_PARAM_ERROR + "a list containing an unresolvable reference to '" + ve.getName() + "'.", next);
                    }
                } else if (!(next instanceof ClassExpression)) {
                    addError(BASE_BAD_PARAM_ERROR + "a list containing type: " + next.getType().getName() + ".", next);
                }
            }
            checkDuplicateNameClashes(list);
        } else if (expr != null) {
            addError(BASE_BAD_PARAM_ERROR + "a type: " + expr.getType().getName() + ".", expr);
        }
        return list;
    }

    /** Finds a class of the current source unit whose simple name equals the variable's name, or null. */
    private ClassNode getSourceUnitClass(VariableExpression ve) {
        List<ClassNode> classes = source.getAST().getClasses();
        for (ClassNode classNode : classes) {
            if (classNode.getNameWithoutPackage().equals(ve.getName())) return classNode;
        }
        return null;
    }

    /**
     * Rewrites newify candidates into constructor calls. Other method calls are rebuilt with
     * transformed parts, and closure bodies and declarations are visited.
     */
    @Override
    public Expression transform(Expression expr) {
        if (expr == null) return null;
        if (expr instanceof MethodCallExpression && candidate == null) {
            MethodCallExpression mce = (MethodCallExpression) expr;
            Expression args = transform(mce.getArguments());
            if (isNewifyCandidate(mce)) {
                Expression transformed = transformMethodCall(mce, args);
                transformed.setSourcePosition(mce);
                return transformed;
            }
            Expression method = transform(mce.getMethod());
            Expression object = transform(mce.getObjectExpression());
            MethodCallExpression transformed = callX(object, method, args);
            // callX creates a fresh call: carry over the call-site flags so that
            // a?.b(), a*.b() and implicit-this calls keep their semantics
            transformed.setSafe(mce.isSafe());
            transformed.setSpreadSafe(mce.isSpreadSafe());
            transformed.setImplicitThis(mce.isImplicitThis());
            transformed.setSourcePosition(mce);
            return transformed;
        } else if (expr instanceof ClosureExpression) {
            ClosureExpression ce = (ClosureExpression) expr;
            ce.getCode().visit(this);
        } else if (expr instanceof DeclarationExpression) {
            DeclarationExpression de = (DeclarationExpression) expr;
            // NOTE(review): with auto=true any declaration visited before the candidate
            // also clears it, widening the scope - confirm this is intended
            if (de == candidate || auto) {
                candidate = null;
                Expression left = de.getLeftExpression();
                Expression right = transform(de.getRightExpression());
                DeclarationExpression newDecl = new DeclarationExpression(left, de.getOperation(), right);
                newDecl.addAnnotations(de.getAnnotations());
                newDecl.setSourcePosition(de);
                return newDecl;
            }
            return de;
        }
        return expr.transformExpression(this);
    }

    /** Applies a class-level annotation to the whole class. Interfaces are reported as errors. */
    private void newifyClass(ClassNode cNode, boolean autoFlag, ListExpression list) {
        String cName = cNode.getName();
        if (cNode.isInterface()) {
            addError("Error processing interface '" + cName + "'. @"
                    + MY_NAME + " not allowed for interfaces.", cNode);
        }
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        classesToNewify = list;
        auto = autoFlag;
        try {
            super.visitClass(cNode);
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Applies a method/constructor/field-level annotation. It first checks for conflicts with
     * any enclosing class-level settings.
     */
    private void newifyMethodOrField(AnnotatedNode parent, boolean autoFlag, ListExpression list) {
        final ListExpression oldClassesToNewify = classesToNewify;
        final boolean oldAuto = auto;
        checkClassLevelClashes(list);
        checkAutoClash(autoFlag, parent);
        classesToNewify = list;
        auto = autoFlag;
        try {
            if (parent instanceof FieldNode) {
                super.visitField((FieldNode) parent);
            } else {
                super.visitMethod((MethodNode) parent);
            }
        } finally {
            classesToNewify = oldClassesToNewify;
            auto = oldAuto;
        }
    }

    /**
     * Reports every class whose simple name was already seen earlier in the list. Non-class
     * entries are skipped because determineClasses has already reported them.
     */
    private void checkDuplicateNameClashes(ListExpression list) {
        final Set<String> seen = new HashSet<String>();
        for (Expression expr : list.getExpressions()) {
            if (!(expr instanceof ClassExpression)) continue;
            final String name = expr.getType().getNameWithoutPackage();
            if (!seen.add(name)) {
                addError("Duplicate name '" + name + "' found during @" + MY_NAME + " processing.", expr);
            }
        }
    }

    /** Reports an error when a nested annotation sets auto=false inside a class-level auto=true scope. */
    private void checkAutoClash(boolean autoFlag, AnnotatedNode parent) {
        if (auto && !autoFlag) {
            addError("Error during @" + MY_NAME + " processing. The 'auto' flag can't be false at " +
                    "method/constructor/field level if it is true at the class level.", parent);
        }
    }

    /** Reports each class in {@code list} whose simple name is already in the enclosing class-level list. */
    private void checkClassLevelClashes(ListExpression list) {
        for (Expression expr : list.getExpressions()) {
            if (!(expr instanceof ClassExpression)) continue; // already reported by determineClasses
            final String name = expr.getType().getNameWithoutPackage();
            if (findClassWithMatchingBasename(name)) {
                addError("Error during @" + MY_NAME + " processing. Class '" + name + "' can't appear at " +
                        "method/constructor/field level if it already appears at the class level.", expr);
            }
        }
    }

    /** Returns whether the current scope's class list contains a class with the given simple name. */
    private boolean findClassWithMatchingBasename(String nameWithoutPackage) {
        if (classesToNewify == null) return false;
        for (Expression expr : classesToNewify.getExpressions()) {
            if (expr instanceof ClassExpression
                    && expr.getType().getNameWithoutPackage().equals(nameWithoutPackage)) {
                return true;
            }
        }
        return false;
    }

    /** A call is a candidate if its receiver is the shared this-expression, or it is {@code Foo.new(..)} with auto on. */
    private boolean isNewifyCandidate(MethodCallExpression mce) {
        return mce.getObjectExpression() == VariableExpression.THIS_EXPRESSION
                || (auto && isNewMethodStyle(mce));
    }

    /** Returns whether the call has the form {@code SomeClass.new(..)}. */
    private boolean isNewMethodStyle(MethodCallExpression mce) {
        final Expression obj = mce.getObjectExpression();
        final Expression meth = mce.getMethod();
        return (obj instanceof ClassExpression && meth instanceof ConstantExpression
                && "new".equals(((ConstantExpression) meth).getValue()));
    }

    /**
     * Builds the constructor call for a candidate. If no class matches, it returns the
     * original call with its (already transformed) arguments.
     */
    private Expression transformMethodCall(MethodCallExpression mce, Expression args) {
        ClassNode classType;
        if (isNewMethodStyle(mce)) {
            classType = mce.getObjectExpression().getType();
        } else {
            classType = findMatchingCandidateClass(mce);
        }
        if (classType != null) {
            return new ConstructorCallExpression(classType, args);
        }
        // set the args as they might have gotten Newify transformed GROOVY-3491
        mce.setArguments(args);
        return mce;
    }

    /** Finds the listed class whose simple name equals the called method's name, or null. */
    private ClassNode findMatchingCandidateClass(MethodCallExpression mce) {
        if (classesToNewify == null) return null;
        final String methodName = mce.getMethodAsString();
        for (Expression expr : classesToNewify.getExpressions()) {
            if (!(expr instanceof ClassExpression)) continue; // already reported by determineClasses
            final ClassNode type = expr.getType();
            if (type.getNameWithoutPackage().equals(methodName)) {
                return type;
            }
        }
        return null;
    }

    /** Always throws; signals a compiler bug rather than a user error. */
    private void internalError(String message) {
        throw new GroovyBugError("Internal error: " + message);
    }

    @Override
    protected SourceUnit getSourceUnit() {
        return source;
    }
}
