处理jsqlparser版本(不兼容改动,需要自行引入模块).

https://github.com/baomidou/mybatis-plus/issues/6497
This commit is contained in:
聂秋荣 2024-10-19 14:10:14 +08:00
parent 9ae6730d0e
commit 37251cf9db
78 changed files with 5817 additions and 11 deletions

View File

@ -1,6 +1,5 @@
dependencies {
api project(":mybatis-plus-annotation")
api "${lib.'jsqlparser'}"
api "${lib.mybatis}"
implementation "${lib.cglib}"

View File

@ -47,9 +47,6 @@ class GeneratePomTest {
Dependency mybatis = dependenciesMap.get("mybatis");
Assertions.assertEquals("compile", mybatis.getScope());
Assertions.assertFalse(mybatis.isOptional());
Dependency jsqlParser = dependenciesMap.get("jsqlparser");
Assertions.assertEquals("compile", jsqlParser.getScope());
Assertions.assertFalse(jsqlParser.isOptional());
Dependency cglib = dependenciesMap.get("cglib");
Assertions.assertEquals("compile", cglib.getScope());
Assertions.assertTrue(cglib.isOptional());

View File

@ -17,9 +17,6 @@ dependencies {
implementation "${lib['mybatis-thymeleaf']}"
implementation "${lib.'mybatis-velocity'}"
implementation "${lib.'mybatis-freemarker'}"
implementation "de.ruedigermoeller:fst:3.0.4-jdk17"
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
testImplementation "io.github.classgraph:classgraph:4.8.176"
testImplementation "${lib."spring-context-support"}"
testImplementation "${lib.h2}"
testImplementation "${lib.mysql}"

View File

@ -2,11 +2,11 @@ package com.baomidou.mybatisplus.test.extension.plugins.pagination.dialects;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
import com.google.common.collect.Lists;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.ReflectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@ -36,7 +36,9 @@ class IDialectTest {
DialectModel model = o.buildPaginationSql("select * from table", 1, 10);
String sql = model.getDialectSql();
if (!map.containsKey(sql)) {
map.put(sql, Lists.newArrayList(i));
ArrayList<Class<?>> list = new ArrayList<>();
list.add(i);
map.put(sql,list);
} else {
map.get(sql).add(i);
}

View File

@ -0,0 +1 @@
tasks.matching {it.group == 'publishing' || it.group == 'central publish' }.each { it.enabled = false }

View File

@ -0,0 +1,14 @@
dependencies {
api "com.github.jsqlparser:jsqlparser:4.9"
api project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common")
implementation "${lib."slf4j-api"}"
implementation "de.ruedigermoeller:fst:3.0.3"
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
testImplementation "io.github.classgraph:classgraph:4.8.176"
testImplementation "${lib."spring-context-support"}"
testImplementation "${lib.h2}"
testImplementation group: 'com.google.guava', name: 'guava', version: '33.3.1-jre'
}
compileJava.dependsOn(processResources)

View File

@ -0,0 +1,266 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser.cache;
import org.nustaq.serialization.FSTConfiguration;
/**
* Fst Factory
*
* @author miemie
* @since 2023-08-06
*/
public class FstFactory {
private static final FstFactory FACTORY = new FstFactory();
private final FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
public static FstFactory getDefaultFactory() {
return FACTORY;
}
public FstFactory() {
conf.registerClass(net.sf.jsqlparser.expression.Alias.class);
conf.registerClass(net.sf.jsqlparser.expression.Alias.AliasColumn.class);
conf.registerClass(net.sf.jsqlparser.expression.AllValue.class);
conf.registerClass(net.sf.jsqlparser.expression.AnalyticExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.AnyComparisonExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.ArrayConstructor.class);
conf.registerClass(net.sf.jsqlparser.expression.ArrayExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.CaseExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.CastExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.CollateExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.ConnectByRootOperator.class);
conf.registerClass(net.sf.jsqlparser.expression.DateTimeLiteralExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.DateValue.class);
conf.registerClass(net.sf.jsqlparser.expression.DoubleValue.class);
conf.registerClass(net.sf.jsqlparser.expression.ExtractExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.FilterOverImpl.class);
conf.registerClass(net.sf.jsqlparser.expression.Function.class);
conf.registerClass(net.sf.jsqlparser.expression.HexValue.class);
conf.registerClass(net.sf.jsqlparser.expression.IntervalExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.JdbcNamedParameter.class);
conf.registerClass(net.sf.jsqlparser.expression.JdbcParameter.class);
conf.registerClass(net.sf.jsqlparser.expression.JsonAggregateFunction.class);
conf.registerClass(net.sf.jsqlparser.expression.JsonExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.JsonFunction.class);
conf.registerClass(net.sf.jsqlparser.expression.JsonFunctionExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.JsonKeyValuePair.class);
conf.registerClass(net.sf.jsqlparser.expression.KeepExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.LongValue.class);
conf.registerClass(net.sf.jsqlparser.expression.MySQLGroupConcat.class);
conf.registerClass(net.sf.jsqlparser.expression.MySQLIndexHint.class);
conf.registerClass(net.sf.jsqlparser.expression.NextValExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.NotExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.NullValue.class);
conf.registerClass(net.sf.jsqlparser.expression.NumericBind.class);
conf.registerClass(net.sf.jsqlparser.expression.OracleHierarchicalExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.OracleHint.class);
conf.registerClass(net.sf.jsqlparser.expression.OracleNamedFunctionParameter.class);
conf.registerClass(net.sf.jsqlparser.expression.OrderByClause.class);
conf.registerClass(net.sf.jsqlparser.expression.OverlapsCondition.class);
conf.registerClass(net.sf.jsqlparser.expression.Parenthesis.class);
conf.registerClass(net.sf.jsqlparser.expression.PartitionByClause.class);
conf.registerClass(net.sf.jsqlparser.expression.RangeExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.RowConstructor.class);
conf.registerClass(net.sf.jsqlparser.expression.RowGetExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.SQLServerHints.class);
conf.registerClass(net.sf.jsqlparser.expression.SignedExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.StringValue.class);
conf.registerClass(net.sf.jsqlparser.expression.TimeKeyExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.TimeValue.class);
conf.registerClass(net.sf.jsqlparser.expression.TimestampValue.class);
conf.registerClass(net.sf.jsqlparser.expression.TimezoneExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.TranscodingFunction.class);
conf.registerClass(net.sf.jsqlparser.expression.TrimFunction.class);
conf.registerClass(net.sf.jsqlparser.expression.UserVariable.class);
conf.registerClass(net.sf.jsqlparser.expression.VariableAssignment.class);
conf.registerClass(net.sf.jsqlparser.expression.WhenClause.class);
conf.registerClass(net.sf.jsqlparser.expression.WindowDefinition.class);
conf.registerClass(net.sf.jsqlparser.expression.WindowElement.class);
conf.registerClass(net.sf.jsqlparser.expression.WindowOffset.class);
conf.registerClass(net.sf.jsqlparser.expression.WindowRange.class);
conf.registerClass(net.sf.jsqlparser.expression.XMLSerializeExpr.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Addition.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseAnd.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseLeftShift.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseOr.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseRightShift.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseXor.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Concat.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Division.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.IntegerDivision.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Modulo.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Multiplication.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Subtraction.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.AndExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.OrExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.XorExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Between.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ContainedBy.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Contains.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.DoubleAnd.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.EqualsTo.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ExistsExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ExpressionList.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.FullTextSearch.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GeometryDistance.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GreaterThan.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.InExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsBooleanExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsDistinctExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsNullExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.JsonOperator.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.LikeExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Matches.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MemberOfExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MinorThan.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MinorThanEquals.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.NamedExpressionList.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.NotEqualsTo.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.RegExpMatchOperator.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.SimilarToExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.TSQLLeftJoin.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.TSQLRightJoin.class);
conf.registerClass(net.sf.jsqlparser.parser.ASTNodeAccessImpl.class);
conf.registerClass(net.sf.jsqlparser.parser.Token.class);
conf.registerClass(net.sf.jsqlparser.schema.Column.class);
conf.registerClass(net.sf.jsqlparser.schema.Sequence.class);
conf.registerClass(net.sf.jsqlparser.schema.Synonym.class);
conf.registerClass(net.sf.jsqlparser.schema.Table.class);
conf.registerClass(net.sf.jsqlparser.statement.Block.class);
conf.registerClass(net.sf.jsqlparser.statement.Commit.class);
conf.registerClass(net.sf.jsqlparser.statement.DeclareStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.DeclareStatement.TypeDefExpr.class);
conf.registerClass(net.sf.jsqlparser.statement.DescribeStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.ExplainStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.ExplainStatement.Option.class);
conf.registerClass(net.sf.jsqlparser.statement.IfElseStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.OutputClause.class);
conf.registerClass(net.sf.jsqlparser.statement.PurgeStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.ReferentialAction.class);
conf.registerClass(net.sf.jsqlparser.statement.ResetStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.RollbackStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.SavepointStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.SetStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.ShowColumnsStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.ShowStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.Statements.class);
conf.registerClass(net.sf.jsqlparser.statement.UnsupportedStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.UseStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.Alter.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDataType.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDropDefault.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDropNotNull.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterSession.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterSystemStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.RenameTableStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.alter.sequence.AlterSequence.class);
conf.registerClass(net.sf.jsqlparser.statement.analyze.Analyze.class);
conf.registerClass(net.sf.jsqlparser.statement.comment.Comment.class);
conf.registerClass(net.sf.jsqlparser.statement.create.function.CreateFunction.class);
conf.registerClass(net.sf.jsqlparser.statement.create.index.CreateIndex.class);
conf.registerClass(net.sf.jsqlparser.statement.create.procedure.CreateProcedure.class);
conf.registerClass(net.sf.jsqlparser.statement.create.schema.CreateSchema.class);
conf.registerClass(net.sf.jsqlparser.statement.create.sequence.CreateSequence.class);
conf.registerClass(net.sf.jsqlparser.statement.create.synonym.CreateSynonym.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.CheckConstraint.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.ColDataType.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.ColumnDefinition.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.CreateTable.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.ExcludeConstraint.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.ForeignKeyIndex.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.Index.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.Index.ColumnParams.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.NamedConstraint.class);
conf.registerClass(net.sf.jsqlparser.statement.create.table.RowMovement.class);
conf.registerClass(net.sf.jsqlparser.statement.create.view.AlterView.class);
conf.registerClass(net.sf.jsqlparser.statement.create.view.CreateView.class);
conf.registerClass(net.sf.jsqlparser.statement.delete.Delete.class);
conf.registerClass(net.sf.jsqlparser.statement.drop.Drop.class);
conf.registerClass(net.sf.jsqlparser.statement.execute.Execute.class);
conf.registerClass(net.sf.jsqlparser.statement.grant.Grant.class);
conf.registerClass(net.sf.jsqlparser.statement.insert.Insert.class);
conf.registerClass(net.sf.jsqlparser.statement.insert.InsertConflictAction.class);
conf.registerClass(net.sf.jsqlparser.statement.insert.InsertConflictTarget.class);
conf.registerClass(net.sf.jsqlparser.statement.merge.Merge.class);
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeDelete.class);
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeInsert.class);
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeUpdate.class);
conf.registerClass(net.sf.jsqlparser.statement.refresh.RefreshMaterializedViewStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.select.AllColumns.class);
conf.registerClass(net.sf.jsqlparser.statement.select.AllTableColumns.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Distinct.class);
conf.registerClass(net.sf.jsqlparser.statement.select.ExceptOp.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Fetch.class);
conf.registerClass(net.sf.jsqlparser.statement.select.First.class);
conf.registerClass(net.sf.jsqlparser.statement.select.ForClause.class);
conf.registerClass(net.sf.jsqlparser.statement.select.GroupByElement.class);
conf.registerClass(net.sf.jsqlparser.statement.select.IntersectOp.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Join.class);
conf.registerClass(net.sf.jsqlparser.statement.select.KSQLJoinWindow.class);
conf.registerClass(net.sf.jsqlparser.statement.select.KSQLWindow.class);
conf.registerClass(net.sf.jsqlparser.statement.select.LateralSubSelect.class);
conf.registerClass(net.sf.jsqlparser.statement.select.LateralView.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Limit.class);
conf.registerClass(net.sf.jsqlparser.statement.select.MinusOp.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Offset.class);
conf.registerClass(net.sf.jsqlparser.statement.select.OptimizeFor.class);
conf.registerClass(net.sf.jsqlparser.statement.select.OrderByElement.class);
conf.registerClass(net.sf.jsqlparser.statement.select.ParenthesedFromItem.class);
conf.registerClass(net.sf.jsqlparser.statement.select.ParenthesedSelect.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Pivot.class);
conf.registerClass(net.sf.jsqlparser.statement.select.PivotXml.class);
conf.registerClass(net.sf.jsqlparser.statement.select.PlainSelect.class);
conf.registerClass(net.sf.jsqlparser.statement.select.SelectItem.class);
conf.registerClass(net.sf.jsqlparser.statement.select.SetOperationList.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Skip.class);
conf.registerClass(net.sf.jsqlparser.statement.select.TableFunction.class);
conf.registerClass(net.sf.jsqlparser.statement.select.TableStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Top.class);
conf.registerClass(net.sf.jsqlparser.statement.select.UnPivot.class);
conf.registerClass(net.sf.jsqlparser.statement.select.UnionOp.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Values.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Wait.class);
conf.registerClass(net.sf.jsqlparser.statement.select.WithIsolation.class);
conf.registerClass(net.sf.jsqlparser.statement.select.WithItem.class);
conf.registerClass(net.sf.jsqlparser.statement.show.ShowIndexStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.show.ShowTablesStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.truncate.Truncate.class);
conf.registerClass(net.sf.jsqlparser.statement.update.Update.class);
conf.registerClass(net.sf.jsqlparser.statement.update.UpdateSet.class);
conf.registerClass(net.sf.jsqlparser.statement.upsert.Upsert.class);
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultiAndExpression.class);
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultiOrExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.BinaryExpression.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ComparisonOperator.class);
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.OldOracleJoinBinaryExpression.class);
conf.registerClass(net.sf.jsqlparser.statement.CreateFunctionalStatement.class);
conf.registerClass(net.sf.jsqlparser.statement.select.Select.class);
conf.registerClass(net.sf.jsqlparser.statement.select.SetOperation.class);
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultipleExpression.class);
}
public byte[] asByteArray(Object obj) {
return conf.asByteArray(obj);
}
public Object asObject(byte[] bytes) {
return conf.asObject(bytes);
}
}

View File

@ -0,0 +1,424 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.*;
import java.util.*;
import java.util.stream.Collectors;
/**
* 多表条件处理基对象从原有的 {@link TenantLineInnerInterceptor} 拦截器中提取出来
*
* @author houkunlin
* @since 3.5.2
*/
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings({"rawtypes"})
public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
protected void processSelectBody(Select selectBody, final String whereSegment) {
if (selectBody == null) {
return;
}
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody, whereSegment);
} else if (selectBody instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) selectBody;
processSelectBody(parenthesedSelect.getSelect(), whereSegment);
} else if (selectBody instanceof SetOperationList) {
SetOperationList operationList = (SetOperationList) selectBody;
List<Select> selectBodyList = operationList.getSelects();
if (CollectionUtils.isNotEmpty(selectBodyList)) {
selectBodyList.forEach(body -> processSelectBody(body, whereSegment));
}
}
}
/**
* delete update 语句 where 处理
*/
protected Expression andExpression(Table table, Expression where, final String whereSegment) {
//获得where条件表达式
final Expression expression = buildTableExpression(table, where, whereSegment);
if (expression == null) {
return where;
}
if (where != null) {
if (where instanceof OrExpression) {
return new AndExpression(new Parenthesis(where), expression);
} else {
return new AndExpression(where, expression);
}
}
return expression;
}
/**
* 处理 PlainSelect
*/
protected void processPlainSelect(final PlainSelect plainSelect, final String whereSegment) {
//#3087 github
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(selectItem -> processSelectItem(selectItem, whereSegment));
}
// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where, whereSegment);
// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem, whereSegment);
List<Table> mainTables = new ArrayList<>(list);
// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
processJoins(mainTables, joins, whereSegment);
}
// 当有 mainTable 进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables, whereSegment));
}
}
private List<Table> processFromItem(FromItem fromItem, final String whereSegment) {
// 处理括号括起来的表达式
// while (fromItem instanceof ParenthesedFromItem) {
// fromItem = ((ParenthesedFromItem) fromItem).getFromItem();
// }
List<Table> mainTables = new ArrayList<>();
// join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
mainTables.add(fromTable);
} else if (fromItem instanceof ParenthesedFromItem) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((ParenthesedFromItem) fromItem, whereSegment);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem, whereSegment);
}
return mainTables;
}
/**
* 处理where条件内的子查询
* <p>
* 支持如下:
* <ol>
* <li>in</li>
* <li>=</li>
* <li>&gt;</li>
* <li>&lt;</li>
* <li>&gt;=</li>
* <li>&lt;=</li>
* <li>&lt;&gt;</li>
* <li>EXISTS</li>
* <li>NOT EXISTS</li>
* </ol>
* <p>
* 前提条件:
* 1. 子查询必须放在小括号中
* 2. 子查询一般放在比较操作符的右边
*
* @param where where 条件
*/
protected void processWhereSubSelect(Expression where, final String whereSegment) {
if (where == null) {
return;
}
if (where instanceof FromItem) {
processOtherFromItem((FromItem) where, whereSegment);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
// 有子查询
if (where instanceof BinaryExpression) {
// 比较符号 , and , or , 等等
BinaryExpression expression = (BinaryExpression) where;
processWhereSubSelect(expression.getLeftExpression(), whereSegment);
processWhereSubSelect(expression.getRightExpression(), whereSegment);
} else if (where instanceof InExpression) {
// in
InExpression expression = (InExpression) where;
Expression inExpression = expression.getRightExpression();
if (inExpression instanceof Select) {
processSelectBody(((Select) inExpression), whereSegment);
}
} else if (where instanceof ExistsExpression) {
// exists
ExistsExpression expression = (ExistsExpression) where;
processWhereSubSelect(expression.getRightExpression(), whereSegment);
} else if (where instanceof NotExpression) {
// not exists
NotExpression expression = (NotExpression) where;
processWhereSubSelect(expression.getExpression(), whereSegment);
} else if (where instanceof Parenthesis) {
Parenthesis expression = (Parenthesis) where;
processWhereSubSelect(expression.getExpression(), whereSegment);
}
}
}
protected void processSelectItem(SelectItem selectItem, final String whereSegment) {
Expression expression = selectItem.getExpression();
if (expression instanceof Select) {
processSelectBody(((Select) expression), whereSegment);
} else if (expression instanceof Function) {
processFunction((Function) expression, whereSegment);
} else if (expression instanceof ExistsExpression) {
ExistsExpression existsExpression = (ExistsExpression) expression;
processSelectBody((Select) existsExpression.getRightExpression(), whereSegment);
}
}
/**
* 处理函数
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
* <p> fixed gitee pulls/141</p>
*
* @param function
*/
protected void processFunction(Function function, final String whereSegment) {
ExpressionList<?> parameters = function.getParameters();
if (parameters != null) {
parameters.forEach(expression -> {
if (expression instanceof Select) {
processSelectBody(((Select) expression), whereSegment);
} else if (expression instanceof Function) {
processFunction((Function) expression, whereSegment);
} else if (expression instanceof EqualsTo) {
if (((EqualsTo) expression).getLeftExpression() instanceof Select) {
processSelectBody(((Select) ((EqualsTo) expression).getLeftExpression()), whereSegment);
}
if (((EqualsTo) expression).getRightExpression() instanceof Select) {
processSelectBody(((Select) ((EqualsTo) expression).getRightExpression()), whereSegment);
}
}
});
}
}
/**
* 处理子查询等
*/
protected void processOtherFromItem(FromItem fromItem, final String whereSegment) {
// 去除括号
// while (fromItem instanceof ParenthesisFromItem) {
// fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
// }
if (fromItem instanceof ParenthesedSelect) {
Select subSelect = (Select) fromItem;
processSelectBody(subSelect, whereSegment);
} else if (fromItem instanceof ParenthesedFromItem) {
logger.debug("Perform a subQuery, if you do not give us feedback");
}
}
/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(ParenthesedFromItem subJoin, final String whereSegment) {
List<Table> mainTables = new ArrayList<>();
while (subJoin.getJoins() == null && subJoin.getFromItem() instanceof ParenthesedFromItem) {
subJoin = (ParenthesedFromItem) subJoin.getFromItem();
}
if (subJoin.getJoins() != null) {
List<Table> list = processFromItem(subJoin.getFromItem(), whereSegment);
mainTables.addAll(list);
processJoins(mainTables, subJoin.getJoins(), whereSegment);
}
return mainTables;
}
/**
* 处理 joins
*
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private List<Table> processJoins(List<Table> mainTables, List<Join> joins, final String whereSegment) {
// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}
//对于 on 表达式写在最后的 join需要记录下前面多个 on 的表名
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem joinItem = join.getRightItem();
// 获取当前 join 的表subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof ParenthesedFromItem) {
joinTables = processSubJoin((ParenthesedFromItem) joinItem, whereSegment);
}
if (joinTables != null) {
// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}
// 当前表是否忽略
Table joinTable = joinTables.get(0);
List<Table> onTables = null;
// 如果不要忽略且是右连接则记录下当前表
if (join.isRight()) {
mainTable = joinTable;
mainTables.clear();
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
mainTables.clear();
} else {
onTables = Collections.singletonList(joinTable);
}
if (mainTable != null && !mainTables.contains(mainTable)) {
mainTables.add(mainTable);
}
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个立刻处理
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables, whereSegment));
join.setOnExpressions(onExpressions);
leftTable = mainTable == null ? joinTable : mainTable;
continue;
}
// 表名压栈忽略的表压入 null以便后续不处理
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTableList, whereSegment));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
processOtherFromItem(joinItem, whereSegment);
leftTable = null;
}
}
return mainTables;
}
/**
* 处理条件
*/
protected Expression builderExpression(Expression currentExpression, List<Table> tables, final String whereSegment) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 构造每张表的条件
List<Expression> expressions = tables.stream()
.map(item -> buildTableExpression(item, currentExpression, whereSegment))
.filter(Objects::nonNull)
.collect(Collectors.toList());
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(expressions)) {
return currentExpression;
}
// 注入的表达式
Expression injectExpression = expressions.get(0);
// 如果有多表则用 and 连接
if (expressions.size() > 1) {
for (int i = 1; i < expressions.size(); i++) {
injectExpression = new AndExpression(injectExpression, expressions.get(i));
}
}
if (currentExpression == null) {
return injectExpression;
}
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), injectExpression);
} else {
return new AndExpression(currentExpression, injectExpression);
}
}
/**
* 构建数据库表的查询条件
*
* @param table 表对象
* @param where 当前where条件
* @param whereSegment 所属Mapper对象全路径
* @return 需要拼接的新条件不会覆盖原有的where条件只会在原有条件上再加条件 null 则不加入新的条件
*/
public abstract Expression buildTableExpression(final Table table, final Expression where, final String whereSegment);
}

