概述：上篇已经详细介绍了 Apache Calcite 的概念，这里就不再赘述，直接看实现代码。
package com.joe.common.util;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.avatica.util.Casing;
import org.apache.calcite.avatica.util.Quoting;
import org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.calcite.config.Lex;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.*;
import org.apache.calcite.sql.SqlDialect.DatabaseProduct;
import org.apache.calcite.sql.dialect.MysqlSqlDialect;
import org.apache.calcite.sql.dialect.OracleSqlDialect;
import org.apache.calcite.sql.fun.SqlFloorFunction;
import org.apache.calcite.sql.fun.SqlLibraryOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.commons.lang.StringUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
public class CalciteSqlUtils {
private static SqlParser.Config mysqlConfig = SqlParser.configBuilder()
.setLex(Lex.MYSQL)
.setCaseSensitive(false)
.setQuoting(Quoting.BACK_TICK)
.setQuotedCasing(Casing.TO_LOWER)
.setUnquotedCasing(Casing.TO_LOWER)
.setConformance(SqlConformanceEnum.MYSQL_5)
.build();
private static SqlParser.Config oralceConfig = SqlParser.configBuilder()
.setLex(Lex.ORACLE)
.setCaseSensitive(false)
.setQuoting(Quoting.BACK_TICK)
.setQuotedCasing(Casing.TO_LOWER)
.setUnquotedCasing(Casing.TO_LOWER)
.setConformance(SqlConformanceEnum.ORACLE_12)
.build();
private static SqlParser.Config sqlserverConfig = SqlParser.configBuilder()
.setLex(Lex.SQL_SERVER)
.setCaseSensitive(false)
.setQuoting(Quoting.BACK_TICK)
.setQuotedCasing(Casing.TO_LOWER)
.setUnquotedCasing(Casing.TO_LOWER)
.setConformance(SqlConformanceEnum.SQL_SERVER_2008)
.build();
public static List<Map<String,String>> mapList = new ArrayList<>();
public static void main(String[] args) {
String sql = "select id,name from t_user where id='${id}' and name='zhangsan'";
String sql2 = "select\n" +
" aa.TOTAL_MONEY,\n" +
" aa.DRUG_MONEY,\n" +
" aa.BASE_DRUG_MONEY,\n" +
" bb.CAL_DATE,\n" +
" bb.CAL_MONTH,\n" +
" bb.CAL_YEAR,\n" +
" cc.OFFICE_NAME,\n" +
" cc.CUSTOM_CODE\n" +
"FROM\n" +
" F_DRUG_USE aa,\n" +
" T_DATES bb,\n" +
" T_OFFICE_PROPERTY cc\n" +
"where\n" +
" aa.date_id = bb.id\n" +
" and aa.BILLING_OFFICE_ID = cc.id\n" +
" and (select cal_year from bb) in ('${year_cond}')\n" +
" and bb.cal_year BETWEEN '${yeardes.get(0)}' and '${yeardes.get(1)}'\n" +
" and office_name like '${office_name_cond}'";
try {
List<Map<String, String>> list = handlerSqlTableAlias(DatabaseProduct.ORACLE, sql2);
System.out.println("$$$$$$$$$$$$$打印别名sql$$$$$$$$$$$$$");
list.forEach(System.out::println);
String rt = handlerSqlParameterSubstitution(DatabaseProduct.ORACLE,sql2, "year_cond");
System.out.println("$$$$$$$$$$$$$打印参数sql$$$$$$$$$$$$$");
System.out.println(rt);
} catch (Exception e) {
throw new RuntimeException("", e);
}
}
public static String handlerSqlParameterSubstitution(DatabaseProduct type,String sql,String param) throws SqlParseException {
SqlParser sqlParser = null;
switch (type){
case ORACLE:
sqlParser = SqlParser.create(sql, oralceConfig);
break;
case MYSQL:
sqlParser = SqlParser.create(sql, mysqlConfig);
break;
case MSSQL:
sqlParser = SqlParser.create(sql, sqlserverConfig);
break;
default:
sqlParser = SqlParser.create(sql, SqlParser.Config.DEFAULT);
break;
}
SqlNode sqlNode = sqlParser.parseQuery();
return handlerWhere(type,sqlNode, param);
}
public static List<Map<String, String>> handlerSqlTableAlias(DatabaseProduct type,String sql) throws SqlParseException {
SqlParser sqlParser = null;
switch (type){
case ORACLE:
sqlParser = SqlParser.create(sql, oralceConfig);
break;
case MYSQL:
sqlParser = SqlParser.create(sql, mysqlConfig);
break;
case MSSQL:
sqlParser = SqlParser.create(sql, sqlserverConfig);
break;
default:
sqlParser = SqlParser.create(sql, SqlParser.Config.DEFAULT);
break;
}
SqlNode sqlNode = sqlParser.parseQuery();
List<Map<String, String>> list = handlerSQL(sqlNode);
return list.stream().distinct().collect(Collectors.toList());
}
private static List<Map<String,String>> handlerSQL(SqlNode sqlNode) {
SqlKind kind = sqlNode.getKind();
switch (kind) {
case SELECT:
handlerSelect(sqlNode);
break;
case AS:
SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode;
SqlNode selectNode1 = sqlBasicCall.getOperandList().get(0);
SqlNode selectNode2 = sqlBasicCall.getOperandList().get(1);
if (!SqlKind.UNION.equals(selectNode1.getKind())){
if (!SqlKind.SELECT.equals(selectNode1.getKind())){
Map<String,String> aliasMap = new HashMap<>();
aliasMap.put(selectNode2.toString(),selectNode1.toString());
mapList.add(aliasMap);
}
}
handlerSQL(selectNode1);
break;
case JOIN:
SqlJoin sqlJoin = (SqlJoin) sqlNode;
SqlNode left = sqlJoin.getLeft();
handlerSQL(left);
SqlNode right = sqlJoin.getRight();
handlerSQL(right);
SqlNode condition = sqlJoin.getCondition();
if (condition!=null){
handlerField(condition);
}
break;
case UNION:
((SqlBasicCall) sqlNode).getOperandList().forEach(node -> {
handlerSQL(node);
});
break;
case ORDER_BY:
handlerOrderBy(sqlNode);
break;
}
return mapList;
}
private static void handlerOrderBy(SqlNode node) {
SqlOrderBy sqlOrderBy = (SqlOrderBy) node;
SqlNode query = sqlOrderBy.query;
handlerSQL(query);
SqlNodeList orderList = sqlOrderBy.orderList;
handlerField(orderList);
}
private static String handlerWhere(DatabaseProduct type,SqlNode sqlNode,String param) {
AtomicReference<String> sqlStr = new AtomicReference<>();
SqlKind kind = sqlNode.getKind();
switch (kind) {
case SELECT:
sqlStr.set(handlerSqlParameter(type,sqlNode, param));
break;
case JOIN:
SqlJoin sqlJoin = (SqlJoin) sqlNode;
SqlNode left = sqlJoin.getLeft();
handlerLeftAndRight(type,left,param);
SqlNode right = sqlJoin.getRight();
handlerLeftAndRight(type,right,param);
break;
case UNION:
((SqlBasicCall) sqlNode).getOperandList().forEach(node -> {
sqlStr.set(handlerSqlParameter(type,node, param));
});
break;
}
return sqlStr.get();
}
private static void handlerLeftAndRight(DatabaseProduct type,SqlNode sqlNode,String param){
SqlBasicCall leftSelectCall = (SqlBasicCall) sqlNode;
List<SqlNode> leftOperandList = leftSelectCall.getOperandList();
for (SqlNode node : leftOperandList) {
SqlKind kind = node.getKind();
if (SqlKind.IDENTIFIER.equals(kind)){
break;
}
if (SqlKind.SELECT.equals(kind)){
handlerWhere(type,node,param);
}else{
handlerLeftAndRight(type,node,param);
}
}
}
private static String handlerSqlParameter(DatabaseProduct type,SqlNode node,String param) {
SqlSelect sqlSelect = (SqlSelect) node;
SqlBasicCall where = (SqlBasicCall) sqlSelect.getWhere();
if (!sqlSelect.hasWhere()){
handlerWhere(type,sqlSelect.getFrom(),param);
}else{
handlerOperand(where,sqlSelect,param);
}
String sql = "";
switch (type){
case ORACLE:
SqlDialect.Context oracleSqlDialect = SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.ORACLE)
.withIdentifierQuoteString("")
.withDataTypeSystem(OracleSqlDialect.DEFAULT.getTypeSystem());
sql = sqlReplace(sqlSelect.toSqlString(new MySqlDialect(oracleSqlDialect)).toString());
break;
case MYSQL:
SqlDialect.Context MYSQL_CONTEXT = SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.MYSQL)
.withIdentifierQuoteString("")
.withDataTypeSystem(MysqlSqlDialect.DEFAULT.getTypeSystem());
sql = sqlReplace(sqlSelect.toSqlString(new MySqlDialect(MYSQL_CONTEXT)).toString());
break;
default:
SqlDialect.Context DEFAULT_CONTEXT = SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.UNKNOWN)
.withIdentifierQuoteString("");
sql = sqlReplace(sqlSelect.toSqlString(new MySqlDialect(DEFAULT_CONTEXT)).toString());
break;
}
return sql;
}
private static void handlerOperand(SqlBasicCall where,SqlSelect sqlSelect,String param){
List<SqlNode> operandList = where.getOperandList();
for (int i = 0; i < operandList.size(); i++) {
SqlBasicCall operandStr = (SqlBasicCall)operandList.get(i);
SqlNode paramName = operandStr.getOperandList().size()>=2?operandStr:operandStr.getOperandList().get(0);
SqlKind kind = paramName.getKind();
if (SqlKind.AND.equals(kind)){
handlerOperand(operandStr,sqlSelect,param);
}
if (!SqlKind.IDENTIFIER.equals(kind)){
if (!SqlKind.BETWEEN.equals(kind)){
SqlBasicCall sqlBasicCall = (SqlBasicCall) paramName;
paramName = sqlBasicCall.getOperandList().get(1);
}
}
if (checkKind(kind)&¶mName.toString().contains(param)){
SqlOperator operator = new SqlBinaryOperator("=",
SqlKind.EQUALS,
0,
false,
operandStr.getOperator().getReturnTypeInference(),
operandStr.getOperator().getOperandTypeInference(),
operandStr.getOperator().getOperandTypeChecker());
SqlNode[] operands = new SqlNode[2];
SqlIdentifier sqlIdentifier = new SqlIdentifier("'jh'",paramName.getParserPosition());
SqlCharStringLiteral literal = SqlCharStringLiteral.createCharString("jh", paramName.getParserPosition());
operands[0]=sqlIdentifier;
operands[1]=literal;
SqlBasicCall operandCall = new SqlBasicCall(operator,operands,paramName.getParserPosition());
where.setOperand(i,operandCall);
sqlSelect.setWhere(where);
}else {
where.setOperand(i,operandStr);
sqlSelect.setWhere(where);
}
}
}
private static void handlerSelect(SqlNode select) {
SqlSelect sqlSelect = (SqlSelect) select;
SqlNodeList selectList = sqlSelect.getSelectList();
selectList.getList().forEach(list -> {
handlerField(list);
});
handlerFrom(sqlSelect.getFrom());
if (sqlSelect.hasWhere()) {
handlerField(sqlSelect.getWhere());
}
if (sqlSelect.hasOrderBy()) {
handlerField(sqlSelect.getOrderList());
}
SqlNodeList group = sqlSelect.getGroup();
if (group != null) {
group.forEach(groupField -> {
handlerField(groupField);
});
}
}
private static List<Map<String,String>> handlerFrom(SqlNode from) {
SqlKind kind = from.getKind();
switch (kind) {
case IDENTIFIER:
SqlIdentifier sqlIdentifier = (SqlIdentifier) from;
break;
case AS:
SqlBasicCall sqlBasicCall = (SqlBasicCall) from;
SqlNode selectNode1 = sqlBasicCall.getOperandList().get(0);
SqlNode selectNode2 = sqlBasicCall.getOperandList().get(1);
if (!SqlKind.UNION.equals(selectNode1.getKind())){
if (!SqlKind.SELECT.equals(selectNode1.getKind())){
Map<String,String> aliasMap = new HashMap<>();
aliasMap.put(selectNode2.toString(),selectNode1.toString());
mapList.add(aliasMap);
}
}
handlerSQL(selectNode1);
break;
case JOIN:
SqlJoin sqlJoin = (SqlJoin) from;
SqlNode left = sqlJoin.getLeft();
handlerSQL(left);
SqlNode right = sqlJoin.getRight();
handlerSQL(right);
SqlNode condition = sqlJoin.getCondition();
if (condition!=null){
handlerField(condition);
}
break;
case SELECT:
handlerSQL(from);
break;
}
return mapList;
}
private static void handlerField(SqlNode field) {
SqlKind kind = field.getKind();
switch (kind) {
case AS:
List<SqlNode> operandList1 = ((SqlBasicCall) field).getOperandList();
SqlNode left_as = operandList1.get(0);
handlerField(left_as);
break;
case IDENTIFIER:
SqlIdentifier sqlIdentifier = (SqlIdentifier) field;
break;
default:
if (field instanceof SqlBasicCall) {
List<SqlNode> operandList2 = ((SqlBasicCall) field).getOperandList();
for (int i = 0; i < operandList2.size(); i++) {
handlerField(operandList2.get(i));
}
}
if (field instanceof SqlNodeList) {
((SqlNodeList) field).getList().forEach(node -> {
handlerField(node);
});
}
break;
}
}
private static boolean checkKind(SqlKind kind){
if (SqlKind.EQUALS.equals(kind)
||SqlKind.BETWEEN.equals(kind)
||SqlKind.LIKE.equals(kind)
||SqlKind.NOT_IN.equals(kind)
||SqlKind.IN.equals(kind)
||SqlKind.LESS_THAN_OR_EQUAL.equals(kind)
||SqlKind.GREATER_THAN_OR_EQUAL.equals(kind)
||SqlKind.LESS_THAN.equals(kind)
||SqlKind.GREATER_THAN.equals(kind)
||SqlKind.NOT_EQUALS.equals(kind)
||SqlKind.IS_NOT_NULL.equals(kind)){
return Boolean.TRUE;
}else{
return Boolean.FALSE;
}
}
private static int countStr(String str1, String str2, int counter) {
if (str1.contains(str2)) {
counter++;
counter = countStr(str1.substring(str1.indexOf(str2) + str2.length()), str2, counter);
}
return counter;
}
private static String sqlReplace(String str){
List<String> list = new ArrayList<>();
list.add("ASYMMETRIC");
String all = "";
if (StringUtils.isNotBlank(str)){
for (String s : list) {
all = str.replaceAll(s, "");
}
}
return all;
}
}
/**
 * SQL dialect that {@code CalciteSqlUtils} uses to unparse rewritten queries.
 *
 * <p>Despite the name, most overrides produce Oracle-flavoured SQL:
 * <ul>
 *   <li>SMALLINT/INTEGER/BIGINT casts become {@code NUMBER(n)};</li>
 *   <li>date-time literals become {@code TO_DATE}/{@code TO_TIMESTAMP}/{@code TO_TIME} calls;</li>
 *   <li>{@code SUBSTRING} is written with {@code SUBSTR_ORACLE} syntax;</li>
 *   <li>two-operand {@code FLOOR} is written as {@code TRUNC};</li>
 *   <li>{@code DUAL} is the single-row table.</li>
 * </ul>
 * {@link #allowsAs()} returns false and charset prefixes are never written.
 *
 * <p>NOTE(review): {@code TO_TIME} is not a built-in Oracle function. Confirm that TIME literal
 * output is acceptable for the target databases.
 */
