/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.infra.rewrite.engine;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;
import org.apache.shardingsphere.sqltranslator.context.SQLTranslatorContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

/**
 * Rewrites the SQL of a routed statement into one {@link SQLRewriteUnit} per target.
 *
 * <p>Route units are grouped by the actual name of their data source. For a group in which
 * {@link #isNeedAggregateRewrite} allows it, the SQL of every unit in the group is joined with
 * {@code UNION ALL} into a single unit, keyed by the group's first route unit. Otherwise each
 * route unit gets its own rewrite unit. Finally every unit is passed through the
 * {@link SQLTranslatorRule} using the storage type of its actual data source.</p>
 */
public final class RouteSQLRewriteEngine {
    
    private final SQLTranslatorRule translatorRule;
    
    private final ShardingSphereDatabase database;
    
    private final RuleMetaData globalRuleMetaData;
    
    /**
     * Creates a rewrite engine.
     *
     * @param translatorRule rule used to translate each rewritten SQL for its target storage type
     * @param database database whose storage units provide the storage type of each data source
     * @param globalRuleMetaData global rule metadata forwarded to the translator rule
     */
    @Generated
    public RouteSQLRewriteEngine(final SQLTranslatorRule translatorRule, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) {
        this.translatorRule = translatorRule;
        this.database = database;
        this.globalRuleMetaData = globalRuleMetaData;
    }
    