View File

@ -0,0 +1,142 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.Assert;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import java.sql.Connection;
/**
* 攻击 SQL 阻断解析器,防止全表更新与删除
*
* @author hubin
* @since 3.4.0
*/
public class BlockAttackInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = handler.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
if (InterceptorIgnoreHelper.willIgnoreBlockAttack(ms.getId())) {
return;
}
BoundSql boundSql = handler.boundSql();
parserMulti(boundSql.getSql(), null);
}
}
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
this.checkWhere(delete.getTable().getName(), delete.getWhere(), "Prohibition of full table deletion");
}
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
this.checkWhere(update.getTable().getName(), update.getWhere(), "Prohibition of table update operation");
}
protected void checkWhere(String tableName, Expression where, String ex) {
Assert.isFalse(this.fullMatch(where, this.getTableLogicField(tableName)), ex);
}
private boolean fullMatch(Expression where, String logicField) {
if (where == null) {
return true;
}
if (StringUtils.isNotBlank(logicField)) {
if (where instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) where;
if (StringUtils.equals(binaryExpression.getLeftExpression().toString(), logicField) || StringUtils.equals(binaryExpression.getRightExpression().toString(), logicField)) {
return true;
}
}
if (where instanceof IsNullExpression) {
IsNullExpression binaryExpression = (IsNullExpression) where;
if (StringUtils.equals(binaryExpression.getLeftExpression().toString(), logicField)) {
return true;
}
}
}
if (where instanceof EqualsTo) {
// example: 1=1
EqualsTo equalsTo = (EqualsTo) where;
return StringUtils.equals(equalsTo.getLeftExpression().toString(), equalsTo.getRightExpression().toString());
} else if (where instanceof NotEqualsTo) {
// example: 1 != 2
NotEqualsTo notEqualsTo = (NotEqualsTo) where;
return !StringUtils.equals(notEqualsTo.getLeftExpression().toString(), notEqualsTo.getRightExpression().toString());
} else if (where instanceof OrExpression) {
OrExpression orExpression = (OrExpression) where;
return fullMatch(orExpression.getLeftExpression(), logicField) || fullMatch(orExpression.getRightExpression(), logicField);
} else if (where instanceof AndExpression) {
AndExpression andExpression = (AndExpression) where;
return fullMatch(andExpression.getLeftExpression(), logicField) && fullMatch(andExpression.getRightExpression(), logicField);
} else if (where instanceof Parenthesis) {
// example: (1 = 1)
Parenthesis parenthesis = (Parenthesis) where;
return fullMatch(parenthesis.getExpression(), logicField);
}
return false;
}
/**
* 获取表名中的逻辑删除字段
*
* @param tableName 表名
* @return 逻辑删除字段
*/
private String getTableLogicField(String tableName) {
if (StringUtils.isBlank(tableName)) {
return StringPool.EMPTY;
}
TableInfo tableInfo = TableInfoHelper.getTableInfo(tableName);
if (tableInfo == null || !tableInfo.isWithLogicDelete() || tableInfo.getLogicDeleteFieldInfo() == null) {
return StringPool.EMPTY;
}
return tableInfo.getLogicDeleteFieldInfo().getColumn();
}
}

View File

@ -0,0 +1,379 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.Assert;
import com.baomidou.mybatisplus.core.toolkit.EncryptUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import lombok.Data;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* 由于开发人员水平参差不齐即使订了开发规范很多人也不遵守
* <p>SQL是影响系统性能最重要的因素所以拦截掉垃圾SQL语句</p>
* <br>
* <p>拦截SQL类型的场景</p>
* <p>1.必须使用到索引包含left join连接字段符合索引最左原则</p>
* <p>必须使用索引好处</p>
* <p>1.1 如果因为动态SQLbug导致update的where条件没有带上全表更新上万条数据</p>
* <p>1.2 如果检查到使用了索引SQL性能基本不会太差</p>
* <br>
* <p>2.SQL尽量单表执行有查询left join的语句必须在注释里面允许该SQL运行否则会被拦截有left join的语句如果不能拆成单表执行的SQL请leader商量在做</p>
* <p>https://gaoxianglong.github.io/shark</p>
* <p>SQL尽量单表执行的好处</p>
* <p>2.1 查询条件简单易于开理解和维护</p>
* <p>2.2 扩展性极强可为分库分表做准备</p>
* <p>2.3 缓存利用率高</p>
* <p>2.在字段上使用函数</p>
* <br>
* <p>3.where条件为空</p>
* <p>4.where条件使用了 !=</p>
* <p>5.where条件使用了 not 关键字</p>
* <p>6.where条件使用了 or 关键字</p>
* <p>7.where条件使用了 使用子查询</p>
*
* @author willenfoo
* @since 3.4.0
*/
public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
/**
* 缓存验证结果提高性能
*/
private static final Set<String> cacheValidResult = new HashSet<>();
/**
* 缓存表的索引信息
*/
private static final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpStatementHandler.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())) {
return;
}
BoundSql boundSql = mpStatementHandler.boundSql();
String originalSql = boundSql.getSql();
logger.debug("检查SQL是否合规SQL:" + originalSql);
String md5Base64 = EncryptUtils.md5Base64(originalSql);
if (cacheValidResult.contains(md5Base64)) {
logger.debug("该SQL已验证无需再次验证SQL:" + originalSql);
return;
}
parserSingle(originalSql, connection);
//缓存验证结果
cacheValidResult.add(md5Base64);
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
if (select instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) select;
FromItem fromItem = ((PlainSelect) select).getFromItem();
while (fromItem instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
plainSelect = (PlainSelect) parenthesedSelect.getSelect();
fromItem = plainSelect.getFromItem();
}
Expression where = plainSelect.getWhere();
Assert.notNull(where, "非法SQL必须要有where条件");
Table table = (Table) plainSelect.getFromItem();
List<Join> joins = plainSelect.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
}
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
Expression where = update.getWhere();
Assert.notNull(where, "非法SQL必须要有where条件");
Table table = update.getTable();
List<Join> joins = update.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
Expression where = delete.getWhere();
Assert.notNull(where, "非法SQL必须要有where条件");
Table table = delete.getTable();
List<Join> joins = delete.getJoins();
validWhere(where, table, (Connection) obj);
validJoins(joins, table, (Connection) obj);
}
/**
* 验证expression对象是不是 ornot等等
*
* @param expression ignore
*/
private void validExpression(Expression expression) {
while (expression instanceof Parenthesis) {
Parenthesis parenthesis = (Parenthesis) expression;
expression = parenthesis.getExpression();
}
//where条件使用了 or 关键字
if (expression instanceof OrExpression) {
OrExpression orExpression = (OrExpression) expression;
throw new MybatisPlusException("非法SQLwhere条件中不能使用【or】关键字错误or信息" + orExpression.toString());
} else if (expression instanceof NotEqualsTo) {
NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
throw new MybatisPlusException("非法SQLwhere条件中不能使用【!=】关键字,错误!=信息:" + notEqualsTo.toString());
} else if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression;
// TODO 升级 jsqlparser 后待实现
// if (binaryExpression.isNot()) {
// throw new MybatisPlusException("非法SQLwhere条件中不能使用【not】关键字错误not信息" + binaryExpression.toString());
// }
if (binaryExpression.getLeftExpression() instanceof Function) {
Function function = (Function) binaryExpression.getLeftExpression();
throw new MybatisPlusException("非法SQLwhere条件中不能使用数据库函数错误函数信息" + function.toString());
}
if (binaryExpression.getRightExpression() instanceof Subtraction) {
Subtraction subSelect = (Subtraction) binaryExpression.getRightExpression();
throw new MybatisPlusException("非法SQLwhere条件中不能使用子查询错误子查询SQL信息" + subSelect.toString());
}
} else if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression;
if (inExpression.getRightExpression() instanceof Subtraction) {
Subtraction subSelect = (Subtraction) inExpression.getRightExpression();
throw new MybatisPlusException("非法SQLwhere条件中不能使用子查询错误子查询SQL信息" + subSelect.toString());
}
}
}
/**
* 如果SQL用了 left Join验证是否有ornot等等并且验证是否使用了索引
*
* @param joins ignore
* @param table ignore
* @param connection ignore
*/
private void validJoins(List<Join> joins, Table table, Connection connection) {
//允许执行join验证jion是否使用索引等等
if (joins != null) {
for (Join join : joins) {
Table rightTable = (Table) join.getRightItem();
Collection<Expression> onExpressions = join.getOnExpressions();
for (Expression expression : onExpressions) {
validWhere(expression, table, rightTable, connection);
}
}
}
}
/**
* 检查是否使用索引
*
* @param table ignore
* @param columnName ignore
* @param connection ignore
*/
private void validUseIndex(Table table, String columnName, Connection connection) {
//是否使用索引
boolean useIndexFlag = false;
if (StringUtils.isNotBlank(columnName)) {
String tableInfo = table.getName();
//表存在的索引
String dbName = null;
String tableName;
String[] tableArray = tableInfo.split("\\.");
if (tableArray.length == 1) {
tableName = tableArray[0];
} else {
dbName = tableArray[0];
tableName = tableArray[1];
}
columnName = SqlParserUtils.removeWrapperSymbol(columnName);
List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
for (IndexInfo indexInfo : indexInfos) {
if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
useIndexFlag = true;
break;
}
}
}
if (!useIndexFlag) {
throw new MybatisPlusException("非法SQLSQL未使用到索引, table:" + table + ", columnName:" + columnName);
}
}
/**
* 验证where条件的字段是否有notor等等并且where的第一个字段必须使用索引
*
* @param expression ignore
* @param table ignore
* @param connection ignore
*/
private void validWhere(Expression expression, Table table, Connection connection) {
validWhere(expression, table, null, connection);
}
/**
* 验证where条件的字段是否有notor等等并且where的第一个字段必须使用索引
*
* @param expression ignore
* @param table ignore
* @param joinTable ignore
* @param connection ignore
*/
private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
validExpression(expression);
if (expression instanceof BinaryExpression) {
//获得左边表达式
Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
validExpression(leftExpression);
//如果左边表达式为Column对象则直接获得列名
if (leftExpression instanceof Column) {
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
if (joinTable != null && rightExpression instanceof Column) {
if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
} else {
validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
}
} else {
//获得列名
validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
}
}
//如果BinaryExpression进行迭代
else if (leftExpression instanceof BinaryExpression) {
validWhere(leftExpression, table, joinTable, connection);
}
//获得右边表达式并分解
if (joinTable != null) {
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
validExpression(rightExpression);
}
}
}
/**
* 得到表的索引信息
*
* @param dbName ignore
* @param tableName ignore
* @param conn ignore
* @return ignore
*/
public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
return getIndexInfos(null, dbName, tableName, conn);
}
/**
* 得到表的索引信息
*
* @param key ignore
* @param dbName ignore
* @param tableName ignore
* @param conn ignore
* @return ignore
*/
public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
List<IndexInfo> indexInfos = null;
if (StringUtils.isNotBlank(key)) {
indexInfos = indexInfoMap.get(key);
}
if (indexInfos == null || indexInfos.isEmpty()) {
ResultSet rs;
try {
DatabaseMetaData metadata = conn.getMetaData();
String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
indexInfos = new ArrayList<>();
while (rs.next()) {
//索引中的列序列号等于1才有效
if (Objects.equals(rs.getString(8), "1")) {
IndexInfo indexInfo = new IndexInfo();
indexInfo.setDbName(rs.getString(1));
indexInfo.setTableName(rs.getString(3));
indexInfo.setColumnName(rs.getString(9));
indexInfos.add(indexInfo);
}
}
if (StringUtils.isNotBlank(key)) {
indexInfoMap.put(key, indexInfos);
}
} catch (SQLException e) {
logger.error(String.format("getIndexInfo fault, with key:%s, dbName:%s, tableName:%s", key, dbName, tableName), e);
}
}
return indexInfos;
}
/**
* 索引对象
*/
@Data
private static class IndexInfo {
private String dbName;
private String tableName;
private String columnName;
}
}

View File

@ -0,0 +1,275 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
import lombok.*;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.RowConstructor;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.update.UpdateSet;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
/**
* @author hubin
* @since 3.4.0
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings({"rawtypes"})
public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
private TenantLineHandler tenantLineHandler;
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
return;
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
mpBs.sql(parserSingle(mpBs.sql(), null));
}
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
mpBs.sql(parserMulti(mpBs.sql(), null));
}
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
final String whereSegment = (String) obj;
processSelectBody(select, whereSegment);
List<WithItem> withItemsList = select.getWithItemsList();
if (!CollectionUtils.isEmpty(withItemsList)) {
withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
}
}
@Override
protected void processInsert(Insert insert, int index, String sql, Object obj) {
if (tenantLineHandler.ignoreTable(insert.getTable().getName())) {
// 过滤退出执行
return;
}
List<Column> columns = insert.getColumns();
if (CollectionUtils.isEmpty(columns)) {
// 针对不给列名的insert 不处理
return;
}
String tenantIdColumn = tenantLineHandler.getTenantIdColumn();
if (tenantLineHandler.ignoreInsert(columns, tenantIdColumn)) {
// 针对已给出租户列的insert 不处理
return;
}
columns.add(new Column(tenantIdColumn));
Expression tenantId = tenantLineHandler.getTenantId();
// fixed gitee pulls/141 duplicate update
List<UpdateSet> duplicateUpdateColumns = insert.getDuplicateUpdateSets();
if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(new StringValue(tenantIdColumn));
equalsTo.setRightExpression(tenantId);
duplicateUpdateColumns.add(new UpdateSet(new Column(tenantIdColumn), tenantId));
}
Select select = insert.getSelect();
if (select instanceof PlainSelect) { //fix github issue 4998 修复升级到4.5版本的问题
this.processInsertSelect(select, (String) obj);
} else if (insert.getValues() != null) {
// fixed github pull/295
Values values = insert.getValues();
ExpressionList<Expression> expressions = (ExpressionList<Expression>) values.getExpressions();
if (expressions instanceof ParenthesedExpressionList) {
expressions.addExpression(tenantId);
} else {
if (CollectionUtils.isNotEmpty(expressions)) {//fix github issue 4998 jsqlparse 4.5 批量insert ItemsList不是MultiExpressionList 需要特殊处理
int len = expressions.size();
for (int i = 0; i < len; i++) {
Expression expression = expressions.get(i);
if (expression instanceof Parenthesis) {
ExpressionList rowConstructor = new RowConstructor<>()
.withExpressions(new ExpressionList<>(((Parenthesis) expression).getExpression(), tenantId));
expressions.set(i, rowConstructor);
} else if (expression instanceof ParenthesedExpressionList) {
((ParenthesedExpressionList) expression).addExpression(tenantId);
} else {
expressions.add(tenantId);
}
}
} else {
expressions.add(tenantId);
}
}
} else {
throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
}
}
/**
* update 语句处理
*/
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
final Table table = update.getTable();
if (tenantLineHandler.ignoreTable(table.getName())) {
// 过滤退出执行
return;
}
List<UpdateSet> sets = update.getUpdateSets();
if (!CollectionUtils.isEmpty(sets)) {
sets.forEach(us -> us.getValues().forEach(ex -> {
if (ex instanceof Select) {
processSelectBody(((Select) ex), (String) obj);
}
}));
}
update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
}
/**
* delete 语句处理
*/
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
if (tenantLineHandler.ignoreTable(delete.getTable().getName())) {
// 过滤退出执行
return;
}
delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere(), (String) obj));
}
/**
* 处理 insert into select
* <p>
* 进入这里表示需要 insert 的表启用了多租户, select 的表都启动了
*
* @param selectBody SelectBody
*/
protected void processInsertSelect(Select selectBody, final String whereSegment) {
if(selectBody instanceof PlainSelect){
PlainSelect plainSelect = (PlainSelect) selectBody;
FromItem fromItem = plainSelect.getFromItem();
if (fromItem instanceof Table) {
// fixed gitee pulls/141 duplicate update
processPlainSelect(plainSelect, whereSegment);
appendSelectItem(plainSelect.getSelectItems());
} else if (fromItem instanceof Select) {
Select subSelect = (Select) fromItem;
appendSelectItem(plainSelect.getSelectItems());
processInsertSelect(subSelect, whereSegment);
}
} else if(selectBody instanceof ParenthesedSelect){
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) selectBody;
processInsertSelect(parenthesedSelect.getSelect(), whereSegment);
}
}
/**
* 追加 SelectItem
*
* @param selectItems SelectItem
*/
protected void appendSelectItem(List<SelectItem<?>> selectItems) {
if (CollectionUtils.isEmpty(selectItems)) {
return;
}
if (selectItems.size() == 1) {
SelectItem item = selectItems.get(0);
Expression expression = item.getExpression();
if (expression instanceof AllColumns) {
return;
}
}
selectItems.add(new SelectItem<>(new Column(tenantLineHandler.getTenantIdColumn())));
}
/**
* 租户字段别名设置
* <p>tenantId tableAlias.tenantId</p>
*
* @param table 表对象
* @return 字段
*/
protected Column getAliasColumn(Table table) {
StringBuilder column = new StringBuilder();
// todo 该起别名就要起别名,禁止修改此处逻辑
if (table.getAlias() != null) {
column.append(table.getAlias().getName()).append(StringPool.DOT);
}
column.append(tenantLineHandler.getTenantIdColumn());
return new Column(column.toString());
}
@Override
public void setProperties(Properties properties) {
PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler",
ClassUtils::newInstance, this::setTenantLineHandler);
}
/**
* 构建租户条件表达式
*
* @param table 表对象
* @param where 当前where条件
* @param whereSegment 所属Mapper对象全路径在原租户拦截器功能中这个参数并不需要参与相关判断
* @return 租户条件表达式
* @see BaseMultiTableInnerInterceptor#buildTableExpression(Table, Expression, String)
*/
@Override
public Expression buildTableExpression(final Table table, final Expression where, final String whereSegment) {
if (tenantLineHandler.ignoreTable(table.getName())) {
return null;
}
return new EqualsTo(getAliasColumn(table), tenantLineHandler.getTenantId());
}
}