class MySqlDialect extends SqlDialect {

    /** Creates the dialect from a configured Calcite {@link SqlDialect.Context}. */
    public MySqlDialect(Context context) {
        super(context);
    }

    /**
     * Appends {@code val} between the dialect's literal quotes, escaping embedded end quotes.
     * {@code charsetName} is ignored.
     */
    @Override
    public void quoteStringLiteral(StringBuilder buf, @Nullable String charsetName, String val) {
        buf.append(literalQuoteString)
                .append(val.replace(literalEndQuoteString, literalEscapedQuote))
                .append(literalEndQuoteString);
    }

    @Override
    public boolean supportsApproxCountDistinct() {
        return true;
    }

    @Override
    public boolean supportsCharSet() {
        return false;
    }

    /** BOOLEAN is reported as unsupported; everything else defers to the base dialect. */
    @Override
    public boolean supportsDataType(RelDataType type) {
        switch (type.getSqlTypeName()) {
            case BOOLEAN:
                return false;
            default:
                break;
        }
        return super.supportsDataType(type);
    }

    /** Casts integer and double types to explicit type names; other types use the base rule. */
    @Override
    public @Nullable SqlNode getCastSpec(RelDataType type) {
        final String targetName = explicitCastName(type);
        if (targetName == null) {
            return super.getCastSpec(type);
        }
        return new SqlDataTypeSpec(
                new SqlAlienSystemTypeNameSpec(targetName, type.getSqlTypeName(), SqlParserPos.ZERO),
                SqlParserPos.ZERO);
    }

