MyBatis拦截器修改SQL语句
背景
最近公司项目想要做成一个云SaaS平台,需要不同租户只能看到各自的数据,也就是要做到租户之间的数据隔离(这里采用的是共享表、按字段区分的逻辑隔离)。目前的方案是在每张业务表中增加一个 platform_id 字段来区分不同的租户,这就意味着原来系统中的增删改查都需要带上 platform_id 字段作为标识。如果在每条SQL上都手动加上这个字段,那就太麻烦、太复杂了,所以就想使用 mybatis 的拦截器 Interceptor 来实现。
准备
自定义注解@PlatformTag @PlatformTagIgnore
为了让代码更灵活,只在mapper 类上标注有@PlatformTag 的标记才会自动拦截并且添加条件。
/**
 * Type-level marker annotation.
 *
 * <p>According to the accompanying notes, only mapper types carrying this
 * marker have the tenant column condition added automatically by the
 * interceptor. The annotation has no members; it is retained at runtime so
 * it can be discovered reflectively.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface PlatformTag {
}
这个注解只作用于类上。但是考虑到自定义的复杂sql没办法自动添加条件,所以再增加一个@PlatformTagIgnore 注解,用来标注类中需要跳过自动拦截的方法,这些方法由开发者在sql中手动添加条件。
/**
 * Method-level marker annotation.
 *
 * <p>According to the accompanying notes, it is placed on individual mapper
 * methods whose SQL is too complex for automatic rewriting, so that the
 * tenant condition is written by hand instead. The annotation has no
 * members; it is retained at runtime so it can be discovered reflectively.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PlatformTagIgnore {
}
拦截器PlatformInterceptor(下方示例代码中的类名为 AreaInterceptor)
首先考虑一下SQL的几种基本操作:增删改查。删除不用考虑(删除一般都是直接根据主键id删除),改和新增分别对应 update 和 insert,查对应 select,目前只需要考虑这三种情况。
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
/**
 * MyBatis plugin that rewrites the SQL of tagged mapper statements so that it
 * carries the {@code platform_id} tenant column.
 *
 * <ul>
 *   <li>UPDATE: appends {@code platform_id = <current id>} to the SET list.</li>
 *   <li>INSERT: appends the {@code platform_id} column and its value.</li>
 *   <li>SELECT: prepends {@code platform_id = <current id>} to the WHERE clause,
 *       qualified with the alias of the first FROM item when it has one.</li>
 * </ul>
 *
 * <p>Other statement types are passed through unchanged. Whether a statement is
 * rewritten at all is decided by {@code ExecutorPluginUtils.isAreaTag}.
 *
 * <p>NOTE(review): the rewrite is best-effort. If it fails, the error is logged and
 * the ORIGINAL SQL is executed, i.e. without the tenant column. Confirm that this
 * fail-open behaviour is acceptable for tenant isolation.
 */
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "update",
                args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class,
                        CacheKey.class, BoundSql.class})
})
public class AreaInterceptor implements Interceptor {

    /** Name of the tenant discriminator column added to every rewritten statement. */
    private static final String COLUMN_NAME = "platform_id";