View File

@ -0,0 +1,77 @@
package com.baomidou.mybatisplus.test;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**
* SQL 解析测试
*/
class JSqlParserTest {
@Test
void parser() throws Exception {
Select select = (Select) CCJSqlParserUtil.parse("SELECT a,b,c FROM tableName t WHERE t.col = 9 and b=c LIMIT 3, ?");
PlainSelect ps = (PlainSelect) select;
System.out.println(ps.getWhere().toString());
System.out.println(ps.getSelectItems().get(1).toString());
AndExpression e = (AndExpression) ps.getWhere();
System.out.println(e.getLeftExpression());
}
@Test
void testDecr() throws JSQLParserException {
// 如果连一起 SqlParser 将无法解析 , 还有种处理方式就自减为负数的时候 转为 自增.
var parse1 = CCJSqlParserUtil.parse("UPDATE test SET a = a --110");
Assertions.assertEquals("UPDATE test SET a = a", parse1.toString());
var parse2 = CCJSqlParserUtil.parse("UPDATE test SET a = a - -110");
Assertions.assertEquals("UPDATE test SET a = a - -110", parse2.toString());
}
@Test
void testIncr() throws JSQLParserException {
var parse1 = CCJSqlParserUtil.parse("UPDATE test SET a = a +-110");
Assertions.assertEquals("UPDATE test SET a = a + -110", parse1.toString());
var parse2 = CCJSqlParserUtil.parse("UPDATE test SET a = a + -110");
Assertions.assertEquals("UPDATE test SET a = a + -110", parse2.toString());
}
@Test
void notLikeParser() throws Exception {
final String targetSql = "SELECT * FROM tableName WHERE id NOT LIKE ?";
Select select = (Select) CCJSqlParserUtil.parse(targetSql);
assertThat(select.toString()).isEqualTo(targetSql);
}
@Test
void updateWhereParser() throws Exception {
Update update = (Update) CCJSqlParserUtil.parse("Update tableName t SET t.a=(select c from tn where tn.id=t.id),b=2,c=3 ");
Assertions.assertNull(update.getWhere());
}
@Test
void deleteWhereParser() throws Exception {
Delete delete = (Delete) CCJSqlParserUtil.parse("delete from tableName t");
Assertions.assertNull(delete.getWhere());
}
// @Test
// void testSelectForUpdate() throws Exception {
// Assertions.assertEquals("SELECT * FROM t_demo WHERE a = 1 FOR UPDATE",
// CCJSqlParserUtil.parse("select * from t_demo where a = 1 for update").toString());
// Assertions.assertEquals("SELECT * FROM sys_sms_send_record WHERE check_status = 0 ORDER BY submit_time ASC LIMIT 10 FOR UPDATE",
// CCJSqlParserUtil.parse("select * from sys_sms_send_record where check_status = 0 for update order by submit_time asc limit 10").toString());
// }
}

View File

@ -1,7 +1,6 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.extension.plugins.inner.DynamicTableNameInnerInterceptor;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
@ -24,7 +23,6 @@ class DynamicTableNameInnerInterceptorTest {
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
// 表名相互包含
@Language("SQL")
String origin = "SELECT * FROM t_user, t_user_role";
assertEquals("SELECT * FROM t_user_r, t_user_role_r", interceptor.changeTable(origin));

View File

@ -0,0 +1,13 @@
dependencies {
api "${lib.'jsqlparser'}"
api project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common")
implementation "${lib."slf4j-api"}"
implementation "de.ruedigermoeller:fst:3.0.4-jdk17"
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
testImplementation "io.github.classgraph:classgraph:4.8.176"
testImplementation "${lib."spring-context-support"}"
testImplementation "${lib.h2}"
testImplementation group: 'com.google.guava', name: 'guava', version: '33.3.1-jre'
}
compileJava.dependsOn(processResources)

View File

@ -0,0 +1,28 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser;
import net.sf.jsqlparser.JSQLParserException;
/**
* @author miemie
* @since 2023-08-05
*/
@FunctionalInterface
public interface JsqlParserFunction<T, R> {
R apply(T t) throws JSQLParserException;
}

View File

@ -0,0 +1,89 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser;
import com.baomidou.mybatisplus.extension.parser.cache.JsqlParseCache;
import lombok.Setter;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
/**
* @author miemie
* @since 2023-08-05
*/
public class JsqlParserGlobal {
/**
* 默认线程数大小
*
* @since 3.5.6
*/
public static final int DEFAULT_THREAD_SIZE = (Runtime.getRuntime().availableProcessors() + 1) / 2;
/**
* 默认解析处理线程池
* <p>注意: 由于项目情况,机器配置等不一样因素,请自行根据情况创建指定线程池.</p>
*
* @see java.util.concurrent.ThreadPoolExecutor
* @since 3.5.6
*/
public static ExecutorService executorService = new ThreadPoolExecutor(DEFAULT_THREAD_SIZE, DEFAULT_THREAD_SIZE, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), r -> {
Thread thread = new Thread(r);
thread.setName("mybatis-plus-jsqlParser-" + thread.getId());
thread.setDaemon(true);
return thread;
});
@Setter
private static JsqlParserFunction<String, Statement> parserSingleFunc = sql -> CCJSqlParserUtil.parse(sql, executorService, null);
@Setter
private static JsqlParserFunction<String, Statements> parserMultiFunc = sql -> CCJSqlParserUtil.parseStatements(sql, executorService, null);
@Setter
private static JsqlParseCache jsqlParseCache;
public static Statement parse(String sql) throws JSQLParserException {
if (jsqlParseCache == null) {
return parserSingleFunc.apply(sql);
}
Statement statement = jsqlParseCache.getStatement(sql);
if (statement == null) {
statement = parserSingleFunc.apply(sql);
jsqlParseCache.putStatement(sql, statement);
}
return statement;
}
public static Statements parseStatements(String sql) throws JSQLParserException {
if (jsqlParseCache == null) {
return parserMultiFunc.apply(sql);
}
Statements statements = jsqlParseCache.getStatements(sql);
if (statements == null) {
statements = parserMultiFunc.apply(sql);
jsqlParseCache.putStatements(sql, statements);
}
return statements;
}
}

View File

@ -0,0 +1,130 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
/**
* https://github.com/JSQLParser/JSqlParser
*
* @author miemie
* @since 2020-06-22
*/
public abstract class JsqlParserSupport {
/**
* 日志
*/
protected final Log logger = LogFactory.getLog(this.getClass());
public String parserSingle(String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("original SQL: " + sql);
}
try {
Statement statement = JsqlParserGlobal.parse(sql);
return processParser(statement, 0, sql, obj);
} catch (JSQLParserException e) {
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
}
}
public String parserMulti(String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("original SQL: " + sql);
}
try {
// fixed github pull/295
StringBuilder sb = new StringBuilder();
Statements statements = JsqlParserGlobal.parseStatements(sql);
int i = 0;
for (Statement statement : statements) {
if (i > 0) {
sb.append(StringPool.SEMICOLON);
}
sb.append(processParser(statement, i, sql, obj));
i++;
}
return sb.toString();
} catch (JSQLParserException e) {
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
}
}
/**
* 执行 SQL 解析
*
* @param statement JsqlParser Statement
* @return sql
*/
protected String processParser(Statement statement, int index, String sql, Object obj) {
if (logger.isDebugEnabled()) {
logger.debug("SQL to parse, SQL: " + sql);
}
if (statement instanceof Insert) {
this.processInsert((Insert) statement, index, sql, obj);
} else if (statement instanceof Select) {
this.processSelect((Select) statement, index, sql, obj);
} else if (statement instanceof Update) {
this.processUpdate((Update) statement, index, sql, obj);
} else if (statement instanceof Delete) {
this.processDelete((Delete) statement, index, sql, obj);
}
sql = statement.toString();
if (logger.isDebugEnabled()) {
logger.debug("parse the finished SQL: " + sql);
}
return sql;
}
/**
* 新增
*/
protected void processInsert(Insert insert, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 删除
*/
protected void processDelete(Delete delete, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 更新
*/
protected void processUpdate(Update update, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
/**
* 查询
*/
protected void processSelect(Select select, int index, String sql, Object obj) {
throw new UnsupportedOperationException();
}
}

View File

@ -0,0 +1,121 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser.cache;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import lombok.Setter;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
/**
* jsqlparser 缓存 Caffeine 缓存实现抽象类
*
* @author miemie hubin
* @since 2023-08-08
*/
public abstract class AbstractCaffeineJsqlParseCache implements JsqlParseCache {
protected final Log logger = LogFactory.getLog(this.getClass());
protected final Cache<String, byte[]> cache;
@Setter
protected boolean async = false;
@Setter
protected Executor executor;
public AbstractCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
this.cache = cache;
}
public AbstractCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
Caffeine<Object, Object> caffeine = Caffeine.newBuilder();
consumer.accept(caffeine);
this.cache = caffeine.build();
}
@Override
public void putStatement(String sql, Statement value) {
this.put(sql, value);
}
@Override
public void putStatements(String sql, Statements value) {
this.put(sql, value);
}
@Override
public Statement getStatement(String sql) {
return this.get(sql);
}
@Override
public Statements getStatements(String sql) {
return this.get(sql);
}
/**
* 获取解析对象异常清空缓存逻辑
*
* @param sql 执行 SQL
* @return 返回泛型对象
*/
protected <T> T get(String sql) {
byte[] bytes = cache.getIfPresent(sql);
if (null != bytes) {
try {
return (T) deserialize(sql, bytes);
} catch (Exception e) {
cache.invalidate(sql);
logger.error("deserialize error", e);
}
}
return null;
}
/**
* 存储解析对象
*
* @param sql 执行 SQL
* @param value 解析对象
*/
protected void put(String sql, Object value) {
if (async) {
if (executor != null) {
CompletableFuture.runAsync(() -> cache.put(sql, serialize(value)), executor);
} else {
CompletableFuture.runAsync(() -> cache.put(sql, serialize(value)));
}
} else {
cache.put(sql, serialize(value));
}
}
/**
* 序列化
*/
public abstract byte[] serialize(Object obj);
/**
* 反序列化
*/
public abstract Object deserialize(String sql, byte[] bytes);
}

View File

@ -0,0 +1,49 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser.cache;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import java.util.function.Consumer;
/**
* jsqlparser 缓存 Fst 序列化 Caffeine 缓存实现
*
* @author miemie
* @since 2023-08-05
*/
public class FstSerialCaffeineJsqlParseCache extends AbstractCaffeineJsqlParseCache {
public FstSerialCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
super(cache);
}
public FstSerialCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
super(consumer);
}
@Override
public byte[] serialize(Object obj) {
return FstFactory.getDefaultFactory().asByteArray(obj);
}
@Override
public Object deserialize(String sql, byte[] bytes) {
return FstFactory.getDefaultFactory().asObject(bytes);
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser.cache;
import com.baomidou.mybatisplus.core.toolkit.SerializationUtils;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import java.util.function.Consumer;
/**
* jsqlparser 缓存 jdk 序列化 Caffeine 缓存实现
*
* @author miemie
* @since 2023-08-05
*/
public class JdkSerialCaffeineJsqlParseCache extends AbstractCaffeineJsqlParseCache {
public JdkSerialCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
super(cache);
}
public JdkSerialCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
super(consumer);
}
@Override
public byte[] serialize(Object obj) {
return SerializationUtils.serialize(obj);
}
@Override
public Object deserialize(String sql, byte[] bytes) {
return SerializationUtils.deserialize(bytes);
}
}

View File

@ -0,0 +1,36 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.parser.cache;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
/**
* jsqlparser 缓存接口
*
* @author miemie
* @since 2023-08-05
*/
public interface JsqlParseCache {
void putStatement(String sql, Statement value);
void putStatements(String sql, Statements value);
Statement getStatement(String sql);
Statements getStatements(String sql);
}

View File

@ -0,0 +1,36 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.handler;
import net.sf.jsqlparser.expression.Expression;
/**
* 数据权限处理器
*
* @author hubin
* @since 3.4.1 +
*/
public interface DataPermissionHandler {
/**
* 获取数据权限 SQL 片段
*
* @param where 待执行 SQL Where 条件表达式
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
* @return JSqlParser 条件表达式返回的条件表达式会覆盖原有的条件表达式
*/
Expression getSqlSegment(Expression where, String mappedStatementId);
}

View File

@ -0,0 +1,53 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.handler;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.schema.Table;
/**
* 支持多表的数据权限处理器
*
* @author houkunlin
* @since 3.5.2 +
*/
public interface MultiDataPermissionHandler extends DataPermissionHandler {
/**
* 为兼容旧版数据权限处理器继承了 {@link DataPermissionHandler} 但是新的多表数据权限处理又不会调用此方法因此标记过时
*
* @param where 待执行 SQL Where 条件表达式
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
* @return JSqlParser 条件表达式
* @deprecated 新的多表数据权限处理不会调用此方法因此标记过时
*/
@Deprecated
@Override
default Expression getSqlSegment(Expression where, String mappedStatementId) {
return where;
}
/**
* 获取数据权限 SQL 片段
* <p>旧的 {@link MultiDataPermissionHandler#getSqlSegment(Expression, String)} 方法第一个参数包含所有的 where 条件信息如果 return null 会覆盖原有的 where 数据</p>
* <p>新版的 {@link MultiDataPermissionHandler#getSqlSegment(Table, Expression, String)} 方法不能覆盖原有的 where 数据如果 return null 则表示不追加任何 where 条件</p>
*
* @param table 所执行的数据库表信息可以通过此参数获取表名和表别名
* @param where 原有的 where 条件信息
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
* @return JSqlParser 条件表达式返回的条件表达式会拼接在原有的表达式后面不会覆盖原有的表达式
*/
Expression getSqlSegment(final Table table, final Expression where, final String mappedStatementId);
}

View File

@ -0,0 +1,72 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.handler;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.schema.Column;
import java.util.List;
/**
* 租户处理器 TenantId 行级
*
* @author hubin
* @since 3.4.0
*/
public interface TenantLineHandler {
/**
* 获取租户 ID 值表达式只支持单个 ID
* <p>
*
* @return 租户 ID 值表达式
*/
Expression getTenantId();
/**
* 获取租户字段名
* <p>
* 默认字段名叫: tenant_id
*
* @return 租户字段名
*/
default String getTenantIdColumn() {
return "tenant_id";
}
/**
* 根据表名判断是否忽略拼接多租户条件
* <p>
* 默认都要进行解析并拼接多租户条件
*
* @param tableName 表名
* @return 是否忽略, true:表示忽略false:需要解析并拼接多租户条件
*/
default boolean ignoreTable(String tableName) {
return false;
}
/**
* 忽略插入租户字段逻辑
*
* @param columns 插入字段
* @param tenantIdColumn 租户 ID 字段
* @return
*/
default boolean ignoreInsert(List<Column> columns, String tenantIdColumn) {
return columns.stream().map(Column::getColumnName).anyMatch(i -> i.equalsIgnoreCase(tenantIdColumn));
}
}

View File