    /** Maps a type to its explicit cast name, or returns {@code null} when there is none. */
    private static @Nullable String explicitCastName(RelDataType type) {
        switch (type.getSqlTypeName()) {
            case SMALLINT:
                return "NUMBER(5)";
            case INTEGER:
                return "NUMBER(10)";
            case BIGINT:
                return "NUMBER(19)";
            case DOUBLE:
                return "DOUBLE PRECISION";
            default:
                return null;
        }
    }

    @Override
    protected boolean allowsAs() {
        return false;
    }

    @Override
    public boolean supportsAliasedValues() {
        return false;
    }

    /** Writes timestamp/date/time literals as conversion-function calls with explicit formats. */
    @Override
    public void unparseDateTimeLiteral(SqlWriter writer,
            SqlAbstractDateTimeLiteral literal, int leftPrec, int rightPrec) {
        final String head;
        final String tail;
        if (literal instanceof SqlTimestampLiteral) {
            head = "TO_TIMESTAMP('";
            tail = "', 'YYYY-MM-DD HH24:MI:SS.FF')";
        } else if (literal instanceof SqlDateLiteral) {
            head = "TO_DATE('";
            tail = "', 'YYYY-MM-DD')";
        } else if (literal instanceof SqlTimeLiteral) {
            head = "TO_TIME('";
            tail = "', 'HH24:MI:SS.FF')";
        } else {
            super.unparseDateTimeLiteral(writer, literal, leftPrec, rightPrec);
            return;
        }
        writer.literal(head + literal.toFormattedString() + tail);
    }

