package org.nuiton.spgeed;

/*-
 * #%L
 * spgeed
 * %%
 * Copyright (C) 2017 CodeLutin
 * %%
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Lesser Public License for more details.
 * 
 * You should have received a copy of the GNU General Lesser Public
 * License along with this program.  If not, see
 * <http://www.gnu.org/licenses/lgpl-3.0.html>.
 * #L%
 */

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.xml.bind.DatatypeConverter;

import jodd.bean.BeanException;
import jodd.bean.BeanUtil;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.json.simple.JSONArray;

/**
 * Represents a statement that can be executed on the current transaction of a
 * {@link SqlSession}.
 *
 * <p>The SQL text is handed to {@code QueryParser.parse(sql, this)}; the public
 * {@link #evalField}, {@link #evalFunction} and {@link #addSqlParameter} helpers are
 * presumably its callbacks for resolving named parameters into JDBC {@code ?} placeholders.
 * {@link #executeQuery()} and {@link #executeUpdate()} wrap the SQL so PostgreSQL returns a
 * single JSON value, which is then decoded into {@link #returnType} with Jackson.
 *
 * @author julien
 */
public class Query {
    
    final static private Log log = LogFactory.getLog(Query.class);

    /**
     * Mapper used to decode JSON results. An {@link ObjectMapper} is thread-safe once
     * configured, so it is built once here rather than once per result row.
     */
    private static final ObjectMapper MAPPER = createMapper();

    protected SqlSession session;
    
    protected String sql;
    protected String[] roles;
    protected Map<String, Object> parameters;
    protected Class returnType;
    
    protected String parsedSql;
    protected List<Object> sqlParameters;
    
    /**
     * Creates a query whose results are filtered by role membership.
     *
     * @param session    session providing the connection and the pipe functions
     * @param sql        SQL text, parsed lazily by {@link #getParsedSql()}
     * @param roles      roles checked with {@code pg_has_role(..., 'member')}; {@code null}
     *                   or empty for no role check
     * @param parameters named values the SQL text can reference
     * @param returnType type the result is decoded into
     * @throws Exception declared for API compatibility; this constructor does not throw
     */
    public Query(SqlSession session, String sql, String[] roles, Map<String, Object> parameters, Class<?> returnType) throws Exception {
        this.session = session;
        
        this.sql = sql;
        this.roles = roles;
        this.parameters = parameters;
        this.returnType = returnType;

        this.sqlParameters = new ArrayList<>();
    }
        
    /**
     * Creates a query without any role check.
     *
     * @see #Query(SqlSession, String, String[], Map, Class)
     */
    public Query(SqlSession session, String sql, Map<String, Object> parameters, Class<?> returnType) throws Exception {
        this(session, sql, null, parameters, returnType);
    }

    /**
     * Builds the shared mapper: java.time support, unknown JSON properties ignored, and
     * {@code byte[]} values read from PostgreSQL's hex {@code bytea} text form.
     */
    private static ObjectMapper createMapper() {
        ObjectMapper mapper = new ObjectMapper();
        mapper.registerModule(new JavaTimeModule());
        mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

        SimpleModule module = new SimpleModule();
        module.addDeserializer(byte[].class, new JsonDeserializer<byte[]>() {
            @Override
            public byte[] deserialize(JsonParser jsonParser, DeserializationContext dc) throws IOException, JsonProcessingException {
                return decodeHex(jsonParser.getText());
            }
        });
        mapper.registerModule(module);
        return mapper;
    }

    /**
     * Decodes PostgreSQL hex {@code bytea} text ({@code \x0a1b...}) into bytes. The
     * {@code \x} prefix is stripped only when present, so an unprefixed value keeps its
     * first two hex digits.
     */
    private static byte[] decodeHex(String value) {
        String hex = value.startsWith("\\x") ? value.substring(2) : value;
        return DatatypeConverter.parseHexBinary(hex);
    }