@ -0,0 +1,990 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import java.io.Reader;
import java.sql.Clob;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import net.sf.jsqlparser.statement.select.Values;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.baomidou.mybatisplus.annotation.IEnum;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.core.handlers.MybatisEnumTypeHandler;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.parser.JsqlParserGlobal;
import lombok.Data;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.JdbcParameter;
import net.sf.jsqlparser.expression.RowConstructor;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.AllColumns;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.update.UpdateSet;
/**
* <p>
* 数据变动记录插件
* 默认会生成一条log格式
* ----------------------INSERT LOG------------------------------
* </p>
* <p>
* {
* "tableName": "h2user",
* "operation": "insert",
* "recordStatus": "true",
* "changedData": [
* {
* "LAST_UPDATED_DT": "null->2022-08-22 18:49:16.512",
* "TEST_ID": "null->1561666810058739714",
* "AGE": "null->THREE"
* }
* ],
* "cost(ms)": 0
* }
* </p>
* <p>
* * ----------------------UPDATE LOG------------------------------
* <p>
* {
* "tableName": "h2user",
* "operation": "update",
* "recordStatus": "true",
* "changedData": [
* {
* "TEST_ID": "102",
* "AGE": "2->THREE",
* "FIRSTNAME": "DOU.HAO->{\"json\":\"abc\"}",
* "LAST_UPDATED_DT": "null->2022-08-22 18:49:16.512"
* }
* ],
* "cost(ms)": 0
* }
* </p>
*
* @author yuxiaobin
* @date 2022-8-21
*/
public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
protected final Logger logger = LoggerFactory.getLogger(this.getClass());
@SuppressWarnings("unused")
public static final String IGNORED_TABLE_COLUMN_PROPERTIES = "ignoredTableColumns";
private final Map<String, Set<String>> ignoredTableColumns = new ConcurrentHashMap<>();
private final Set<String> ignoreAllColumns = new HashSet<>();//全部表的这些字段名INSERT/UPDATE都忽略delete暂时保留
//批量更新上限, 默认一次最多1000条
private int BATCH_UPDATE_LIMIT = 1000;
private boolean batchUpdateLimitationOpened = false;
private final Map<String, Integer> BATCH_UPDATE_LIMIT_MAP = new ConcurrentHashMap<>();//表名->批量更新上限
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
final BoundSql boundSql = mpSh.boundSql();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
OperationResult operationResult;
long startTs = System.currentTimeMillis();
try {
Statement statement = JsqlParserGlobal.parse(mpBs.sql());
if (statement instanceof Insert) {
operationResult = processInsert((Insert) statement, mpSh.boundSql());
} else if (statement instanceof Update) {
operationResult = processUpdate((Update) statement, ms, boundSql, connection);
} else if (statement instanceof Delete) {
operationResult = processDelete((Delete) statement, ms, boundSql, connection);
} else {
logger.info("other operation sql={}", mpBs.sql());
return;
}
} catch (Exception e) {
if (e instanceof DataUpdateLimitationException) {
throw (DataUpdateLimitationException) e;
}
logger.error("Unexpected error for mappedStatement={}, sql={}", ms.getId(), mpBs.sql(), e);
return;
}
long costThis = System.currentTimeMillis() - startTs;
if (operationResult != null) {
operationResult.setCost(costThis);
dealOperationResult(operationResult);
}
}
}
/**
* 判断哪些SQL需要处理
* 默认INSERT/UPDATE/DELETE语句
*
* @param sql
* @return
*/
protected boolean allowProcess(String sql) {
String sqlTrim = sql.trim().toUpperCase();
return sqlTrim.startsWith("INSERT") || sqlTrim.startsWith("UPDATE") || sqlTrim.startsWith("DELETE");
}
/**
* 处理数据更新结果默认打印
*
* @param operationResult
*/
protected void dealOperationResult(OperationResult operationResult) {
logger.info("{}", operationResult);
}
public OperationResult processInsert(Insert insertStmt, BoundSql boundSql) {
String operation = SqlCommandType.INSERT.name().toLowerCase();
Table table = insertStmt.getTable();
String tableName = table.getName();
Optional<OperationResult> optionalOperationResult = ignoredTableColumns(tableName, operation);
if (optionalOperationResult.isPresent()) {
return optionalOperationResult.get();
}
OperationResult result = new OperationResult();
result.setOperation(operation);
result.setTableName(tableName);
result.setRecordStatus(true);
Map<String, Object> updatedColumnDatas = getUpdatedColumnDatas(tableName, boundSql, insertStmt);
result.buildDataStr(compareAndGetUpdatedColumnDatas(result.getTableName(), null, updatedColumnDatas));
return result;
}
public OperationResult processUpdate(Update updateStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
Expression where = updateStmt.getWhere();
PlainSelect selectBody = new PlainSelect();
Table table = updateStmt.getTable();
String tableName = table.getName();
String operation = SqlCommandType.UPDATE.name().toLowerCase();
Optional<OperationResult> optionalOperationResult = ignoredTableColumns(tableName, operation);
if (optionalOperationResult.isPresent()) {
return optionalOperationResult.get();
}
selectBody.setFromItem(table);
List<Column> updateColumns = new ArrayList<>();
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
updateColumns.addAll(updateSet.getColumns());
}
Columns2SelectItemsResult buildColumns2SelectItems = buildColumns2SelectItems(tableName, updateColumns);
selectBody.setSelectItems(buildColumns2SelectItems.getSelectItems());
selectBody.setWhere(where);
SelectItem<PlainSelect> plainSelectSelectItem = new SelectItem<>(selectBody);
BoundSql boundSql4Select = new BoundSql(mappedStatement.getConfiguration(), plainSelectSelectItem.toString(),
prepareParameterMapping4Select(boundSql.getParameterMappings(), updateStmt),
boundSql.getParameterObject());
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
Map<String, Object> additionalParameters = mpBoundSql.additionalParameters();
if (additionalParameters != null && !additionalParameters.isEmpty()) {
for (Map.Entry<String, Object> ety : additionalParameters.entrySet()) {
boundSql4Select.setAdditionalParameter(ety.getKey(), ety.getValue());
}
}
Map<String, Object> updatedColumnDatas = getUpdatedColumnDatas(tableName, boundSql, updateStmt);
OriginalDataObj originalData = buildOriginalObjectData(updatedColumnDatas, selectBody, buildColumns2SelectItems.getPk(), mappedStatement, boundSql4Select, connection);
OperationResult result = new OperationResult();
result.setOperation(operation);
result.setTableName(tableName);
result.setRecordStatus(true);
result.buildDataStr(compareAndGetUpdatedColumnDatas(result.getTableName(), originalData, updatedColumnDatas));
return result;
}
private Optional<OperationResult> ignoredTableColumns(String table, String operation) {
final Set<String> ignoredColumns = ignoredTableColumns.get(table.toUpperCase());
if (ignoredColumns != null) {
if (ignoredColumns.stream().anyMatch("*"::equals)) {
OperationResult result = new OperationResult();
result.setOperation(operation);
result.setTableName(table + ":*");
result.setRecordStatus(false);
return Optional.of(result);
}
}
return Optional.empty();
}
private TableInfo getTableInfoByTableName(String tableName) {
for (TableInfo tableInfo : TableInfoHelper.getTableInfos()) {
if (tableName.equalsIgnoreCase(tableInfo.getTableName())) {
return tableInfo;
}
}
return null;
}
/**
* 将update SET部分的jdbc参数去除
*
* @param originalMappingList 这里只会包含JdbcParameter参数
* @param updateStmt
* @return
*/
private List<ParameterMapping> prepareParameterMapping4Select(List<ParameterMapping> originalMappingList, Update updateStmt) {
List<Expression> updateValueExpressions = new ArrayList<>();
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
updateValueExpressions.addAll(updateSet.getValues());
}
int removeParamCount = 0;
for (Expression expression : updateValueExpressions) {
if (expression instanceof JdbcParameter) {
++removeParamCount;
}
}
return originalMappingList.subList(removeParamCount, originalMappingList.size());
}
protected Map<String, Object> getUpdatedColumnDatas(String tableName, BoundSql updateSql, Statement statement) {
Map<String, Object> columnNameValMap = new HashMap<>(updateSql.getParameterMappings().size());
Map<Integer, String> columnSetIndexMap = new HashMap<>(updateSql.getParameterMappings().size());
List<Column> selectItemsFromUpdateSql = new ArrayList<>();
if (statement instanceof Update) {
Update updateStmt = (Update) statement;
int index = 0;
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
selectItemsFromUpdateSql.addAll(updateSet.getColumns());
final ExpressionList<Expression> updateList = (ExpressionList<Expression>) updateSet.getValues();
for (int i = 0; i < updateList.size(); ++i) {
Expression updateExps = updateList.get(i);
if (!(updateExps instanceof JdbcParameter)) {
columnNameValMap.put(updateSet.getColumns().get(i).getColumnName().toUpperCase(), updateExps.toString());
}
columnSetIndexMap.put(index++, updateSet.getColumns().get(i).getColumnName().toUpperCase());
}
}
} else if (statement instanceof Insert) {
Insert insert = (Insert) statement;
selectItemsFromUpdateSql.addAll(insert.getColumns());
columnNameValMap.putAll(detectInsertColumnValuesNonJdbcParameters(insert));
}
Map<String, String> relatedColumnsUpperCaseWithoutUnderline = new HashMap<>(selectItemsFromUpdateSql.size(), 1);
for (Column item : selectItemsFromUpdateSql) {
//FIRSTNAME: FIRST_NAME/FIRST-NAME/FIRST$NAME/FIRST.NAME
relatedColumnsUpperCaseWithoutUnderline.put(item.getColumnName().replaceAll("[._\\-$]", "").toUpperCase(), item.getColumnName().toUpperCase());
}
MetaObject metaObject = SystemMetaObject.forObject(updateSql.getParameterObject());
int index = 0;
for (ParameterMapping parameterMapping : updateSql.getParameterMappings()) {
String propertyName = parameterMapping.getProperty();
if (propertyName.startsWith("ew.paramNameValuePairs")) {
++index;
continue;
}
String[] arr = propertyName.split("\\.");
String propertyNameTrim = arr[arr.length - 1].replace("_", "").toUpperCase();
try {
final String columnName = columnSetIndexMap.getOrDefault(index++, getColumnNameByProperty(propertyNameTrim, tableName));
if (relatedColumnsUpperCaseWithoutUnderline.containsKey(propertyNameTrim)) {
final String colkey = relatedColumnsUpperCaseWithoutUnderline.get(propertyNameTrim);
Object valObj = metaObject.getValue(propertyName);
if (valObj instanceof IEnum) {
valObj = ((IEnum<?>) valObj).getValue();
} else if (valObj instanceof Enum) {
valObj = getEnumValue((Enum) valObj);
}
if (columnNameValMap.containsKey(colkey)) {
columnNameValMap.put(relatedColumnsUpperCaseWithoutUnderline.get(propertyNameTrim), String.valueOf(columnNameValMap.get(colkey)).replace("?", valObj == null ? "" : valObj.toString()));
}
if (columnName != null && !columnNameValMap.containsKey(columnName)) {
columnNameValMap.put(columnName, valObj);
}
} else {
if (columnName != null) {
columnNameValMap.put(columnName, metaObject.getValue(propertyName));
}
}
} catch (Exception e) {
logger.warn("get value error,propertyName:{},parameterMapping:{}", propertyName, parameterMapping);
}
}
dealWithUpdateWrapper(columnSetIndexMap, columnNameValMap, updateSql);
return columnNameValMap;
}
/**
* @param originalDataObj
* @return
*/
private List<DataChangedRecord> compareAndGetUpdatedColumnDatas(String tableName, OriginalDataObj originalDataObj, Map<String, Object> columnNameValMap) {
final Set<String> ignoredColumns = ignoredTableColumns.get(tableName.toUpperCase());
if (originalDataObj == null || originalDataObj.isEmpty()) {
DataChangedRecord oneRecord = new DataChangedRecord();
List<DataColumnChangeResult> updateColumns = new ArrayList<>(columnNameValMap.size());
for (Map.Entry<String, Object> ety : columnNameValMap.entrySet()) {
String columnName = ety.getKey();
if ((ignoredColumns == null || !ignoredColumns.contains(columnName)) && !ignoreAllColumns.contains(columnName)) {
updateColumns.add(DataColumnChangeResult.constrcutByUpdateVal(columnName, ety.getValue()));
}
}
oneRecord.setUpdatedColumns(updateColumns);
// oneRecord.setUpdatedColumns(Collections.EMPTY_LIST);
return Collections.singletonList(oneRecord);
}
List<DataChangedRecord> originalDataList = originalDataObj.getOriginalDataObj();
List<DataChangedRecord> updateDataList = new ArrayList<>(originalDataList.size());
for (DataChangedRecord originalData : originalDataList) {
if (originalData.hasUpdate(columnNameValMap, ignoredColumns, ignoreAllColumns)) {
updateDataList.add(originalData);
}
}
return updateDataList;
}
private Object getEnumValue(Enum enumVal) {
Optional<String> enumValueFieldName = MybatisEnumTypeHandler.findEnumValueFieldName(enumVal.getClass());
if (enumValueFieldName.isPresent()) {
return SystemMetaObject.forObject(enumVal).getValue(enumValueFieldName.get());
}
return enumVal;
}
@SuppressWarnings("rawtypes")
private void dealWithUpdateWrapper(Map<Integer, String> columnSetIndexMap, Map<String, Object> columnNameValMap, BoundSql updateSql) {
if (columnSetIndexMap.size() <= columnNameValMap.size()) {
return;
}
MetaObject mpgenVal = SystemMetaObject.forObject(updateSql.getParameterObject());
if(!mpgenVal.hasGetter(Constants.WRAPPER)){
return;
}
Object ew = mpgenVal.getValue(Constants.WRAPPER);
if (ew instanceof UpdateWrapper || ew instanceof LambdaUpdateWrapper) {
final String sqlSet = ew instanceof UpdateWrapper ? ((UpdateWrapper) ew).getSqlSet() : ((LambdaUpdateWrapper) ew).getSqlSet();// columnName=#{val}
if (sqlSet == null) {
return;
}
MetaObject ewMeta = SystemMetaObject.forObject(ew);
final Map paramNameValuePairs = (Map) ewMeta.getValue("paramNameValuePairs");
String[] setItems = sqlSet.split(",");
for (String setItem : setItems) {
//age=#{ew.paramNameValuePairs.MPGENVAL1}
String[] nameAndValuePair = setItem.split("=", 2);
if (nameAndValuePair.length == 2) {
String setColName = nameAndValuePair[0].trim().toUpperCase();
String setColVal = nameAndValuePair[1].trim();//#{.mp}
if (columnSetIndexMap.containsValue(setColName)) {
String[] mpGenKeyArray = setColVal.split("\\.");
String mpGenKey = mpGenKeyArray[mpGenKeyArray.length - 1].replace("}", "");
final Object setVal = paramNameValuePairs.get(mpGenKey);
if (setVal instanceof IEnum) {
columnNameValMap.put(setColName, String.valueOf(((IEnum<?>) setVal).getValue()));
} else {
columnNameValMap.put(setColName, setVal);
}
}
}
}
}
}
private Map<String, String> detectInsertColumnValuesNonJdbcParameters(Insert insert) {
Map<String, String> columnNameValMap = new HashMap<>(4);
final Select select = insert.getSelect();
List<Column> columns = insert.getColumns();
if (select instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) select;
final List<Select> selects = setOperationList.getSelects();
if (CollectionUtils.isEmpty(selects)) {
return columnNameValMap;
}
final Select selectBody = selects.get(0);
if (!(selectBody instanceof Values)) {
return columnNameValMap;
}
Values valuesStatement = (Values) selectBody;
if (valuesStatement.getExpressions() instanceof ExpressionList) {
ExpressionList expressionList = valuesStatement.getExpressions();
List<Expression> expressions = expressionList;
for (Expression expression : expressions) {
if (expression instanceof RowConstructor) {
final ExpressionList exprList = ((RowConstructor) expression);
final List<Expression> insertExpList = exprList;
for (int i = 0; i < insertExpList.size(); ++i) {
Expression e = insertExpList.get(i);
if (!(e instanceof JdbcParameter)) {
final String columnName = columns.get(i).getColumnName();
final String val = e.toString();
columnNameValMap.put(columnName, val);
}
}
}
}
}
}
return columnNameValMap;
}
private String getColumnNameByProperty(String propertyName, String tableName) {
for (TableInfo tableInfo : TableInfoHelper.getTableInfos()) {
if (tableName.equalsIgnoreCase(tableInfo.getTableName())) {
final List<TableFieldInfo> fieldList = tableInfo.getFieldList();
if (CollectionUtils.isEmpty(fieldList)) {
return propertyName;
}
for (TableFieldInfo tableFieldInfo : fieldList) {
if (propertyName.equalsIgnoreCase(tableFieldInfo.getProperty())) {
return tableFieldInfo.getColumn().toUpperCase();
}
}
return propertyName;
}
}
return propertyName;
}
private Map<String, Object> buildParameterObjectMap(BoundSql boundSql) {
MetaObject metaObject = PluginUtils.getMetaObject(boundSql.getParameterObject());
Map<String, Object> propertyValMap = new HashMap<>(boundSql.getParameterMappings().size());
for (ParameterMapping parameterMapping : boundSql.getParameterMappings()) {
String propertyName = parameterMapping.getProperty();
if (propertyName.startsWith("ew.paramNameValuePairs")) {
continue;
}
Object propertyValue = metaObject.getValue(propertyName);
propertyValMap.put(propertyName, propertyValue);
}
return propertyValMap;
}
private String buildOriginalData(Select selectStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
try (PreparedStatement statement = connection.prepareStatement(selectStmt.toString())) {
DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
parameterHandler.setParameters(statement);
ResultSet resultSet = statement.executeQuery();
final ResultSetMetaData metaData = resultSet.getMetaData();
int columnCount = metaData.getColumnCount();
StringBuilder sb = new StringBuilder("[");
int count = 0;
while (resultSet.next()) {
++count;
if (checkTableBatchLimitExceeded(selectStmt, count)) {
logger.error("batch delete limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
throw DataUpdateLimitationException.DEFAULT;
}
sb.append("{");
for (int i = 1; i <= columnCount; ++i) {
sb.append("\"").append(metaData.getColumnName(i)).append("\":\"");
Object res = resultSet.getObject(i);
if (res instanceof Clob) {
sb.append(DataColumnChangeResult.convertClob((Clob) res));
} else {
sb.append(res);
}
sb.append("\",");
}
sb.replace(sb.length() - 1, sb.length(), "}");
}
sb.append("]");
resultSet.close();
return sb.toString();
} catch (Exception e) {
if (e instanceof DataUpdateLimitationException) {
throw (DataUpdateLimitationException) e;
}
logger.error("try to get record tobe deleted for selectStmt={}", selectStmt, e);
return "failed to get original data";
}
}
private OriginalDataObj buildOriginalObjectData(Map<String, Object> updatedColumnDatas, Select selectStmt, Column pk, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
try (PreparedStatement statement = connection.prepareStatement(selectStmt.toString())) {
DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
parameterHandler.setParameters(statement);
ResultSet resultSet = statement.executeQuery();
List<DataChangedRecord> originalObjectDatas = new LinkedList<>();
int count = 0;
while (resultSet.next()) {
++count;
if (checkTableBatchLimitExceeded(selectStmt, count)) {
logger.error("batch update limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
throw DataUpdateLimitationException.DEFAULT;
}
originalObjectDatas.add(prepareOriginalDataObj(updatedColumnDatas, resultSet, pk));
}
OriginalDataObj result = new OriginalDataObj();
result.setOriginalDataObj(originalObjectDatas);
resultSet.close();
return result;
} catch (Exception e) {
if (e instanceof DataUpdateLimitationException) {
throw (DataUpdateLimitationException) e;
}
logger.error("try to get record tobe updated for selectStmt={}", selectStmt, e);
return new OriginalDataObj();
}
}
/**
* 防止出现全表批量更新
* 默认一次更新不超过1000条
*
* @param selectStmt
* @param count
* @return
*/
private boolean checkTableBatchLimitExceeded(Select selectStmt, int count) {
if (!batchUpdateLimitationOpened) {
return false;
}
final PlainSelect selectBody = (PlainSelect) selectStmt;
final FromItem fromItem = selectBody.getFromItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
final String tableName = fromTable.getName().toUpperCase();
if (!BATCH_UPDATE_LIMIT_MAP.containsKey(tableName)) {
if (count > BATCH_UPDATE_LIMIT) {
logger.error("batch update limit exceed for tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
tableName, BATCH_UPDATE_LIMIT, count);
return true;
}
return false;
}
final Integer limit = BATCH_UPDATE_LIMIT_MAP.get(tableName);
if (count > limit) {
logger.error("batch update limit exceed for configured tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
tableName, limit, count);
return true;
}
return false;
}
return count > BATCH_UPDATE_LIMIT;
}
/**
* get records : include related column with original data in DB
*
* @param resultSet
* @param pk
* @return
* @throws SQLException
*/
private DataChangedRecord prepareOriginalDataObj(Map<String, Object> updatedColumnDatas, ResultSet resultSet, Column pk) throws SQLException {
final ResultSetMetaData metaData = resultSet.getMetaData();
int columnCount = metaData.getColumnCount();
List<DataColumnChangeResult> originalColumnDatas = new LinkedList<>();
DataColumnChangeResult pkval = null;
for (int i = 1; i <= columnCount; ++i) {
String columnName = metaData.getColumnName(i).toUpperCase();
DataColumnChangeResult col;
Object updateVal = updatedColumnDatas.get(columnName);
if (updateVal != null && updateVal.getClass().getCanonicalName().startsWith("java.")) {
col = DataColumnChangeResult.constrcutByOriginalVal(columnName, resultSet.getObject(i, updateVal.getClass()));
} else {
col = DataColumnChangeResult.constrcutByOriginalVal(columnName, resultSet.getObject(i));
}
if (pk != null && columnName.equalsIgnoreCase(pk.getColumnName())) {
pkval = col;
} else {
originalColumnDatas.add(col);
}
}
DataChangedRecord changedRecord = new DataChangedRecord();
changedRecord.setOriginalColumnDatas(originalColumnDatas);
if (pkval != null) {
changedRecord.setPkColumnName(pkval.getColumnName());
changedRecord.setPkColumnVal(pkval.getOriginalValue());
}
return changedRecord;
}
private Columns2SelectItemsResult buildColumns2SelectItems(String tableName, List<Column> columns) {
if (columns == null || columns.isEmpty()) {
return Columns2SelectItemsResult.build(Collections.singletonList(new SelectItem<>(new AllColumns())), 0);
}
List<SelectItem<?>> selectItems = new ArrayList<>(columns.size());
for (Column column : columns) {
selectItems.add(new SelectItem<>(column));
}
TableInfo tableInfo = getTableInfoByTableName(tableName);
if (tableInfo == null) {
return Columns2SelectItemsResult.build(selectItems, 0);
}
Column pk = new Column(tableInfo.getKeyColumn());
selectItems.add(new SelectItem<>(pk));
Columns2SelectItemsResult result = Columns2SelectItemsResult.build(selectItems, 1);
result.setPk(pk);
return result;
}
private String buildParameterObject(BoundSql boundSql) {
Object paramObj = boundSql.getParameterObject();
StringBuilder sb = new StringBuilder();
sb.append("{");
if (paramObj instanceof Map) {
Map<String, Object> paramMap = (Map<String, Object>) paramObj;
int index = 1;
boolean hasParamIndex = false;
String key;
while (paramMap.containsKey((key = "param" + index))) {
Object paramIndex = paramMap.get(key);
sb.append("\"").append(key).append("\"").append(":").append("\"").append(paramIndex).append("\"").append(",");
hasParamIndex = true;
++index;
}
if (hasParamIndex) {
sb.delete(sb.length() - 1, sb.length());
sb.append("}");
return sb.toString();
}
for (Map.Entry<String, Object> ety : paramMap.entrySet()) {
sb.append("\"").append(ety.getKey()).append("\"").append(":").append("\"").append(ety.getValue()).append("\"").append(",");
}
sb.delete(sb.length() - 1, sb.length());
sb.append("}");
return sb.toString();
}
sb.append("param:").append(paramObj);
sb.append("}");
return sb.toString();
}
public OperationResult processDelete(Delete deleteStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
Table table = deleteStmt.getTable();
Expression where = deleteStmt.getWhere();
PlainSelect selectBody = new PlainSelect();
selectBody.setFromItem(table);
selectBody.setSelectItems(Collections.singletonList(new SelectItem<>((new AllColumns()))));
selectBody.setWhere(where);
String originalData = buildOriginalData(selectBody, mappedStatement, boundSql, connection);
OperationResult result = new OperationResult();
result.setOperation("delete");
result.setTableName(table.getName());
result.setRecordStatus(originalData.startsWith("["));
result.setChangedData(originalData);
return result;
}
/**
* 设置批量更新记录条数上限
*
* @param limit
* @return
*/
public DataChangeRecorderInnerInterceptor setBatchUpdateLimit(int limit) {
this.BATCH_UPDATE_LIMIT = limit;
return this;
}
public DataChangeRecorderInnerInterceptor openBatchUpdateLimitation() {
this.batchUpdateLimitationOpened = true;
return this;
}
public DataChangeRecorderInnerInterceptor configTableLimitation(String tableName, int limit) {
this.BATCH_UPDATE_LIMIT_MAP.put(tableName.toUpperCase(), limit);
return this;
}
/**
* ignoredColumns = TABLE_NAME1.COLUMN1,COLUMN2; TABLE2.COLUMN1,COLUMN2; TABLE3.*; *.COLUMN1,COLUMN2
* 多个表用分号分隔
* TABLE_NAME1.COLUMN1,COLUMN2 : 表示忽略这个表的这2个字段
* TABLE3.*: 表示忽略这张表的INSERT/UPDATEdelete暂时还保留
* *.COLUMN1,COLUMN2:表示所有表的这个2个字段名都忽略
*
* @param properties
*/
@Override
public void setProperties(Properties properties) {
String ignoredTableColumns = properties.getProperty("ignoredTableColumns");
if (ignoredTableColumns == null || ignoredTableColumns.trim().isEmpty()) {
return;
}
String[] array = ignoredTableColumns.split(";");
for (String table : array) {
int index = table.indexOf(".");
if (index == -1) {
logger.warn("invalid data={} for ignoredColumns, format should be TABLE_NAME1.COLUMN1,COLUMN2; TABLE2.COLUMN1,COLUMN2;", table);
continue;
}
String tableName = table.substring(0, index).trim().toUpperCase();
String[] columnArray = table.substring(index + 1).split(",");
Set<String> columnSet = new HashSet<>(columnArray.length);
for (String column : columnArray) {
column = column.trim().toUpperCase();
if (column.isEmpty()) {
continue;
}
columnSet.add(column);
}
if ("*".equals(tableName)) {
ignoreAllColumns.addAll(columnSet);
} else {
this.ignoredTableColumns.put(tableName, columnSet);
}
}
}
@Data
public static class OperationResult {
private String operation;
private boolean recordStatus;
private String tableName;
private String changedData;
/**
* cost for this plugin, ms
*/
private long cost;
public void buildDataStr(List<DataChangedRecord> records) {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (DataChangedRecord r : records) {
sb.append(r.generateUpdatedDataStr()).append(",");
}
if (sb.length() == 1) {
sb.append("]");
changedData = sb.toString();
return;
}
sb.replace(sb.length() - 1, sb.length(), "]");
changedData = sb.toString();
}
@Override
public String toString() {
return "{" +
"\"tableName\":\"" + tableName + "\"," +
"\"operation\":\"" + operation + "\"," +
"\"recordStatus\":\"" + recordStatus + "\"," +
"\"changedData\":" + changedData + "," +
"\"cost(ms)\":" + cost + "}";
}
}
@Data
public static class Columns2SelectItemsResult {
private Column pk;
/**
* all column with additional columns: ID, etc.
*/
private List<SelectItem<?>> selectItems;
/**
* newly added column count from meta data.
*/
private int additionalItemCount;
public static Columns2SelectItemsResult build(List<SelectItem<?>> selectItems, int additionalItemCount) {
Columns2SelectItemsResult result = new Columns2SelectItemsResult();
result.setSelectItems(selectItems);
result.setAdditionalItemCount(additionalItemCount);
return result;
}
}
@Data
public static class OriginalDataObj {
private List<DataChangedRecord> originalDataObj;
public boolean isEmpty() {
return originalDataObj == null || originalDataObj.isEmpty();
}
}
@Data
public static class DataColumnChangeResult {
private String columnName;
private Object originalValue;
private Object updateValue;
@SuppressWarnings("rawtypes")
public boolean isDataChanged(Object updateValue) {
if (!Objects.equals(originalValue, updateValue)) {
if (originalValue instanceof Clob) {
String originalStr = convertClob((Clob) originalValue);
setOriginalValue(originalStr);
return !originalStr.equals(updateValue);
}
if (originalValue instanceof Comparable) {
Comparable original = (Comparable) originalValue;
Comparable update = (Comparable) updateValue;
try {
return update == null || original.compareTo(update) != 0;
} catch (Exception e) {
return true;
}
}
return true;
}
return false;
}
public static String convertClob(Clob clobObj) {
try {
return clobObj.getSubString(0, (int) clobObj.length());
} catch (Exception e) {
try (Reader is = clobObj.getCharacterStream()) {
char[] chars = new char[64];
int readChars;
StringBuilder sb = new StringBuilder();
while ((readChars = is.read(chars)) != -1) {
sb.append(chars, 0, readChars);
}
return sb.toString();
} catch (Exception e2) {
//ignored
return "unknown clobObj";
}
}
}
public static DataColumnChangeResult constrcutByUpdateVal(String columnName, Object updateValue) {
DataColumnChangeResult res = new DataColumnChangeResult();
res.setColumnName(columnName);
res.setUpdateValue(updateValue);
return res;
}
public static DataColumnChangeResult constrcutByOriginalVal(String columnName, Object originalValue) {
DataColumnChangeResult res = new DataColumnChangeResult();
res.setColumnName(columnName);
res.setOriginalValue(originalValue);
return res;
}
public String generateDataStr() {
StringBuilder sb = new StringBuilder();
sb.append("\"").append(columnName).append("\"").append(":").append("\"").append(convertDoubleQuotes(originalValue)).append("->").append(convertDoubleQuotes(updateValue)).append("\"").append(",");
return sb.toString();
}
public String convertDoubleQuotes(Object obj) {
if (obj == null) {
return null;
}
return obj.toString().replace("\"", "\\\"");
}
}
@Data
public static class DataChangedRecord {
private String pkColumnName;
private Object pkColumnVal;
private List<DataColumnChangeResult> originalColumnDatas;
private List<DataColumnChangeResult> updatedColumns;
public boolean hasUpdate(Map<String, Object> columnNameValMap, Set<String> ignoredColumns, Set<String> ignoreAllColumns) {
if (originalColumnDatas == null) {
return true;
}
boolean hasUpdate = false;
updatedColumns = new ArrayList<>(originalColumnDatas.size());
for (DataColumnChangeResult originalColumn : originalColumnDatas) {
final String columnName = originalColumn.getColumnName().toUpperCase();
if (ignoredColumns != null && ignoredColumns.contains(columnName) || ignoreAllColumns.contains(columnName)) {
continue;
}
Object updatedValue = columnNameValMap.get(columnName);
if (originalColumn.isDataChanged(updatedValue)) {
hasUpdate = true;
originalColumn.setUpdateValue(updatedValue);
updatedColumns.add(originalColumn);
}
}
return hasUpdate;
}
public String generateUpdatedDataStr() {
StringBuilder sb = new StringBuilder();
sb.append("{");
if (pkColumnName != null) {
sb.append("\"").append(pkColumnName).append("\"").append(":").append("\"").append(convertDoubleQuotes(pkColumnVal)).append("\"").append(",");
}
for (DataColumnChangeResult update : updatedColumns) {
sb.append(update.generateDataStr());
}
sb.replace(sb.length() - 1, sb.length(), "}");
return sb.toString();
}
public String convertDoubleQuotes(Object obj) {
if (obj == null) {
return null;
}
return obj.toString().replace("\"", "\\\"");
}
}
public static class DataUpdateLimitationException extends MybatisPlusException {
public DataUpdateLimitationException(String message) {
super(message);
}
public static DataUpdateLimitationException DEFAULT = new DataUpdateLimitationException("本次操作 因超过系统安全阈值 被拦截,如需继续,请联系管理员!");
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.select.WithItem;
import net.sf.jsqlparser.statement.update.Update;
/**
* 数据权限处理器
*
* @author hubin
* @since 3.5.2
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuppressWarnings({"rawtypes"})
public class DataPermissionInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
private DataPermissionHandler dataPermissionHandler;
@SuppressWarnings("RedundantThrows")
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
return;
}
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
mpBs.sql(parserSingle(mpBs.sql(), ms.getId()));
}
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
mpBs.sql(parserMulti(mpBs.sql(), ms.getId()));
}
}
@Override
protected void processSelect(Select select, int index, String sql, Object obj) {
if (dataPermissionHandler == null) {
return;
}
if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
// 参照 com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor.processSelect 做的修改
final String whereSegment = (String) obj;
processSelectBody(select, whereSegment);
List<WithItem> withItemsList = select.getWithItemsList();
if (!CollectionUtils.isEmpty(withItemsList)) {
withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
}
} else {
// 兼容原来的旧版 DataPermissionHandler 场景
if (select instanceof PlainSelect) {
this.setWhere((PlainSelect) select, (String) obj);
} else if (select instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) select;
List<Select> selectBodyList = setOperationList.getSelects();
selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
}
}
}
/**
* 设置 where 条件
*
* @param plainSelect 查询对象
* @param whereSegment 查询条件片段
*/
protected void setWhere(PlainSelect plainSelect, String whereSegment) {
if (dataPermissionHandler == null) {
return;
}
// 兼容旧版的数据权限处理
final Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
if (null != sqlSegment) {
plainSelect.setWhere(sqlSegment);
}
}
/**
* update 语句处理
*/
@Override
protected void processUpdate(Update update, int index, String sql, Object obj) {
final Expression sqlSegment = getUpdateOrDeleteExpression(update.getTable(), update.getWhere(), (String) obj);
if (null != sqlSegment) {
update.setWhere(sqlSegment);
}
}
/**
* delete 语句处理
*/
@Override
protected void processDelete(Delete delete, int index, String sql, Object obj) {
final Expression sqlSegment = getUpdateOrDeleteExpression(delete.getTable(), delete.getWhere(), (String) obj);
if (null != sqlSegment) {
delete.setWhere(sqlSegment);
}
}
protected Expression getUpdateOrDeleteExpression(final Table table, final Expression where, final String whereSegment) {
if (dataPermissionHandler == null) {
return null;
}
if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
return andExpression(table, where, whereSegment);
} else {
// 兼容旧版的数据权限处理
return dataPermissionHandler.getSqlSegment(where, whereSegment);
}
}
@Override
public Expression buildTableExpression(final Table table, final Expression where, final String whereSegment) {
if (dataPermissionHandler == null) {
return null;
}
// 只有新版数据权限处理器才会执行到这里
final MultiDataPermissionHandler handler = (MultiDataPermissionHandler) dataPermissionHandler;
return handler.getSqlSegment(table, where, whereSegment);
}
}

View File

@ -0,0 +1,104 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.TableNameParser;
import com.baomidou.mybatisplus.extension.plugins.handler.TableNameHandler;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
/**
* 动态表名
*
* @author jobob
* @since 3.4.0
*/
@Getter
@Setter
@NoArgsConstructor
@SuppressWarnings({"rawtypes"})
public class DynamicTableNameInnerInterceptor implements InnerInterceptor {
private Runnable hook;
/**
* 表名处理器是否处理表名的情况都在该处理器中自行判断
*/
private TableNameHandler tableNameHandler;
public DynamicTableNameInnerInterceptor(TableNameHandler tableNameHandler) {
this.tableNameHandler = tableNameHandler;
}
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) return;
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
mpBs.sql(this.changeTable(mpBs.sql()));
}
@Override
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
MappedStatement ms = mpSh.mappedStatement();
SqlCommandType sct = ms.getSqlCommandType();
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) {
return;
}
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
mpBs.sql(this.changeTable(mpBs.sql()));
}
}
public String changeTable(String sql) {
ExceptionUtils.throwMpe(null == tableNameHandler, "Please implement TableNameHandler processing logic");
TableNameParser parser = new TableNameParser(sql);
List<TableNameParser.SqlToken> names = new ArrayList<>();
parser.accept(names::add);
StringBuilder builder = new StringBuilder();
int last = 0;
for (TableNameParser.SqlToken name : names) {
int start = name.getStart();
if (start != last) {
builder.append(sql, last, start);
builder.append(tableNameHandler.dynamicTableName(sql, name.getValue()));
}
last = name.getEnd();
}
if (last != sql.length()) {
builder.append(sql.substring(last));
}
if (hook != null) {
hook.run();
}
return builder.toString();
}
}

View File

@ -0,0 +1,319 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.annotation.Version;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.conditions.update.Update;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.mapper.Mapper;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import java.lang.reflect.Field;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Optimistic Lock Light version
* <p>Intercept on {@link Executor}.update;</p>
* <p>Support version types: int/Integer, long/Long, java.util.Date, java.sql.Timestamp</p>
* <p>For extra types, please define a subclass and override {@code getUpdatedVersionVal}() method.</p>
* <br>
* <p>How to use?</p>
* <p>(1) Define an Entity and add {@link Version} annotation on one entity field.</p>
* <p>(2) Add {@link OptimisticLockerInnerInterceptor} into mybatis plugin.</p>
* <br>
* <p>How to work?</p>
* <p>if update entity with version column=1:</p>
* <p>(1) no {@link OptimisticLockerInnerInterceptor}:</p>
* <p>SQL: update tbl_test set name='abc' where id=100001;</p>
* <p>(2) add {@link OptimisticLockerInnerInterceptor}:</p>
* <p>SQL: update tbl_test set name='abc',version=2 where id=100001 and version=1;</p>
*
* @author yuxiaobin
* @since 3.4.0
*/
@SuppressWarnings({"unchecked"})
public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
private RuntimeException exception;
public void setException(RuntimeException exception) {
this.exception = exception;
}
/**
* entity类缓存
*/
private static final Map<String, Class<?>> ENTITY_CLASS_CACHE = new ConcurrentHashMap<>();
/**
* 变量占位符正则
*/
private static final Pattern PARAM_PAIRS_RE = Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(" + Constants.WRAPPER_PARAM + "\\d+)\\}");
/**
* paramNameValuePairs存放的version值的key
*/
private static final String UPDATED_VERSION_VAL_KEY = "#updatedVersionVal#";
/**
* Support wrapper mode
*/
private final boolean wrapperMode;
public OptimisticLockerInnerInterceptor() {
this(false);
}
public OptimisticLockerInnerInterceptor(boolean wrapperMode) {
this.wrapperMode = wrapperMode;
}
@Override
public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
return;
}
if (parameter instanceof Map) {
Map<String, Object> map = (Map<String, Object>) parameter;
doOptimisticLocker(map, ms.getId());
}
}
protected void doOptimisticLocker(Map<String, Object> map, String msId) {
// updateById(et), update(et, wrapper);
Object et = map.getOrDefault(Constants.ENTITY, null);
if (Objects.nonNull(et)) {
// version field
TableFieldInfo fieldInfo = this.getVersionFieldInfo(et.getClass());
if (null == fieldInfo) {
return;
}
try {
Field versionField = fieldInfo.getField();
// 旧的 version
Object originalVersionVal = versionField.get(et);
if (originalVersionVal == null) {
if (null != exception) {
/**
* 自定义异常处理
*/
throw exception;
}
return;
}
String versionColumn = fieldInfo.getColumn();
// 新的 version
Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal);
String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
if ("update".equals(methodName)) {
AbstractWrapper<?, ?, ?> aw = (AbstractWrapper<?, ?, ?>) map.getOrDefault(Constants.WRAPPER, null);
if (aw == null) {
UpdateWrapper<?> uw = new UpdateWrapper<>();
uw.eq(versionColumn, originalVersionVal);
map.put(Constants.WRAPPER, uw);
} else {
aw.apply(versionColumn + " = {0}", originalVersionVal);
}
} else {
map.put(Constants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
}
versionField.set(et, updatedVersionVal);
} catch (IllegalAccessException e) {
throw ExceptionUtils.mpe(e);
}
}
// update(LambdaUpdateWrapper) or update(UpdateWrapper)
else if (wrapperMode && map.entrySet().stream().anyMatch(t -> Objects.equals(t.getKey(), Constants.WRAPPER))) {
setVersionByWrapper(map, msId);
}
}
protected TableFieldInfo getVersionFieldInfo(Class<?> entityClazz) {
TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClazz);
return (null != tableInfo && tableInfo.isWithVersion()) ? tableInfo.getVersionFieldInfo() : null;
}
private void setVersionByWrapper(Map<String, Object> map, String msId) {
Object ew = map.get(Constants.WRAPPER);
if (ew instanceof AbstractWrapper && ew instanceof Update) {
Class<?> entityClass = ENTITY_CLASS_CACHE.get(msId);
if (null == entityClass) {
try {
final String className = msId.substring(0, msId.lastIndexOf('.'));
entityClass = ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
ENTITY_CLASS_CACHE.put(msId, entityClass);
} catch (ClassNotFoundException e) {
throw ExceptionUtils.mpe(e);
}
}
final TableFieldInfo versionField = getVersionFieldInfo(entityClass);
if (null == versionField) {
return;
}
final String versionColumn = versionField.getColumn();
final FieldEqFinder fieldEqFinder = new FieldEqFinder(versionColumn, (Wrapper<?>) ew);
if (!fieldEqFinder.isPresent()) {
return;
}
final Map<String, Object> paramNameValuePairs = ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs();
final Object originalVersionValue = paramNameValuePairs.get(fieldEqFinder.valueKey);
if (originalVersionValue == null) {
return;
}
final Object updatedVersionVal = getUpdatedVersionVal(originalVersionValue.getClass(), originalVersionValue);
if (originalVersionValue == updatedVersionVal) {
return;
}
// 拼接新的version值
paramNameValuePairs.put(UPDATED_VERSION_VAL_KEY, updatedVersionVal);
((Update<?, ?>) ew).setSql(String.format("%s = #{%s.%s}", versionColumn, "ew.paramNameValuePairs", UPDATED_VERSION_VAL_KEY));
}
}
/**
* EQ字段查找器
*/
private static class FieldEqFinder {
/**
* 状态机
*/
enum State {
INIT,
FIELD_FOUND,
EQ_FOUND,
VERSION_VALUE_PRESENT;
}
/**
* 字段值的key
*/
private String valueKey;
/**
* 当前状态
*/
private State state;
/**
* 字段名
*/
private final String fieldName;
public FieldEqFinder(String fieldName, Wrapper<?> wrapper) {
this.fieldName = fieldName;
state = State.INIT;
find(wrapper);
}
/**
* 是否已存在
*/
public boolean isPresent() {
return state == State.VERSION_VALUE_PRESENT;
}
private boolean find(Wrapper<?> wrapper) {
Matcher matcher;
final NormalSegmentList segments = wrapper.getExpression().getNormal();
for (ISqlSegment segment : segments) {
// 如果字段已找到并且当前segment为EQ
if (state == State.FIELD_FOUND && segment == SqlKeyword.EQ) {
this.state = State.EQ_FOUND;
// 如果EQ找到并且value已找到
} else if (state == State.EQ_FOUND
&& (matcher = PARAM_PAIRS_RE.matcher(segment.getSqlSegment())).matches()) {
this.valueKey = matcher.group(1);
this.state = State.VERSION_VALUE_PRESENT;
return true;
// 处理嵌套
} else if (segment instanceof Wrapper) {
if (find((Wrapper<?>) segment)) {
return true;
}
// 判断字段是否是要查找字段
} else if (segment.getSqlSegment().equals(this.fieldName)) {
this.state = State.FIELD_FOUND;
}
}
return false;
}
}
private static class VersionFactory {
/**
* 存放版本号类型与获取更新后版本号的map
*/
private static final Map<Class<?>, Function<Object, Object>> VERSION_FUNCTION_MAP = new HashMap<>();
static {
VERSION_FUNCTION_MAP.put(long.class, version -> (long) version + 1);
VERSION_FUNCTION_MAP.put(Long.class, version -> (long) version + 1);
VERSION_FUNCTION_MAP.put(int.class, version -> (int) version + 1);
VERSION_FUNCTION_MAP.put(Integer.class, version -> (int) version + 1);
VERSION_FUNCTION_MAP.put(Date.class, version -> new Date());
VERSION_FUNCTION_MAP.put(Timestamp.class, version -> new Timestamp(System.currentTimeMillis()));
VERSION_FUNCTION_MAP.put(LocalDateTime.class, version -> LocalDateTime.now());
VERSION_FUNCTION_MAP.put(Instant.class, version -> Instant.now());
}
public static Object getUpdatedVersionVal(Class<?> clazz, Object originalVersionVal) {
Function<Object, Object> versionFunction = VERSION_FUNCTION_MAP.get(clazz);
if (versionFunction == null) {
// not supported type, return original val.
return originalVersionVal;
}
return versionFunction.apply(originalVersionVal);
}
}
/**
* This method provides the control for version value.<BR>
* Returned value type must be the same as original one.
*
* @param originalVersionVal ignore
* @return updated version val
*/
protected Object getUpdatedVersionVal(Class<?> clazz, Object originalVersionVal) {
return VersionFactory.getUpdatedVersionVal(clazz, originalVersionVal);
}
}

View File

@ -0,0 +1,478 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.extension.parser.JsqlParserGlobal;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import lombok.Data;
import lombok.NoArgsConstructor;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* 分页拦截器
* <p>
* 默认对 left join 进行优化,虽然能优化count,但是加上分页的话如果1对多本身结果条数就是不正确的
*
* @author hubin
* @since 3.4.0
*/
@Data
@NoArgsConstructor
@SuppressWarnings({"rawtypes"})
public class PaginationInnerInterceptor implements InnerInterceptor {
/**
* 获取jsqlparser中count的SelectItem
*/
protected static final List<SelectItem<?>> COUNT_SELECT_ITEM = Collections.singletonList(
new SelectItem<>(new Column().withColumnName("COUNT(*)")).withAlias(new Alias("total"))
);
protected static final Map<String, MappedStatement> countMsCache = new ConcurrentHashMap<>();
protected final Log logger = LogFactory.getLog(this.getClass());
/**
* 溢出总页数后是否进行处理
*/
protected boolean overflow;
/**
* 单页分页条数限制
*/
protected Long maxLimit;
/**
* 数据库类型
* <p>
* 查看 {@link #findIDialect(Executor)} 逻辑
*/
private DbType dbType;
/**
* 方言实现类
* <p>
* 查看 {@link #findIDialect(Executor)} 逻辑
*/
private IDialect dialect;
/**
* 生成 countSql 优化掉 join
* 现在只支持 left join
*
* @since 3.4.2
*/
protected boolean optimizeJoin = true;
public PaginationInnerInterceptor(DbType dbType) {
this.dbType = dbType;
}
public PaginationInnerInterceptor(IDialect dialect) {
this.dialect = dialect;
}
/**
* 这里进行count,如果count为0这返回false(就是不再执行sql了)
*/
@Override
public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
if (page == null || page.getSize() < 0 || !page.searchCount() || resultHandler != Executor.NO_RESULT_HANDLER) {
return true;
}
BoundSql countSql;
MappedStatement countMs = buildCountMappedStatement(ms, page.countId());
if (countMs != null) {
countSql = countMs.getBoundSql(parameter);
} else {
countMs = buildAutoCountMappedStatement(ms);
String countSqlStr = autoCountSql(page, boundSql.getSql());
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
countSql = new BoundSql(countMs.getConfiguration(), countSqlStr, mpBoundSql.parameterMappings(), parameter);
PluginUtils.setAdditionalParameter(countSql, mpBoundSql.additionalParameters());
}
CacheKey cacheKey = executor.createCacheKey(countMs, parameter, rowBounds, countSql);
List<Object> result = executor.query(countMs, parameter, rowBounds, resultHandler, cacheKey, countSql);
long total = 0;
if (CollectionUtils.isNotEmpty(result)) {
// 个别数据库 count 没数据不会返回 0
Object o = result.get(0);
if (o != null) {
total = Long.parseLong(o.toString());
}
}
page.setTotal(total);
return continuePage(page);
}
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
if (null == page) {
return;
}
// 处理 orderBy 拼接
boolean addOrdered = false;
String buildSql = boundSql.getSql();
List<OrderItem> orders = page.orders();
if (CollectionUtils.isNotEmpty(orders)) {
addOrdered = true;
buildSql = this.concatOrderBy(buildSql, orders);
}
// size 小于 0 且不限制返回值则不构造分页sql
Long _limit = page.maxLimit() != null ? page.maxLimit() : maxLimit;
if (page.getSize() < 0 && null == _limit) {
if (addOrdered) {
PluginUtils.mpBoundSql(boundSql).sql(buildSql);
}
return;
}
handlerLimit(page, _limit);
IDialect dialect = findIDialect(executor);
final Configuration configuration = ms.getConfiguration();
DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
List<ParameterMapping> mappings = mpBoundSql.parameterMappings();
Map<String, Object> additionalParameter = mpBoundSql.additionalParameters();
model.consumers(mappings, configuration, additionalParameter);
mpBoundSql.sql(model.getDialectSql());
mpBoundSql.parameterMappings(mappings);
}
/**
* 获取分页方言类的逻辑
*
* @param executor Executor
* @return 分页方言类
*/
protected IDialect findIDialect(Executor executor) {
if (dialect != null) {
return dialect;
}
if (dbType != null) {
dialect = DialectFactory.getDialect(dbType);
return dialect;
}
return DialectFactory.getDialect(JdbcUtils.getDbType(executor));
}
/**
* 获取指定的 id MappedStatement
*
* @param ms MappedStatement
* @param countId id
* @return MappedStatement
*/
protected MappedStatement buildCountMappedStatement(MappedStatement ms, String countId) {
if (StringUtils.isNotBlank(countId)) {
final String id = ms.getId();
if (!countId.contains(StringPool.DOT)) {
countId = id.substring(0, id.lastIndexOf(StringPool.DOT) + 1) + countId;
}
final Configuration configuration = ms.getConfiguration();
try {
return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> configuration.getMappedStatement(key, false));
} catch (Exception e) {
logger.warn(String.format("can not find this countId: [\"%s\"]", countId));
}
}
return null;
}
/**
* 构建 mp 自用自动的 MappedStatement
*
* @param ms MappedStatement
* @return MappedStatement
*/
protected MappedStatement buildAutoCountMappedStatement(MappedStatement ms) {
final String countId = ms.getId() + "_mpCount";
final Configuration configuration = ms.getConfiguration();
return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> {
MappedStatement.Builder builder = new MappedStatement.Builder(configuration, key, ms.getSqlSource(), ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(Collections.singletonList(new ResultMap.Builder(configuration, Constants.MYBATIS_PLUS, Long.class, Collections.emptyList()).build()));
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
});
}
/**
* 获取自动优化的 countSql
*
* @param page 参数
* @param sql sql
* @return countSql
*/
public String autoCountSql(IPage<?> page, String sql) {
if (!page.optimizeCountSql()) {
return lowLevelCountSql(sql);
}
try {
Select select = (Select) JsqlParserGlobal.parse(sql);
// https://github.com/baomidou/mybatis-plus/issues/3920 分页增加union语法支持
if (select instanceof SetOperationList) {
return lowLevelCountSql(sql);
}
PlainSelect plainSelect = (PlainSelect) select;
// 优化 order by 在非分组情况下
List<OrderByElement> orderBy = plainSelect.getOrderByElements();
if (CollectionUtils.isNotEmpty(orderBy)) {
boolean canClean = true;
for (OrderByElement order : orderBy) {
// order by 里带参数,不去除order by
Expression expression = order.getExpression();
if (!(expression instanceof Column) && expression.toString().contains(StringPool.QUESTION_MARK)) {
canClean = false;
break;
}
}
if (canClean) {
plainSelect.setOrderByElements(null);
}
}
Distinct distinct = plainSelect.getDistinct();
GroupByElement groupBy = plainSelect.getGroupBy();
// 包含 distinctgroupBy 不优化
if (null != distinct || null != groupBy) {
return lowLevelCountSql(select.toString());
}
//#95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
for (SelectItem item : plainSelect.getSelectItems()) {
if (item.toString().contains(StringPool.QUESTION_MARK)) {
return lowLevelCountSql(select.toString());
}
}
// 包含 join 连表,进行判断是否移除 join 连表
if (optimizeJoin && page.optimizeJoinOfCountSql()) {
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
boolean canRemoveJoin = true;
String whereS = Optional.ofNullable(plainSelect.getWhere()).map(Expression::toString).orElse(StringPool.EMPTY);
// 不区分大小写
whereS = whereS.toLowerCase();
for (Join join : joins) {
if (!join.isLeft()) {
canRemoveJoin = false;
break;
}
FromItem rightItem = join.getRightItem();
String str = "";
if (rightItem instanceof Table) {
Table table = (Table) rightItem;
str = Optional.ofNullable(table.getAlias()).map(Alias::getName).orElse(table.getName()) + StringPool.DOT;
} else if (rightItem instanceof ParenthesedSelect) {
ParenthesedSelect subSelect = (ParenthesedSelect) rightItem;
/* 如果 left join 是子查询,并且子查询里包含 ?(代表有入参) 或者 where 条件里包含使用 join 的表的字段作条件,就不移除 join */
if (subSelect.toString().contains(StringPool.QUESTION_MARK)) {
canRemoveJoin = false;
break;
}
str = subSelect.getAlias().getName() + StringPool.DOT;
}
// 不区分大小写
str = str.toLowerCase();
if (whereS.contains(str)) {
/* 如果 where 条件里包含使用 join 的表的字段作条件,就不移除 join */
canRemoveJoin = false;
break;
}
for (Expression expression : join.getOnExpressions()) {
if (expression.toString().contains(StringPool.QUESTION_MARK)) {
/* 如果 join 里包含 ?(代表有入参) 就不移除 join */
canRemoveJoin = false;
break;
}
}
}
if (canRemoveJoin) {
plainSelect.setJoins(null);
}
}
}
// 优化 SQL
plainSelect.setSelectItems(COUNT_SELECT_ITEM);
return select.toString();
} catch (JSQLParserException e) {
// 无法优化使用原 SQL
logger.warn("optimize this sql to a count sql has exception, sql:\"" + sql + "\", exception:\n" + e.getCause());
} catch (Exception e) {
logger.warn("optimize this sql to a count sql has error, sql:\"" + sql + "\", exception:\n" + e);
}
return lowLevelCountSql(sql);
}
/**
* 无法进行count优化时,降级使用此方法
*
* @param originalSql 原始sql
* @return countSql
*/
protected String lowLevelCountSql(String originalSql) {
return SqlParserUtils.getOriginalCountSql(originalSql);
}
/**
* 查询SQL拼接Order By
*
* @param originalSql 需要拼接的SQL
* @return ignore
*/
public String concatOrderBy(String originalSql, List<OrderItem> orderList) {
try {
Select selectBody = (Select) JsqlParserGlobal.parse(originalSql);
if (selectBody instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectBody;
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
plainSelect.setOrderByElements(orderByElementsReturn);
return plainSelect.toString();
} else if (selectBody instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectBody;
List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
setOperationList.setOrderByElements(orderByElementsReturn);
return setOperationList.toString();
} else if (selectBody instanceof WithItem) {
// todo: don't known how to resole
return originalSql;
} else {
return originalSql;
}
} catch (JSQLParserException e) {
logger.warn("failed to concat orderBy from IPage, exception:\n" + e.getCause());
} catch (Exception e) {
logger.warn("failed to concat orderBy from IPage, exception:\n" + e);
}
return originalSql;
}
protected List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) {
List<OrderByElement> additionalOrderBy = orderList.stream()
.filter(item -> StringUtils.isNotBlank(item.getColumn()))
.map(item -> {
OrderByElement element = new OrderByElement();
element.setExpression(new Column(item.getColumn()));
element.setAsc(item.isAsc());
element.setAscDescPresent(true);
return element;
}).collect(Collectors.toList());
if (CollectionUtils.isEmpty(orderByElements)) {
return additionalOrderBy;
}
// github pull/3550 优化排序比如默认 order by id 前端传了name排序设置为 order by name,id
additionalOrderBy.addAll(orderByElements);
return additionalOrderBy;
}
/**
* count 查询之后,是否继续执行分页
*
* @param page 分页对象
* @return 是否
*/
protected boolean continuePage(IPage<?> page) {
if (page.getTotal() <= 0) {
return false;
}
if (page.getCurrent() > page.getPages()) {
if (overflow) {
//溢出总页数处理
handlerOverflow(page);
} else {
// 超过最大范围未设置溢出逻辑中断 list 执行
return false;
}
}
return true;
}
/**
* 处理超出分页条数限制,默认归为限制数
*
* @param page IPage
*/
protected void handlerLimit(IPage<?> page, Long limit) {
final long size = page.getSize();
if (limit != null && limit > 0 && (size > limit || size < 0)) {
page.setSize(limit);
}
}
/**
* 处理页数溢出,默认设置为第一页
*
* @param page IPage
*/
protected void handlerOverflow(IPage<?> page) {
page.setCurrent(1);
}
@Override
public void setProperties(Properties properties) {
PropertyMapper.newInstance(properties)
.whenNotBlank("overflow", Boolean::parseBoolean, this::setOverflow)
.whenNotBlank("dbType", DbType::getDbType, this::setDbType)
.whenNotBlank("dialect", ClassUtils::newInstance, this::setDialect)
.whenNotBlank("maxLimit", Long::parseLong, this::setMaxLimit)
.whenNotBlank("optimizeJoin", Boolean::parseBoolean, this::setOptimizeJoin);
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.config.GlobalConfig;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.sql.SqlUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.SQLException;
import java.util.List;
/**
* 功能类似于 {@link GlobalConfig.DbConfig#isReplacePlaceholder()},
* 只是这个是在运行时实时替换,适用范围更广
*
* @author miemie
* @since 2020-11-19
*/
public class ReplacePlaceholderInnerInterceptor implements InnerInterceptor {
protected final Log logger = LogFactory.getLog(this.getClass());
private String escapeSymbol;
public ReplacePlaceholderInnerInterceptor() {
}
public ReplacePlaceholderInnerInterceptor(String escapeSymbol) {
this.escapeSymbol = escapeSymbol;
}
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
String sql = boundSql.getSql();
List<String> find = SqlUtils.findPlaceholder(sql);
if (CollectionUtils.isNotEmpty(find)) {
sql = SqlUtils.replaceSqlPlaceholder(sql, find, escapeSymbol);
PluginUtils.mpBoundSql(boundSql).sql(sql);
}
}
}

View File

@ -0,0 +1,93 @@
package com.baomidou.mybatisplus.test.extension.parser;
import com.baomidou.mybatisplus.extension.parser.cache.FstFactory;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledOnJre;
import org.junit.jupiter.api.condition.JRE;
import org.springframework.util.SerializationUtils;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author miemie
* @since 2023-08-03
*/
class JsqlParserSimpleSerialTest {
private final static int len = 1000;
private final static String sql = "SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN entity2 e2 ON e2.id = e1.id " +
"ON e1.id = e.id " +
"WHERE (e.id = ? OR e.NAME = ?)";
@Test
@EnabledOnJre(JRE.JAVA_8)
void test() throws JSQLParserException {
System.out.println("循环次数: " + len);
noSerial();
jdkSerial();
fstSerial();
}
void noSerial() throws JSQLParserException {
long startTime = System.currentTimeMillis();
for (int i = 0; i < len; i++) {
CCJSqlParserUtil.parse(sql);
}
long endTime = System.currentTimeMillis();
long e1 = endTime - startTime;
System.out.printf("普通解析执行耗时: %s 毫秒, 均耗时: %s%n", e1, (double) e1 / len);
}
void jdkSerial() throws JSQLParserException {
Statement statement = CCJSqlParserUtil.parse(sql);
String target = statement.toString();
byte[] serial = null;
long startTime = System.currentTimeMillis();
for (int i = 0; i < len; i++) {
serial = SerializationUtils.serialize(statement);
}
long endTime = System.currentTimeMillis();
long et = endTime - startTime;
System.out.printf("jdk serialize 执行耗时: %s 毫秒,byte大小: %s, 均耗时: %s%n", et, serial.length, (double) et / len);
startTime = System.currentTimeMillis();
for (int i = 0; i < len; i++) {
statement = (Statement) SerializationUtils.deserialize(serial);
}
endTime = System.currentTimeMillis();
et = endTime - startTime;
System.out.printf("jdk deserialize 执行耗时: %s 毫秒, 均耗时: %s%n", et, (double) et / len);
assertThat(statement).isNotNull();
assertThat(statement.toString()).isEqualTo(target);
}
void fstSerial() throws JSQLParserException {
Statement statement = CCJSqlParserUtil.parse(sql);
String target = statement.toString();
FstFactory factory = FstFactory.getDefaultFactory();
byte[] serial = null;
long startTime = System.currentTimeMillis();
for (int i = 0; i < len; i++) {
serial = factory.asByteArray(statement);
}
long endTime = System.currentTimeMillis();
long et = endTime - startTime;
System.out.printf("fst serialize 执行耗时: %s 毫秒,byte大小: %s, 均耗时: %s%n", et, serial.length, (double) et / len);
startTime = System.currentTimeMillis();
for (int i = 0; i < len; i++) {
statement = (Statement) factory.asObject(serial);
}
endTime = System.currentTimeMillis();
et = endTime - startTime;
System.out.printf("fst deserialize 执行耗时: %s 毫秒, 均耗时: %s%n", et, (double) et / len);
assertThat(statement).isNotNull();
assertThat(statement.toString()).isEqualTo(target);
}
}