    /**
     * Rewrites the SQL for all route units of the given route context.
     *
     * @param sqlRewriteContext rewrite context holding the statement context and parameter builder
     * @param routeContext route context whose route units are rewritten
     * @param queryContext query context forwarded to the translator rule
     * @return rewrite result whose units keep the order in which route-unit groups were encountered
     */
    public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) {
        Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1.0f);
        // Only the grouped route units matter here; the data source name key is not needed.
        for (Collection<RouteUnit> routeUnits : aggregateRouteUnitGroups(routeContext.getRouteUnits()).values()) {
            if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
                // The merged UNION ALL statement is registered under the group's first route unit.
                sqlRewriteUnits.put(routeUnits.iterator().next(), createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
            } else {
                addSQLRewriteUnits(sqlRewriteUnits, sqlRewriteContext, routeContext, routeUnits);
            }
        }
        return new RouteSQLRewriteResult(translate(queryContext, sqlRewriteUnits));
    }
    
    /**
     * Builds one rewrite unit covering all given route units.
     *
     * <p>Each unit's SQL is passed through {@code SQLUtils#trimSemicolon}. The results are joined
     * with {@code " UNION ALL "}.</p>
     *
     * @param sqlRewriteContext rewrite context
     * @param routeContext route context
     * @param routeUnits route units of one data source; expected to be non-empty
     * @return aggregated rewrite unit
     */
    private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
        Collection<String> sql = new LinkedList<>();
        List<Object> params = new LinkedList<>();
        SQLStatementContext sqlStatementContext = sqlRewriteContext.getSqlStatementContext();
        boolean containsDollarMarker = sqlStatementContext instanceof SelectStatementContext && ((SelectStatementContext) sqlStatementContext).isContainsDollarParameterMarker();
        for (RouteUnit each : routeUnits) {
            sql.add(SQLUtils.trimSemicolon(new RouteSQLBuilder(sqlRewriteContext, each).toSQL()));
            // With dollar parameter markers, parameters are collected only until the list is non-empty.
            // Presumably this is because numbered ($n) markers are shared by every UNION ALL branch.
            if (containsDollarMarker && !params.isEmpty()) {
                continue;
            }
            params.addAll(getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each));
        }
        return new SQLRewriteUnit(String.join(" UNION ALL ", sql), params);
    }
    
    /**
     * Adds one rewrite unit per route unit, without any aggregation.
     *
     * @param sqlRewriteUnits target map, mutated in place
     * @param sqlRewriteContext rewrite context
     * @param routeContext route context
     * @param routeUnits route units to rewrite individually
     */
    private void addSQLRewriteUnits(final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits, final SQLRewriteContext sqlRewriteContext,
                                    final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
        for (RouteUnit each : routeUnits) {
            sqlRewriteUnits.put(each, new SQLRewriteUnit(new RouteSQLBuilder(sqlRewriteContext, each).toSQL(), getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each)));
        }
    }
    
    /**
     * Decides whether the route units of one data source can be merged with {@code UNION ALL}.
     *
     * <p>Returns {@code false} for non-select statements and for single-unit groups without touching
     * the statement context. Otherwise merging is allowed only when the select has:</p>
     * <ul>
     *   <li>no subquery and no join,</li>
     *   <li>no ORDER BY items and no pagination,</li>
     *   <li>no lock clause.</li>
     * </ul>
     * <p>In that case the decision is also stored via
     * {@code SelectStatementContext#setNeedAggregateRewrite}.</p>
     *
     * @param sqlStatementContext statement context
     * @param routeUnits route units of one data source
     * @return whether aggregate rewrite is needed
     */
    private boolean isNeedAggregateRewrite(final SQLStatementContext sqlStatementContext, final Collection<RouteUnit> routeUnits) {
        if (!(sqlStatementContext instanceof SelectStatementContext) || routeUnits.size() == 1) {
            return false;
        }
        SelectStatementContext statementContext = (SelectStatementContext) sqlStatementContext;
        boolean containsSubqueryJoinQuery = statementContext.isContainsSubquery() || statementContext.isContainsJoinQuery();
        boolean containsOrderByLimitClause = !statementContext.getOrderByContext().getItems().isEmpty() || statementContext.getPaginationContext().isHasPagination();
        boolean containsLockClause = SelectStatementHandler.getLockSegment((SelectStatement) statementContext.getSqlStatement()).isPresent();
        boolean needAggregateRewrite = !containsSubqueryJoinQuery && !containsOrderByLimitClause && !containsLockClause;
        statementContext.setNeedAggregateRewrite(needAggregateRewrite);
        return needAggregateRewrite;
    }
    
    /**
     * Groups route units by the actual name of their data source, preserving encounter order.
     *
     * @param routeUnits route units to group
     * @return insertion-ordered map from actual data source name to its route units
     */
    private Map<String, Collection<RouteUnit>> aggregateRouteUnitGroups(final Collection<RouteUnit> routeUnits) {
        Map<String, Collection<RouteUnit>> result = new LinkedHashMap<>(routeUnits.size(), 1.0f);
        for (RouteUnit each : routeUnits) {
            result.computeIfAbsent(each.getDataSourceMapper().getActualName(), unused -> new LinkedList<>()).add(each);
        }
        return result;
    }
    
    /**
     * Resolves the parameters to bind for one route unit.
     *
     * <ul>
     *   <li>A {@link StandardParameterBuilder} yields its parameters as-is.</li>
     *   <li>A {@link GroupedParameterBuilder} yields all its parameters when the route context has
     *       no original data nodes.</li>
     *   <li>Otherwise a {@link GroupedParameterBuilder} yields only the parameter groups relevant
     *       to the route unit.</li>
     * </ul>
     *
     * @param paramBuilder parameter builder; assumed to be one of the two handled implementations
     * @param routeContext route context
     * @param routeUnit route unit
     * @return parameters for the route unit
     */
    private List<Object> getParameters(final ParameterBuilder paramBuilder, final RouteContext routeContext, final RouteUnit routeUnit) {
        if (paramBuilder instanceof StandardParameterBuilder) {
            return paramBuilder.getParameters();
        }
        GroupedParameterBuilder groupedParamBuilder = (GroupedParameterBuilder) paramBuilder;
        return routeContext.getOriginalDataNodes().isEmpty() ? groupedParamBuilder.getParameters() : buildRouteParameters(groupedParamBuilder, routeContext, routeUnit);
    }
    
    /**
     * Collects the grouped parameters for one route unit.
     *
     * <p>The i-th original data-node collection selects parameter group i. A group is kept when
     * {@link #isInSameDataNode} matches it against the route unit. The generic parameters are
     * appended last.</p>
     *
     * @param paramBuilder grouped parameter builder
     * @param routeContext route context providing the original data nodes
     * @param routeUnit route unit
     * @return parameters for the route unit
     */
    private List<Object> buildRouteParameters(final GroupedParameterBuilder paramBuilder, final RouteContext routeContext, final RouteUnit routeUnit) {
        List<Object> result = new LinkedList<>();
        int count = 0;
        for (Collection<DataNode> each : routeContext.getOriginalDataNodes()) {
            if (isInSameDataNode(each, routeUnit)) {
                result.addAll(paramBuilder.getParameters(count));
            }
            count++;
        }
        result.addAll(paramBuilder.getGenericParameterBuilder().getParameters());
        return result;
    }
    
    /**
     * Checks whether any of the given data nodes maps to a table of the route unit.
     *
     * @param dataNodes data nodes; an empty collection matches every route unit
     * @param routeUnit route unit
     * @return {@code true} if the collection is empty or one of its nodes has a table mapper in the route unit
     */
    private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final RouteUnit routeUnit) {
        if (dataNodes.isEmpty()) {
            return true;
        }
        for (DataNode each : dataNodes) {
            if (routeUnit.findTableMapper(each.getDataSourceName(), each.getTableName()).isPresent()) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * Translates every rewrite unit with the {@link SQLTranslatorRule}.
     *
     * <p>The storage type used for each unit is that of the storage unit registered under the route
     * unit's actual data source name.</p>
     *
     * @param queryContext query context forwarded to the translator rule
     * @param sqlRewriteUnits rewrite units to translate
     * @return translated rewrite units, in the same order as the input
     * @throws IllegalStateException if no storage unit is registered for a route unit's actual data source
     */
    private Map<RouteUnit, SQLRewriteUnit> translate(final QueryContext queryContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
        Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1.0f);
        Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
        for (Map.Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
            String dataSourceName = entry.getKey().getDataSourceMapper().getActualName();
            StorageUnit storageUnit = storageUnits.get(dataSourceName);
            // Fail with context instead of an anonymous NullPointerException on a missing storage unit.
            if (storageUnit == null) {
                throw new IllegalStateException(String.format("Can not find storage unit for data source `%s`.", dataSourceName));
            }
            DatabaseType storageType = storageUnit.getStorageType();
            SQLRewriteUnit sqlRewriteUnit = entry.getValue();
            SQLTranslatorContext sqlTranslatorContext = translatorRule.translate(sqlRewriteUnit.getSql(), sqlRewriteUnit.getParameters(), queryContext, storageType, database, globalRuleMetaData);
            result.put(entry.getKey(), new SQLRewriteUnit(sqlTranslatorContext.getSql(), sqlTranslatorContext.getParameters()));
        }
        return result;
    }
}