    /**
     * Returns the SQL with named parameters replaced, parsing it on first call and caching it.
     *
     * @return the parsed SQL
     * @throws SQLException wrapping any failure raised while parsing
     */
    public String getParsedSql() throws SQLException {
        if (this.parsedSql == null) {
            try {
                this.parsedSql = this.parseParameters(this.sql);
            } catch (Exception eee) {
                throw new SQLException("Can't parse sql: " + this.sql, eee);
            }
        }
        return parsedSql;
    }

    /** Delegates parsing of {@code sql} to {@code QueryParser}, passing this query along. */
    protected String parseParameters(String sql) throws Exception {
        return QueryParser.parse(sql, this);
    }

    /**
     * Resolves a property path against this query's parameter map.
     *
     * @param path property path; blank returns the whole parameter map
     * @return the resolved value
     */
    public Object evalField(String path) {
        return this.evalField(this.parameters, path);
    }

    /**
     * Resolves a property path on {@code o} with Jodd's {@code BeanUtil.pojo}.
     *
     * @param o    object to read from
     * @param path property path; when blank, {@code o} itself is returned
     * @return the resolved value
     * @throws RuntimeException wrapping the {@code BeanException} when the path cannot be read
     */
    public Object evalField(Object o, String path) {
        Object result = o;
        if (StringUtils.isNotBlank(path)) {
            try {
                result = BeanUtil.pojo.getProperty(result, path.trim());
            } catch (BeanException eee) {
                throw new RuntimeException(String.format("Can't find properties '%s' in '%s'", path, result), eee);
            }
        }
        return result;
    }

    /**
     * Applies the session's pipe function called {@code name} to {@code value}.
     *
     * @param value value piped into the function
     * @param name  pipe function name, looked up in {@code session.getPipeFunctions()}
     * @param args  extra arguments, or {@code null}
     * @return the function's result
     * @throws RuntimeException when the function is unknown or fails
     */
    public Object evalFunction(Object value, String name, List args) {
        try {
            Map<String, PipeFunction> pipeFunctions = this.session.getPipeFunctions();
            PipeFunction pipeFunction = pipeFunctions.get(name);
            if (pipeFunction == null) {
                throw new RuntimeException(String.format("Can't find function for evalFunction(%s, '%s', %s)", value, name, args));
            }
            
            return pipeFunction.function(this, value, args == null ? null : args.toArray());
        } catch (Exception eee) {
            throw new RuntimeException(String.format("Error during evalFunction(%s, %s, %s)", value, name, args), eee);
        }
    }

    /**
     * Records a value to bind as the next JDBC parameter.
     *
     * @param value object to use as a parameter of the SQL query
     * @return the placeholder ({@code "?"}) that must replace the argument in the SQL string
     */
    public String addSqlParameter(Object value) {
        this.sqlParameters.add(value);
        return "?";
    }

    public SqlSession getSession() {
        return session;
    }
    
    /**
     * Prepares {@code sql} on the session connection and binds {@link #sqlParameters} in order.
     * Enum values are bound as their {@code toString()} text.
     *
     * @throws SQLException if preparing or binding fails; the statement is closed in that case
     */
    protected PreparedStatement getStatement(String sql) throws SQLException {
        Connection connection = this.session.getConnection();
        PreparedStatement statement = connection.prepareStatement(sql);
        try {
            int position = 1;
            for (Object arg : this.sqlParameters) {
                // instanceof rather than getClass().isEnum(): constants declaring their own
                // body are instances of an anonymous subclass whose isEnum() is false.
                if (arg instanceof Enum) {
                    statement.setObject(position++, arg.toString());
                } else {
                    statement.setObject(position++, arg);
                }
            }
            return statement;
        } catch (SQLException | RuntimeException eee) {
            // The statement never reached a caller, so nobody else would close it.
            try {
                statement.close();
            } catch (SQLException closeError) {
                eee.addSuppressed(closeError);
            }
            throw eee;
        }
    }

    /**
     * Builds the role filter {@code pg_has_role('r1', 'member') OR pg_has_role('r2', 'member') ...}.
     * Single quotes in role names are doubled so each name stays inside its SQL literal.
     */
    protected String getFormattedRoles() {
        return Stream.of(this.roles)
                .map(role -> String.valueOf(role).replace("'", "''"))
                .collect(Collectors.joining("', 'member') OR pg_has_role('", "pg_has_role('", "', 'member')"));
    }

    /**
     * Executes {@code statement} and decodes its first column into {@link #returnType}.
     *
     * <p>Supported shapes: the raw {@link ResultSet}; {@code void}/{@link Void} (always
     * {@code null}); {@code byte[]} from hex {@code bytea}; arrays and {@link JSONArray}
     * decoded from the whole JSON; {@link UUID}, enums, primitives/wrappers and
     * {@link String} taken from the first value of the first JSON object; any other type
     * decoded from the first element of the JSON array.
     *
     * @return the decoded value, or {@code null} when there is no row or the JSON is {@code null}
     */
    protected <E> E getResult(final PreparedStatement statement) throws SQLException, IOException, ClassNotFoundException {
        ResultSet result = statement.executeQuery();
        
        if (this.returnType.equals(ResultSet.class)) {
            return (E) result;
        }
        if (this.returnType.equals(Void.TYPE) || this.returnType.equals(Void.class)) {
            return null;
        }

        // Only the first row is read: the SQL built by executeQuery/executeUpdate yields a
        // single row holding one JSON value.
        if (!result.next()) {
            return null;
        }
        String json = result.getString(1);
        if (json == null) {
            return null;
        }

        if (byte[].class.equals(this.returnType)) {
            Object firstValue = getFirstValue(MAPPER, json);
            return firstValue == null ? null : (E) decodeHex((String) firstValue);
        }
        
        if (this.returnType.isArray() || this.returnType.equals(JSONArray.class)) {
            return (E) MAPPER.readValue(json, this.returnType);
        }
        
        if (this.returnType.equals(UUID.class)) {
            Object firstValue = getFirstValue(MAPPER, json);
            return firstValue == null ? null : (E) UUID.fromString((String) firstValue);
        }
        
        if (this.returnType.isEnum()) {
            Object firstValue = getFirstValue(MAPPER, json);
            return firstValue == null ? null : (E) Enum.valueOf(this.returnType, (String) firstValue);
        }
        
        if (ClassUtils.isPrimitiveOrWrapper(this.returnType) || this.returnType.equals(String.class)) {
            Object firstValue = getFirstValue(MAPPER, json);
            // Jackson picks Integer/Long/Double/... from the JSON literal itself; convert to
            // the declared (wrapper) type so e.g. a long-returning query holding a small
            // number does not fail with a ClassCastException in the caller.
            return (E) MAPPER.convertValue(firstValue, ClassUtils.primitiveToWrapper(this.returnType));
        }

        // Decode as an array of the return type (built via Jackson's TypeFactory rather than
        // Class.forName, which would depend on this class's class loader).
        E[] values = MAPPER.readValue(json, MAPPER.getTypeFactory().constructArrayType(this.returnType));
        return values[0];
    }

    /**
     * Returns the first property value of the first object in a JSON array, i.e. the first
     * column of the first row of a {@code json_agg} result (Jackson keeps key order).
     */
    protected Object getFirstValue(ObjectMapper mapper, String json) throws IOException {
        JSONArray readValue = mapper.readValue(json, JSONArray.class);
        Map map = (Map) readValue.get(0);
        Collection values = map.values();
        return values.iterator().next();
    }

    /** Logs a failed execution with the parsed SQL, the bound arguments and the cause. */
    private void logFailure(Exception cause) {
        log.error(String.format("Can't execute query '%s' with args: %s", this.parsedSql, this.sqlParameters), cause);
    }

    /**
     * Prepares and runs {@code sql}, decoding the result with {@link #getResult}.
     *
     * <p>The statement is closed afterwards, except when the result is the live
     * {@link ResultSet}: closing the statement would close that result set as well, so the
     * statement is marked {@code closeOnCompletion()} and is released when the caller
     * closes the result set.
     */
    private <E> E executeAndMap(String sql) throws Exception {
        PreparedStatement statement = null;
        boolean handedOff = false;
        try {
            statement = this.getStatement(sql);
            E result = getResult(statement);
            if (result instanceof ResultSet) {
                statement.closeOnCompletion();
                handedOff = true;
            }
            return result;
        } catch (Exception eee) {
            logFailure(eee);
            throw eee;
        } finally {
            if (statement != null && !handedOff) {
                statement.close();
            }
        }
    }
    