View File

@ -0,0 +1,37 @@
package com.baomidou.mybatisplus.test.extension.parser.cache;
import io.github.classgraph.ClassGraph;
import io.github.classgraph.ClassInfo;
import io.github.classgraph.ScanResult;
import org.junit.jupiter.api.Test;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
/**
* @author miemie
* @since 2023-08-06
*/
class FstFactoryTest {
@Test
void clazz() {
List<ClassInfo> list = new ArrayList<>();
List<ClassInfo> absList = new ArrayList<>();
try (ScanResult scanResult = new ClassGraph().enableClassInfo().acceptPackages("net.sf.jsqlparser").scan()) {
for (ClassInfo classInfo : scanResult.getAllClasses()) {
if (!classInfo.isInterface() && classInfo.implementsInterface(Serializable.class)) {
if (classInfo.isAbstract()) {
absList.add(classInfo);
continue;
}
list.add(classInfo);
}
}
}
list.forEach(i -> System.out.printf("conf.registerClass(%s.class);%n", i.getName().replace("$", ".")));
System.out.println("↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓");
absList.forEach(i -> System.out.printf("conf.registerClass(%s.class);%n", i.getName().replace("$", ".")));
}
}

View File

@ -0,0 +1,40 @@
package com.baomidou.mybatisplus.test.extension.plugins;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Properties;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author miemie
* @since 2020-06-30
*/
class MybatisPlusInterceptorTest {
@Test
void setProperties() {
Properties properties = new Properties();
properties.setProperty("@page", "com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor");
properties.setProperty("page:maxLimit", "10");
properties.setProperty("page:dbType", "h2");
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
interceptor.setProperties(properties);
List<InnerInterceptor> interceptors = interceptor.getInterceptors();
assertThat(interceptors).isNotEmpty();
assertThat(interceptors.size()).isEqualTo(1);
InnerInterceptor page = interceptors.get(0);
assertThat(page).isInstanceOf(PaginationInnerInterceptor.class);
PaginationInnerInterceptor pii = (PaginationInnerInterceptor) page;
assertThat(pii.getMaxLimit()).isEqualTo(10);
assertThat(pii.getDbType()).isEqualTo(DbType.H2);
}
}

View File

@ -0,0 +1,62 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.extension.plugins.inner.BlockAttackInnerInterceptor;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author miemie
* @since 2020-08-18
*/
class BlockAttackInnerInterceptorTest {
private final BlockAttackInnerInterceptor interceptor = new BlockAttackInnerInterceptor();
@Test
void update() {
checkEx("update user set name = null", "null where");
checkEx("update user set name = null where 1=1", "1=1");
checkEx("update user set name = null where 1<>2", "1<>2");
checkEx("update user set name = null where 1!=2", "1!=2");
checkEx("update user set name = null where 1=1 and 2=2", "1=1 and 2=2");
checkEx("update user set name = null where 1=1 and 2=3 or 1=1", "1=1 and 2=3 or 1=1");
checkEx("update user set name = null where 1=1 and (2=2)", "1=1 and (2=2)");
checkEx("update user set name = null where (1=1 and 2=2)", "(1=1 and 2=2)");
checkEx("update user set name = null where (1=1 and (2=3 or 3=3))", "(1=1 and (2=3 or 3=3))");
checkNotEx("update user set name = null where 1=?", "1=?");
}
@Test
void delete() {
checkEx("delete from user", "null where");
checkEx("delete from user where 1=1", "1=1");
checkEx("delete from user where 1<>2", "1<>2");
checkEx("delete from user where 1!=2", "1!=2");
checkEx("delete from user where 1=1 and 2=2", "1=1 and 2=2");
checkEx("delete from user where 1=1 and 2=3 or 1=1", "1=1 and 2=3 or 1=1");
}
void checkEx(String sql, String as) {
Exception e = null;
try {
interceptor.parserSingle(sql, null);
} catch (Exception x) {
e = x;
}
assertThat(e).as(as).isNotNull();
assertThat(e).as(as).isInstanceOf(MybatisPlusException.class);
}
void checkNotEx(String sql, String as) {
Exception e = null;
try {
interceptor.parserSingle(sql, null);
} catch (Exception x) {
e = x;
}
assertThat(e).as(as).isNull();
}
}

View File

@ -0,0 +1,63 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.extension.plugins.inner.DataChangeRecorderInnerInterceptor;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.update.Update;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.lang.reflect.Field;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
/**
* @author miemie
* @since 2020-06-28
*/
class DataChangeRecorderInnerInterceptorTest {
private final DataChangeRecorderInnerInterceptor interceptor = new DataChangeRecorderInnerInterceptor();
@BeforeEach
public void initProperties() {
Properties properties = new Properties();
properties.put("ignoredTableColumns", "table_name1.column1,column2; h2user.*; *.column1,COLUMN2");
interceptor.setProperties(properties);
}
@Test
void setProperties() throws Exception {
final Object ignoreAllColumns = getFieldValue(interceptor, "ignoreAllColumns");
Assertions.assertEquals(Set.of("COLUMN1", "COLUMN2"), ignoreAllColumns);
final Object ignoredTableColumns = getFieldValue(interceptor, "ignoredTableColumns");
Assertions.assertEquals(Map.of("H2USER", Set.of("*"), "TABLE_NAME1", Set.of("COLUMN1", "COLUMN2")), ignoredTableColumns);
}
private Object getFieldValue(Object obj, String fieldName) throws NoSuchFieldException, IllegalAccessException {
final Field field = DataChangeRecorderInnerInterceptor.class.getDeclaredField(fieldName);
field.setAccessible(true);
return field.get(obj);
}
@Test
void processInsert() {
final Insert insert = new Insert();
insert.setTable(new Table("H2USER"));
final DataChangeRecorderInnerInterceptor.OperationResult operationResult = interceptor.processInsert(insert, null);
Assertions.assertEquals(operationResult.getTableName(), "H2USER:*");
Assertions.assertFalse(operationResult.isRecordStatus());
Assertions.assertNull(operationResult.getChangedData());
}
@Test
void processUpdate() {
final Update update = new Update();
update.setTable(new Table("H2USER"));
final DataChangeRecorderInnerInterceptor.OperationResult operationResult = interceptor.processUpdate(update, null, null, null);
Assertions.assertEquals(operationResult.getTableName(), "H2USER:*");
Assertions.assertFalse(operationResult.isRecordStatus());
Assertions.assertNull(operationResult.getChangedData());
}
}

View File

@ -0,0 +1,105 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.junit.jupiter.api.Test;
import java.util.HashMap;
import java.util.Map;
import static org.assertj.core.api.Assertions.assertThat;
/**
* 数据权限拦截器测试
*
* @author hubin
* @since 3.4.1 +
*/
public class DataPermissionInterceptorTest {
private static String TEST_1 = "com.baomidou.userMapper.selectByUsername";
private static String TEST_2 = "com.baomidou.userMapper.selectById";
private static String TEST_3 = "com.baomidou.roleMapper.selectByCompanyId";
private static String TEST_4 = "com.baomidou.roleMapper.selectById";
private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
/**
* 这里可以理解为数据库配置的数据权限规则 SQL
*/
private static Map<String, String> sqlSegmentMap = new HashMap<String, String>() {
{
put(TEST_1, "username='123' or userId IN (1,2,3)");
put(TEST_2, "u.state=1 and u.amount > 1000");
put(TEST_3, "companyId in (1,2,3)");
put(TEST_4, "username like 'abc%'");
put(TEST_5, "id=1 and role_id in (select id from sys_role)");
}
};
private static DataPermissionInterceptor interceptor = new DataPermissionInterceptor(new DataPermissionHandler() {
@Override
public Expression getSqlSegment(Expression where, String mappedStatementId) {
try {
String sqlSegment = sqlSegmentMap.get(mappedStatementId);
Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
if (null != where) {
System.out.println("原 where : " + where.toString());
if (mappedStatementId.equals(TEST_4)) {
// 这里测试返回 OR 条件
return new OrExpression(where, sqlSegmentExpression);
}
return new AndExpression(where, sqlSegmentExpression);
}
return sqlSegmentExpression;
} catch (JSQLParserException e) {
e.printStackTrace();
}
return null;
}
});
@Test
void test1() {
assertSql(TEST_1, "select * from sys_user",
"SELECT * FROM sys_user WHERE username = '123' OR userId IN (1, 2, 3)");
}
@Test
void test2() {
assertSql(TEST_2, "select u.username from sys_user u join sys_user_role r on u.id=r.user_id where r.role_id=3",
"SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE r.role_id = 3 AND u.state = 1 AND u.amount > 1000");
}
@Test
void test3() {
assertSql(TEST_3, "select * from sys_role where company_id=6",
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3)");
}
@Test
void test3unionAll() {
assertSql(TEST_3, "select * from sys_role where company_id=6 union all select * from sys_role where company_id=7",
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3) UNION ALL SELECT * FROM sys_role WHERE company_id = 7 AND companyId IN (1, 2, 3)");
}
@Test
void test4() {
assertSql(TEST_4, "select * from sys_role where id=3",
"SELECT * FROM sys_role WHERE id = 3 OR username LIKE 'abc%'");
}
@Test
void test5() {
assertSql(TEST_5, "select * from sys_role where id=3",
"SELECT * FROM sys_role WHERE id = 3 AND id = 1 AND role_id IN (SELECT id FROM sys_role)");
}
void assertSql(String mappedStatementId, String sql, String targetSql) {
assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
}
}

View File

@ -0,0 +1,49 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.extension.plugins.inner.DynamicTableNameInnerInterceptor;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* 动态表名内部拦截器测试
*
* @author miemie, hcl
* @since 2020-07-16
*/
class DynamicTableNameInnerInterceptorTest {
/**
* 测试 SQL 中的动态表名替换
*/
@Test
@SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"})
void doIt() {
DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
// 表名相互包含
String origin = "SELECT * FROM t_user, t_user_role";
assertEquals("SELECT * FROM t_user_r, t_user_role_r", interceptor.changeTable(origin));
// 表名在末尾
origin = "SELECT * FROM t_user";
assertEquals("SELECT * FROM t_user_r", interceptor.changeTable(origin));
// 表名前后有注释
origin = "SELECT * FROM /**/t_user/* t_user */";
assertEquals("SELECT * FROM /**/t_user_r/* t_user */", interceptor.changeTable(origin));
// 值中带有表名
origin = "SELECT * FROM t_user WHERE name = 't_user'";
assertEquals("SELECT * FROM t_user_r WHERE name = 't_user'", interceptor.changeTable(origin));
// 别名被声明要替换
origin = "SELECT t_user.* FROM t_user_real t_user";
assertEquals("SELECT t_user.* FROM t_user_real_r t_user", interceptor.changeTable(origin));
// 别名被声明要替换
origin = "SELECT t.* FROM t_user_real t left join entity e on e.id = t.id";
assertEquals("SELECT t.* FROM t_user_real_r t left join entity_r e on e.id = t.id", interceptor.changeTable(origin));
}
}

View File

@ -0,0 +1,123 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.extension.plugins.inner.IllegalSQLInnerInterceptor;
import org.apache.ibatis.jdbc.SqlRunner;
import org.h2.jdbcx.JdbcDataSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
/**
* @author miemie
* @since 2022-04-11
*/
class IllegalSQLInnerInterceptorTest {
private final IllegalSQLInnerInterceptor interceptor = new IllegalSQLInnerInterceptor();
private static DataSource dataSource;
@BeforeAll
public static void beforeAll() throws SQLException {
var jdbcDataSource = new JdbcDataSource();
jdbcDataSource.setURL("jdbc:h2:mem:test;MODE=mysql;DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE");
jdbcDataSource.setPassword("");
jdbcDataSource.setUser("sa");
dataSource = jdbcDataSource;
Connection connection = jdbcDataSource.getConnection();
var sql = """
CREATE TABLE T_DEMO (
`a` int DEFAULT NULL,
`b` int DEFAULT NULL,
`c` int DEFAULT NULL,
KEY `ab_index1` (`a`,`b`)
);
CREATE TABLE T_TEST (
`a` int DEFAULT NULL,
`b` int DEFAULT NULL,
`c` int DEFAULT NULL,
KEY `ab_index2` (`a`,`b`)
);
""";
SqlRunner sqlRunner = new SqlRunner(connection);
sqlRunner.run(sql);
}
@Test
void test() {
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT COUNT(*) AS total FROM t_user WHERE (client_id = ?)", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 18", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete from t_user set age = 18", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age != 1", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age = 1 or name = 'test'", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where (age = 1 or name = 'test')", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where age = 1 or name = 'test'", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where (age = 1 or name = 'test')", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where age = 1 or name = 'test'", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where (age = 1 or name = 'test')", null));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from T_DEMO where `a` = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where `a` = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where c = 3", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from T_DEMO where c = 3", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.`T_DEMO` where c = 3", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.T_DEMO where c = 3", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a LEFT JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a LEFT JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where (c = 3 OR b = 2)", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where c = 3 OR b = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = 3 AND (c = 3 OR b = 2)", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where (a = 3 AND c = 3 OR b = 2)", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2)", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2) or b = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2) AND b = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where (a = 3 AND c = 3 AND b = 2)", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` a INNER JOIN T_TEST b ON a.a = b.a where a.a = 3 AND (b.c = 3 OR b.b = 2)", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where a != (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
//TODO 低版本这里的抛异常了.看着应该不用抛出
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a >= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a <= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b >= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b <= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
}
@Test
void testCount() {
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a", dataSource.getConnection()));
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a) c", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO`) a ", dataSource.getConnection()));
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
}
}

View File

@ -0,0 +1,155 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
import com.google.common.collect.HashBasedTable;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.assertj.core.api.Assertions.assertThat;
/**
* SQL多表场景的数据权限拦截器测试
*
* @author houkunlin
* @since 3.5.2 +
*/
public class MultiDataPermissionInterceptorTest {
private static final Logger logger = LoggerFactory.getLogger(MultiDataPermissionInterceptorTest.class);
/**
* 这里可以理解为数据库配置的数据权限规则 SQL
*/
private static final com.google.common.collect.Table<String, String, String> sqlSegmentMap;
private static final DataPermissionInterceptor interceptor;
private static String TEST_1 = "com.baomidou.userMapper.selectByUsername";
private static String TEST_2 = "com.baomidou.userMapper.selectById";
private static String TEST_3 = "com.baomidou.roleMapper.selectByCompanyId";
private static String TEST_4 = "com.baomidou.roleMapper.selectById";
private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
private static String TEST_6 = "com.baomidou.roleMapper.selectUserInfo";
private static String TEST_7 = "com.baomidou.roleMapper.summarySum";
private static String TEST_8_1 = "com.baomidou.CustomMapper.selectByOnlyMyData";
private static String TEST_8_2 = "com.baomidou.CustomMapper.selectByOnlyOrgData";
private static String TEST_8_3 = "com.baomidou.CustomMapper.selectByOnlyDeptData";
private static String TEST_8_4 = "com.baomidou.CustomMapper.selectByMyDataOrDeptData";
private static String TEST_8_5 = "com.baomidou.CustomMapper.selectByMyData";
static {
sqlSegmentMap = HashBasedTable.create();
sqlSegmentMap.put(TEST_1, "sys_user", "username='123' or userId IN (1,2,3)");
sqlSegmentMap.put(TEST_2, "sys_user", "u.state=1 and u.amount > 1000");
sqlSegmentMap.put(TEST_3, "sys_role", "companyId in (1,2,3)");
sqlSegmentMap.put(TEST_4, "sys_role", "username like 'abc%'");
sqlSegmentMap.put(TEST_5, "sys_role", "id=1 and role_id in (select id from sys_role)");
sqlSegmentMap.put(TEST_6, "sys_user", "u.state=1 and u.amount > 1000");
sqlSegmentMap.put(TEST_6, "sys_user_role", "r.role_id=3 AND r.role_id IN (7,9,11)");
sqlSegmentMap.put(TEST_7, "`fund`", "a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111");
sqlSegmentMap.put(TEST_7, "`fund_month`", "b.fund_id = 2 AND b.month <= '2022-05'");
sqlSegmentMap.put(TEST_8_1, "fund", "user_id=1");
sqlSegmentMap.put(TEST_8_2, "fund", "org_id=1");
sqlSegmentMap.put(TEST_8_3, "fund", "dept_id=1");
sqlSegmentMap.put(TEST_8_4, "fund", "user_id=1 or dept_id=1");
sqlSegmentMap.put(TEST_8_5, "table1", "u.user_id=1");
sqlSegmentMap.put(TEST_8_5, "table2", "u.dept_id=1");
interceptor = new DataPermissionInterceptor(new MultiDataPermissionHandler() {
@Override
public Expression getSqlSegment(final Table table, final Expression where, final String mappedStatementId) {
try {
String sqlSegment = sqlSegmentMap.get(mappedStatementId, table.getName());
if (sqlSegment == null) {
logger.info("{} {} AS {} : NOT FOUND", mappedStatementId, table.getName(), table.getAlias());
return null;
}
if (table.getAlias() != null) {
// 替换表别名
sqlSegment = sqlSegment.replaceAll("u\\.", table.getAlias().getName() + StringPool.DOT);
}
Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
logger.info("{} {} AS {} : {}", mappedStatementId, table.getName(), table.getAlias(), sqlSegmentExpression.toString());
return sqlSegmentExpression;
} catch (JSQLParserException e) {
e.printStackTrace();
}
return null;
}
});
}
@Test
void test1() {
assertSql(TEST_1, "select * from sys_user",
"SELECT * FROM sys_user WHERE username = '123' OR userId IN (1, 2, 3)");
}
@Test
void test2() {
assertSql(TEST_2, "select u.username from sys_user u join sys_user_role r on u.id=r.user_id where r.role_id=3",
"SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE r.role_id = 3 AND u.state = 1 AND u.amount > 1000");
}
@Test
void test3() {
assertSql(TEST_3, "select * from sys_role where company_id=6",
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3)");
}
@Test
void test3unionAll() {
assertSql(TEST_3, "select * from sys_role where company_id=6 union all select * from sys_role where company_id=7",
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3) UNION ALL SELECT * FROM sys_role WHERE company_id = 7 AND companyId IN (1, 2, 3)");
}
@Test
void test4() {
assertSql(TEST_4, "select * from sys_role where id=3",
"SELECT * FROM sys_role WHERE id = 3 AND username LIKE 'abc%'");
}
@Test
void test5() {
assertSql(TEST_5, "select * from sys_role where id=3",
"SELECT * FROM sys_role WHERE id = 3 AND id = 1 AND role_id IN (SELECT id FROM sys_role)");
}
@Test
void test6() {
// 显式指定 JOIN 类型时 JOIN 右侧表才能进行拼接条件
assertSql(TEST_6, "select u.username from sys_user u LEFT join sys_user_role r on u.id=r.user_id",
"SELECT u.username FROM sys_user u LEFT JOIN sys_user_role r ON u.id = r.user_id AND r.role_id = 3 AND r.role_id IN (7, 9, 11) WHERE u.state = 1 AND u.amount > 1000");
}
@Test
void test7() {
assertSql(TEST_7, "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE) c WHERE c.row_index = 1 GROUP BY title LIMIT 20",
"SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE AND b.fund_id = 2 AND b.month <= '2022-05' WHERE a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111) c WHERE c.row_index = 1 GROUP BY title LIMIT 20");
}
@Test
void test8() {
assertSql(TEST_8_1, "select * from fund where id=3",
"SELECT * FROM fund WHERE id = 3 AND user_id = 1");
assertSql(TEST_8_2, "select * from fund where id=3",
"SELECT * FROM fund WHERE id = 3 AND org_id = 1");
assertSql(TEST_8_3, "select * from fund where id=3",
"SELECT * FROM fund WHERE id = 3 AND dept_id = 1");
assertSql(TEST_8_4, "select * from fund where id=3",
"SELECT * FROM fund WHERE id = 3 AND user_id = 1 OR dept_id = 1");
// 修改之前旧版的多表数据权限对这个SQL的表现形式
// 输入 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp"
// 输出 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid) SELECT * FROM temp"
// 修改之后的多表数据权限对这个SQL的表现形式
assertSql(TEST_8_5, "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp",
"WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid AND t2.dept_id = 1 WHERE t1.user_id = 1) SELECT * FROM temp");
}
void assertSql(String mappedStatementId, String sql, String targetSql) {
assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
}
}

