06. ShardingJDBC in Practice: SQL Rewriting

1 Preface

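The previous article, "SQL Routing Implementation", touched on SQL rewriting. Rewriting is tightly coupled with routing and runs right after it. Routing determines which databases and tables to query. Rewriting adjusts the SQL where needed, for example when query results must be aggregated, and produces the actual SQL to execute against each sharded database and table.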

2 SQLToken

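The SQL rewriting source code lives in the core module under io.shardingjdbc.core.rewrite. It consists of SQLBuilder and SQLRewriteEngine.

SQLToken is the interface for SQL token (marker) objects. It belongs to SQL parsing and lives under io.shardingjdbc.core.parsing.parser.token. SQLRewriteEngine rewrites SQL based on SQLTokens. One of the main jobs of the SQL parser is to mark the parts of the SQL that need rewriting, and those marks are the SQLTokens.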

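The main SQLToken types, and the situations in which the parser creates them:

GeneratedKeyToken: auto-increment primary key token
    The INSERT omits the auto-increment column: order_id is missing from INSERT INTO t_order(nickname) VALUES ...
TableToken: table token
    Table alias in a select column: the o in SELECT o.order_id
ItemsToken: select-items token
    AVG column: AVG(price) in SELECT AVG(price) FROM t_order
    ORDER BY column not in the select list: create_time in SELECT order_id FROM t_order ORDER BY create_time
    GROUP BY column not in the select list: user_id in SELECT COUNT(order_id) FROM t_order GROUP BY user_id
    Auto-increment key not among the inserted columns: order_id is missing from INSERT INTO t_order(nickname) VALUES ...
OffsetToken: pagination offset token
    The pagination offset is a literal value rather than a ? placeholder
RowCountToken: pagination row-count token
    The pagination row count is a literal value rather than a ? placeholder
OrderByToken: ORDER BY token
    GROUP BY is present but ORDER BY is not: order_id in SELECT COUNT(*) FROM t_order GROUP BY order_id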

3 SQL Rewriting

SQLRewriteEngine#rewrite() implements SQL rewriting.

    /**
     * rewrite SQL.
     *
     * @param isRewriteLimit is rewrite limit
     * @return SQL builder
     */
    public SQLBuilder rewrite(final boolean isRewriteLimit) {
        SQLBuilder result = new SQLBuilder();
        if (sqlTokens.isEmpty()) {
            // No tokens: the original SQL is used unchanged
            result.appendLiterals(originalSQL);
            return result;
        }
        int count = 0;
        // Sort the SQLTokens by beginPosition in ascending order
        sortByBeginPosition();
        for (SQLToken each : sqlTokens) {
            if (0 == count) {
                // First token: copy the text from the start of the original SQL up to the token's begin position.
                // For example, if the first token of "SELECT x.id FROM table_x x LIMIT 2, 2" is the TableToken
                // at table_x, the copied text is "SELECT x.id FROM "
                result.appendLiterals(originalSQL.substring(0, each.getBeginPosition()));
            }
            // Append each SQLToken according to its type (explained below)
            if (each instanceof TableToken) {
                appendTableToken(result, (TableToken) each, count, sqlTokens);
            } else if (each instanceof IndexToken) {
                appendIndexToken(result, (IndexToken) each, count, sqlTokens);
            } else if (each instanceof ItemsToken) {
                appendItemsToken(result, (ItemsToken) each, count, sqlTokens);
            } else if (each instanceof RowCountToken) {
                appendLimitRowCount(result, (RowCountToken) each, count, sqlTokens, isRewriteLimit);
            } else if (each instanceof OffsetToken) {
                appendLimitOffsetToken(result, (OffsetToken) each, count, sqlTokens, isRewriteLimit);
            } else if (each instanceof OrderByToken) {
                appendOrderByToken(result, count, sqlTokens);
            }
            count++;
        }
        return result;
    }

    private void sortByBeginPosition() {
        Collections.sort(sqlTokens, new Comparator<SQLToken>() {

            @Override
            public int compare(final SQLToken o1, final SQLToken o2) {
                return o1.getBeginPosition() - o2.getBeginPosition();
            }
        });
    }
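To make the splicing loop concrete, here is a minimal sketch of the same idea outside ShardingJDBC. It assumes a single hypothetical token type, TableTok, whose literal is replaced by an actual table name taken from a map. The class and method names are illustrative and are not the library's API. The sketch sorts the tokens by begin position, copies the text before the first token, replaces each token, and then copies the "rest" up to the next token, as appendRest() does.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Minimal sketch of the splicing idea in SQLRewriteEngine#rewrite().
// Assumption: one hypothetical token type (TableTok) replaced by an actual table name;
// this is not ShardingJDBC's API.
final class TokenSpliceSketch {

    // Hypothetical token: a logic table name found at beginPosition in the original SQL
    static final class TableTok {
        final int beginPosition;
        final String literal;

        TableTok(final int beginPosition, final String literal) {
            this.beginPosition = beginPosition;
            this.literal = literal;
        }
    }

    static String rewrite(final String originalSQL, final List<TableTok> tokens, final Map<String, String> actualTables) {
        if (tokens.isEmpty()) {
            return originalSQL; // nothing to rewrite, like the early return in rewrite()
        }
        List<TableTok> sorted = new ArrayList<>(tokens);
        // Same role as sortByBeginPosition()
        sorted.sort(Comparator.comparingInt((TableTok t) -> t.beginPosition));
        StringBuilder result = new StringBuilder();
        // Before the first token: copy from the start of the SQL up to the token
        result.append(originalSQL, 0, sorted.get(0).beginPosition);
        for (int i = 0; i < sorted.size(); i++) {
            TableTok token = sorted.get(i);
            // Replace the token itself (keep the logic name if no mapping exists)
            result.append(actualTables.getOrDefault(token.literal, token.literal));
            // appendRest(): copy from the end of this token to the next token, or to the end of the SQL
            int restBegin = token.beginPosition + token.literal.length();
            int restEnd = (i == sorted.size() - 1) ? originalSQL.length() : sorted.get(i + 1).beginPosition;
            result.append(originalSQL, restBegin, restEnd);
        }
        return result.toString();
    }
}
```

Consider tokens for t_order and t_order_item, mapped to t_order_0 and t_order_item_0. Under that mapping, "SELECT * FROM t_order o JOIN t_order_item i ON o.order_id = i.order_id" becomes "SELECT * FROM t_order_0 o JOIN t_order_item_0 i ON o.order_id = i.order_id". The real engine differs in one respect. It keeps table placeholders in an SQLBuilder and substitutes them later, once per routed table. This sketch substitutes immediately to keep it short.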

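SQLBuilder is the SQL builder. Its main fields are:

public final class SQLBuilder {

    // Literal text (StringBuilder) interleaved with table/index placeholders (TableToken, IndexToken)
    private final List<Object> segments;

    // The literal segment that appendLiterals() currently appends to
    private StringBuilder currentSegment;

    // ... constructor and methods omitted
}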

The following subsections show how each token type is appended:

3.1 TableToken

    private void appendTableToken(final SQLBuilder sqlBuilder, final TableToken tableToken, final int count, final List<SQLToken> sqlTokens) {
        // Add the lower-cased logic table name as a placeholder segment
        sqlBuilder.appendTable(tableToken.getTableName().toLowerCase());
        // Resume copying the original SQL right after the original table literal
        int beginPosition = tableToken.getBeginPosition() + tableToken.getOriginalLiterals().length();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }

Note that getTableName() strips quoting characters ([, ], ' and ") from the original literal:

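    // TableToken#getTableName()
    public String getTableName() {
        return SQLUtil.getExactlyValue(originalLiterals);
    }

    // SQLUtil#getExactlyValue(): remove every [, ], ' and " character
    public static String getExactlyValue(final String value) {
        return null == value ? null : CharMatcher.anyOf("[]'\"").removeFrom(value);
    }

    // SQLBuilder#appendTable(): the TableToken is also added to the SQLBuilder's segments (List<Object>),
    // followed by a fresh StringBuilder for the literals that come after it
    public void appendTable(final String tableName) {
        segments.add(new TableToken(tableName));
        currentSegment = new StringBuilder();
        segments.add(currentSegment);
    }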
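If Guava is not available, the same stripping can be sketched with a plain JDK regular expression. The following is an illustrative equivalent of getExactlyValue(), not the library's code, and the class name is hypothetical.

```java
// Hypothetical plain-JDK sketch of SQLUtil.getExactlyValue(); the original uses Guava's CharMatcher.
final class ExactlyValueSketch {

    static String getExactlyValue(final String value) {
        // Character class [ ] ' " : every occurrence is removed, wherever it appears in the string
        return null == value ? null : value.replaceAll("[\\[\\]'\"]", "");
    }
}
```

For instance, "[t_order]", "'t_order'" and "\"t_order\"" all become t_order. Backticks are not in the character set, so "`t_order`" is returned unchanged. This matches the CharMatcher.anyOf("[]'\"") call above.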

3.2 IndexToken

 private void appendIndexToken(final SQLBuilder sqlBuilder, final IndexToken indexToken, final int count, final List<SQLToken> sqlTokens) {
        String indexName = indexToken.getIndexName().toLowerCase();
        String logicTableName = indexToken.getTableName().toLowerCase();
        if (Strings.isNullOrEmpty(logicTableName)) {
            logicTableName = shardingRule.getLogicTableName(indexName);
        }
        sqlBuilder.appendIndex(indexName, logicTableName);
        int beginPosition = indexToken.getBeginPosition() + indexToken.getOriginalLiterals().length();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }

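This lower-cases the index name and its logic table name. If the statement does not name the table, the logic table is looked up from the sharding rule by index name. The pair is then added to the SQLBuilder as an IndexToken, so that toSQL() can later emit the index name suffixed with the actual table name: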

    public void appendIndex(final String indexName, final String tableName) {
        segments.add(new IndexToken(indexName, tableName));
        currentSegment = new StringBuilder();
        segments.add(currentSegment);
    }

3.3 ItemsToken

 private void appendItemsToken(final SQLBuilder sqlBuilder, final ItemsToken itemsToken, final int count, final List<SQLToken> sqlTokens) {
        for (String item : itemsToken.getItems()) {
            sqlBuilder.appendLiterals(", ");
            sqlBuilder.appendLiterals(SQLUtil.getOriginalValue(item, databaseType));
        }
        int beginPosition = itemsToken.getBeginPosition();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }

Example with AVG: SELECT AVG(order_id) FROM t_order WHERE user_id = #{userId,jdbcType=INTEGER}

After appendItemsToken() runs, the builder contains:

item [SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM ]

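Similarly, the logic SQL may contain ORDER BY or GROUP BY columns that are missing from the select list. In that case, appendItemsToken() appends those columns to the select list. The rows returned by each shard then carry what is needed to sort and group them in memory.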
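A small worked example shows why the derived COUNT and SUM columns matter. The sketch below is hypothetical and is not ShardingJDBC's merge code. Each shard returns its SUM and COUNT, and the global average is the total SUM divided by the total COUNT. Averaging the per-shard AVG values instead gives the wrong answer whenever the shards hold different numbers of rows.

```java
// Hypothetical sketch: merging AVG across shards from the derived SUM and COUNT columns.
final class AvgMergeSketch {

    // Each element is {AVG_DERIVED_SUM_0, AVG_DERIVED_COUNT_0} as returned by one shard
    static double mergeAvg(final long[][] shardSumAndCount) {
        long totalSum = 0;
        long totalCount = 0;
        for (long[] row : shardSumAndCount) {
            totalSum += row[0];
            totalCount += row[1];
        }
        // No rows at all: there is no meaningful average (SQL would return NULL)
        return totalCount == 0 ? Double.NaN : (double) totalSum / totalCount;
    }

    // The naive alternative: averaging the per-shard AVG values (wrong when shard row counts differ)
    static double averageOfShardAverages(final long[][] shardSumAndCount) {
        double total = 0;
        for (long[] row : shardSumAndCount) {
            total += (double) row[0] / row[1];
        }
        return total / shardSumAndCount.length;
    }
}
```

Suppose shard A holds order_id 10 and 20 (SUM 30, COUNT 2) and shard B holds 60 (SUM 60, COUNT 1). mergeAvg returns 30.0, which is the true average of 10, 20 and 60. Averaging the shard averages 15 and 60 gives 37.5 instead.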

3.4 OffsetToken

 private void appendLimitOffsetToken(final SQLBuilder sqlBuilder, final OffsetToken offsetToken, final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
        sqlBuilder.appendLiterals(isRewrite ? "0" : String.valueOf(offsetToken.getOffset()));
        int beginPosition = offsetToken.getBeginPosition() + String.valueOf(offsetToken.getOffset()).length();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }

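When a paginated query spans multiple shards, every shard must be queried and the results merged in memory. In this case isRewrite = true. Why start at "0"? Any of a shard's records in the range [0, offset) might belong to the final page, so each shard must be queried from offset 0.

When a paginated query routes to a single shard, no rewrite is needed, because that shard's result is already the final result.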
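The following hypothetical sketch illustrates this reasoning, using integers as rows and an ascending ORDER BY. rewritten() queries every shard with LIMIT 0, offset + rowCount and applies the original offset and rowCount after merging. notRewritten() sends the original LIMIT to every shard. The merge here is a simple sort, standing in for ShardingJDBC's merge of already-sorted shard results.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch: why LIMIT offset, rowCount becomes LIMIT 0, offset + rowCount across shards.
// Integers stand in for rows; the ORDER BY is ascending.
final class PagingMergeSketch {

    // Simulates "ORDER BY value LIMIT offset, rowCount" on one list of rows
    static List<Integer> limit(final List<Integer> rows, final int offset, final int rowCount) {
        List<Integer> sorted = new ArrayList<>(rows);
        Collections.sort(sorted);
        int from = Math.min(offset, sorted.size());
        int to = Math.min(offset + rowCount, sorted.size());
        return new ArrayList<>(sorted.subList(from, to));
    }

    // Rewritten: each shard runs LIMIT 0, offset + rowCount; the original paging is applied after merging
    static List<Integer> rewritten(final List<List<Integer>> shards, final int offset, final int rowCount) {
        List<Integer> merged = new ArrayList<>();
        for (List<Integer> shard : shards) {
            merged.addAll(limit(shard, 0, offset + rowCount));
        }
        return limit(merged, offset, rowCount);
    }

    // Not rewritten: each shard runs the original LIMIT offset, rowCount; the merge takes the first rowCount
    static List<Integer> notRewritten(final List<List<Integer>> shards, final int offset, final int rowCount) {
        List<Integer> merged = new ArrayList<>();
        for (List<Integer> shard : shards) {
            merged.addAll(limit(shard, offset, rowCount));
        }
        return limit(merged, 0, rowCount);
    }
}
```

Take shards [1, 4, 5, 8] and [2, 3, 6, 7] with LIMIT 2, 2. The correct page over all rows is [3, 4], and rewritten() returns [3, 4]. notRewritten() returns [5, 6], because the shards have already skipped the values 3 and 4 that belong to the real page.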

3.5 RowCountToken

    private void appendLimitRowCount(final SQLBuilder sqlBuilder, final RowCountToken rowCountToken, final int count, final List<SQLToken> sqlTokens, final boolean isRewrite) {
        SelectStatement selectStatement = (SelectStatement) sqlStatement;
        Limit limit = selectStatement.getLimit();
        if (!isRewrite) {
            // No rewrite needed: the query routes to a single shard, so append rowCount unchanged
            sqlBuilder.appendLiterals(String.valueOf(rowCountToken.getRowCount()));
        } else if ((!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty())
                && !selectStatement.isSameGroupByAndOrderByItems()) {
            // Cross-shard GROUP BY or aggregation is computed in memory. If the GROUP BY and ORDER BY items differ,
            // the shards' results are not sorted the way the merge needs, so all rows may have to be loaded
            sqlBuilder.appendLiterals(String.valueOf(Integer.MAX_VALUE));
        } else {
            // Routed to multiple shards: rewrite rowCount to offset + rowCount
            sqlBuilder.appendLiterals(String.valueOf(limit.isNeedRewriteRowCount() ? rowCountToken.getRowCount() + limit.getOffsetValue() : rowCountToken.getRowCount()));
        }
        int beginPosition = rowCountToken.getBeginPosition() + String.valueOf(rowCountToken.getRowCount()).length();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }

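The second branch (fetching up to Integer.MAX_VALUE rows) applies only when the GROUP BY and ORDER BY items differ. When they are the same, each shard has already sorted its results, so no in-memory re-sort is required and fetching offset + rowCount rows per shard is enough.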
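The branching in appendLimitRowCount() can be restated as a pure function, which makes the three cases easy to check. The boolean parameters are hypothetical stand-ins for the following checks:

- isRewrite
- the GROUP BY / aggregation check
- isSameGroupByAndOrderByItems()
- limit.isNeedRewriteRowCount()

This is a sketch for reasoning about the cases, not library code.

```java
// Hypothetical restatement of appendLimitRowCount()'s three branches as a pure function.
// The boolean parameters stand in for the SelectStatement / Limit checks used by the original.
final class RowCountRewriteSketch {

    static int rewrittenRowCount(final int offset, final int rowCount, final boolean isRewrite,
                                 final boolean hasGroupByOrAggregation, final boolean sameGroupByAndOrderBy,
                                 final boolean needRewriteRowCount) {
        if (!isRewrite) {
            return rowCount;                  // single shard: keep the original row count
        }
        if (hasGroupByOrAggregation && !sameGroupByAndOrderBy) {
            return Integer.MAX_VALUE;         // merge in memory: fetch everything
        }
        return needRewriteRowCount ? offset + rowCount : rowCount; // multi-shard: offset + rowCount
    }
}
```

For LIMIT 2, 2 the function yields 2 for a single shard, Integer.MAX_VALUE for a grouped query whose GROUP BY and ORDER BY differ, and 4 for an ordinary multi-shard query.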

A note on pagination

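OffsetToken and RowCountToken exist only when the corresponding pagination position holds a literal rather than a ? placeholder. When it is a placeholder, the matching prepared-statement parameters are rewritten instead, following the same logic as OffsetToken and RowCountToken. The following method is called from ParsingSQLRouter.route: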

 private void processLimit(final List<Object> parameters, final SelectStatement selectStatement, final boolean isSingleRouting) {
        if (isSingleRouting) {
            selectStatement.setLimit(null);
            return;
        }
        boolean isNeedFetchAll = (!selectStatement.getGroupByItems().isEmpty() || !selectStatement.getAggregationSelectItems().isEmpty()) && !selectStatement.isSameGroupByAndOrderByItems();
        selectStatement.getLimit().processParameters(parameters, isNeedFetchAll);
    }
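For the placeholder case, the same idea applied to the parameter list might look like the sketch below. The method name and the offsetIndex/rowCountIndex parameters are hypothetical. The real logic lives in Limit.processParameters(), whose internals are not shown here. Unlike the original, the sketch returns a rewritten copy rather than changing the list in place.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of rewriting LIMIT ?, ? parameters; offsetIndex and rowCountIndex
// (the positions of the two placeholders in the parameter list) are assumptions.
final class LimitParameterSketch {

    static List<Object> rewriteLimitParameters(final List<Object> parameters, final int offsetIndex,
                                               final int rowCountIndex, final boolean isNeedFetchAll) {
        int offset = ((Number) parameters.get(offsetIndex)).intValue();
        int rowCount = ((Number) parameters.get(rowCountIndex)).intValue();
        List<Object> result = new ArrayList<>(parameters); // copy: the caller's list is left untouched
        result.set(offsetIndex, 0);                        // every shard starts from its first row
        result.set(rowCountIndex, isNeedFetchAll ? Integer.MAX_VALUE : offset + rowCount);
        return result;
    }
}
```

For example, the parameters [101, 2, 2] for user_id = ? LIMIT ?, ? become [101, 0, 4] for a multi-shard query. When isNeedFetchAll is true, they become [101, 0, Integer.MAX_VALUE].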

3.6 OrderByToken

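In MySQL (before 8.0), a query that has GROUP BY but no ORDER BY returns its result implicitly sorted by the GROUP BY columns in ascending order:
SELECT order_id FROM t_order GROUP BY order_id is equivalent to SELECT order_id FROM t_order GROUP BY order_id ORDER BY order_id ASC

SELECT order_id FROM t_order GROUP BY order_id DESC is equivalent to SELECT order_id FROM t_order GROUP BY order_id ORDER BY order_id DESC

Each shard must return rows in that same order so that the results can be merged. To ensure this, appendOrderByToken() writes an explicit ORDER BY built from the statement's order items and inserts it right after the last GROUP BY item (getGroupByLastPosition()). When there is no explicit ORDER BY, those order items mirror the GROUP BY items.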

 private void appendOrderByToken(final SQLBuilder sqlBuilder, final int count, final List<SQLToken> sqlTokens) {
        SelectStatement selectStatement = (SelectStatement) sqlStatement;
        StringBuilder orderByLiterals = new StringBuilder();
        orderByLiterals.append(" ").append(DefaultKeyword.ORDER).append(" ").append(DefaultKeyword.BY).append(" ");
        int i = 0;
        for (OrderItem each : selectStatement.getOrderByItems()) {
            String columnLabel = SQLUtil.getOriginalValue(each.getColumnLabel(), databaseType);
            if (0 == i) {
                orderByLiterals.append(columnLabel).append(" ").append(each.getType().name());
            } else {
                orderByLiterals.append(",").append(columnLabel).append(" ").append(each.getType().name());
            }
            i++;
        }
        orderByLiterals.append(" ");
        sqlBuilder.appendLiterals(orderByLiterals.toString());
        int beginPosition = ((SelectStatement) sqlStatement).getGroupByLastPosition();
        appendRest(sqlBuilder, count, sqlTokens, beginPosition);
    }
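The standalone sketch below shows what appendOrderByToken() produces. It builds the ORDER BY literal from the order items, using the same spacing and commas, and splices it in right after the last GROUP BY item. Item and insertOrderBy() are hypothetical names for illustration.

```java
import java.util.List;

// Hypothetical sketch of the ORDER BY literal that appendOrderByToken() builds, and where it goes.
final class OrderByTokenSketch {

    // Hypothetical order item: a column label and a direction such as "ASC" or "DESC"
    static final class Item {
        final String columnLabel;
        final String direction;

        Item(final String columnLabel, final String direction) {
            this.columnLabel = columnLabel;
            this.direction = direction;
        }
    }

    // Builds " ORDER BY col1 DIR1,col2 DIR2 " with the same spacing and commas as the original code
    static String orderByLiterals(final List<Item> items) {
        StringBuilder literals = new StringBuilder(" ORDER BY ");
        for (int i = 0; i < items.size(); i++) {
            if (i > 0) {
                literals.append(",");
            }
            literals.append(items.get(i).columnLabel).append(" ").append(items.get(i).direction);
        }
        return literals.append(" ").toString();
    }

    // Splices the ORDER BY clause in right after the last GROUP BY item
    static String insertOrderBy(final String sql, final int groupByLastPosition, final List<Item> items) {
        return sql.substring(0, groupByLastPosition) + orderByLiterals(items) + sql.substring(groupByLastPosition);
    }
}
```

Take SELECT COUNT(*) FROM t_order GROUP BY order_id and insert at the end of the SQL. The result is SELECT COUNT(*) FROM t_order GROUP BY order_id ORDER BY order_id ASC, followed by the trailing space that the original code also appends.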

3.7 appendRest

    private void appendRest(final SQLBuilder sqlBuilder, final int count, final List<SQLToken> sqlTokens, final int beginPosition) {
        // If the current token is the last one, copy up to the end of the SQL;
        // otherwise copy up to the begin position of the next token
        int endPosition = sqlTokens.size() - 1 == count ? originalSQL.length() : sqlTokens.get(count + 1).getBeginPosition();
        sqlBuilder.appendLiterals(originalSQL.substring(beginPosition, endPosition));
    }

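Every append method ends by calling appendRest(), which appends the remaining literal text. This is everything from the end of the current token up to the start of the next token, or up to the end of the SQL for the last token.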

4 Generating the SQL

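After rewriting, SQLBuilder.toSQL() produces the final SQL for a given routed table. It replaces each logic table placeholder with the actual table name from the tableTokens map. In the route flow below, this happens through rewriteEngine.generateSQL():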

/**
     * Convert to SQL string.
     *
     * @param tableTokens table tokens
     * @return SQL string
     */
    public String toSQL(final Map<String, String> tableTokens) {
        StringBuilder result = new StringBuilder();
        for (Object each : segments) {
            if (each instanceof TableToken && tableTokens.containsKey(((TableToken) each).tableName)) {
                result.append(tableTokens.get(((TableToken) each).tableName));
            } else if (each instanceof IndexToken) {
                IndexToken indexToken = (IndexToken) each;
                result.append(indexToken.indexName);
                String tableName = tableTokens.get(indexToken.tableName);
                if (!Strings.isNullOrEmpty(tableName)) {
                    result.append("_");
                    result.append(tableName);
                }
            } else {
                result.append(each);
            }
        }
        return result.toString();
    }
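The segment design is what lets a single rewrite serve every routed table. Literals stay as text, and only the table placeholders change per table unit. The following is a minimal sketch of that idea that handles only table placeholders. MiniSqlBuilder is a hypothetical name, and IndexToken handling is omitted.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified version of the segment idea behind SQLBuilder (table placeholders only).
final class MiniSqlBuilder {

    // Placeholder segment for a logic table name
    private static final class TablePlaceholder {
        private final String tableName;

        TablePlaceholder(final String tableName) {
            this.tableName = tableName;
        }

        @Override
        public String toString() {
            return tableName; // used when no actual table is mapped
        }
    }

    private final List<Object> segments = new ArrayList<>();
    private StringBuilder currentSegment = new StringBuilder();

    MiniSqlBuilder() {
        segments.add(currentSegment);
    }

    void appendLiterals(final String literals) {
        currentSegment.append(literals);
    }

    void appendTable(final String tableName) {
        segments.add(new TablePlaceholder(tableName));
        currentSegment = new StringBuilder(); // literals after the table go into a new segment
        segments.add(currentSegment);
    }

    // Same substitution rule as toSQL(): mapped placeholders become actual tables, everything else is copied
    String toSQL(final Map<String, String> tableTokens) {
        StringBuilder result = new StringBuilder();
        for (Object each : segments) {
            if (each instanceof TablePlaceholder && tableTokens.containsKey(((TablePlaceholder) each).tableName)) {
                result.append(tableTokens.get(((TablePlaceholder) each).tableName));
            } else {
                result.append(each);
            }
        }
        return result.toString();
    }
}
```

The builder is built once, as "SELECT AVG(order_id) FROM " + placeholder t_order + " WHERE user_id =?". Passing different maps then yields the SQL for t_order_0, t_order_1 and so on. This mirrors how one SQLBuilder is reused for each TableUnit in route().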

Example: for the AVG query above, the SQLBuilder looks like this after rewriting:

sqlbuilder:[[SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM , t_order, WHERE user_id =?]

Finally, let's look at where SQL rewriting sits in the overall routing flow, ParsingSQLRouter#route():

    @Override
    public SQLRouteResult route(final String logicSQL, final List<Object> parameters, final SQLStatement sqlStatement) {
        SQLRouteResult result = new SQLRouteResult(sqlStatement);
        // Handle the generated (auto-increment) primary key of an INSERT
        if (sqlStatement instanceof InsertStatement && null != ((InsertStatement) sqlStatement).getGeneratedKey()) {
            processGeneratedKey(parameters, (InsertStatement) sqlStatement, result);
        }
        // Routing
        RoutingResult routingResult = route(parameters, sqlStatement);
        // SQL rewrite engine
        SQLRewriteEngine rewriteEngine = new SQLRewriteEngine(shardingRule, logicSQL, databaseType, sqlStatement);
        boolean isSingleRouting = routingResult.isSingleRouting();
        // Handle pagination
        if (sqlStatement instanceof SelectStatement && null != ((SelectStatement) sqlStatement).getLimit()) {
            processLimit(parameters, (SelectStatement) sqlStatement, isSingleRouting);
        }
        // SQL rewriting (LIMIT literals are rewritten only for multi-shard routing)
        SQLBuilder sqlBuilder = rewriteEngine.rewrite(!isSingleRouting);
        // Generate the SQLExecutionUnits
        if (routingResult instanceof CartesianRoutingResult) {
            for (CartesianDataSource cartesianDataSource : ((CartesianRoutingResult) routingResult).getRoutingDataSources()) {
                for (CartesianTableReference cartesianTableReference : cartesianDataSource.getRoutingTableReferences()) {
                    result.getExecutionUnits().add(new SQLExecutionUnit(cartesianDataSource.getDataSource(), rewriteEngine.generateSQL(cartesianTableReference, sqlBuilder)));
                }
            }
        } else {
            for (TableUnit each : routingResult.getTableUnits().getTableUnits()) {
                result.getExecutionUnits().add(new SQLExecutionUnit(each.getDataSourceName(), rewriteEngine.generateSQL(each, sqlBuilder)));
            }
        }
        // Log the SQL
        if (showSQL) {
            SQLLogger.logSQL(logicSQL, sqlStatement, result.getExecutionUnits(), parameters);
        }
        return result;
    }

Routing result (16 execution units, all on demo_ds_0, one per actual table t_order_0 to t_order_15):

routeResult [SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_0 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_1 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_2 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_3 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_4 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_5 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_6 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_7 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_8 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_9 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_10 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_11 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_12 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_13 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_14 WHERE user_id =?),
SQLExecutionUnit(dataSource=demo_ds_0, sql=SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_15 WHERE user_id =?)]

Zooming out to the outer flow, ShardingPreparedStatement.execute() first calls route() and then runs PreparedStatementExecutor.execute():

private Collection<PreparedStatementUnit> route() throws SQLException {
        Collection<PreparedStatementUnit> result = new LinkedList<>();
        routeResult = routingEngine.route(getParameters());
        for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
            SQLType sqlType = routeResult.getSqlStatement().getType();
            Collection<PreparedStatement> preparedStatements;
            if (SQLType.DDL == sqlType) {
                preparedStatements = generatePreparedStatementForDDL(each);
            } else {
                preparedStatements = Collections.singletonList(generatePreparedStatement(each));
            }
            routedStatements.addAll(preparedStatements);
            for (PreparedStatement preparedStatement : preparedStatements) {
                replaySetParameter(preparedStatement);
                result.add(new PreparedStatementUnit(each, preparedStatement));
            }
        }
        return result;
    }

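Example output. Each PreparedStatement is first created from the rewritten SQL with its parameter unset. replaySetParameter() then replays the recorded parameters onto it: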

preparedStatements:[com.mysql.jdbc.JDBC4PreparedStatement@64b70919: SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_0 WHERE user_id =** NOT SPECIFIED **]
replaySetParameter:[com.mysql.jdbc.JDBC4PreparedStatement@64b70919: SELECT AVG(order_id) , COUNT(order_id) AS AVG_DERIVED_COUNT_0 , SUM(order_id) AS AVG_DERIVED_SUM_0 FROM t_order_0 WHERE user_id =101]

References:

http://www.iocoder.cn/Sharding-JDBC/sql-rewrite/

https://blog.csdn.net/feelwing1314/article/details/80404010

Copyright notice: this article is not original to this site; copyright belongs to the original author | Original URL: