/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.infra.rewrite.engine;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.impl.RouteSQLBuilder;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtil;
import org.apache.shardingsphere.sql.parser.sql.dialect.handler.dml.SelectStatementHandler;

public final class RouteSQLRewriteEngine {
    public RouteSQLRewriteResult rewrite(SQLRewriteContext sqlRewriteContext, RouteContext routeContext) {
        LinkedHashMap<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<RouteUnit, SQLRewriteUnit>(routeContext.getRouteUnits().size(), 1.0f);
        for (Map.Entry<String, Collection<RouteUnit>> entry : this.aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
            Collection<RouteUnit> routeUnits = entry.getValue();
            if (this.isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits)) {
                result.put(routeUnits.iterator().next(), this.createSQLRewriteUnit(sqlRewriteContext, routeContext, routeUnits));
                continue;
            }
            result.putAll(this.createSQLRewriteUnits(sqlRewriteContext, routeContext, routeUnits));
        }
        return new RouteSQLRewriteResult(result);
    }

    private SQLRewriteUnit createSQLRewriteUnit(SQLRewriteContext sqlRewriteContext, RouteContext routeContext, Collection<RouteUnit> routeUnits) {
        LinkedList<String> sql = new LinkedList<String>();
        LinkedList<Object> parameters = new LinkedList<Object>();
        boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext && ((SelectStatementContext)sqlRewriteContext.getSqlStatementContext()).isContainsDollarParameterMarker();
        for (RouteUnit each : routeUnits) {
            sql.add(SQLUtil.trimSemicolon((String)new RouteSQLBuilder(sqlRewriteContext, each).toSQL()));
            if (containsDollarMarker && !parameters.isEmpty()) continue;
            parameters.addAll(this.getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each));
        }
        return new SQLRewriteUnit(String.join((CharSequence)" UNION ALL ", sql), parameters);
    }

    private Map<RouteUnit, SQLRewriteUnit> createSQLRewriteUnits(SQLRewriteContext sqlRewriteContext, RouteContext routeContext, Collection<RouteUnit> routeUnits) {
        LinkedHashMap<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<RouteUnit, SQLRewriteUnit>(routeUnits.size(), 1.0f);
        for (RouteUnit each : routeUnits) {
            result.put(each, new SQLRewriteUnit(new RouteSQLBuilder(sqlRewriteContext, each).toSQL(), this.getParameters(sqlRewriteContext.getParameterBuilder(), routeContext, each)));
        }
        return result;
    }

    private boolean isNeedAggregateRewrite(SQLStatementContext<?> sqlStatementContext, Collection<RouteUnit> routeUnits) {
        if (!(sqlStatementContext instanceof SelectStatementContext) || routeUnits.size() == 1) {
            return false;
        }
        SelectStatementContext statementContext = (SelectStatementContext)sqlStatementContext;
        boolean containsSubqueryJoinQuery = statementContext.isContainsSubquery() || statementContext.isContainsJoinQuery();
        boolean containsOrderByLimitClause = !statementContext.getOrderByContext().getItems().isEmpty() || statementContext.getPaginationContext().isHasPagination();
        boolean containsLockClause = SelectStatementHandler.getLockSegment((SelectStatement)((SelectStatement)statementContext.getSqlStatement())).isPresent();
        boolean needAggregateRewrite = !containsSubqueryJoinQuery && !containsOrderByLimitClause && !containsLockClause;
        statementContext.setNeedAggregateRewrite(needAggregateRewrite);
        return needAggregateRewrite;
    }

    private Map<String, Collection<RouteUnit>> aggregateRouteUnitGroups(Collection<RouteUnit> routeUnits) {
        LinkedHashMap<String, Collection<RouteUnit>> result = new LinkedHashMap<String, Collection<RouteUnit>>(routeUnits.size(), 1.0f);
        for (RouteUnit each : routeUnits) {
            String dataSourceName = each.getDataSourceMapper().getActualName();
            if (!result.containsKey(dataSourceName)) {
                result.put(dataSourceName, new LinkedList());
            }
            ((Collection)result.get(dataSourceName)).add(each);
        }
        return result;
    }

    private List<Object> getParameters(ParameterBuilder parameterBuilder, RouteContext routeContext, RouteUnit routeUnit) {
        if (parameterBuilder instanceof StandardParameterBuilder) {
            return parameterBuilder.getParameters();
        }
        return routeContext.getOriginalDataNodes().isEmpty() ? ((GroupedParameterBuilder)parameterBuilder).getParameters() : this.buildRouteParameters((GroupedParameterBuilder)parameterBuilder, routeContext, routeUnit);
    }

    private List<Object> buildRouteParameters(GroupedParameterBuilder parameterBuilder, RouteContext routeContext, RouteUnit routeUnit) {
        LinkedList<Object> result = new LinkedList<Object>();
        int count = 0;
        for (Collection each : routeContext.getOriginalDataNodes()) {
            if (this.isInSameDataNode(each, routeUnit)) {
                result.addAll(parameterBuilder.getParameters(count));
            }
            ++count;
        }
        result.addAll(parameterBuilder.getGenericParameterBuilder().getParameters());
        return result;
    }

    private boolean isInSameDataNode(Collection<DataNode> dataNodes, RouteUnit routeUnit) {
        if (dataNodes.isEmpty()) {
            return true;
        }
        for (DataNode each : dataNodes) {
            if (!routeUnit.findTableMapper(each.getDataSourceName(), each.getTableName()).isPresent()) continue;
            return true;
        }
        return false;
    }
}