    /**
     * Rewrites the SQL of the intercepted call when its mapper is tagged, then
     * lets the invocation proceed.
     *
     * @param invocation the intercepted {@code Executor} call; its first argument
     *                   is the {@link MappedStatement}
     * @return whatever the wrapped executor method returns
     * @throws Throwable if the SQL of a tagged statement cannot be parsed, or if
     *                   the wrapped call itself fails
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        String processSql = ExecutorPluginUtils.getSqlByInvocation(invocation);
        log.debug("schema替换前:{}", processSql);
        String sql2Reset = processSql;

        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        if (ExecutorPluginUtils.isAreaTag(mappedStatement)) {
            // Parse only statements that will actually be rewritten, so SQL that the
            // parser cannot handle no longer breaks untagged mappers.
            Statement statement = CCJSqlParserUtil.parse(processSql);
            try {
                if (statement instanceof Update) {
                    sql2Reset = rewriteUpdate((Update) statement,
                            CCJSqlParserUtil.parseExpression(CurrentPlatformIdCache.getCurrentPlatformId()),
                            processSql);
                } else if (statement instanceof Insert) {
                    sql2Reset = rewriteInsert((Insert) statement,
                            CCJSqlParserUtil.parseExpression(CurrentPlatformIdCache.getCurrentPlatformId()));
                } else if (statement instanceof Select) {
                    // Only plain SELECTs are supported; any other select body fails the
                    // cast and ends up in the catch block below.
                    PlainSelect plain = (PlainSelect) ((Select) statement).getSelectBody();
                    String condition = buildPlatformCondition(plain,
                            CurrentPlatformIdCache.getCurrentPlatformId());
                    plain.setWhere(CCJSqlParserUtil.parseCondExpression(condition));
                    sql2Reset = statement.toString();
                }
            } catch (Exception e) {
                // Best-effort: keep the original SQL, but report through the logger
                // instead of printing the stack trace to stderr.
                log.error("Failed to add {} to SQL, executing it unchanged: {}",
                        COLUMN_NAME, processSql, e);
            }
        }

        log.info("schema替换后:{}", sql2Reset);
        ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);
        return invocation.proceed();
    }

    /**
     * Appends {@code platform_id = <platformId>} to the SET list of an UPDATE.
     *
     * <p>NOTE(review): this only assigns the column; it does not restrict the
     * WHERE clause to the current tenant. Confirm that this is intended.
     *
     * @param update      parsed UPDATE statement; modified in place
     * @param platformId  expression holding the current platform id
     * @param originalSql SQL returned unchanged when the first table entry is {@code null}
     * @return the rewritten SQL, or {@code originalSql} when nothing was changed
     */
    private static String rewriteUpdate(Update update, Expression platformId, String originalSql) {
        Table firstTable = update.getTables().get(0);
        if (firstTable == null) {
            return originalSql;
        }
        List<Column> columns = update.getColumns();
        List<Expression> expressions = update.getExpressions();
        columns.add(new Column(COLUMN_NAME));
        expressions.add(platformId);
        update.setColumns(columns);
        update.setExpressions(expressions);
        return update.toString();
    }

    /**
     * Appends the {@code platform_id} column and its value to a single-row
     * {@code INSERT ... VALUES (...)} statement.
     *
     * <p>An INSERT without an explicit column list, or whose items list is not an
     * {@link ExpressionList}, makes this method throw; the caller treats that as a
     * failed rewrite.
     *
     * @param insert     parsed INSERT statement; modified in place
     * @param platformId expression holding the current platform id
     * @return the rewritten SQL
     */
    private static String rewriteInsert(Insert insert, Expression platformId) {
        List<Column> columns = insert.getColumns();
        ExpressionList itemsList = (ExpressionList) insert.getItemsList();
        columns.add(new Column(COLUMN_NAME));
        // Copy the value list before appending rather than mutating the parser's list.
        List<Expression> values = new ArrayList<>(itemsList.getExpressions());
        values.add(platformId);
        itemsList.setExpressions(values);
        insert.setItemsList(itemsList);
        insert.setColumns(columns);
        return insert.toString();
    }

    /**
     * Builds the WHERE text for a SELECT: the tenant predicate, followed by the
     * existing WHERE clause (if any) wrapped in parentheses and joined with
     * {@code and}.
     *
     * <p>NOTE(review): the platform id is concatenated into the SQL text as-is;
     * it is assumed to come from a trusted source. Confirm against
     * {@code CurrentPlatformIdCache}.
     *
     * @param plain      the plain SELECT whose FROM item and WHERE clause are read
     * @param platformId current platform id, inserted verbatim
     * @return condition text to be parsed and set as the new WHERE clause
     */
    private static String buildPlatformCondition(PlainSelect plain, String platformId) {
        StringBuilder condition = new StringBuilder();
        FromItem fromItem = plain.getFromItem();
        if (fromItem.getAlias() != null) {
            // Qualify the column with the alias of the first FROM item.
            condition.append(fromItem.getAlias().getName()).append('.');
        }
        condition.append(COLUMN_NAME).append(" = ").append(platformId);

        Expression where = plain.getWhere();
        if (where != null) {
            // Parenthesise the existing clause so an OR inside it cannot bypass the predicate.
            condition.append(" and ( ").append(where).append(" )");
        }
        return condition.toString();
    }

    /**
     * Wraps the target with this interceptor using {@link Plugin#wrap}.
     *
     * @param target the object MyBatis offers for interception
     * @return the result of {@code Plugin.wrap(target, this)}
     */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * No configuration properties are used.
     *
     * @param properties ignored
     */
    @Override
    public void setProperties(Properties properties) {
        // Intentionally empty: the interceptor has no configurable properties.
    }
}
|