    @Override
    public List<String> getSingleRowTableName() {
        return ImmutableList.of("DUAL");
    }

    /**
     * Writes {@code SUBSTRING} using the {@code SUBSTR_ORACLE} operator's syntax. Writes
     * two-operand {@code FLOOR} as {@code TRUNC}. Everything else goes to the base dialect.
     */
    @Override
    public void unparseCall(SqlWriter writer, SqlCall call,
            int leftPrec, int rightPrec) {
        if (call.getOperator() == SqlStdOperatorTable.SUBSTRING) {
            SqlUtil.unparseFunctionSyntax(SqlLibraryOperators.SUBSTR_ORACLE, writer,
                    call, false);
            return;
        }
        switch (call.getKind()) {
            case FLOOR:
                if (call.operandCount() == 2) {
                    unparseFloorAsTrunc(writer, call);
                    return;
                }
                break;
            default:
                break;
        }
        super.unparseCall(writer, call, leftPrec, rightPrec);
    }

    /** Renders {@code FLOOR(x TO unit)} as {@code TRUNC} with the time unit as its operand. */
    private static void unparseFloorAsTrunc(SqlWriter writer, SqlCall call) {
        final SqlLiteral unitLiteral = call.operand(1);
        final TimeUnitRange unit = unitLiteral.getValueAs(TimeUnitRange.class);
        final SqlCall truncCall = SqlFloorFunction.replaceTimeUnitOperand(call, unit.name(),
                unitLiteral.getParserPosition());
        SqlFloorFunction.unparseDatetimeFunction(writer, truncCall, "TRUNC", true);
    }
}
代码中的 main 方法附带了测试用的 SQL 例子，直接运行即可。关键还是要理解原理，实现本身并不难。