View File

@ -0,0 +1,109 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author miemie
* @since 2020-06-28
*/
class PaginationInnerInterceptorTest {
private final PaginationInnerInterceptor interceptor = new PaginationInnerInterceptor();
@Test
void optimizeCount() {
/* 能进行优化的 SQL */
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id",
"SELECT COUNT(*) AS total FROM user u");
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xx = ?",
"SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id",
"SELECT COUNT(*) AS total FROM user u");
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id WHERE u.xx = ?",
"SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
assertsCountSql("select distinct id from table order by id", "SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table) TOTAL");
assertsCountSql("select distinct id from table", "SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table) TOTAL");
}
@Test
void notOptimizeCount() {
/* 不能进行优化的 SQL */
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? where u.xx = ?",
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? WHERE u.xx = ?");
/* join 表与 where 条件大小写不同的情况 */
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id where R.NAME = ?",
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE R.NAME = ?");
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?",
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?");
}
@Test
void optimizeCountOrderBy() {
/* order by 里不带参数,去除order by */
assertsCountSql("SELECT * FROM comment ORDER BY name",
"SELECT COUNT(*) AS total FROM comment");
/* order by 里带参数,不去除order by */
assertsCountSql("SELECT * FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)",
"SELECT COUNT(*) AS total FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)");
}
@Test
void withAsCount() {
assertsCountSql("with A as (select * from class) select * from A",
"WITH A AS (SELECT * FROM class) SELECT COUNT(*) AS total FROM A");
}
@Test
void withAsOrderBy() {
assertsConcatOrderBy("with A as (select * from class) select * from A",
"WITH A AS (SELECT * FROM class) SELECT * FROM A ORDER BY column ASC",
OrderItem.asc("column"));
}
@Test
void groupByCount() {
assertsCountSql("SELECT * FROM record_1 WHERE id = ? GROUP BY date(date_time)",
"SELECT COUNT(*) FROM (SELECT * FROM record_1 WHERE id = ? GROUP BY date(date_time)) TOTAL");
}
@Test
void leftJoinSelectCount() {
assertsCountSql("select r.id, r.name, r.phone,rlr.total_top_up from reseller r " +
"left join (select ral.reseller_id, sum(ral.top_up_money) as total_top_up, sum(ral.acquire_money) as total_acquire " +
"from reseller_acquire_log ral " +
"group by ral.reseller_id) rlr on r.id = rlr.reseller_id " +
"order by r.created_at desc",
"SELECT COUNT(*) AS total FROM reseller r");
// 不优化
assertsCountSql("SELECT f.ca, f.cb FROM table_a f LEFT JOIN " +
"(SELECT ca FROM table_b WHERE cc = ?) rf on rf.ca = f.ca",
"SELECT COUNT(*) AS total FROM table_a f LEFT JOIN (SELECT ca FROM table_b WHERE cc = ?) rf ON rf.ca = f.ca");
assertsCountSql("select * from order_info left join (select count(1) from order_info where create_time between ? and ?) tt on 1=1 WHERE equipment_id=?",
"SELECT COUNT(*) AS total FROM order_info LEFT JOIN (SELECT count(1) FROM order_info WHERE create_time BETWEEN ? AND ?) tt ON 1 = 1 WHERE equipment_id = ?");
}
void assertsCountSql(String sql, String targetSql) {
assertThat(interceptor.autoCountSql(new Page<>(), sql)).isEqualTo(targetSql);
}
void assertsConcatOrderBy(String sql, String targetSql, OrderItem... orderItems) {
assertThat(interceptor.concatOrderBy(sql, Arrays.asList(orderItems))).isEqualTo(targetSql);
}
}

View File

@ -0,0 +1,470 @@
package com.baomidou.mybatisplus.test.extension.plugins.inner;
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
/**
* @author miemie
* @since 2020-07-30
*/
class TenantLineInnerInterceptorTest {
private final TenantLineInnerInterceptor interceptor = new TenantLineInnerInterceptor(new TenantLineHandler() {
private boolean ignoreFirst;// 需要执行 getTenantId 前必须先执行 ignoreTable
@Override
public Expression getTenantId() {
assertThat(ignoreFirst).isEqualTo(true);
ignoreFirst = false;
return new LongValue(1);
}
@Override
public boolean ignoreTable(String tableName) {
ignoreFirst = true;
return tableName.startsWith("with_as");
}
});
@Test
void insert() {
// plain
assertSql("insert into entity (id) values (?)",
"INSERT INTO entity (id, tenant_id) VALUES (?, 1)");
assertSql("insert into entity (id,name) values (?,?)",
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1)");
// batch
assertSql("insert into entity (id) values (?),(?)",
"INSERT INTO entity (id, tenant_id) VALUES (?, 1), (?, 1)");
assertSql("insert into entity (id,name) values (?,?),(?,?)",
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1), (?, ?, 1)");
// insert的列
assertSql("insert into entity value (?,?)",
"INSERT INTO entity VALUES (?, ?)");
// 自己加了insert的列
assertSql("insert into entity (id,name,tenant_id) value (?,?,?)",
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, ?)");
// insert into select
assertSql("insert into entity (id,name) select id,name from entity2",
"INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE tenant_id = 1");
assertSql("insert into entity (id,name) select * from entity2 e2",
"INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 e2 WHERE e2.tenant_id = 1");
assertSql("insert into entity (id,name) select id,name from (select id,name from entity3 e3) t",
"INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
assertSql("insert into entity (id,name) select * from (select id,name from entity3 e3) t",
"INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
assertSql("insert into entity (id,name) select t.* from (select id,name from entity3 e3) t",
"INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
}
@Test
void delete() {
assertSql("delete from entity where id = ?",
"DELETE FROM entity WHERE id = ? AND tenant_id = 1");
}
@Test
void update() {
assertSql("update entity set name = ? where id = ?",
"UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1");
// set subSelect
assertSql("UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ?) WHERE e.id = ?",
"UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ? AND e1.tenant_id = 1) " +
"WHERE e.id = ? AND e.tenant_id = 1");
assertSql("UPDATE sys_user SET (name, age) = ('秋秋', 18), address = 'test'",
"UPDATE sys_user SET (name, age) = ('秋秋', 18), address = 'test' WHERE tenant_id = 1");
assertSql("UPDATE entity t1 INNER JOIN entity t2 ON t1.a= t2.a SET t1.b = t2.b, t1.c = t2.c",
"UPDATE entity t1 INNER JOIN entity t2 ON t1.a = t2.a SET t1.b = t2.b, t1.c = t2.c WHERE t1.tenant_id = 1");
}
@Test
void selectSingle() {
// 单表
assertSql("select * from entity where id = ?",
"SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
assertSql("select * from entity where id = ? or name = ?",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
/* not */
assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
assertSql("SELECT * FROM entity u WHERE not (u.id = ? OR u.name = ?)",
"SELECT * FROM entity u WHERE NOT (u.id = ? OR u.name = ?) AND u.tenant_id = 1");
}
@Test
void selectSubSelectIn() {
/* in */
assertSql("SELECT * FROM entity e WHERE e.id IN (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id IN (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
// 在最前
assertSql("SELECT * FROM entity e WHERE e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
"SELECT * FROM entity e WHERE e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
// 在最后
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
// 在中间
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
}
@Test
void selectSubSelectEq() {
/* = */
assertSql("SELECT * FROM entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectSubSelectInnerNotEq() {
/* inner not = */
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?))",
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1)) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?) and e.id = ?)",
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ?) AND e.tenant_id = 1");
}
@Test
void selectSubSelectExists() {
/* EXISTS */
assertSql("SELECT * FROM entity e WHERE EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
assertSql("SELECT EXISTS (SELECT 1 FROM entity1 e WHERE e.id = ? LIMIT 1)","SELECT EXISTS (SELECT 1 FROM entity1 e WHERE e.id = ? AND e.tenant_id = 1 LIMIT 1)");
/* NOT EXISTS */
assertSql("SELECT * FROM entity e WHERE NOT EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE NOT EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectWhereSubSelect() {
/* >= */
assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <= */
assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
/* <> */
assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
"SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
}
@Test
void selectFromSelect() {
assertSql("SELECT * FROM (select e.id from entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?))",
"SELECT * FROM (SELECT e.id FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1)");
}
@Test
void selectBodySubSelect() {
assertSql("select t1.col1,(select t2.col2 from t2 t2 where t1.col1=t2.col1) from t1 t1",
"SELECT t1.col1, (SELECT t2.col2 FROM t2 t2 WHERE t1.col1 = t2.col1 AND t2.tenant_id = 1) FROM t1 t1 WHERE t1.tenant_id = 1");
}
@Test
void selectBodyFuncSubSelect() {
assertSql("SELECT e1.*, IF((SELECT e2.id FROM entity2 e2 WHERE e2.id = 1) = 1, e2.type, e1.type) AS type " +
"FROM entity e1 WHERE e1.id = ?",
"SELECT e1.*, IF((SELECT e2.id FROM entity2 e2 WHERE e2.id = 1 AND e2.tenant_id = 1) = 1, e2.type, e1.type) AS type " +
"FROM entity e1 WHERE e1.id = ? AND e1.tenant_id = 1");
}
@Test
void selectLeftJoin() {
// left join
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
}
@Test
void selectRightJoin() {
// right join
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM with_as_1 e " +
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM with_as_1 e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id ",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
}
@Test
void selectMixJoin() {
assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 " +
"WHERE e2.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"left join entity1 e1 on e1.id = e.id " +
"inner join entity2 e2 on e1.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
}
@Test
void selectJoinSubSelect() {
assertSql("select * from (select * from entity e) e1 " +
"left join entity2 e2 on e1.id = e2.id",
"SELECT * FROM (SELECT * FROM entity e WHERE e.tenant_id = 1) e1 " +
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
assertSql("select * from entity1 e1 " +
"left join (select * from entity2 e2) e22 " +
"on e1.id = e22.id",
"SELECT * FROM entity1 e1 " +
"LEFT JOIN (SELECT * FROM entity2 e2 WHERE e2.tenant_id = 1) e22 " +
"ON e1.id = e22.id " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectSubJoin() {
assertSql("select * FROM " +
"(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"WHERE e2.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"WHERE e1.tenant_id = 1");
assertSql("select * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
"right join entity3 e3 on e1.id = e3.id",
"SELECT * FROM " +
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
"WHERE e3.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
"ON e.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
assertSql("select * FROM entity e " +
"RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
"on e.id = e2.id",
"SELECT * FROM entity e " +
"RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
"ON e.id = e2.id AND e.tenant_id = 1 " +
"WHERE e1.tenant_id = 1");
}
@Test
void selectLeftJoinMultipleTrailingOn() {
// 多个 on 尾缀的
assertSql("SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN entity2 e2 ON e2.id = e1.id " +
"ON e1.id = e.id " +
"WHERE (e.id = ? OR e.NAME = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN entity2 e2 ON e2.id = e1.id AND e2.tenant_id = 1 " +
"ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
"ON e1.id = e.id " +
"WHERE (e.id = ? OR e.NAME = ?)",
"SELECT * FROM entity e " +
"LEFT JOIN entity1 e1 " +
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
"ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
}
@Test
void selectInnerJoin() {
// inner join
assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE e.id = ? OR e.name = ?");
assertSql("SELECT * FROM entity e " +
"inner join entity1 e1 on e1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// ignore table
assertSql("SELECT * FROM entity e " +
"inner join with_as_1 w1 on w1.id = e.id " +
"WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e " +
"INNER JOIN with_as_1 w1 ON w1.id = e.id AND e.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?)");
// 隐式内连接
assertSql("SELECT * FROM entity e,entity1 e1 " +
"WHERE e.id = e1.id",
"SELECT * FROM entity e, entity1 e1 " +
"WHERE e.id = e1.id AND e.tenant_id = 1 AND e1.tenant_id = 1");
// 隐式内连接
assertSql("SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM entity a, with_as_entity1 b " +
"WHERE a.id = b.id AND a.tenant_id = 1");
assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id",
"SELECT * FROM with_as_entity a, with_as_entity1 b " +
"WHERE a.id = b.id");
// SubJoin with 隐式内连接
assertSql("SELECT * FROM (entity e,entity1 e1) " +
"WHERE e.id = e1.id",
"SELECT * FROM (entity e, entity1 e1) " +
"WHERE e.id = e1.id " +
"AND e.tenant_id = 1 AND e1.tenant_id = 1");
assertSql("SELECT * FROM ((entity e,entity1 e1),entity2 e2) " +
"WHERE e.id = e1.id and e.id = e2.id",
"SELECT * FROM ((entity e, entity1 e1), entity2 e2) " +
"WHERE e.id = e1.id AND e.id = e2.id " +
"AND e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
assertSql("SELECT * FROM (entity e,(entity1 e1,entity2 e2)) " +
"WHERE e.id = e1.id and e.id = e2.id",
"SELECT * FROM (entity e, (entity1 e1, entity2 e2)) " +
"WHERE e.id = e1.id AND e.id = e2.id " +
"AND e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
// 沙雕的括号写法
assertSql("SELECT * FROM (((entity e,entity1 e1))) " +
"WHERE e.id = e1.id",
"SELECT * FROM (((entity e, entity1 e1))) " +
"WHERE e.id = e1.id " +
"AND e.tenant_id = 1 AND e1.tenant_id = 1");
}
@Test
void selectSingleJoin() {
// join
assertSql("SELECT * FROM entity e join entity1 e1 on e1.id = e.id WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
assertSql("SELECT * FROM entity e join entity1 e1 on e1.id = e.id WHERE (e.id = ? OR e.name = ?)",
"SELECT * FROM entity e JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
}
@Test
void selectWithAs() {
assertSql("with with_as_A as (select * from entity) select * from with_as_A",
"WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A");
}
@Test
void testDuplicateKeyUpdate() {
assertSql("INSERT INTO entity (name,age) VALUES ('秋秋',18),('秋秋','22') ON DUPLICATE KEY UPDATE age=18",
"INSERT INTO entity (name, age, tenant_id) VALUES ('秋秋', 18, 1), ('秋秋', '22', 1) ON DUPLICATE KEY UPDATE age = 18, tenant_id = 1");
}
void assertSql(String sql, String targetSql) {
assertThat(interceptor.parserSingle(sql, null)).isEqualTo(targetSql);
}
}

View File

@ -0,0 +1,113 @@
package com.baomidou.mybatisplus.test.pagination;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
/**
* SelectBody强转PlainSelect不支持sql里面最外层带union
* 用SetOperationList处理sql带union的语句
*/
class SelectBodyToPlainSelectTest {
private static final List<OrderItem> ITEMS = new ArrayList<>();
static {
ITEMS.add(OrderItem.asc("column"));
}
/**
* 报错的测试
*/
@Test
void testSelectBodyToPlainSelectThrowException() {
Select selectStatement = null;
try {
String originalUnionSql = "select * from test union select * from test";
selectStatement = (Select) CCJSqlParserUtil.parse(originalUnionSql);
} catch (JSQLParserException e) {
e.printStackTrace();
}
assert selectStatement != null;
Select finalSelectStatement = selectStatement;
Assertions.assertThrows(ClassCastException.class, () -> {
PlainSelect plainSelect = (PlainSelect) finalSelectStatement.getSelectBody();
});
}
@BeforeEach
void setup() {
List<OrderItem> orderItems = new ArrayList<>();
OrderItem order = new OrderItem();
order.setAsc(true);
order.setColumn("column");
orderItems.add(order);
OrderItem orderEmptyColumn = new OrderItem();
orderEmptyColumn.setAsc(false);
orderEmptyColumn.setColumn("");
orderItems.add(orderEmptyColumn);
}
@Test
void testPaginationInterceptorConcatOrderByBefore() {
String actualSql = new PaginationInnerInterceptor()
.concatOrderBy("select * from test", ITEMS);
assertThat(actualSql).isEqualTo("SELECT * FROM test ORDER BY column ASC");
String actualSqlWhere = new PaginationInnerInterceptor()
.concatOrderBy("select * from test where 1 = 1", ITEMS);
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 ORDER BY column ASC");
}
@Test
void testPaginationInterceptorConcatOrderByFix() {
List<OrderItem> orderList = new ArrayList<>();
// 测试可能的 sql 注入 https://github.com/baomidou/mybatis-plus/issues/5745
orderList.add(OrderItem.asc("col umn"));
String actualSql = new PaginationInnerInterceptor()
.concatOrderBy("select * from test union select * from test2", orderList);
assertThat(actualSql).isEqualTo("SELECT * FROM test UNION SELECT * FROM test2 ORDER BY column ASC");
String actualSqlUnionAll = new PaginationInnerInterceptor()
.concatOrderBy("select * from test union all select * from test2", orderList);
assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test UNION ALL SELECT * FROM test2 ORDER BY column ASC");
}
@Test
void testPaginationInterceptorConcatOrderByFixWithWhere() {
String actualSqlWhere = new PaginationInnerInterceptor()
.concatOrderBy("select * from test where 1 = 1 union select * from test2 where 1 = 1", ITEMS);
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION SELECT * FROM test2 WHERE 1 = 1 ORDER BY column ASC");
String actualSqlUnionAll = new PaginationInnerInterceptor()
.concatOrderBy("select * from test where 1 = 1 union all select * from test2 where 1 = 1 ", ITEMS);
assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION ALL SELECT * FROM test2 WHERE 1 = 1 ORDER BY column ASC");
}
@Test
void testPaginationInterceptorOrderByEmptyColumnFix() {
String actualSql = new PaginationInnerInterceptor()
.concatOrderBy("select * from test", ITEMS);
assertThat(actualSql).isEqualTo("SELECT * FROM test ORDER BY column ASC");
String actualSqlWhere = new PaginationInnerInterceptor()
.concatOrderBy("select * from test where 1 = 1", ITEMS);
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 ORDER BY column ASC");
}
}

View File

@ -0,0 +1,6 @@
dependencies {
api project(":mybatis-plus-extension")
implementation "${lib."slf4j-api"}"
}
compileJava.dependsOn(processResources)

View File

@ -28,5 +28,6 @@ dependencies {
testImplementation "${lib.'logback-classic'}"
testImplementation "${lib.cglib}"
testImplementation "${lib.postgresql}"
testImplementation project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-5.0")
// testCompile ('org.apache.phoenix:phoenix-core:5.0.0-HBase-2.0')
}

View File

@ -29,4 +29,8 @@ include ':spring-boot-starter:mybatis-plus-spring-boot-autoconfigure'
include ':spring-boot-starter:mybatis-plus-spring-boot-test-autoconfigure'
include ':spring-boot-starter:mybatis-plus-spring-boot3-starter'
include ':spring-boot-starter:mybatis-plus-spring-boot3-starter-test'
include 'mybatis-plus-jsqlparser'
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common'
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-4.9'
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-5.0'