    /**
     * Runs the query as {@code WITH __all AS (sql) SELECT json_agg(__all.*) FROM __all},
     * adding the role filter as a {@code WHERE} clause when roles are set, and decodes the
     * aggregated JSON into {@link #returnType}.
     */
    public<E> E executeQuery() throws Exception {
        String sql = this.getParsedSql();
        
        if (this.roles != null && this.roles.length != 0) {
            String formattedRoles = getFormattedRoles();
            sql = String.format("WITH __all AS (%s) SELECT json_agg(__all.*) FROM __all WHERE %s", sql, formattedRoles);
        } else {
            sql = String.format("WITH __all AS (%s) SELECT json_agg(__all.*) FROM __all", sql);
        }
        
        return executeAndMap(sql);
    }
    
    /**
     * Runs a data-modifying statement.
     *
     * <ul>
     * <li>SQL containing {@code RETURNING}: the returned rows are aggregated with
     * {@code json_agg} and decoded like {@link #executeQuery()}.</li>
     * <li>SQL starting with {@code UPDATE}/{@code INSERT} (or {@code DELETE} when no roles are
     * set): wrapped with {@code RETURNING 1} and the affected row count is decoded from
     * {@code [{"result": n}]}.</li>
     * <li>Anything else without roles: plain {@code executeUpdate()}, returning the
     * {@code Integer} count. With roles it is rejected.</li>
     * </ul>
     *
     * <p>NOTE(review): keyword detection is case-sensitive and requires the statement to start
     * with the keyword (no leading whitespace). Also, PostgreSQL runs data-modifying
     * statements inside {@code WITH} to completion whatever the outer {@code WHERE} keeps, so
     * the role filter seems to limit only what is reported back, not the write itself —
     * confirm this is the intended semantics.
     *
     * @throws SQLException when roles are set for an unsupported kind of statement
     */
    public <E> E executeUpdate() throws SQLException, Exception {
        String sql = this.getParsedSql();
        
        if (this.roles != null && this.roles.length != 0) {
            String formattedRoles = getFormattedRoles();
            
            if (sql.contains("RETURNING")) {
                sql = String.format("WITH __all AS (%s) SELECT json_agg(__all.*) FROM __all WHERE %s", sql, formattedRoles);
                
            } else if (sql.startsWith("UPDATE") || sql.startsWith("INSERT")) {
                sql = String.format("WITH __all AS (%s RETURNING 1) SELECT '[{\"result\":' || COUNT(__all.*) || '}]' FROM __all WHERE %s", sql, formattedRoles);
                
            } else {
                throw new SQLException("You cannot use roles in this type of request");
            }
        } else if (sql.contains("RETURNING")) {
            sql = String.format("WITH __all AS (%s) SELECT json_agg(__all.*) FROM __all", sql);
            
        } else if (sql.startsWith("UPDATE") || sql.startsWith("INSERT") || sql.startsWith("DELETE")) {
            sql = String.format("WITH __all AS (%s RETURNING 1) SELECT '[{\"result\":' || COUNT(__all.*) || '}]' FROM __all", sql);
            
        } else {
            try (PreparedStatement statement = this.getStatement(sql)) {
                Integer result = statement.executeUpdate();
                return (E) result;

            } catch (Exception eee) {
                logFailure(eee);
                throw eee;
            }
        }
        
        return executeAndMap(sql);
    }

    /**
     * Runs the parsed SQL as-is with {@link PreparedStatement#execute()}.
     *
     * @return {@code true} if the first result is a result set, as reported by JDBC
     */
    public boolean execute() throws SQLException {
        String sql = this.getParsedSql();
        
        try (PreparedStatement statement = this.getStatement(sql)) {
            return statement.execute();
            
        } catch (Exception eee) {
            logFailure(eee);
            throw eee;
        }
    }
}
