处理jsqlparser版本(不兼容改动,需要自行引入模块).
https://github.com/baomidou/mybatis-plus/issues/6497
This commit is contained in:
parent
9ae6730d0e
commit
37251cf9db
@ -1,6 +1,5 @@
|
||||
dependencies {
|
||||
api project(":mybatis-plus-annotation")
|
||||
api "${lib.'jsqlparser'}"
|
||||
api "${lib.mybatis}"
|
||||
|
||||
implementation "${lib.cglib}"
|
||||
|
@ -47,9 +47,6 @@ class GeneratePomTest {
|
||||
Dependency mybatis = dependenciesMap.get("mybatis");
|
||||
Assertions.assertEquals("compile", mybatis.getScope());
|
||||
Assertions.assertFalse(mybatis.isOptional());
|
||||
Dependency jsqlParser = dependenciesMap.get("jsqlparser");
|
||||
Assertions.assertEquals("compile", jsqlParser.getScope());
|
||||
Assertions.assertFalse(jsqlParser.isOptional());
|
||||
Dependency cglib = dependenciesMap.get("cglib");
|
||||
Assertions.assertEquals("compile", cglib.getScope());
|
||||
Assertions.assertTrue(cglib.isOptional());
|
||||
|
@ -17,9 +17,6 @@ dependencies {
|
||||
implementation "${lib['mybatis-thymeleaf']}"
|
||||
implementation "${lib.'mybatis-velocity'}"
|
||||
implementation "${lib.'mybatis-freemarker'}"
|
||||
implementation "de.ruedigermoeller:fst:3.0.4-jdk17"
|
||||
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
|
||||
testImplementation "io.github.classgraph:classgraph:4.8.176"
|
||||
testImplementation "${lib."spring-context-support"}"
|
||||
testImplementation "${lib.h2}"
|
||||
testImplementation "${lib.mysql}"
|
||||
|
@ -2,11 +2,11 @@ package com.baomidou.mybatisplus.test.extension.plugins.pagination.dialects;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.platform.commons.util.ReflectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
@ -36,7 +36,9 @@ class IDialectTest {
|
||||
DialectModel model = o.buildPaginationSql("select * from table", 1, 10);
|
||||
String sql = model.getDialectSql();
|
||||
if (!map.containsKey(sql)) {
|
||||
map.put(sql, Lists.newArrayList(i));
|
||||
ArrayList<Class<?>> list = new ArrayList<>();
|
||||
list.add(i);
|
||||
map.put(sql,list);
|
||||
} else {
|
||||
map.get(sql).add(i);
|
||||
}
|
||||
|
1
mybatis-plus-jsqlparser/build.gradle
Normal file
1
mybatis-plus-jsqlparser/build.gradle
Normal file
@ -0,0 +1 @@
|
||||
tasks.matching {it.group == 'publishing' || it.group == 'central publish' }.each { it.enabled = false }
|
@ -0,0 +1,14 @@
|
||||
dependencies {
|
||||
api "com.github.jsqlparser:jsqlparser:4.9"
|
||||
api project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common")
|
||||
implementation "${lib."slf4j-api"}"
|
||||
implementation "de.ruedigermoeller:fst:3.0.3"
|
||||
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
|
||||
testImplementation "io.github.classgraph:classgraph:4.8.176"
|
||||
testImplementation "${lib."spring-context-support"}"
|
||||
testImplementation "${lib.h2}"
|
||||
testImplementation group: 'com.google.guava', name: 'guava', version: '33.3.1-jre'
|
||||
|
||||
}
|
||||
|
||||
compileJava.dependsOn(processResources)
|
@ -0,0 +1,266 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser.cache;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
|
||||
/**
|
||||
* Fst Factory
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2023-08-06
|
||||
*/
|
||||
public class FstFactory {
|
||||
private static final FstFactory FACTORY = new FstFactory();
|
||||
private final FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
|
||||
public static FstFactory getDefaultFactory() {
|
||||
return FACTORY;
|
||||
}
|
||||
|
||||
public FstFactory() {
|
||||
conf.registerClass(net.sf.jsqlparser.expression.Alias.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.Alias.AliasColumn.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.AllValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.AnalyticExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.AnyComparisonExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.ArrayConstructor.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.ArrayExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.CaseExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.CastExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.CollateExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.ConnectByRootOperator.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.DateTimeLiteralExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.DateValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.DoubleValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.ExtractExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.FilterOverImpl.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.Function.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.HexValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.IntervalExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JdbcNamedParameter.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JdbcParameter.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JsonAggregateFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JsonExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JsonFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JsonFunctionExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.JsonKeyValuePair.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.KeepExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.LongValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.MySQLGroupConcat.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.MySQLIndexHint.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.NextValExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.NotExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.NullValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.NumericBind.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.OracleHierarchicalExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.OracleHint.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.OracleNamedFunctionParameter.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.OrderByClause.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.OverlapsCondition.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.Parenthesis.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.PartitionByClause.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.RangeExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.RowConstructor.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.RowGetExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.SQLServerHints.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.SignedExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.StringValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TimeKeyExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TimeValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TimestampValue.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TimezoneExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TranscodingFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.TrimFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.UserVariable.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.VariableAssignment.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.WhenClause.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.WindowDefinition.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.WindowElement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.WindowOffset.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.WindowRange.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.XMLSerializeExpr.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Addition.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseAnd.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseLeftShift.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseOr.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseRightShift.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.BitwiseXor.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Concat.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Division.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.IntegerDivision.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Modulo.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Multiplication.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.arithmetic.Subtraction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.AndExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.OrExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.conditional.XorExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Between.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ContainedBy.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Contains.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.DoubleAnd.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.EqualsTo.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ExistsExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ExpressionList.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.FullTextSearch.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GeometryDistance.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GreaterThan.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.InExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsBooleanExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsDistinctExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.IsNullExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.JsonOperator.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.LikeExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.Matches.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MemberOfExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MinorThan.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.MinorThanEquals.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.NamedExpressionList.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.NotEqualsTo.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.RegExpMatchOperator.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.SimilarToExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.TSQLLeftJoin.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.TSQLRightJoin.class);
|
||||
conf.registerClass(net.sf.jsqlparser.parser.ASTNodeAccessImpl.class);
|
||||
conf.registerClass(net.sf.jsqlparser.parser.Token.class);
|
||||
conf.registerClass(net.sf.jsqlparser.schema.Column.class);
|
||||
conf.registerClass(net.sf.jsqlparser.schema.Sequence.class);
|
||||
conf.registerClass(net.sf.jsqlparser.schema.Synonym.class);
|
||||
conf.registerClass(net.sf.jsqlparser.schema.Table.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.Block.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.Commit.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.DeclareStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.DeclareStatement.TypeDefExpr.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.DescribeStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ExplainStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ExplainStatement.Option.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.IfElseStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.OutputClause.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.PurgeStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ReferentialAction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ResetStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.RollbackStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.SavepointStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.SetStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ShowColumnsStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.ShowStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.Statements.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.UnsupportedStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.UseStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.Alter.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDataType.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDropDefault.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterExpression.ColumnDropNotNull.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterSession.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.AlterSystemStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.RenameTableStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.alter.sequence.AlterSequence.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.analyze.Analyze.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.comment.Comment.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.function.CreateFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.index.CreateIndex.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.procedure.CreateProcedure.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.schema.CreateSchema.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.sequence.CreateSequence.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.synonym.CreateSynonym.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.CheckConstraint.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.ColDataType.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.ColumnDefinition.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.CreateTable.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.ExcludeConstraint.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.ForeignKeyIndex.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.Index.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.Index.ColumnParams.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.NamedConstraint.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.table.RowMovement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.view.AlterView.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.create.view.CreateView.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.delete.Delete.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.drop.Drop.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.execute.Execute.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.grant.Grant.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.insert.Insert.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.insert.InsertConflictAction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.insert.InsertConflictTarget.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.merge.Merge.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeDelete.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeInsert.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.merge.MergeUpdate.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.refresh.RefreshMaterializedViewStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.AllColumns.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.AllTableColumns.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Distinct.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.ExceptOp.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Fetch.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.First.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.ForClause.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.GroupByElement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.IntersectOp.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Join.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.KSQLJoinWindow.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.KSQLWindow.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.LateralSubSelect.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.LateralView.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Limit.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.MinusOp.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Offset.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.OptimizeFor.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.OrderByElement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.ParenthesedFromItem.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.ParenthesedSelect.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Pivot.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.PivotXml.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.PlainSelect.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.SelectItem.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.SetOperationList.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Skip.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.TableFunction.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.TableStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Top.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.UnPivot.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.UnionOp.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Values.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Wait.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.WithIsolation.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.WithItem.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.show.ShowIndexStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.show.ShowTablesStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.truncate.Truncate.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.update.Update.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.update.UpdateSet.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.upsert.Upsert.class);
|
||||
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultiAndExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultiOrExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.BinaryExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.ComparisonOperator.class);
|
||||
conf.registerClass(net.sf.jsqlparser.expression.operators.relational.OldOracleJoinBinaryExpression.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.CreateFunctionalStatement.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.Select.class);
|
||||
conf.registerClass(net.sf.jsqlparser.statement.select.SetOperation.class);
|
||||
conf.registerClass(net.sf.jsqlparser.util.cnfexpression.MultipleExpression.class);
|
||||
}
|
||||
|
||||
public byte[] asByteArray(Object obj) {
|
||||
return conf.asByteArray(obj);
|
||||
}
|
||||
|
||||
public Object asObject(byte[] bytes) {
|
||||
return conf.asObject(bytes);
|
||||
}
|
||||
}
|
@ -0,0 +1,424 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
import net.sf.jsqlparser.expression.*;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.select.*;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 多表条件处理基对象,从原有的 {@link TenantLineInnerInterceptor} 拦截器中提取出来
|
||||
*
|
||||
* @author houkunlin
|
||||
* @since 3.5.2
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@ToString(callSuper = true)
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
|
||||
|
||||
protected void processSelectBody(Select selectBody, final String whereSegment) {
|
||||
if (selectBody == null) {
|
||||
return;
|
||||
}
|
||||
if (selectBody instanceof PlainSelect) {
|
||||
processPlainSelect((PlainSelect) selectBody, whereSegment);
|
||||
} else if (selectBody instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) selectBody;
|
||||
processSelectBody(parenthesedSelect.getSelect(), whereSegment);
|
||||
} else if (selectBody instanceof SetOperationList) {
|
||||
SetOperationList operationList = (SetOperationList) selectBody;
|
||||
List<Select> selectBodyList = operationList.getSelects();
|
||||
if (CollectionUtils.isNotEmpty(selectBodyList)) {
|
||||
selectBodyList.forEach(body -> processSelectBody(body, whereSegment));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* delete update 语句 where 处理
|
||||
*/
|
||||
protected Expression andExpression(Table table, Expression where, final String whereSegment) {
|
||||
//获得where条件表达式
|
||||
final Expression expression = buildTableExpression(table, where, whereSegment);
|
||||
if (expression == null) {
|
||||
return where;
|
||||
}
|
||||
if (where != null) {
|
||||
if (where instanceof OrExpression) {
|
||||
return new AndExpression(new Parenthesis(where), expression);
|
||||
} else {
|
||||
return new AndExpression(where, expression);
|
||||
}
|
||||
}
|
||||
return expression;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 PlainSelect
|
||||
*/
|
||||
protected void processPlainSelect(final PlainSelect plainSelect, final String whereSegment) {
|
||||
//#3087 github
|
||||
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||
if (CollectionUtils.isNotEmpty(selectItems)) {
|
||||
selectItems.forEach(selectItem -> processSelectItem(selectItem, whereSegment));
|
||||
}
|
||||
|
||||
// 处理 where 中的子查询
|
||||
Expression where = plainSelect.getWhere();
|
||||
processWhereSubSelect(where, whereSegment);
|
||||
|
||||
// 处理 fromItem
|
||||
FromItem fromItem = plainSelect.getFromItem();
|
||||
List<Table> list = processFromItem(fromItem, whereSegment);
|
||||
List<Table> mainTables = new ArrayList<>(list);
|
||||
|
||||
// 处理 join
|
||||
List<Join> joins = plainSelect.getJoins();
|
||||
if (CollectionUtils.isNotEmpty(joins)) {
|
||||
processJoins(mainTables, joins, whereSegment);
|
||||
}
|
||||
|
||||
// 当有 mainTable 时,进行 where 条件追加
|
||||
if (CollectionUtils.isNotEmpty(mainTables)) {
|
||||
plainSelect.setWhere(builderExpression(where, mainTables, whereSegment));
|
||||
}
|
||||
}
|
||||
|
||||
private List<Table> processFromItem(FromItem fromItem, final String whereSegment) {
|
||||
// 处理括号括起来的表达式
|
||||
// while (fromItem instanceof ParenthesedFromItem) {
|
||||
// fromItem = ((ParenthesedFromItem) fromItem).getFromItem();
|
||||
// }
|
||||
|
||||
List<Table> mainTables = new ArrayList<>();
|
||||
// 无 join 时的处理逻辑
|
||||
if (fromItem instanceof Table) {
|
||||
Table fromTable = (Table) fromItem;
|
||||
mainTables.add(fromTable);
|
||||
} else if (fromItem instanceof ParenthesedFromItem) {
|
||||
// SubJoin 类型则还需要添加上 where 条件
|
||||
List<Table> tables = processSubJoin((ParenthesedFromItem) fromItem, whereSegment);
|
||||
mainTables.addAll(tables);
|
||||
} else {
|
||||
// 处理下 fromItem
|
||||
processOtherFromItem(fromItem, whereSegment);
|
||||
}
|
||||
return mainTables;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理where条件内的子查询
|
||||
* <p>
|
||||
* 支持如下:
|
||||
* <ol>
|
||||
* <li>in</li>
|
||||
* <li>=</li>
|
||||
* <li>></li>
|
||||
* <li><</li>
|
||||
* <li>>=</li>
|
||||
* <li><=</li>
|
||||
* <li><></li>
|
||||
* <li>EXISTS</li>
|
||||
* <li>NOT EXISTS</li>
|
||||
* </ol>
|
||||
* <p>
|
||||
* 前提条件:
|
||||
* 1. 子查询必须放在小括号中
|
||||
* 2. 子查询一般放在比较操作符的右边
|
||||
*
|
||||
* @param where where 条件
|
||||
*/
|
||||
protected void processWhereSubSelect(Expression where, final String whereSegment) {
|
||||
if (where == null) {
|
||||
return;
|
||||
}
|
||||
if (where instanceof FromItem) {
|
||||
processOtherFromItem((FromItem) where, whereSegment);
|
||||
return;
|
||||
}
|
||||
if (where.toString().indexOf("SELECT") > 0) {
|
||||
// 有子查询
|
||||
if (where instanceof BinaryExpression) {
|
||||
// 比较符号 , and , or , 等等
|
||||
BinaryExpression expression = (BinaryExpression) where;
|
||||
processWhereSubSelect(expression.getLeftExpression(), whereSegment);
|
||||
processWhereSubSelect(expression.getRightExpression(), whereSegment);
|
||||
} else if (where instanceof InExpression) {
|
||||
// in
|
||||
InExpression expression = (InExpression) where;
|
||||
Expression inExpression = expression.getRightExpression();
|
||||
if (inExpression instanceof Select) {
|
||||
processSelectBody(((Select) inExpression), whereSegment);
|
||||
}
|
||||
} else if (where instanceof ExistsExpression) {
|
||||
// exists
|
||||
ExistsExpression expression = (ExistsExpression) where;
|
||||
processWhereSubSelect(expression.getRightExpression(), whereSegment);
|
||||
} else if (where instanceof NotExpression) {
|
||||
// not exists
|
||||
NotExpression expression = (NotExpression) where;
|
||||
processWhereSubSelect(expression.getExpression(), whereSegment);
|
||||
} else if (where instanceof Parenthesis) {
|
||||
Parenthesis expression = (Parenthesis) where;
|
||||
processWhereSubSelect(expression.getExpression(), whereSegment);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected void processSelectItem(SelectItem selectItem, final String whereSegment) {
|
||||
Expression expression = selectItem.getExpression();
|
||||
if (expression instanceof Select) {
|
||||
processSelectBody(((Select) expression), whereSegment);
|
||||
} else if (expression instanceof Function) {
|
||||
processFunction((Function) expression, whereSegment);
|
||||
} else if (expression instanceof ExistsExpression) {
|
||||
ExistsExpression existsExpression = (ExistsExpression) expression;
|
||||
processSelectBody((Select) existsExpression.getRightExpression(), whereSegment);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理函数
|
||||
* <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
|
||||
* <p> fixed gitee pulls/141</p>
|
||||
*
|
||||
* @param function
|
||||
*/
|
||||
protected void processFunction(Function function, final String whereSegment) {
|
||||
ExpressionList<?> parameters = function.getParameters();
|
||||
if (parameters != null) {
|
||||
parameters.forEach(expression -> {
|
||||
if (expression instanceof Select) {
|
||||
processSelectBody(((Select) expression), whereSegment);
|
||||
} else if (expression instanceof Function) {
|
||||
processFunction((Function) expression, whereSegment);
|
||||
} else if (expression instanceof EqualsTo) {
|
||||
if (((EqualsTo) expression).getLeftExpression() instanceof Select) {
|
||||
processSelectBody(((Select) ((EqualsTo) expression).getLeftExpression()), whereSegment);
|
||||
}
|
||||
if (((EqualsTo) expression).getRightExpression() instanceof Select) {
|
||||
processSelectBody(((Select) ((EqualsTo) expression).getRightExpression()), whereSegment);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理子查询等
|
||||
*/
|
||||
protected void processOtherFromItem(FromItem fromItem, final String whereSegment) {
|
||||
// 去除括号
|
||||
// while (fromItem instanceof ParenthesisFromItem) {
|
||||
// fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
|
||||
// }
|
||||
|
||||
if (fromItem instanceof ParenthesedSelect) {
|
||||
Select subSelect = (Select) fromItem;
|
||||
processSelectBody(subSelect, whereSegment);
|
||||
} else if (fromItem instanceof ParenthesedFromItem) {
|
||||
logger.debug("Perform a subQuery, if you do not give us feedback");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 sub join
|
||||
*
|
||||
* @param subJoin subJoin
|
||||
* @return Table subJoin 中的主表
|
||||
*/
|
||||
private List<Table> processSubJoin(ParenthesedFromItem subJoin, final String whereSegment) {
|
||||
List<Table> mainTables = new ArrayList<>();
|
||||
while (subJoin.getJoins() == null && subJoin.getFromItem() instanceof ParenthesedFromItem) {
|
||||
subJoin = (ParenthesedFromItem) subJoin.getFromItem();
|
||||
}
|
||||
if (subJoin.getJoins() != null) {
|
||||
List<Table> list = processFromItem(subJoin.getFromItem(), whereSegment);
|
||||
mainTables.addAll(list);
|
||||
processJoins(mainTables, subJoin.getJoins(), whereSegment);
|
||||
}
|
||||
return mainTables;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 joins
|
||||
*
|
||||
* @param mainTables 可以为 null
|
||||
* @param joins join 集合
|
||||
* @return List<Table> 右连接查询的 Table 列表
|
||||
*/
|
||||
private List<Table> processJoins(List<Table> mainTables, List<Join> joins, final String whereSegment) {
|
||||
// join 表达式中最终的主表
|
||||
Table mainTable = null;
|
||||
// 当前 join 的左表
|
||||
Table leftTable = null;
|
||||
|
||||
if (mainTables.size() == 1) {
|
||||
mainTable = mainTables.get(0);
|
||||
leftTable = mainTable;
|
||||
}
|
||||
|
||||
//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
|
||||
Deque<List<Table>> onTableDeque = new LinkedList<>();
|
||||
for (Join join : joins) {
|
||||
// 处理 on 表达式
|
||||
FromItem joinItem = join.getRightItem();
|
||||
|
||||
// 获取当前 join 的表,subJoint 可以看作是一张表
|
||||
List<Table> joinTables = null;
|
||||
if (joinItem instanceof Table) {
|
||||
joinTables = new ArrayList<>();
|
||||
joinTables.add((Table) joinItem);
|
||||
} else if (joinItem instanceof ParenthesedFromItem) {
|
||||
joinTables = processSubJoin((ParenthesedFromItem) joinItem, whereSegment);
|
||||
}
|
||||
|
||||
if (joinTables != null) {
|
||||
|
||||
// 如果是隐式内连接
|
||||
if (join.isSimple()) {
|
||||
mainTables.addAll(joinTables);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 当前表是否忽略
|
||||
Table joinTable = joinTables.get(0);
|
||||
|
||||
List<Table> onTables = null;
|
||||
// 如果不要忽略,且是右连接,则记录下当前表
|
||||
if (join.isRight()) {
|
||||
mainTable = joinTable;
|
||||
mainTables.clear();
|
||||
if (leftTable != null) {
|
||||
onTables = Collections.singletonList(leftTable);
|
||||
}
|
||||
} else if (join.isInner()) {
|
||||
if (mainTable == null) {
|
||||
onTables = Collections.singletonList(joinTable);
|
||||
} else {
|
||||
onTables = Arrays.asList(mainTable, joinTable);
|
||||
}
|
||||
mainTable = null;
|
||||
mainTables.clear();
|
||||
} else {
|
||||
onTables = Collections.singletonList(joinTable);
|
||||
}
|
||||
|
||||
if (mainTable != null && !mainTables.contains(mainTable)) {
|
||||
mainTables.add(mainTable);
|
||||
}
|
||||
|
||||
// 获取 join 尾缀的 on 表达式列表
|
||||
Collection<Expression> originOnExpressions = join.getOnExpressions();
|
||||
// 正常 join on 表达式只有一个,立刻处理
|
||||
if (originOnExpressions.size() == 1 && onTables != null) {
|
||||
List<Expression> onExpressions = new LinkedList<>();
|
||||
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables, whereSegment));
|
||||
join.setOnExpressions(onExpressions);
|
||||
leftTable = mainTable == null ? joinTable : mainTable;
|
||||
continue;
|
||||
}
|
||||
// 表名压栈,忽略的表压入 null,以便后续不处理
|
||||
onTableDeque.push(onTables);
|
||||
// 尾缀多个 on 表达式的时候统一处理
|
||||
if (originOnExpressions.size() > 1) {
|
||||
Collection<Expression> onExpressions = new LinkedList<>();
|
||||
for (Expression originOnExpression : originOnExpressions) {
|
||||
List<Table> currentTableList = onTableDeque.poll();
|
||||
if (CollectionUtils.isEmpty(currentTableList)) {
|
||||
onExpressions.add(originOnExpression);
|
||||
} else {
|
||||
onExpressions.add(builderExpression(originOnExpression, currentTableList, whereSegment));
|
||||
}
|
||||
}
|
||||
join.setOnExpressions(onExpressions);
|
||||
}
|
||||
leftTable = joinTable;
|
||||
} else {
|
||||
processOtherFromItem(joinItem, whereSegment);
|
||||
leftTable = null;
|
||||
}
|
||||
}
|
||||
|
||||
return mainTables;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理条件
|
||||
*/
|
||||
protected Expression builderExpression(Expression currentExpression, List<Table> tables, final String whereSegment) {
|
||||
// 没有表需要处理直接返回
|
||||
if (CollectionUtils.isEmpty(tables)) {
|
||||
return currentExpression;
|
||||
}
|
||||
// 构造每张表的条件
|
||||
List<Expression> expressions = tables.stream()
|
||||
.map(item -> buildTableExpression(item, currentExpression, whereSegment))
|
||||
.filter(Objects::nonNull)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// 没有表需要处理直接返回
|
||||
if (CollectionUtils.isEmpty(expressions)) {
|
||||
return currentExpression;
|
||||
}
|
||||
|
||||
// 注入的表达式
|
||||
Expression injectExpression = expressions.get(0);
|
||||
// 如果有多表,则用 and 连接
|
||||
if (expressions.size() > 1) {
|
||||
for (int i = 1; i < expressions.size(); i++) {
|
||||
injectExpression = new AndExpression(injectExpression, expressions.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (currentExpression == null) {
|
||||
return injectExpression;
|
||||
}
|
||||
if (currentExpression instanceof OrExpression) {
|
||||
return new AndExpression(new Parenthesis(currentExpression), injectExpression);
|
||||
} else {
|
||||
return new AndExpression(currentExpression, injectExpression);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建数据库表的查询条件
|
||||
*
|
||||
* @param table 表对象
|
||||
* @param where 当前where条件
|
||||
* @param whereSegment 所属Mapper对象全路径
|
||||
* @return 需要拼接的新条件(不会覆盖原有的where条件,只会在原有条件上再加条件),为 null 则不加入新的条件
|
||||
*/
|
||||
public abstract Expression buildTableExpression(final Table table, final Expression where, final String whereSegment);
|
||||
}
|
@ -0,0 +1,142 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfo;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
|
||||
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.Assert;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
|
||||
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
|
||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
|
||||
import java.sql.Connection;
|
||||
|
||||
/**
|
||||
* 攻击 SQL 阻断解析器,防止全表更新与删除
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.4.0
|
||||
*/
|
||||
public class BlockAttackInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = handler.mappedStatement();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
|
||||
if (InterceptorIgnoreHelper.willIgnoreBlockAttack(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
BoundSql boundSql = handler.boundSql();
|
||||
parserMulti(boundSql.getSql(), null);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processDelete(Delete delete, int index, String sql, Object obj) {
|
||||
this.checkWhere(delete.getTable().getName(), delete.getWhere(), "Prohibition of full table deletion");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processUpdate(Update update, int index, String sql, Object obj) {
|
||||
this.checkWhere(update.getTable().getName(), update.getWhere(), "Prohibition of table update operation");
|
||||
}
|
||||
|
||||
protected void checkWhere(String tableName, Expression where, String ex) {
|
||||
Assert.isFalse(this.fullMatch(where, this.getTableLogicField(tableName)), ex);
|
||||
}
|
||||
|
||||
private boolean fullMatch(Expression where, String logicField) {
|
||||
if (where == null) {
|
||||
return true;
|
||||
}
|
||||
if (StringUtils.isNotBlank(logicField)) {
|
||||
|
||||
if (where instanceof BinaryExpression) {
|
||||
BinaryExpression binaryExpression = (BinaryExpression) where;
|
||||
if (StringUtils.equals(binaryExpression.getLeftExpression().toString(), logicField) || StringUtils.equals(binaryExpression.getRightExpression().toString(), logicField)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (where instanceof IsNullExpression) {
|
||||
IsNullExpression binaryExpression = (IsNullExpression) where;
|
||||
if (StringUtils.equals(binaryExpression.getLeftExpression().toString(), logicField)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (where instanceof EqualsTo) {
|
||||
// example: 1=1
|
||||
EqualsTo equalsTo = (EqualsTo) where;
|
||||
return StringUtils.equals(equalsTo.getLeftExpression().toString(), equalsTo.getRightExpression().toString());
|
||||
} else if (where instanceof NotEqualsTo) {
|
||||
// example: 1 != 2
|
||||
NotEqualsTo notEqualsTo = (NotEqualsTo) where;
|
||||
return !StringUtils.equals(notEqualsTo.getLeftExpression().toString(), notEqualsTo.getRightExpression().toString());
|
||||
} else if (where instanceof OrExpression) {
|
||||
|
||||
OrExpression orExpression = (OrExpression) where;
|
||||
return fullMatch(orExpression.getLeftExpression(), logicField) || fullMatch(orExpression.getRightExpression(), logicField);
|
||||
} else if (where instanceof AndExpression) {
|
||||
|
||||
AndExpression andExpression = (AndExpression) where;
|
||||
return fullMatch(andExpression.getLeftExpression(), logicField) && fullMatch(andExpression.getRightExpression(), logicField);
|
||||
} else if (where instanceof Parenthesis) {
|
||||
// example: (1 = 1)
|
||||
Parenthesis parenthesis = (Parenthesis) where;
|
||||
return fullMatch(parenthesis.getExpression(), logicField);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取表名中的逻辑删除字段
|
||||
*
|
||||
* @param tableName 表名
|
||||
* @return 逻辑删除字段
|
||||
*/
|
||||
private String getTableLogicField(String tableName) {
|
||||
if (StringUtils.isBlank(tableName)) {
|
||||
return StringPool.EMPTY;
|
||||
}
|
||||
|
||||
TableInfo tableInfo = TableInfoHelper.getTableInfo(tableName);
|
||||
if (tableInfo == null || !tableInfo.isWithLogicDelete() || tableInfo.getLogicDeleteFieldInfo() == null) {
|
||||
return StringPool.EMPTY;
|
||||
}
|
||||
return tableInfo.getLogicDeleteFieldInfo().getColumn();
|
||||
}
|
||||
}
|
@ -0,0 +1,379 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
|
||||
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.Assert;
|
||||
import com.baomidou.mybatisplus.core.toolkit.EncryptUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
|
||||
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
|
||||
import lombok.Data;
|
||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.select.FromItem;
|
||||
import net.sf.jsqlparser.statement.select.Join;
|
||||
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.DatabaseMetaData;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* 由于开发人员水平参差不齐,即使订了开发规范很多人也不遵守
|
||||
* <p>SQL是影响系统性能最重要的因素,所以拦截掉垃圾SQL语句</p>
|
||||
* <br>
|
||||
* <p>拦截SQL类型的场景</p>
|
||||
* <p>1.必须使用到索引,包含left join连接字段,符合索引最左原则</p>
|
||||
* <p>必须使用索引好处,</p>
|
||||
* <p>1.1 如果因为动态SQL,bug导致update的where条件没有带上,全表更新上万条数据</p>
|
||||
* <p>1.2 如果检查到使用了索引,SQL性能基本不会太差</p>
|
||||
* <br>
|
||||
* <p>2.SQL尽量单表执行,有查询left join的语句,必须在注释里面允许该SQL运行,否则会被拦截,有left join的语句,如果不能拆成单表执行的SQL,请leader商量在做</p>
|
||||
* <p>https://gaoxianglong.github.io/shark</p>
|
||||
* <p>SQL尽量单表执行的好处</p>
|
||||
* <p>2.1 查询条件简单、易于开理解和维护;</p>
|
||||
* <p>2.2 扩展性极强;(可为分库分表做准备)</p>
|
||||
* <p>2.3 缓存利用率高;</p>
|
||||
* <p>2.在字段上使用函数</p>
|
||||
* <br>
|
||||
* <p>3.where条件为空</p>
|
||||
* <p>4.where条件使用了 !=</p>
|
||||
* <p>5.where条件使用了 not 关键字</p>
|
||||
* <p>6.where条件使用了 or 关键字</p>
|
||||
* <p>7.where条件使用了 使用子查询</p>
|
||||
*
|
||||
* @author willenfoo
|
||||
* @since 3.4.0
|
||||
*/
|
||||
public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
|
||||
|
||||
/**
|
||||
* 缓存验证结果,提高性能
|
||||
*/
|
||||
private static final Set<String> cacheValidResult = new HashSet<>();
|
||||
/**
|
||||
* 缓存表的索引信息
|
||||
*/
|
||||
private static final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = mpStatementHandler.mappedStatement();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.INSERT || InterceptorIgnoreHelper.willIgnoreIllegalSql(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
BoundSql boundSql = mpStatementHandler.boundSql();
|
||||
String originalSql = boundSql.getSql();
|
||||
logger.debug("检查SQL是否合规,SQL:" + originalSql);
|
||||
String md5Base64 = EncryptUtils.md5Base64(originalSql);
|
||||
if (cacheValidResult.contains(md5Base64)) {
|
||||
logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
|
||||
return;
|
||||
}
|
||||
parserSingle(originalSql, connection);
|
||||
//缓存验证结果
|
||||
cacheValidResult.add(md5Base64);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processSelect(Select select, int index, String sql, Object obj) {
|
||||
if (select instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) select;
|
||||
FromItem fromItem = ((PlainSelect) select).getFromItem();
|
||||
while (fromItem instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
|
||||
plainSelect = (PlainSelect) parenthesedSelect.getSelect();
|
||||
fromItem = plainSelect.getFromItem();
|
||||
}
|
||||
Expression where = plainSelect.getWhere();
|
||||
Assert.notNull(where, "非法SQL,必须要有where条件");
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
List<Join> joins = plainSelect.getJoins();
|
||||
validWhere(where, table, (Connection) obj);
|
||||
validJoins(joins, table, (Connection) obj);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processUpdate(Update update, int index, String sql, Object obj) {
|
||||
Expression where = update.getWhere();
|
||||
Assert.notNull(where, "非法SQL,必须要有where条件");
|
||||
Table table = update.getTable();
|
||||
List<Join> joins = update.getJoins();
|
||||
validWhere(where, table, (Connection) obj);
|
||||
validJoins(joins, table, (Connection) obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processDelete(Delete delete, int index, String sql, Object obj) {
|
||||
Expression where = delete.getWhere();
|
||||
Assert.notNull(where, "非法SQL,必须要有where条件");
|
||||
Table table = delete.getTable();
|
||||
List<Join> joins = delete.getJoins();
|
||||
validWhere(where, table, (Connection) obj);
|
||||
validJoins(joins, table, (Connection) obj);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证expression对象是不是 or、not等等
|
||||
*
|
||||
* @param expression ignore
|
||||
*/
|
||||
private void validExpression(Expression expression) {
|
||||
while (expression instanceof Parenthesis) {
|
||||
Parenthesis parenthesis = (Parenthesis) expression;
|
||||
expression = parenthesis.getExpression();
|
||||
}
|
||||
//where条件使用了 or 关键字
|
||||
if (expression instanceof OrExpression) {
|
||||
OrExpression orExpression = (OrExpression) expression;
|
||||
throw new MybatisPlusException("非法SQL,where条件中不能使用【or】关键字,错误or信息:" + orExpression.toString());
|
||||
} else if (expression instanceof NotEqualsTo) {
|
||||
NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
|
||||
throw new MybatisPlusException("非法SQL,where条件中不能使用【!=】关键字,错误!=信息:" + notEqualsTo.toString());
|
||||
} else if (expression instanceof BinaryExpression) {
|
||||
BinaryExpression binaryExpression = (BinaryExpression) expression;
|
||||
// TODO 升级 jsqlparser 后待实现
|
||||
// if (binaryExpression.isNot()) {
|
||||
// throw new MybatisPlusException("非法SQL,where条件中不能使用【not】关键字,错误not信息:" + binaryExpression.toString());
|
||||
// }
|
||||
if (binaryExpression.getLeftExpression() instanceof Function) {
|
||||
Function function = (Function) binaryExpression.getLeftExpression();
|
||||
throw new MybatisPlusException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
|
||||
}
|
||||
if (binaryExpression.getRightExpression() instanceof Subtraction) {
|
||||
Subtraction subSelect = (Subtraction) binaryExpression.getRightExpression();
|
||||
throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
|
||||
}
|
||||
} else if (expression instanceof InExpression) {
|
||||
InExpression inExpression = (InExpression) expression;
|
||||
if (inExpression.getRightExpression() instanceof Subtraction) {
|
||||
Subtraction subSelect = (Subtraction) inExpression.getRightExpression();
|
||||
throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
|
||||
*
|
||||
* @param joins ignore
|
||||
* @param table ignore
|
||||
* @param connection ignore
|
||||
*/
|
||||
private void validJoins(List<Join> joins, Table table, Connection connection) {
|
||||
//允许执行join,验证jion是否使用索引等等
|
||||
if (joins != null) {
|
||||
for (Join join : joins) {
|
||||
Table rightTable = (Table) join.getRightItem();
|
||||
Collection<Expression> onExpressions = join.getOnExpressions();
|
||||
for (Expression expression : onExpressions) {
|
||||
validWhere(expression, table, rightTable, connection);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否使用索引
|
||||
*
|
||||
* @param table ignore
|
||||
* @param columnName ignore
|
||||
* @param connection ignore
|
||||
*/
|
||||
private void validUseIndex(Table table, String columnName, Connection connection) {
|
||||
//是否使用索引
|
||||
boolean useIndexFlag = false;
|
||||
if (StringUtils.isNotBlank(columnName)) {
|
||||
String tableInfo = table.getName();
|
||||
//表存在的索引
|
||||
String dbName = null;
|
||||
String tableName;
|
||||
String[] tableArray = tableInfo.split("\\.");
|
||||
if (tableArray.length == 1) {
|
||||
tableName = tableArray[0];
|
||||
} else {
|
||||
dbName = tableArray[0];
|
||||
tableName = tableArray[1];
|
||||
}
|
||||
columnName = SqlParserUtils.removeWrapperSymbol(columnName);
|
||||
List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
|
||||
for (IndexInfo indexInfo : indexInfos) {
|
||||
if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
|
||||
useIndexFlag = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!useIndexFlag) {
|
||||
throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
|
||||
*
|
||||
* @param expression ignore
|
||||
* @param table ignore
|
||||
* @param connection ignore
|
||||
*/
|
||||
private void validWhere(Expression expression, Table table, Connection connection) {
|
||||
validWhere(expression, table, null, connection);
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
|
||||
*
|
||||
* @param expression ignore
|
||||
* @param table ignore
|
||||
* @param joinTable ignore
|
||||
* @param connection ignore
|
||||
*/
|
||||
private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
|
||||
validExpression(expression);
|
||||
if (expression instanceof BinaryExpression) {
|
||||
//获得左边表达式
|
||||
Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
|
||||
validExpression(leftExpression);
|
||||
|
||||
//如果左边表达式为Column对象,则直接获得列名
|
||||
if (leftExpression instanceof Column) {
|
||||
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
|
||||
if (joinTable != null && rightExpression instanceof Column) {
|
||||
if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
|
||||
validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
|
||||
validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
|
||||
} else {
|
||||
validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
|
||||
validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
|
||||
}
|
||||
} else {
|
||||
//获得列名
|
||||
validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
|
||||
}
|
||||
}
|
||||
//如果BinaryExpression,进行迭代
|
||||
else if (leftExpression instanceof BinaryExpression) {
|
||||
validWhere(leftExpression, table, joinTable, connection);
|
||||
}
|
||||
|
||||
//获得右边表达式,并分解
|
||||
if (joinTable != null) {
|
||||
Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
|
||||
validExpression(rightExpression);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 得到表的索引信息
|
||||
*
|
||||
* @param dbName ignore
|
||||
* @param tableName ignore
|
||||
* @param conn ignore
|
||||
* @return ignore
|
||||
*/
|
||||
public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
|
||||
return getIndexInfos(null, dbName, tableName, conn);
|
||||
}
|
||||
|
||||
/**
|
||||
* 得到表的索引信息
|
||||
*
|
||||
* @param key ignore
|
||||
* @param dbName ignore
|
||||
* @param tableName ignore
|
||||
* @param conn ignore
|
||||
* @return ignore
|
||||
*/
|
||||
public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
|
||||
List<IndexInfo> indexInfos = null;
|
||||
if (StringUtils.isNotBlank(key)) {
|
||||
indexInfos = indexInfoMap.get(key);
|
||||
}
|
||||
if (indexInfos == null || indexInfos.isEmpty()) {
|
||||
ResultSet rs;
|
||||
try {
|
||||
DatabaseMetaData metadata = conn.getMetaData();
|
||||
String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
|
||||
String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
|
||||
rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
|
||||
indexInfos = new ArrayList<>();
|
||||
while (rs.next()) {
|
||||
//索引中的列序列号等于1,才有效
|
||||
if (Objects.equals(rs.getString(8), "1")) {
|
||||
IndexInfo indexInfo = new IndexInfo();
|
||||
indexInfo.setDbName(rs.getString(1));
|
||||
indexInfo.setTableName(rs.getString(3));
|
||||
indexInfo.setColumnName(rs.getString(9));
|
||||
indexInfos.add(indexInfo);
|
||||
}
|
||||
}
|
||||
if (StringUtils.isNotBlank(key)) {
|
||||
indexInfoMap.put(key, indexInfos);
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
logger.error(String.format("getIndexInfo fault, with key:%s, dbName:%s, tableName:%s", key, dbName, tableName), e);
|
||||
}
|
||||
}
|
||||
return indexInfos;
|
||||
}
|
||||
|
||||
/**
|
||||
* 索引对象
|
||||
*/
|
||||
@Data
|
||||
private static class IndexInfo {
|
||||
|
||||
private String dbName;
|
||||
|
||||
private String tableName;
|
||||
|
||||
private String columnName;
|
||||
}
|
||||
}
|
@ -0,0 +1,275 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.*;
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
|
||||
import lombok.*;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Parenthesis;
|
||||
import net.sf.jsqlparser.expression.RowConstructor;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.insert.Insert;
|
||||
import net.sf.jsqlparser.statement.select.*;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import net.sf.jsqlparser.statement.update.UpdateSet;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
import org.apache.ibatis.session.ResultHandler;
|
||||
import org.apache.ibatis.session.RowBounds;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
* @author hubin
|
||||
* @since 3.4.0
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@ToString(callSuper = true)
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
|
||||
|
||||
private TenantLineHandler tenantLineHandler;
|
||||
|
||||
@Override
|
||||
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
|
||||
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
|
||||
mpBs.sql(parserSingle(mpBs.sql(), null));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = mpSh.mappedStatement();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
|
||||
if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
||||
mpBs.sql(parserMulti(mpBs.sql(), null));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processSelect(Select select, int index, String sql, Object obj) {
|
||||
final String whereSegment = (String) obj;
|
||||
processSelectBody(select, whereSegment);
|
||||
List<WithItem> withItemsList = select.getWithItemsList();
|
||||
if (!CollectionUtils.isEmpty(withItemsList)) {
|
||||
withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processInsert(Insert insert, int index, String sql, Object obj) {
|
||||
if (tenantLineHandler.ignoreTable(insert.getTable().getName())) {
|
||||
// 过滤退出执行
|
||||
return;
|
||||
}
|
||||
List<Column> columns = insert.getColumns();
|
||||
if (CollectionUtils.isEmpty(columns)) {
|
||||
// 针对不给列名的insert 不处理
|
||||
return;
|
||||
}
|
||||
String tenantIdColumn = tenantLineHandler.getTenantIdColumn();
|
||||
if (tenantLineHandler.ignoreInsert(columns, tenantIdColumn)) {
|
||||
// 针对已给出租户列的insert 不处理
|
||||
return;
|
||||
}
|
||||
columns.add(new Column(tenantIdColumn));
|
||||
Expression tenantId = tenantLineHandler.getTenantId();
|
||||
// fixed gitee pulls/141 duplicate update
|
||||
List<UpdateSet> duplicateUpdateColumns = insert.getDuplicateUpdateSets();
|
||||
if (CollectionUtils.isNotEmpty(duplicateUpdateColumns)) {
|
||||
EqualsTo equalsTo = new EqualsTo();
|
||||
equalsTo.setLeftExpression(new StringValue(tenantIdColumn));
|
||||
equalsTo.setRightExpression(tenantId);
|
||||
duplicateUpdateColumns.add(new UpdateSet(new Column(tenantIdColumn), tenantId));
|
||||
}
|
||||
|
||||
Select select = insert.getSelect();
|
||||
if (select instanceof PlainSelect) { //fix github issue 4998 修复升级到4.5版本的问题
|
||||
this.processInsertSelect(select, (String) obj);
|
||||
} else if (insert.getValues() != null) {
|
||||
// fixed github pull/295
|
||||
Values values = insert.getValues();
|
||||
ExpressionList<Expression> expressions = (ExpressionList<Expression>) values.getExpressions();
|
||||
if (expressions instanceof ParenthesedExpressionList) {
|
||||
expressions.addExpression(tenantId);
|
||||
} else {
|
||||
if (CollectionUtils.isNotEmpty(expressions)) {//fix github issue 4998 jsqlparse 4.5 批量insert ItemsList不是MultiExpressionList 了,需要特殊处理
|
||||
int len = expressions.size();
|
||||
for (int i = 0; i < len; i++) {
|
||||
Expression expression = expressions.get(i);
|
||||
if (expression instanceof Parenthesis) {
|
||||
ExpressionList rowConstructor = new RowConstructor<>()
|
||||
.withExpressions(new ExpressionList<>(((Parenthesis) expression).getExpression(), tenantId));
|
||||
expressions.set(i, rowConstructor);
|
||||
} else if (expression instanceof ParenthesedExpressionList) {
|
||||
((ParenthesedExpressionList) expression).addExpression(tenantId);
|
||||
} else {
|
||||
expressions.add(tenantId);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
expressions.add(tenantId);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* update 语句处理
|
||||
*/
|
||||
@Override
|
||||
protected void processUpdate(Update update, int index, String sql, Object obj) {
|
||||
final Table table = update.getTable();
|
||||
if (tenantLineHandler.ignoreTable(table.getName())) {
|
||||
// 过滤退出执行
|
||||
return;
|
||||
}
|
||||
List<UpdateSet> sets = update.getUpdateSets();
|
||||
if (!CollectionUtils.isEmpty(sets)) {
|
||||
sets.forEach(us -> us.getValues().forEach(ex -> {
|
||||
if (ex instanceof Select) {
|
||||
processSelectBody(((Select) ex), (String) obj);
|
||||
}
|
||||
}));
|
||||
}
|
||||
update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
|
||||
}
|
||||
|
||||
/**
|
||||
* delete 语句处理
|
||||
*/
|
||||
@Override
|
||||
protected void processDelete(Delete delete, int index, String sql, Object obj) {
|
||||
if (tenantLineHandler.ignoreTable(delete.getTable().getName())) {
|
||||
// 过滤退出执行
|
||||
return;
|
||||
}
|
||||
delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere(), (String) obj));
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 insert into select
|
||||
* <p>
|
||||
* 进入这里表示需要 insert 的表启用了多租户,则 select 的表都启动了
|
||||
*
|
||||
* @param selectBody SelectBody
|
||||
*/
|
||||
protected void processInsertSelect(Select selectBody, final String whereSegment) {
|
||||
if(selectBody instanceof PlainSelect){
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
FromItem fromItem = plainSelect.getFromItem();
|
||||
if (fromItem instanceof Table) {
|
||||
// fixed gitee pulls/141 duplicate update
|
||||
processPlainSelect(plainSelect, whereSegment);
|
||||
appendSelectItem(plainSelect.getSelectItems());
|
||||
} else if (fromItem instanceof Select) {
|
||||
Select subSelect = (Select) fromItem;
|
||||
appendSelectItem(plainSelect.getSelectItems());
|
||||
processInsertSelect(subSelect, whereSegment);
|
||||
}
|
||||
} else if(selectBody instanceof ParenthesedSelect){
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) selectBody;
|
||||
processInsertSelect(parenthesedSelect.getSelect(), whereSegment);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 追加 SelectItem
|
||||
*
|
||||
* @param selectItems SelectItem
|
||||
*/
|
||||
protected void appendSelectItem(List<SelectItem<?>> selectItems) {
|
||||
if (CollectionUtils.isEmpty(selectItems)) {
|
||||
return;
|
||||
}
|
||||
if (selectItems.size() == 1) {
|
||||
SelectItem item = selectItems.get(0);
|
||||
Expression expression = item.getExpression();
|
||||
if (expression instanceof AllColumns) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
selectItems.add(new SelectItem<>(new Column(tenantLineHandler.getTenantIdColumn())));
|
||||
}
|
||||
|
||||
/**
|
||||
* 租户字段别名设置
|
||||
* <p>tenantId 或 tableAlias.tenantId</p>
|
||||
*
|
||||
* @param table 表对象
|
||||
* @return 字段
|
||||
*/
|
||||
protected Column getAliasColumn(Table table) {
|
||||
StringBuilder column = new StringBuilder();
|
||||
// todo 该起别名就要起别名,禁止修改此处逻辑
|
||||
if (table.getAlias() != null) {
|
||||
column.append(table.getAlias().getName()).append(StringPool.DOT);
|
||||
}
|
||||
column.append(tenantLineHandler.getTenantIdColumn());
|
||||
return new Column(column.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setProperties(Properties properties) {
|
||||
PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler",
|
||||
ClassUtils::newInstance, this::setTenantLineHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建租户条件表达式
|
||||
*
|
||||
* @param table 表对象
|
||||
* @param where 当前where条件
|
||||
* @param whereSegment 所属Mapper对象全路径(在原租户拦截器功能中,这个参数并不需要参与相关判断)
|
||||
* @return 租户条件表达式
|
||||
* @see BaseMultiTableInnerInterceptor#buildTableExpression(Table, Expression, String)
|
||||
*/
|
||||
@Override
|
||||
public Expression buildTableExpression(final Table table, final Expression where, final String whereSegment) {
|
||||
if (tenantLineHandler.ignoreTable(table.getName())) {
|
||||
return null;
|
||||
}
|
||||
return new EqualsTo(getAliasColumn(table), tenantLineHandler.getTenantId());
|
||||
}
|
||||
}
|
@ -0,0 +1,77 @@
|
||||
package com.baomidou.mybatisplus.test;
|
||||
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* SQL 解析测试
|
||||
*/
|
||||
class JSqlParserTest {
|
||||
|
||||
@Test
|
||||
void parser() throws Exception {
|
||||
Select select = (Select) CCJSqlParserUtil.parse("SELECT a,b,c FROM tableName t WHERE t.col = 9 and b=c LIMIT 3, ?");
|
||||
|
||||
PlainSelect ps = (PlainSelect) select;
|
||||
|
||||
System.out.println(ps.getWhere().toString());
|
||||
System.out.println(ps.getSelectItems().get(1).toString());
|
||||
|
||||
AndExpression e = (AndExpression) ps.getWhere();
|
||||
System.out.println(e.getLeftExpression());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDecr() throws JSQLParserException {
|
||||
// 如果连一起 SqlParser 将无法解析 , 还有种处理方式就自减为负数的时候 转为 自增.
|
||||
var parse1 = CCJSqlParserUtil.parse("UPDATE test SET a = a --110");
|
||||
Assertions.assertEquals("UPDATE test SET a = a", parse1.toString());
|
||||
var parse2 = CCJSqlParserUtil.parse("UPDATE test SET a = a - -110");
|
||||
Assertions.assertEquals("UPDATE test SET a = a - -110", parse2.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIncr() throws JSQLParserException {
|
||||
var parse1 = CCJSqlParserUtil.parse("UPDATE test SET a = a +-110");
|
||||
Assertions.assertEquals("UPDATE test SET a = a + -110", parse1.toString());
|
||||
var parse2 = CCJSqlParserUtil.parse("UPDATE test SET a = a + -110");
|
||||
Assertions.assertEquals("UPDATE test SET a = a + -110", parse2.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
void notLikeParser() throws Exception {
|
||||
final String targetSql = "SELECT * FROM tableName WHERE id NOT LIKE ?";
|
||||
Select select = (Select) CCJSqlParserUtil.parse(targetSql);
|
||||
assertThat(select.toString()).isEqualTo(targetSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void updateWhereParser() throws Exception {
|
||||
Update update = (Update) CCJSqlParserUtil.parse("Update tableName t SET t.a=(select c from tn where tn.id=t.id),b=2,c=3 ");
|
||||
Assertions.assertNull(update.getWhere());
|
||||
}
|
||||
|
||||
@Test
|
||||
void deleteWhereParser() throws Exception {
|
||||
Delete delete = (Delete) CCJSqlParserUtil.parse("delete from tableName t");
|
||||
Assertions.assertNull(delete.getWhere());
|
||||
}
|
||||
|
||||
// @Test
|
||||
// void testSelectForUpdate() throws Exception {
|
||||
// Assertions.assertEquals("SELECT * FROM t_demo WHERE a = 1 FOR UPDATE",
|
||||
// CCJSqlParserUtil.parse("select * from t_demo where a = 1 for update").toString());
|
||||
// Assertions.assertEquals("SELECT * FROM sys_sms_send_record WHERE check_status = 0 ORDER BY submit_time ASC LIMIT 10 FOR UPDATE",
|
||||
// CCJSqlParserUtil.parse("select * from sys_sms_send_record where check_status = 0 for update order by submit_time asc limit 10").toString());
|
||||
// }
|
||||
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.DynamicTableNameInnerInterceptor;
|
||||
import org.intellij.lang.annotations.Language;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
@ -24,7 +23,6 @@ class DynamicTableNameInnerInterceptorTest {
|
||||
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
|
||||
|
||||
// 表名相互包含
|
||||
@Language("SQL")
|
||||
String origin = "SELECT * FROM t_user, t_user_role";
|
||||
assertEquals("SELECT * FROM t_user_r, t_user_role_r", interceptor.changeTable(origin));
|
||||
|
@ -0,0 +1,13 @@
|
||||
dependencies {
|
||||
api "${lib.'jsqlparser'}"
|
||||
api project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common")
|
||||
implementation "${lib."slf4j-api"}"
|
||||
implementation "de.ruedigermoeller:fst:3.0.4-jdk17"
|
||||
implementation "com.github.ben-manes.caffeine:caffeine:2.9.3"
|
||||
testImplementation "io.github.classgraph:classgraph:4.8.176"
|
||||
testImplementation "${lib."spring-context-support"}"
|
||||
testImplementation "${lib.h2}"
|
||||
testImplementation group: 'com.google.guava', name: 'guava', version: '33.3.1-jre'
|
||||
}
|
||||
|
||||
compileJava.dependsOn(processResources)
|
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser;
|
||||
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2023-08-05
|
||||
*/
|
||||
@FunctionalInterface
|
||||
public interface JsqlParserFunction<T, R> {
|
||||
|
||||
R apply(T t) throws JSQLParserException;
|
||||
}
|
@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.parser.cache.JsqlParseCache;
|
||||
import lombok.Setter;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import net.sf.jsqlparser.statement.Statements;
|
||||
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2023-08-05
|
||||
*/
|
||||
public class JsqlParserGlobal {
|
||||
|
||||
/**
|
||||
* 默认线程数大小
|
||||
*
|
||||
* @since 3.5.6
|
||||
*/
|
||||
public static final int DEFAULT_THREAD_SIZE = (Runtime.getRuntime().availableProcessors() + 1) / 2;
|
||||
|
||||
/**
|
||||
* 默认解析处理线程池
|
||||
* <p>注意: 由于项目情况,机器配置等不一样因素,请自行根据情况创建指定线程池.</p>
|
||||
*
|
||||
* @see java.util.concurrent.ThreadPoolExecutor
|
||||
* @since 3.5.6
|
||||
*/
|
||||
public static ExecutorService executorService = new ThreadPoolExecutor(DEFAULT_THREAD_SIZE, DEFAULT_THREAD_SIZE, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), r -> {
|
||||
Thread thread = new Thread(r);
|
||||
thread.setName("mybatis-plus-jsqlParser-" + thread.getId());
|
||||
thread.setDaemon(true);
|
||||
return thread;
|
||||
});
|
||||
|
||||
@Setter
|
||||
private static JsqlParserFunction<String, Statement> parserSingleFunc = sql -> CCJSqlParserUtil.parse(sql, executorService, null);
|
||||
|
||||
@Setter
|
||||
private static JsqlParserFunction<String, Statements> parserMultiFunc = sql -> CCJSqlParserUtil.parseStatements(sql, executorService, null);
|
||||
|
||||
@Setter
|
||||
private static JsqlParseCache jsqlParseCache;
|
||||
|
||||
public static Statement parse(String sql) throws JSQLParserException {
|
||||
if (jsqlParseCache == null) {
|
||||
return parserSingleFunc.apply(sql);
|
||||
}
|
||||
Statement statement = jsqlParseCache.getStatement(sql);
|
||||
if (statement == null) {
|
||||
statement = parserSingleFunc.apply(sql);
|
||||
jsqlParseCache.putStatement(sql, statement);
|
||||
}
|
||||
return statement;
|
||||
}
|
||||
|
||||
public static Statements parseStatements(String sql) throws JSQLParserException {
|
||||
if (jsqlParseCache == null) {
|
||||
return parserMultiFunc.apply(sql);
|
||||
}
|
||||
Statements statements = jsqlParseCache.getStatements(sql);
|
||||
if (statements == null) {
|
||||
statements = parserMultiFunc.apply(sql);
|
||||
jsqlParseCache.putStatements(sql, statements);
|
||||
}
|
||||
return statements;
|
||||
}
|
||||
}
|
@ -0,0 +1,130 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import net.sf.jsqlparser.statement.Statements;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.insert.Insert;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import org.apache.ibatis.logging.Log;
|
||||
import org.apache.ibatis.logging.LogFactory;
|
||||
|
||||
/**
|
||||
* https://github.com/JSQLParser/JSqlParser
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2020-06-22
|
||||
*/
|
||||
public abstract class JsqlParserSupport {
|
||||
|
||||
/**
|
||||
* 日志
|
||||
*/
|
||||
protected final Log logger = LogFactory.getLog(this.getClass());
|
||||
|
||||
public String parserSingle(String sql, Object obj) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("original SQL: " + sql);
|
||||
}
|
||||
try {
|
||||
Statement statement = JsqlParserGlobal.parse(sql);
|
||||
return processParser(statement, 0, sql, obj);
|
||||
} catch (JSQLParserException e) {
|
||||
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
|
||||
}
|
||||
}
|
||||
|
||||
public String parserMulti(String sql, Object obj) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("original SQL: " + sql);
|
||||
}
|
||||
try {
|
||||
// fixed github pull/295
|
||||
StringBuilder sb = new StringBuilder();
|
||||
Statements statements = JsqlParserGlobal.parseStatements(sql);
|
||||
int i = 0;
|
||||
for (Statement statement : statements) {
|
||||
if (i > 0) {
|
||||
sb.append(StringPool.SEMICOLON);
|
||||
}
|
||||
sb.append(processParser(statement, i, sql, obj));
|
||||
i++;
|
||||
}
|
||||
return sb.toString();
|
||||
} catch (JSQLParserException e) {
|
||||
throw ExceptionUtils.mpe("Failed to process, Error SQL: %s", e.getCause(), sql);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 SQL 解析
|
||||
*
|
||||
* @param statement JsqlParser Statement
|
||||
* @return sql
|
||||
*/
|
||||
protected String processParser(Statement statement, int index, String sql, Object obj) {
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("SQL to parse, SQL: " + sql);
|
||||
}
|
||||
if (statement instanceof Insert) {
|
||||
this.processInsert((Insert) statement, index, sql, obj);
|
||||
} else if (statement instanceof Select) {
|
||||
this.processSelect((Select) statement, index, sql, obj);
|
||||
} else if (statement instanceof Update) {
|
||||
this.processUpdate((Update) statement, index, sql, obj);
|
||||
} else if (statement instanceof Delete) {
|
||||
this.processDelete((Delete) statement, index, sql, obj);
|
||||
}
|
||||
sql = statement.toString();
|
||||
if (logger.isDebugEnabled()) {
|
||||
logger.debug("parse the finished SQL: " + sql);
|
||||
}
|
||||
return sql;
|
||||
}
|
||||
|
||||
/**
|
||||
* 新增
|
||||
*/
|
||||
protected void processInsert(Insert insert, int index, String sql, Object obj) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除
|
||||
*/
|
||||
protected void processDelete(Delete delete, int index, String sql, Object obj) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新
|
||||
*/
|
||||
protected void processUpdate(Update update, int index, String sql, Object obj) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询
|
||||
*/
|
||||
protected void processSelect(Select select, int index, String sql, Object obj) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser.cache;
|
||||
|
||||
import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.github.benmanes.caffeine.cache.Caffeine;
|
||||
import lombok.Setter;
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import net.sf.jsqlparser.statement.Statements;
|
||||
import org.apache.ibatis.logging.Log;
|
||||
import org.apache.ibatis.logging.LogFactory;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* jsqlparser 缓存 Caffeine 缓存实现抽象类
|
||||
*
|
||||
* @author miemie hubin
|
||||
* @since 2023-08-08
|
||||
*/
|
||||
public abstract class AbstractCaffeineJsqlParseCache implements JsqlParseCache {
|
||||
protected final Log logger = LogFactory.getLog(this.getClass());
|
||||
protected final Cache<String, byte[]> cache;
|
||||
@Setter
|
||||
protected boolean async = false;
|
||||
@Setter
|
||||
protected Executor executor;
|
||||
|
||||
public AbstractCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
|
||||
this.cache = cache;
|
||||
}
|
||||
|
||||
public AbstractCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
|
||||
Caffeine<Object, Object> caffeine = Caffeine.newBuilder();
|
||||
consumer.accept(caffeine);
|
||||
this.cache = caffeine.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putStatement(String sql, Statement value) {
|
||||
this.put(sql, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putStatements(String sql, Statements value) {
|
||||
this.put(sql, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Statement getStatement(String sql) {
|
||||
return this.get(sql);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Statements getStatements(String sql) {
|
||||
return this.get(sql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取解析对象,异常清空缓存逻辑
|
||||
*
|
||||
* @param sql 执行 SQL
|
||||
* @return 返回泛型对象
|
||||
*/
|
||||
protected <T> T get(String sql) {
|
||||
byte[] bytes = cache.getIfPresent(sql);
|
||||
if (null != bytes) {
|
||||
try {
|
||||
return (T) deserialize(sql, bytes);
|
||||
} catch (Exception e) {
|
||||
cache.invalidate(sql);
|
||||
logger.error("deserialize error", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 存储解析对象
|
||||
*
|
||||
* @param sql 执行 SQL
|
||||
* @param value 解析对象
|
||||
*/
|
||||
protected void put(String sql, Object value) {
|
||||
if (async) {
|
||||
if (executor != null) {
|
||||
CompletableFuture.runAsync(() -> cache.put(sql, serialize(value)), executor);
|
||||
} else {
|
||||
CompletableFuture.runAsync(() -> cache.put(sql, serialize(value)));
|
||||
}
|
||||
} else {
|
||||
cache.put(sql, serialize(value));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 序列化
|
||||
*/
|
||||
public abstract byte[] serialize(Object obj);
|
||||
|
||||
/**
|
||||
* 反序列化
|
||||
*/
|
||||
public abstract Object deserialize(String sql, byte[] bytes);
|
||||
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser.cache;
|
||||
|
||||
import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.github.benmanes.caffeine.cache.Caffeine;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* jsqlparser 缓存 Fst 序列化 Caffeine 缓存实现
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2023-08-05
|
||||
*/
|
||||
public class FstSerialCaffeineJsqlParseCache extends AbstractCaffeineJsqlParseCache {
|
||||
|
||||
|
||||
public FstSerialCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
|
||||
super(cache);
|
||||
}
|
||||
|
||||
public FstSerialCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
|
||||
super(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] serialize(Object obj) {
|
||||
return FstFactory.getDefaultFactory().asByteArray(obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object deserialize(String sql, byte[] bytes) {
|
||||
return FstFactory.getDefaultFactory().asObject(bytes);
|
||||
}
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser.cache;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.SerializationUtils;
|
||||
import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.github.benmanes.caffeine.cache.Caffeine;
|
||||
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* jsqlparser 缓存 jdk 序列化 Caffeine 缓存实现
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2023-08-05
|
||||
*/
|
||||
public class JdkSerialCaffeineJsqlParseCache extends AbstractCaffeineJsqlParseCache {
|
||||
|
||||
|
||||
public JdkSerialCaffeineJsqlParseCache(Cache<String, byte[]> cache) {
|
||||
super(cache);
|
||||
}
|
||||
|
||||
public JdkSerialCaffeineJsqlParseCache(Consumer<Caffeine<Object, Object>> consumer) {
|
||||
super(consumer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] serialize(Object obj) {
|
||||
return SerializationUtils.serialize(obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object deserialize(String sql, byte[] bytes) {
|
||||
return SerializationUtils.deserialize(bytes);
|
||||
}
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.parser.cache;
|
||||
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import net.sf.jsqlparser.statement.Statements;
|
||||
|
||||
/**
|
||||
* jsqlparser 缓存接口
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2023-08-05
|
||||
*/
|
||||
public interface JsqlParseCache {
|
||||
|
||||
void putStatement(String sql, Statement value);
|
||||
|
||||
void putStatements(String sql, Statements value);
|
||||
|
||||
Statement getStatement(String sql);
|
||||
|
||||
Statements getStatements(String sql);
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.handler;
|
||||
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
/**
|
||||
* 数据权限处理器
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.4.1 +
|
||||
*/
|
||||
public interface DataPermissionHandler {
|
||||
|
||||
/**
|
||||
* 获取数据权限 SQL 片段
|
||||
*
|
||||
* @param where 待执行 SQL Where 条件表达式
|
||||
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
|
||||
* @return JSqlParser 条件表达式,返回的条件表达式会覆盖原有的条件表达式
|
||||
*/
|
||||
Expression getSqlSegment(Expression where, String mappedStatementId);
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.handler;
|
||||
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
|
||||
/**
|
||||
* 支持多表的数据权限处理器
|
||||
*
|
||||
* @author houkunlin
|
||||
* @since 3.5.2 +
|
||||
*/
|
||||
public interface MultiDataPermissionHandler extends DataPermissionHandler {
|
||||
/**
|
||||
* 为兼容旧版数据权限处理器,继承了 {@link DataPermissionHandler} 但是新的多表数据权限处理又不会调用此方法,因此标记过时
|
||||
*
|
||||
* @param where 待执行 SQL Where 条件表达式
|
||||
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
|
||||
* @return JSqlParser 条件表达式
|
||||
* @deprecated 新的多表数据权限处理不会调用此方法,因此标记过时
|
||||
*/
|
||||
@Deprecated
|
||||
@Override
|
||||
default Expression getSqlSegment(Expression where, String mappedStatementId) {
|
||||
return where;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取数据权限 SQL 片段。
|
||||
* <p>旧的 {@link MultiDataPermissionHandler#getSqlSegment(Expression, String)} 方法第一个参数包含所有的 where 条件信息,如果 return 了 null 会覆盖原有的 where 数据,</p>
|
||||
* <p>新版的 {@link MultiDataPermissionHandler#getSqlSegment(Table, Expression, String)} 方法不能覆盖原有的 where 数据,如果 return 了 null 则表示不追加任何 where 条件</p>
|
||||
*
|
||||
* @param table 所执行的数据库表信息,可以通过此参数获取表名和表别名
|
||||
* @param where 原有的 where 条件信息
|
||||
* @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
|
||||
* @return JSqlParser 条件表达式,返回的条件表达式会拼接在原有的表达式后面(不会覆盖原有的表达式)
|
||||
*/
|
||||
Expression getSqlSegment(final Table table, final Expression where, final String mappedStatementId);
|
||||
}
|
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.handler;
|
||||
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 租户处理器( TenantId 行级 )
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.4.0
|
||||
*/
|
||||
public interface TenantLineHandler {
|
||||
|
||||
/**
|
||||
* 获取租户 ID 值表达式,只支持单个 ID 值
|
||||
* <p>
|
||||
*
|
||||
* @return 租户 ID 值表达式
|
||||
*/
|
||||
Expression getTenantId();
|
||||
|
||||
/**
|
||||
* 获取租户字段名
|
||||
* <p>
|
||||
* 默认字段名叫: tenant_id
|
||||
*
|
||||
* @return 租户字段名
|
||||
*/
|
||||
default String getTenantIdColumn() {
|
||||
return "tenant_id";
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据表名判断是否忽略拼接多租户条件
|
||||
* <p>
|
||||
* 默认都要进行解析并拼接多租户条件
|
||||
*
|
||||
* @param tableName 表名
|
||||
* @return 是否忽略, true:表示忽略,false:需要解析并拼接多租户条件
|
||||
*/
|
||||
default boolean ignoreTable(String tableName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* 忽略插入租户字段逻辑
|
||||
*
|
||||
* @param columns 插入字段
|
||||
* @param tenantIdColumn 租户 ID 字段
|
||||
* @return
|
||||
*/
|
||||
default boolean ignoreInsert(List<Column> columns, String tenantIdColumn) {
|
||||
return columns.stream().map(Column::getColumnName).anyMatch(i -> i.equalsIgnoreCase(tenantIdColumn));
|
||||
}
|
||||
}
|
@ -0,0 +1,990 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import java.io.Reader;
|
||||
import java.sql.Clob;
|
||||
import java.sql.Connection;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.ResultSetMetaData;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Properties;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.Constants;
|
||||
import net.sf.jsqlparser.statement.select.Values;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.ParameterMapping;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
import org.apache.ibatis.reflection.MetaObject;
|
||||
import org.apache.ibatis.reflection.SystemMetaObject;
|
||||
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IEnum;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
|
||||
import com.baomidou.mybatisplus.core.handlers.MybatisEnumTypeHandler;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfo;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
import com.baomidou.mybatisplus.extension.parser.JsqlParserGlobal;
|
||||
|
||||
import lombok.Data;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.JdbcParameter;
|
||||
import net.sf.jsqlparser.expression.RowConstructor;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.insert.Insert;
|
||||
import net.sf.jsqlparser.statement.select.AllColumns;
|
||||
import net.sf.jsqlparser.statement.select.FromItem;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import net.sf.jsqlparser.statement.update.UpdateSet;
|
||||
|
||||
/**
|
||||
* <p>
|
||||
* 数据变动记录插件
|
||||
* 默认会生成一条log,格式:
|
||||
* ----------------------INSERT LOG------------------------------
|
||||
* </p>
|
||||
* <p>
|
||||
* {
|
||||
* "tableName": "h2user",
|
||||
* "operation": "insert",
|
||||
* "recordStatus": "true",
|
||||
* "changedData": [
|
||||
* {
|
||||
* "LAST_UPDATED_DT": "null->2022-08-22 18:49:16.512",
|
||||
* "TEST_ID": "null->1561666810058739714",
|
||||
* "AGE": "null->THREE"
|
||||
* }
|
||||
* ],
|
||||
* "cost(ms)": 0
|
||||
* }
|
||||
* </p>
|
||||
* <p>
|
||||
* * ----------------------UPDATE LOG------------------------------
|
||||
* <p>
|
||||
* {
|
||||
* "tableName": "h2user",
|
||||
* "operation": "update",
|
||||
* "recordStatus": "true",
|
||||
* "changedData": [
|
||||
* {
|
||||
* "TEST_ID": "102",
|
||||
* "AGE": "2->THREE",
|
||||
* "FIRSTNAME": "DOU.HAO->{\"json\":\"abc\"}",
|
||||
* "LAST_UPDATED_DT": "null->2022-08-22 18:49:16.512"
|
||||
* }
|
||||
* ],
|
||||
* "cost(ms)": 0
|
||||
* }
|
||||
* </p>
|
||||
*
|
||||
* @author yuxiaobin
|
||||
* @date 2022-8-21
|
||||
*/
|
||||
public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
|
||||
|
||||
protected final Logger logger = LoggerFactory.getLogger(this.getClass());
|
||||
@SuppressWarnings("unused")
|
||||
public static final String IGNORED_TABLE_COLUMN_PROPERTIES = "ignoredTableColumns";
|
||||
|
||||
private final Map<String, Set<String>> ignoredTableColumns = new ConcurrentHashMap<>();
|
||||
private final Set<String> ignoreAllColumns = new HashSet<>();//全部表的这些字段名,INSERT/UPDATE都忽略,delete暂时保留
|
||||
//批量更新上限, 默认一次最多1000条
|
||||
private int BATCH_UPDATE_LIMIT = 1000;
|
||||
private boolean batchUpdateLimitationOpened = false;
|
||||
private final Map<String, Integer> BATCH_UPDATE_LIMIT_MAP = new ConcurrentHashMap<>();//表名->批量更新上限
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = mpSh.mappedStatement();
|
||||
final BoundSql boundSql = mpSh.boundSql();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
|
||||
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
||||
OperationResult operationResult;
|
||||
long startTs = System.currentTimeMillis();
|
||||
try {
|
||||
Statement statement = JsqlParserGlobal.parse(mpBs.sql());
|
||||
if (statement instanceof Insert) {
|
||||
operationResult = processInsert((Insert) statement, mpSh.boundSql());
|
||||
} else if (statement instanceof Update) {
|
||||
operationResult = processUpdate((Update) statement, ms, boundSql, connection);
|
||||
} else if (statement instanceof Delete) {
|
||||
operationResult = processDelete((Delete) statement, ms, boundSql, connection);
|
||||
} else {
|
||||
logger.info("other operation sql={}", mpBs.sql());
|
||||
return;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
if (e instanceof DataUpdateLimitationException) {
|
||||
throw (DataUpdateLimitationException) e;
|
||||
}
|
||||
logger.error("Unexpected error for mappedStatement={}, sql={}", ms.getId(), mpBs.sql(), e);
|
||||
return;
|
||||
}
|
||||
long costThis = System.currentTimeMillis() - startTs;
|
||||
if (operationResult != null) {
|
||||
operationResult.setCost(costThis);
|
||||
dealOperationResult(operationResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 判断哪些SQL需要处理
|
||||
* 默认INSERT/UPDATE/DELETE语句
|
||||
*
|
||||
* @param sql
|
||||
* @return
|
||||
*/
|
||||
protected boolean allowProcess(String sql) {
|
||||
String sqlTrim = sql.trim().toUpperCase();
|
||||
return sqlTrim.startsWith("INSERT") || sqlTrim.startsWith("UPDATE") || sqlTrim.startsWith("DELETE");
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理数据更新结果,默认打印
|
||||
*
|
||||
* @param operationResult
|
||||
*/
|
||||
protected void dealOperationResult(OperationResult operationResult) {
|
||||
logger.info("{}", operationResult);
|
||||
}
|
||||
|
||||
public OperationResult processInsert(Insert insertStmt, BoundSql boundSql) {
|
||||
String operation = SqlCommandType.INSERT.name().toLowerCase();
|
||||
Table table = insertStmt.getTable();
|
||||
String tableName = table.getName();
|
||||
Optional<OperationResult> optionalOperationResult = ignoredTableColumns(tableName, operation);
|
||||
if (optionalOperationResult.isPresent()) {
|
||||
return optionalOperationResult.get();
|
||||
}
|
||||
OperationResult result = new OperationResult();
|
||||
result.setOperation(operation);
|
||||
result.setTableName(tableName);
|
||||
result.setRecordStatus(true);
|
||||
Map<String, Object> updatedColumnDatas = getUpdatedColumnDatas(tableName, boundSql, insertStmt);
|
||||
result.buildDataStr(compareAndGetUpdatedColumnDatas(result.getTableName(), null, updatedColumnDatas));
|
||||
return result;
|
||||
}
|
||||
|
||||
public OperationResult processUpdate(Update updateStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
|
||||
Expression where = updateStmt.getWhere();
|
||||
PlainSelect selectBody = new PlainSelect();
|
||||
Table table = updateStmt.getTable();
|
||||
String tableName = table.getName();
|
||||
String operation = SqlCommandType.UPDATE.name().toLowerCase();
|
||||
Optional<OperationResult> optionalOperationResult = ignoredTableColumns(tableName, operation);
|
||||
if (optionalOperationResult.isPresent()) {
|
||||
return optionalOperationResult.get();
|
||||
}
|
||||
selectBody.setFromItem(table);
|
||||
List<Column> updateColumns = new ArrayList<>();
|
||||
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
|
||||
updateColumns.addAll(updateSet.getColumns());
|
||||
}
|
||||
Columns2SelectItemsResult buildColumns2SelectItems = buildColumns2SelectItems(tableName, updateColumns);
|
||||
selectBody.setSelectItems(buildColumns2SelectItems.getSelectItems());
|
||||
selectBody.setWhere(where);
|
||||
SelectItem<PlainSelect> plainSelectSelectItem = new SelectItem<>(selectBody);
|
||||
|
||||
BoundSql boundSql4Select = new BoundSql(mappedStatement.getConfiguration(), plainSelectSelectItem.toString(),
|
||||
prepareParameterMapping4Select(boundSql.getParameterMappings(), updateStmt),
|
||||
boundSql.getParameterObject());
|
||||
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
|
||||
Map<String, Object> additionalParameters = mpBoundSql.additionalParameters();
|
||||
if (additionalParameters != null && !additionalParameters.isEmpty()) {
|
||||
for (Map.Entry<String, Object> ety : additionalParameters.entrySet()) {
|
||||
boundSql4Select.setAdditionalParameter(ety.getKey(), ety.getValue());
|
||||
}
|
||||
}
|
||||
Map<String, Object> updatedColumnDatas = getUpdatedColumnDatas(tableName, boundSql, updateStmt);
|
||||
OriginalDataObj originalData = buildOriginalObjectData(updatedColumnDatas, selectBody, buildColumns2SelectItems.getPk(), mappedStatement, boundSql4Select, connection);
|
||||
OperationResult result = new OperationResult();
|
||||
result.setOperation(operation);
|
||||
result.setTableName(tableName);
|
||||
result.setRecordStatus(true);
|
||||
result.buildDataStr(compareAndGetUpdatedColumnDatas(result.getTableName(), originalData, updatedColumnDatas));
|
||||
return result;
|
||||
}
|
||||
|
||||
private Optional<OperationResult> ignoredTableColumns(String table, String operation) {
|
||||
final Set<String> ignoredColumns = ignoredTableColumns.get(table.toUpperCase());
|
||||
if (ignoredColumns != null) {
|
||||
if (ignoredColumns.stream().anyMatch("*"::equals)) {
|
||||
OperationResult result = new OperationResult();
|
||||
result.setOperation(operation);
|
||||
result.setTableName(table + ":*");
|
||||
result.setRecordStatus(false);
|
||||
return Optional.of(result);
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private TableInfo getTableInfoByTableName(String tableName) {
|
||||
for (TableInfo tableInfo : TableInfoHelper.getTableInfos()) {
|
||||
if (tableName.equalsIgnoreCase(tableInfo.getTableName())) {
|
||||
return tableInfo;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 将update SET部分的jdbc参数去除
|
||||
*
|
||||
* @param originalMappingList 这里只会包含JdbcParameter参数
|
||||
* @param updateStmt
|
||||
* @return
|
||||
*/
|
||||
private List<ParameterMapping> prepareParameterMapping4Select(List<ParameterMapping> originalMappingList, Update updateStmt) {
|
||||
List<Expression> updateValueExpressions = new ArrayList<>();
|
||||
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
|
||||
updateValueExpressions.addAll(updateSet.getValues());
|
||||
}
|
||||
int removeParamCount = 0;
|
||||
for (Expression expression : updateValueExpressions) {
|
||||
if (expression instanceof JdbcParameter) {
|
||||
++removeParamCount;
|
||||
}
|
||||
}
|
||||
return originalMappingList.subList(removeParamCount, originalMappingList.size());
|
||||
}
|
||||
|
||||
protected Map<String, Object> getUpdatedColumnDatas(String tableName, BoundSql updateSql, Statement statement) {
|
||||
Map<String, Object> columnNameValMap = new HashMap<>(updateSql.getParameterMappings().size());
|
||||
Map<Integer, String> columnSetIndexMap = new HashMap<>(updateSql.getParameterMappings().size());
|
||||
List<Column> selectItemsFromUpdateSql = new ArrayList<>();
|
||||
if (statement instanceof Update) {
|
||||
Update updateStmt = (Update) statement;
|
||||
int index = 0;
|
||||
for (UpdateSet updateSet : updateStmt.getUpdateSets()) {
|
||||
selectItemsFromUpdateSql.addAll(updateSet.getColumns());
|
||||
final ExpressionList<Expression> updateList = (ExpressionList<Expression>) updateSet.getValues();
|
||||
for (int i = 0; i < updateList.size(); ++i) {
|
||||
Expression updateExps = updateList.get(i);
|
||||
if (!(updateExps instanceof JdbcParameter)) {
|
||||
columnNameValMap.put(updateSet.getColumns().get(i).getColumnName().toUpperCase(), updateExps.toString());
|
||||
}
|
||||
columnSetIndexMap.put(index++, updateSet.getColumns().get(i).getColumnName().toUpperCase());
|
||||
}
|
||||
}
|
||||
} else if (statement instanceof Insert) {
|
||||
Insert insert = (Insert) statement;
|
||||
selectItemsFromUpdateSql.addAll(insert.getColumns());
|
||||
columnNameValMap.putAll(detectInsertColumnValuesNonJdbcParameters(insert));
|
||||
}
|
||||
Map<String, String> relatedColumnsUpperCaseWithoutUnderline = new HashMap<>(selectItemsFromUpdateSql.size(), 1);
|
||||
for (Column item : selectItemsFromUpdateSql) {
|
||||
//FIRSTNAME: FIRST_NAME/FIRST-NAME/FIRST$NAME/FIRST.NAME
|
||||
relatedColumnsUpperCaseWithoutUnderline.put(item.getColumnName().replaceAll("[._\\-$]", "").toUpperCase(), item.getColumnName().toUpperCase());
|
||||
}
|
||||
MetaObject metaObject = SystemMetaObject.forObject(updateSql.getParameterObject());
|
||||
int index = 0;
|
||||
for (ParameterMapping parameterMapping : updateSql.getParameterMappings()) {
|
||||
String propertyName = parameterMapping.getProperty();
|
||||
if (propertyName.startsWith("ew.paramNameValuePairs")) {
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
String[] arr = propertyName.split("\\.");
|
||||
String propertyNameTrim = arr[arr.length - 1].replace("_", "").toUpperCase();
|
||||
try {
|
||||
final String columnName = columnSetIndexMap.getOrDefault(index++, getColumnNameByProperty(propertyNameTrim, tableName));
|
||||
if (relatedColumnsUpperCaseWithoutUnderline.containsKey(propertyNameTrim)) {
|
||||
final String colkey = relatedColumnsUpperCaseWithoutUnderline.get(propertyNameTrim);
|
||||
Object valObj = metaObject.getValue(propertyName);
|
||||
if (valObj instanceof IEnum) {
|
||||
valObj = ((IEnum<?>) valObj).getValue();
|
||||
} else if (valObj instanceof Enum) {
|
||||
valObj = getEnumValue((Enum) valObj);
|
||||
}
|
||||
if (columnNameValMap.containsKey(colkey)) {
|
||||
columnNameValMap.put(relatedColumnsUpperCaseWithoutUnderline.get(propertyNameTrim), String.valueOf(columnNameValMap.get(colkey)).replace("?", valObj == null ? "" : valObj.toString()));
|
||||
}
|
||||
if (columnName != null && !columnNameValMap.containsKey(columnName)) {
|
||||
columnNameValMap.put(columnName, valObj);
|
||||
}
|
||||
} else {
|
||||
if (columnName != null) {
|
||||
columnNameValMap.put(columnName, metaObject.getValue(propertyName));
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.warn("get value error,propertyName:{},parameterMapping:{}", propertyName, parameterMapping);
|
||||
}
|
||||
}
|
||||
dealWithUpdateWrapper(columnSetIndexMap, columnNameValMap, updateSql);
|
||||
return columnNameValMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param originalDataObj
|
||||
* @return
|
||||
*/
|
||||
private List<DataChangedRecord> compareAndGetUpdatedColumnDatas(String tableName, OriginalDataObj originalDataObj, Map<String, Object> columnNameValMap) {
|
||||
final Set<String> ignoredColumns = ignoredTableColumns.get(tableName.toUpperCase());
|
||||
if (originalDataObj == null || originalDataObj.isEmpty()) {
|
||||
DataChangedRecord oneRecord = new DataChangedRecord();
|
||||
List<DataColumnChangeResult> updateColumns = new ArrayList<>(columnNameValMap.size());
|
||||
for (Map.Entry<String, Object> ety : columnNameValMap.entrySet()) {
|
||||
String columnName = ety.getKey();
|
||||
if ((ignoredColumns == null || !ignoredColumns.contains(columnName)) && !ignoreAllColumns.contains(columnName)) {
|
||||
updateColumns.add(DataColumnChangeResult.constrcutByUpdateVal(columnName, ety.getValue()));
|
||||
}
|
||||
}
|
||||
oneRecord.setUpdatedColumns(updateColumns);
|
||||
// oneRecord.setUpdatedColumns(Collections.EMPTY_LIST);
|
||||
return Collections.singletonList(oneRecord);
|
||||
}
|
||||
List<DataChangedRecord> originalDataList = originalDataObj.getOriginalDataObj();
|
||||
List<DataChangedRecord> updateDataList = new ArrayList<>(originalDataList.size());
|
||||
for (DataChangedRecord originalData : originalDataList) {
|
||||
if (originalData.hasUpdate(columnNameValMap, ignoredColumns, ignoreAllColumns)) {
|
||||
updateDataList.add(originalData);
|
||||
}
|
||||
}
|
||||
return updateDataList;
|
||||
}
|
||||
|
||||
private Object getEnumValue(Enum enumVal) {
|
||||
Optional<String> enumValueFieldName = MybatisEnumTypeHandler.findEnumValueFieldName(enumVal.getClass());
|
||||
if (enumValueFieldName.isPresent()) {
|
||||
return SystemMetaObject.forObject(enumVal).getValue(enumValueFieldName.get());
|
||||
}
|
||||
return enumVal;
|
||||
|
||||
}
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
private void dealWithUpdateWrapper(Map<Integer, String> columnSetIndexMap, Map<String, Object> columnNameValMap, BoundSql updateSql) {
|
||||
if (columnSetIndexMap.size() <= columnNameValMap.size()) {
|
||||
return;
|
||||
}
|
||||
MetaObject mpgenVal = SystemMetaObject.forObject(updateSql.getParameterObject());
|
||||
if(!mpgenVal.hasGetter(Constants.WRAPPER)){
|
||||
return;
|
||||
}
|
||||
Object ew = mpgenVal.getValue(Constants.WRAPPER);
|
||||
if (ew instanceof UpdateWrapper || ew instanceof LambdaUpdateWrapper) {
|
||||
final String sqlSet = ew instanceof UpdateWrapper ? ((UpdateWrapper) ew).getSqlSet() : ((LambdaUpdateWrapper) ew).getSqlSet();// columnName=#{val}
|
||||
if (sqlSet == null) {
|
||||
return;
|
||||
}
|
||||
MetaObject ewMeta = SystemMetaObject.forObject(ew);
|
||||
final Map paramNameValuePairs = (Map) ewMeta.getValue("paramNameValuePairs");
|
||||
String[] setItems = sqlSet.split(",");
|
||||
for (String setItem : setItems) {
|
||||
//age=#{ew.paramNameValuePairs.MPGENVAL1}
|
||||
String[] nameAndValuePair = setItem.split("=", 2);
|
||||
if (nameAndValuePair.length == 2) {
|
||||
String setColName = nameAndValuePair[0].trim().toUpperCase();
|
||||
String setColVal = nameAndValuePair[1].trim();//#{.mp}
|
||||
if (columnSetIndexMap.containsValue(setColName)) {
|
||||
String[] mpGenKeyArray = setColVal.split("\\.");
|
||||
String mpGenKey = mpGenKeyArray[mpGenKeyArray.length - 1].replace("}", "");
|
||||
final Object setVal = paramNameValuePairs.get(mpGenKey);
|
||||
if (setVal instanceof IEnum) {
|
||||
columnNameValMap.put(setColName, String.valueOf(((IEnum<?>) setVal).getValue()));
|
||||
} else {
|
||||
columnNameValMap.put(setColName, setVal);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, String> detectInsertColumnValuesNonJdbcParameters(Insert insert) {
|
||||
Map<String, String> columnNameValMap = new HashMap<>(4);
|
||||
final Select select = insert.getSelect();
|
||||
List<Column> columns = insert.getColumns();
|
||||
if (select instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
final List<Select> selects = setOperationList.getSelects();
|
||||
if (CollectionUtils.isEmpty(selects)) {
|
||||
return columnNameValMap;
|
||||
}
|
||||
final Select selectBody = selects.get(0);
|
||||
if (!(selectBody instanceof Values)) {
|
||||
return columnNameValMap;
|
||||
}
|
||||
Values valuesStatement = (Values) selectBody;
|
||||
if (valuesStatement.getExpressions() instanceof ExpressionList) {
|
||||
ExpressionList expressionList = valuesStatement.getExpressions();
|
||||
List<Expression> expressions = expressionList;
|
||||
for (Expression expression : expressions) {
|
||||
if (expression instanceof RowConstructor) {
|
||||
final ExpressionList exprList = ((RowConstructor) expression);
|
||||
final List<Expression> insertExpList = exprList;
|
||||
for (int i = 0; i < insertExpList.size(); ++i) {
|
||||
Expression e = insertExpList.get(i);
|
||||
if (!(e instanceof JdbcParameter)) {
|
||||
final String columnName = columns.get(i).getColumnName();
|
||||
final String val = e.toString();
|
||||
columnNameValMap.put(columnName, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return columnNameValMap;
|
||||
}
|
||||
|
||||
private String getColumnNameByProperty(String propertyName, String tableName) {
|
||||
for (TableInfo tableInfo : TableInfoHelper.getTableInfos()) {
|
||||
if (tableName.equalsIgnoreCase(tableInfo.getTableName())) {
|
||||
final List<TableFieldInfo> fieldList = tableInfo.getFieldList();
|
||||
if (CollectionUtils.isEmpty(fieldList)) {
|
||||
return propertyName;
|
||||
}
|
||||
for (TableFieldInfo tableFieldInfo : fieldList) {
|
||||
if (propertyName.equalsIgnoreCase(tableFieldInfo.getProperty())) {
|
||||
return tableFieldInfo.getColumn().toUpperCase();
|
||||
}
|
||||
}
|
||||
return propertyName;
|
||||
}
|
||||
}
|
||||
return propertyName;
|
||||
}
|
||||
|
||||
|
||||
private Map<String, Object> buildParameterObjectMap(BoundSql boundSql) {
|
||||
MetaObject metaObject = PluginUtils.getMetaObject(boundSql.getParameterObject());
|
||||
Map<String, Object> propertyValMap = new HashMap<>(boundSql.getParameterMappings().size());
|
||||
for (ParameterMapping parameterMapping : boundSql.getParameterMappings()) {
|
||||
String propertyName = parameterMapping.getProperty();
|
||||
if (propertyName.startsWith("ew.paramNameValuePairs")) {
|
||||
continue;
|
||||
}
|
||||
Object propertyValue = metaObject.getValue(propertyName);
|
||||
propertyValMap.put(propertyName, propertyValue);
|
||||
}
|
||||
return propertyValMap;
|
||||
|
||||
}
|
||||
|
||||
|
||||
private String buildOriginalData(Select selectStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
|
||||
try (PreparedStatement statement = connection.prepareStatement(selectStmt.toString())) {
|
||||
DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
|
||||
parameterHandler.setParameters(statement);
|
||||
ResultSet resultSet = statement.executeQuery();
|
||||
final ResultSetMetaData metaData = resultSet.getMetaData();
|
||||
int columnCount = metaData.getColumnCount();
|
||||
StringBuilder sb = new StringBuilder("[");
|
||||
int count = 0;
|
||||
while (resultSet.next()) {
|
||||
++count;
|
||||
if (checkTableBatchLimitExceeded(selectStmt, count)) {
|
||||
logger.error("batch delete limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
|
||||
throw DataUpdateLimitationException.DEFAULT;
|
||||
}
|
||||
sb.append("{");
|
||||
for (int i = 1; i <= columnCount; ++i) {
|
||||
sb.append("\"").append(metaData.getColumnName(i)).append("\":\"");
|
||||
Object res = resultSet.getObject(i);
|
||||
if (res instanceof Clob) {
|
||||
sb.append(DataColumnChangeResult.convertClob((Clob) res));
|
||||
} else {
|
||||
sb.append(res);
|
||||
}
|
||||
sb.append("\",");
|
||||
}
|
||||
sb.replace(sb.length() - 1, sb.length(), "}");
|
||||
}
|
||||
sb.append("]");
|
||||
resultSet.close();
|
||||
return sb.toString();
|
||||
} catch (Exception e) {
|
||||
if (e instanceof DataUpdateLimitationException) {
|
||||
throw (DataUpdateLimitationException) e;
|
||||
}
|
||||
logger.error("try to get record tobe deleted for selectStmt={}", selectStmt, e);
|
||||
return "failed to get original data";
|
||||
}
|
||||
}
|
||||
|
||||
private OriginalDataObj buildOriginalObjectData(Map<String, Object> updatedColumnDatas, Select selectStmt, Column pk, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
|
||||
try (PreparedStatement statement = connection.prepareStatement(selectStmt.toString())) {
|
||||
DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
|
||||
parameterHandler.setParameters(statement);
|
||||
ResultSet resultSet = statement.executeQuery();
|
||||
List<DataChangedRecord> originalObjectDatas = new LinkedList<>();
|
||||
int count = 0;
|
||||
|
||||
while (resultSet.next()) {
|
||||
++count;
|
||||
if (checkTableBatchLimitExceeded(selectStmt, count)) {
|
||||
logger.error("batch update limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
|
||||
throw DataUpdateLimitationException.DEFAULT;
|
||||
}
|
||||
originalObjectDatas.add(prepareOriginalDataObj(updatedColumnDatas, resultSet, pk));
|
||||
}
|
||||
OriginalDataObj result = new OriginalDataObj();
|
||||
result.setOriginalDataObj(originalObjectDatas);
|
||||
resultSet.close();
|
||||
return result;
|
||||
} catch (Exception e) {
|
||||
if (e instanceof DataUpdateLimitationException) {
|
||||
throw (DataUpdateLimitationException) e;
|
||||
}
|
||||
logger.error("try to get record tobe updated for selectStmt={}", selectStmt, e);
|
||||
return new OriginalDataObj();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 防止出现全表批量更新
|
||||
* 默认一次更新不超过1000条
|
||||
*
|
||||
* @param selectStmt
|
||||
* @param count
|
||||
* @return
|
||||
*/
|
||||
private boolean checkTableBatchLimitExceeded(Select selectStmt, int count) {
|
||||
if (!batchUpdateLimitationOpened) {
|
||||
return false;
|
||||
}
|
||||
final PlainSelect selectBody = (PlainSelect) selectStmt;
|
||||
final FromItem fromItem = selectBody.getFromItem();
|
||||
if (fromItem instanceof Table) {
|
||||
Table fromTable = (Table) fromItem;
|
||||
final String tableName = fromTable.getName().toUpperCase();
|
||||
if (!BATCH_UPDATE_LIMIT_MAP.containsKey(tableName)) {
|
||||
if (count > BATCH_UPDATE_LIMIT) {
|
||||
logger.error("batch update limit exceed for tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
|
||||
tableName, BATCH_UPDATE_LIMIT, count);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
final Integer limit = BATCH_UPDATE_LIMIT_MAP.get(tableName);
|
||||
if (count > limit) {
|
||||
logger.error("batch update limit exceed for configured tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
|
||||
tableName, limit, count);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return count > BATCH_UPDATE_LIMIT;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* get records : include related column with original data in DB
|
||||
*
|
||||
* @param resultSet
|
||||
* @param pk
|
||||
* @return
|
||||
* @throws SQLException
|
||||
*/
|
||||
private DataChangedRecord prepareOriginalDataObj(Map<String, Object> updatedColumnDatas, ResultSet resultSet, Column pk) throws SQLException {
|
||||
final ResultSetMetaData metaData = resultSet.getMetaData();
|
||||
int columnCount = metaData.getColumnCount();
|
||||
List<DataColumnChangeResult> originalColumnDatas = new LinkedList<>();
|
||||
DataColumnChangeResult pkval = null;
|
||||
for (int i = 1; i <= columnCount; ++i) {
|
||||
String columnName = metaData.getColumnName(i).toUpperCase();
|
||||
DataColumnChangeResult col;
|
||||
Object updateVal = updatedColumnDatas.get(columnName);
|
||||
if (updateVal != null && updateVal.getClass().getCanonicalName().startsWith("java.")) {
|
||||
col = DataColumnChangeResult.constrcutByOriginalVal(columnName, resultSet.getObject(i, updateVal.getClass()));
|
||||
} else {
|
||||
col = DataColumnChangeResult.constrcutByOriginalVal(columnName, resultSet.getObject(i));
|
||||
}
|
||||
if (pk != null && columnName.equalsIgnoreCase(pk.getColumnName())) {
|
||||
pkval = col;
|
||||
} else {
|
||||
originalColumnDatas.add(col);
|
||||
}
|
||||
}
|
||||
DataChangedRecord changedRecord = new DataChangedRecord();
|
||||
changedRecord.setOriginalColumnDatas(originalColumnDatas);
|
||||
if (pkval != null) {
|
||||
changedRecord.setPkColumnName(pkval.getColumnName());
|
||||
changedRecord.setPkColumnVal(pkval.getOriginalValue());
|
||||
}
|
||||
return changedRecord;
|
||||
}
|
||||
|
||||
|
||||
private Columns2SelectItemsResult buildColumns2SelectItems(String tableName, List<Column> columns) {
|
||||
if (columns == null || columns.isEmpty()) {
|
||||
return Columns2SelectItemsResult.build(Collections.singletonList(new SelectItem<>(new AllColumns())), 0);
|
||||
}
|
||||
List<SelectItem<?>> selectItems = new ArrayList<>(columns.size());
|
||||
for (Column column : columns) {
|
||||
selectItems.add(new SelectItem<>(column));
|
||||
}
|
||||
TableInfo tableInfo = getTableInfoByTableName(tableName);
|
||||
if (tableInfo == null) {
|
||||
return Columns2SelectItemsResult.build(selectItems, 0);
|
||||
}
|
||||
Column pk = new Column(tableInfo.getKeyColumn());
|
||||
selectItems.add(new SelectItem<>(pk));
|
||||
Columns2SelectItemsResult result = Columns2SelectItemsResult.build(selectItems, 1);
|
||||
result.setPk(pk);
|
||||
return result;
|
||||
}
|
||||
|
||||
private String buildParameterObject(BoundSql boundSql) {
|
||||
Object paramObj = boundSql.getParameterObject();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("{");
|
||||
if (paramObj instanceof Map) {
|
||||
Map<String, Object> paramMap = (Map<String, Object>) paramObj;
|
||||
int index = 1;
|
||||
boolean hasParamIndex = false;
|
||||
String key;
|
||||
while (paramMap.containsKey((key = "param" + index))) {
|
||||
Object paramIndex = paramMap.get(key);
|
||||
sb.append("\"").append(key).append("\"").append(":").append("\"").append(paramIndex).append("\"").append(",");
|
||||
hasParamIndex = true;
|
||||
++index;
|
||||
}
|
||||
if (hasParamIndex) {
|
||||
sb.delete(sb.length() - 1, sb.length());
|
||||
sb.append("}");
|
||||
return sb.toString();
|
||||
}
|
||||
for (Map.Entry<String, Object> ety : paramMap.entrySet()) {
|
||||
sb.append("\"").append(ety.getKey()).append("\"").append(":").append("\"").append(ety.getValue()).append("\"").append(",");
|
||||
}
|
||||
sb.delete(sb.length() - 1, sb.length());
|
||||
sb.append("}");
|
||||
return sb.toString();
|
||||
}
|
||||
sb.append("param:").append(paramObj);
|
||||
sb.append("}");
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public OperationResult processDelete(Delete deleteStmt, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
|
||||
Table table = deleteStmt.getTable();
|
||||
Expression where = deleteStmt.getWhere();
|
||||
PlainSelect selectBody = new PlainSelect();
|
||||
selectBody.setFromItem(table);
|
||||
selectBody.setSelectItems(Collections.singletonList(new SelectItem<>((new AllColumns()))));
|
||||
selectBody.setWhere(where);
|
||||
String originalData = buildOriginalData(selectBody, mappedStatement, boundSql, connection);
|
||||
OperationResult result = new OperationResult();
|
||||
result.setOperation("delete");
|
||||
result.setTableName(table.getName());
|
||||
result.setRecordStatus(originalData.startsWith("["));
|
||||
result.setChangedData(originalData);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置批量更新记录条数上限
|
||||
*
|
||||
* @param limit
|
||||
* @return
|
||||
*/
|
||||
public DataChangeRecorderInnerInterceptor setBatchUpdateLimit(int limit) {
|
||||
this.BATCH_UPDATE_LIMIT = limit;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DataChangeRecorderInnerInterceptor openBatchUpdateLimitation() {
|
||||
this.batchUpdateLimitationOpened = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
public DataChangeRecorderInnerInterceptor configTableLimitation(String tableName, int limit) {
|
||||
this.BATCH_UPDATE_LIMIT_MAP.put(tableName.toUpperCase(), limit);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* ignoredColumns = TABLE_NAME1.COLUMN1,COLUMN2; TABLE2.COLUMN1,COLUMN2; TABLE3.*; *.COLUMN1,COLUMN2
|
||||
* 多个表用分号分隔
|
||||
* TABLE_NAME1.COLUMN1,COLUMN2 : 表示忽略这个表的这2个字段
|
||||
* TABLE3.*: 表示忽略这张表的INSERT/UPDATE,delete暂时还保留
|
||||
* *.COLUMN1,COLUMN2:表示所有表的这个2个字段名都忽略
|
||||
*
|
||||
* @param properties
|
||||
*/
|
||||
@Override
|
||||
public void setProperties(Properties properties) {
|
||||
|
||||
String ignoredTableColumns = properties.getProperty("ignoredTableColumns");
|
||||
if (ignoredTableColumns == null || ignoredTableColumns.trim().isEmpty()) {
|
||||
return;
|
||||
}
|
||||
String[] array = ignoredTableColumns.split(";");
|
||||
for (String table : array) {
|
||||
int index = table.indexOf(".");
|
||||
if (index == -1) {
|
||||
logger.warn("invalid data={} for ignoredColumns, format should be TABLE_NAME1.COLUMN1,COLUMN2; TABLE2.COLUMN1,COLUMN2;", table);
|
||||
continue;
|
||||
}
|
||||
String tableName = table.substring(0, index).trim().toUpperCase();
|
||||
String[] columnArray = table.substring(index + 1).split(",");
|
||||
Set<String> columnSet = new HashSet<>(columnArray.length);
|
||||
for (String column : columnArray) {
|
||||
column = column.trim().toUpperCase();
|
||||
if (column.isEmpty()) {
|
||||
continue;
|
||||
}
|
||||
columnSet.add(column);
|
||||
}
|
||||
if ("*".equals(tableName)) {
|
||||
ignoreAllColumns.addAll(columnSet);
|
||||
} else {
|
||||
this.ignoredTableColumns.put(tableName, columnSet);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class OperationResult {
|
||||
|
||||
private String operation;
|
||||
private boolean recordStatus;
|
||||
private String tableName;
|
||||
private String changedData;
|
||||
/**
|
||||
* cost for this plugin, ms
|
||||
*/
|
||||
private long cost;
|
||||
|
||||
public void buildDataStr(List<DataChangedRecord> records) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("[");
|
||||
for (DataChangedRecord r : records) {
|
||||
sb.append(r.generateUpdatedDataStr()).append(",");
|
||||
}
|
||||
if (sb.length() == 1) {
|
||||
sb.append("]");
|
||||
changedData = sb.toString();
|
||||
return;
|
||||
}
|
||||
sb.replace(sb.length() - 1, sb.length(), "]");
|
||||
changedData = sb.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "{" +
|
||||
"\"tableName\":\"" + tableName + "\"," +
|
||||
"\"operation\":\"" + operation + "\"," +
|
||||
"\"recordStatus\":\"" + recordStatus + "\"," +
|
||||
"\"changedData\":" + changedData + "," +
|
||||
"\"cost(ms)\":" + cost + "}";
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class Columns2SelectItemsResult {
|
||||
|
||||
private Column pk;
|
||||
/**
|
||||
* all column with additional columns: ID, etc.
|
||||
*/
|
||||
private List<SelectItem<?>> selectItems;
|
||||
/**
|
||||
* newly added column count from meta data.
|
||||
*/
|
||||
private int additionalItemCount;
|
||||
|
||||
public static Columns2SelectItemsResult build(List<SelectItem<?>> selectItems, int additionalItemCount) {
|
||||
Columns2SelectItemsResult result = new Columns2SelectItemsResult();
|
||||
result.setSelectItems(selectItems);
|
||||
result.setAdditionalItemCount(additionalItemCount);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class OriginalDataObj {
|
||||
|
||||
private List<DataChangedRecord> originalDataObj;
|
||||
|
||||
public boolean isEmpty() {
|
||||
return originalDataObj == null || originalDataObj.isEmpty();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class DataColumnChangeResult {
|
||||
|
||||
private String columnName;
|
||||
private Object originalValue;
|
||||
private Object updateValue;
|
||||
|
||||
@SuppressWarnings("rawtypes")
|
||||
public boolean isDataChanged(Object updateValue) {
|
||||
if (!Objects.equals(originalValue, updateValue)) {
|
||||
if (originalValue instanceof Clob) {
|
||||
String originalStr = convertClob((Clob) originalValue);
|
||||
setOriginalValue(originalStr);
|
||||
return !originalStr.equals(updateValue);
|
||||
}
|
||||
if (originalValue instanceof Comparable) {
|
||||
Comparable original = (Comparable) originalValue;
|
||||
Comparable update = (Comparable) updateValue;
|
||||
try {
|
||||
return update == null || original.compareTo(update) != 0;
|
||||
} catch (Exception e) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static String convertClob(Clob clobObj) {
|
||||
try {
|
||||
return clobObj.getSubString(0, (int) clobObj.length());
|
||||
} catch (Exception e) {
|
||||
try (Reader is = clobObj.getCharacterStream()) {
|
||||
char[] chars = new char[64];
|
||||
int readChars;
|
||||
StringBuilder sb = new StringBuilder();
|
||||
while ((readChars = is.read(chars)) != -1) {
|
||||
sb.append(chars, 0, readChars);
|
||||
}
|
||||
return sb.toString();
|
||||
} catch (Exception e2) {
|
||||
//ignored
|
||||
return "unknown clobObj";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static DataColumnChangeResult constrcutByUpdateVal(String columnName, Object updateValue) {
|
||||
DataColumnChangeResult res = new DataColumnChangeResult();
|
||||
res.setColumnName(columnName);
|
||||
res.setUpdateValue(updateValue);
|
||||
return res;
|
||||
}
|
||||
|
||||
public static DataColumnChangeResult constrcutByOriginalVal(String columnName, Object originalValue) {
|
||||
DataColumnChangeResult res = new DataColumnChangeResult();
|
||||
res.setColumnName(columnName);
|
||||
res.setOriginalValue(originalValue);
|
||||
return res;
|
||||
}
|
||||
|
||||
public String generateDataStr() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("\"").append(columnName).append("\"").append(":").append("\"").append(convertDoubleQuotes(originalValue)).append("->").append(convertDoubleQuotes(updateValue)).append("\"").append(",");
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public String convertDoubleQuotes(Object obj) {
|
||||
if (obj == null) {
|
||||
return null;
|
||||
}
|
||||
return obj.toString().replace("\"", "\\\"");
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class DataChangedRecord {
|
||||
|
||||
private String pkColumnName;
|
||||
private Object pkColumnVal;
|
||||
private List<DataColumnChangeResult> originalColumnDatas;
|
||||
private List<DataColumnChangeResult> updatedColumns;
|
||||
|
||||
public boolean hasUpdate(Map<String, Object> columnNameValMap, Set<String> ignoredColumns, Set<String> ignoreAllColumns) {
|
||||
if (originalColumnDatas == null) {
|
||||
return true;
|
||||
}
|
||||
boolean hasUpdate = false;
|
||||
updatedColumns = new ArrayList<>(originalColumnDatas.size());
|
||||
for (DataColumnChangeResult originalColumn : originalColumnDatas) {
|
||||
final String columnName = originalColumn.getColumnName().toUpperCase();
|
||||
if (ignoredColumns != null && ignoredColumns.contains(columnName) || ignoreAllColumns.contains(columnName)) {
|
||||
continue;
|
||||
}
|
||||
Object updatedValue = columnNameValMap.get(columnName);
|
||||
if (originalColumn.isDataChanged(updatedValue)) {
|
||||
hasUpdate = true;
|
||||
originalColumn.setUpdateValue(updatedValue);
|
||||
updatedColumns.add(originalColumn);
|
||||
}
|
||||
}
|
||||
return hasUpdate;
|
||||
}
|
||||
|
||||
public String generateUpdatedDataStr() {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("{");
|
||||
if (pkColumnName != null) {
|
||||
sb.append("\"").append(pkColumnName).append("\"").append(":").append("\"").append(convertDoubleQuotes(pkColumnVal)).append("\"").append(",");
|
||||
}
|
||||
for (DataColumnChangeResult update : updatedColumns) {
|
||||
sb.append(update.generateDataStr());
|
||||
}
|
||||
sb.replace(sb.length() - 1, sb.length(), "}");
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
public String convertDoubleQuotes(Object obj) {
|
||||
if (obj == null) {
|
||||
return null;
|
||||
}
|
||||
return obj.toString().replace("\"", "\\\"");
|
||||
}
|
||||
}
|
||||
|
||||
public static class DataUpdateLimitationException extends MybatisPlusException {
|
||||
|
||||
public DataUpdateLimitationException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public static DataUpdateLimitationException DEFAULT = new DataUpdateLimitationException("本次操作 因超过系统安全阈值 被拦截,如需继续,请联系管理员!");
|
||||
}
|
||||
}
|
@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
import java.util.List;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
import org.apache.ibatis.session.ResultHandler;
|
||||
import org.apache.ibatis.session.RowBounds;
|
||||
|
||||
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.delete.Delete;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||
import net.sf.jsqlparser.statement.select.WithItem;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
|
||||
/**
|
||||
* 数据权限处理器
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.5.2
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@ToString(callSuper = true)
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
public class DataPermissionInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
|
||||
|
||||
private DataPermissionHandler dataPermissionHandler;
|
||||
|
||||
@SuppressWarnings("RedundantThrows")
|
||||
@Override
|
||||
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
|
||||
if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
|
||||
mpBs.sql(parserSingle(mpBs.sql(), ms.getId()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = mpSh.mappedStatement();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
|
||||
if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
||||
mpBs.sql(parserMulti(mpBs.sql(), ms.getId()));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void processSelect(Select select, int index, String sql, Object obj) {
|
||||
if (dataPermissionHandler == null) {
|
||||
return;
|
||||
}
|
||||
if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
|
||||
// 参照 com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor.processSelect 做的修改
|
||||
final String whereSegment = (String) obj;
|
||||
processSelectBody(select, whereSegment);
|
||||
List<WithItem> withItemsList = select.getWithItemsList();
|
||||
if (!CollectionUtils.isEmpty(withItemsList)) {
|
||||
withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
|
||||
}
|
||||
} else {
|
||||
// 兼容原来的旧版 DataPermissionHandler 场景
|
||||
if (select instanceof PlainSelect) {
|
||||
this.setWhere((PlainSelect) select, (String) obj);
|
||||
} else if (select instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
List<Select> selectBodyList = setOperationList.getSelects();
|
||||
selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 where 条件
|
||||
*
|
||||
* @param plainSelect 查询对象
|
||||
* @param whereSegment 查询条件片段
|
||||
*/
|
||||
protected void setWhere(PlainSelect plainSelect, String whereSegment) {
|
||||
if (dataPermissionHandler == null) {
|
||||
return;
|
||||
}
|
||||
// 兼容旧版的数据权限处理
|
||||
final Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
|
||||
if (null != sqlSegment) {
|
||||
plainSelect.setWhere(sqlSegment);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* update 语句处理
|
||||
*/
|
||||
@Override
|
||||
protected void processUpdate(Update update, int index, String sql, Object obj) {
|
||||
final Expression sqlSegment = getUpdateOrDeleteExpression(update.getTable(), update.getWhere(), (String) obj);
|
||||
if (null != sqlSegment) {
|
||||
update.setWhere(sqlSegment);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* delete 语句处理
|
||||
*/
|
||||
@Override
|
||||
protected void processDelete(Delete delete, int index, String sql, Object obj) {
|
||||
final Expression sqlSegment = getUpdateOrDeleteExpression(delete.getTable(), delete.getWhere(), (String) obj);
|
||||
if (null != sqlSegment) {
|
||||
delete.setWhere(sqlSegment);
|
||||
}
|
||||
}
|
||||
|
||||
protected Expression getUpdateOrDeleteExpression(final Table table, final Expression where, final String whereSegment) {
|
||||
if (dataPermissionHandler == null) {
|
||||
return null;
|
||||
}
|
||||
if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
|
||||
return andExpression(table, where, whereSegment);
|
||||
} else {
|
||||
// 兼容旧版的数据权限处理
|
||||
return dataPermissionHandler.getSqlSegment(where, whereSegment);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression buildTableExpression(final Table table, final Expression where, final String whereSegment) {
|
||||
if (dataPermissionHandler == null) {
|
||||
return null;
|
||||
}
|
||||
// 只有新版数据权限处理器才会执行到这里
|
||||
final MultiDataPermissionHandler handler = (MultiDataPermissionHandler) dataPermissionHandler;
|
||||
return handler.getSqlSegment(table, where, whereSegment);
|
||||
}
|
||||
}
|
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.TableNameParser;
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.TableNameHandler;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.executor.statement.StatementHandler;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
import org.apache.ibatis.session.ResultHandler;
|
||||
import org.apache.ibatis.session.RowBounds;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 动态表名
|
||||
*
|
||||
* @author jobob
|
||||
* @since 3.4.0
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
public class DynamicTableNameInnerInterceptor implements InnerInterceptor {
|
||||
private Runnable hook;
|
||||
/**
|
||||
* 表名处理器,是否处理表名的情况都在该处理器中自行判断
|
||||
*/
|
||||
private TableNameHandler tableNameHandler;
|
||||
|
||||
public DynamicTableNameInnerInterceptor(TableNameHandler tableNameHandler) {
|
||||
this.tableNameHandler = tableNameHandler;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
|
||||
if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) return;
|
||||
PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
|
||||
mpBs.sql(this.changeTable(mpBs.sql()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
|
||||
PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
|
||||
MappedStatement ms = mpSh.mappedStatement();
|
||||
SqlCommandType sct = ms.getSqlCommandType();
|
||||
if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
|
||||
if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) {
|
||||
return;
|
||||
}
|
||||
PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
|
||||
mpBs.sql(this.changeTable(mpBs.sql()));
|
||||
}
|
||||
}
|
||||
|
||||
public String changeTable(String sql) {
|
||||
ExceptionUtils.throwMpe(null == tableNameHandler, "Please implement TableNameHandler processing logic");
|
||||
TableNameParser parser = new TableNameParser(sql);
|
||||
List<TableNameParser.SqlToken> names = new ArrayList<>();
|
||||
parser.accept(names::add);
|
||||
StringBuilder builder = new StringBuilder();
|
||||
int last = 0;
|
||||
for (TableNameParser.SqlToken name : names) {
|
||||
int start = name.getStart();
|
||||
if (start != last) {
|
||||
builder.append(sql, last, start);
|
||||
builder.append(tableNameHandler.dynamicTableName(sql, name.getValue()));
|
||||
}
|
||||
last = name.getEnd();
|
||||
}
|
||||
if (last != sql.length()) {
|
||||
builder.append(sql.substring(last));
|
||||
}
|
||||
if (hook != null) {
|
||||
hook.run();
|
||||
}
|
||||
return builder.toString();
|
||||
}
|
||||
}
|
@ -0,0 +1,319 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.Version;
|
||||
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
|
||||
import com.baomidou.mybatisplus.core.conditions.Wrapper;
|
||||
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.Update;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
|
||||
import com.baomidou.mybatisplus.core.mapper.Mapper;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfo;
|
||||
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
|
||||
import com.baomidou.mybatisplus.core.toolkit.Constants;
|
||||
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.SqlCommandType;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Timestamp;
|
||||
import java.time.Instant;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Function;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* Optimistic Lock Light version
|
||||
* <p>Intercept on {@link Executor}.update;</p>
|
||||
* <p>Support version types: int/Integer, long/Long, java.util.Date, java.sql.Timestamp</p>
|
||||
* <p>For extra types, please define a subclass and override {@code getUpdatedVersionVal}() method.</p>
|
||||
* <br>
|
||||
* <p>How to use?</p>
|
||||
* <p>(1) Define an Entity and add {@link Version} annotation on one entity field.</p>
|
||||
* <p>(2) Add {@link OptimisticLockerInnerInterceptor} into mybatis plugin.</p>
|
||||
* <br>
|
||||
* <p>How to work?</p>
|
||||
* <p>if update entity with version column=1:</p>
|
||||
* <p>(1) no {@link OptimisticLockerInnerInterceptor}:</p>
|
||||
* <p>SQL: update tbl_test set name='abc' where id=100001;</p>
|
||||
* <p>(2) add {@link OptimisticLockerInnerInterceptor}:</p>
|
||||
* <p>SQL: update tbl_test set name='abc',version=2 where id=100001 and version=1;</p>
|
||||
*
|
||||
* @author yuxiaobin
|
||||
* @since 3.4.0
|
||||
*/
|
||||
@SuppressWarnings({"unchecked"})
|
||||
public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
|
||||
private RuntimeException exception;
|
||||
|
||||
public void setException(RuntimeException exception) {
|
||||
this.exception = exception;
|
||||
}
|
||||
|
||||
/**
|
||||
* entity类缓存
|
||||
*/
|
||||
private static final Map<String, Class<?>> ENTITY_CLASS_CACHE = new ConcurrentHashMap<>();
|
||||
/**
|
||||
* 变量占位符正则
|
||||
*/
|
||||
private static final Pattern PARAM_PAIRS_RE = Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(" + Constants.WRAPPER_PARAM + "\\d+)\\}");
|
||||
/**
|
||||
* paramNameValuePairs存放的version值的key
|
||||
*/
|
||||
private static final String UPDATED_VERSION_VAL_KEY = "#updatedVersionVal#";
|
||||
/**
|
||||
* Support wrapper mode
|
||||
*/
|
||||
private final boolean wrapperMode;
|
||||
|
||||
public OptimisticLockerInnerInterceptor() {
|
||||
this(false);
|
||||
}
|
||||
|
||||
public OptimisticLockerInnerInterceptor(boolean wrapperMode) {
|
||||
this.wrapperMode = wrapperMode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
|
||||
if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
|
||||
return;
|
||||
}
|
||||
if (parameter instanceof Map) {
|
||||
Map<String, Object> map = (Map<String, Object>) parameter;
|
||||
doOptimisticLocker(map, ms.getId());
|
||||
}
|
||||
}
|
||||
|
||||
protected void doOptimisticLocker(Map<String, Object> map, String msId) {
|
||||
// updateById(et), update(et, wrapper);
|
||||
Object et = map.getOrDefault(Constants.ENTITY, null);
|
||||
if (Objects.nonNull(et)) {
|
||||
|
||||
// version field
|
||||
TableFieldInfo fieldInfo = this.getVersionFieldInfo(et.getClass());
|
||||
if (null == fieldInfo) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
Field versionField = fieldInfo.getField();
|
||||
// 旧的 version 值
|
||||
Object originalVersionVal = versionField.get(et);
|
||||
if (originalVersionVal == null) {
|
||||
if (null != exception) {
|
||||
/**
|
||||
* 自定义异常处理
|
||||
*/
|
||||
throw exception;
|
||||
}
|
||||
return;
|
||||
}
|
||||
String versionColumn = fieldInfo.getColumn();
|
||||
// 新的 version 值
|
||||
Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal);
|
||||
String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
|
||||
if ("update".equals(methodName)) {
|
||||
AbstractWrapper<?, ?, ?> aw = (AbstractWrapper<?, ?, ?>) map.getOrDefault(Constants.WRAPPER, null);
|
||||
if (aw == null) {
|
||||
UpdateWrapper<?> uw = new UpdateWrapper<>();
|
||||
uw.eq(versionColumn, originalVersionVal);
|
||||
map.put(Constants.WRAPPER, uw);
|
||||
} else {
|
||||
aw.apply(versionColumn + " = {0}", originalVersionVal);
|
||||
}
|
||||
} else {
|
||||
map.put(Constants.MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
|
||||
}
|
||||
versionField.set(et, updatedVersionVal);
|
||||
} catch (IllegalAccessException e) {
|
||||
throw ExceptionUtils.mpe(e);
|
||||
}
|
||||
}
|
||||
|
||||
// update(LambdaUpdateWrapper) or update(UpdateWrapper)
|
||||
else if (wrapperMode && map.entrySet().stream().anyMatch(t -> Objects.equals(t.getKey(), Constants.WRAPPER))) {
|
||||
setVersionByWrapper(map, msId);
|
||||
}
|
||||
}
|
||||
|
||||
protected TableFieldInfo getVersionFieldInfo(Class<?> entityClazz) {
|
||||
TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClazz);
|
||||
return (null != tableInfo && tableInfo.isWithVersion()) ? tableInfo.getVersionFieldInfo() : null;
|
||||
}
|
||||
|
||||
private void setVersionByWrapper(Map<String, Object> map, String msId) {
|
||||
Object ew = map.get(Constants.WRAPPER);
|
||||
if (ew instanceof AbstractWrapper && ew instanceof Update) {
|
||||
Class<?> entityClass = ENTITY_CLASS_CACHE.get(msId);
|
||||
if (null == entityClass) {
|
||||
try {
|
||||
final String className = msId.substring(0, msId.lastIndexOf('.'));
|
||||
entityClass = ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
|
||||
ENTITY_CLASS_CACHE.put(msId, entityClass);
|
||||
} catch (ClassNotFoundException e) {
|
||||
throw ExceptionUtils.mpe(e);
|
||||
}
|
||||
}
|
||||
|
||||
final TableFieldInfo versionField = getVersionFieldInfo(entityClass);
|
||||
if (null == versionField) {
|
||||
return;
|
||||
}
|
||||
|
||||
final String versionColumn = versionField.getColumn();
|
||||
final FieldEqFinder fieldEqFinder = new FieldEqFinder(versionColumn, (Wrapper<?>) ew);
|
||||
if (!fieldEqFinder.isPresent()) {
|
||||
return;
|
||||
}
|
||||
final Map<String, Object> paramNameValuePairs = ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs();
|
||||
final Object originalVersionValue = paramNameValuePairs.get(fieldEqFinder.valueKey);
|
||||
if (originalVersionValue == null) {
|
||||
return;
|
||||
}
|
||||
final Object updatedVersionVal = getUpdatedVersionVal(originalVersionValue.getClass(), originalVersionValue);
|
||||
if (originalVersionValue == updatedVersionVal) {
|
||||
return;
|
||||
}
|
||||
// 拼接新的version值
|
||||
paramNameValuePairs.put(UPDATED_VERSION_VAL_KEY, updatedVersionVal);
|
||||
((Update<?, ?>) ew).setSql(String.format("%s = #{%s.%s}", versionColumn, "ew.paramNameValuePairs", UPDATED_VERSION_VAL_KEY));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* EQ字段查找器
|
||||
*/
|
||||
private static class FieldEqFinder {
|
||||
|
||||
/**
|
||||
* 状态机
|
||||
*/
|
||||
enum State {
|
||||
INIT,
|
||||
FIELD_FOUND,
|
||||
EQ_FOUND,
|
||||
VERSION_VALUE_PRESENT;
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 字段值的key
|
||||
*/
|
||||
private String valueKey;
|
||||
/**
|
||||
* 当前状态
|
||||
*/
|
||||
private State state;
|
||||
/**
|
||||
* 字段名
|
||||
*/
|
||||
private final String fieldName;
|
||||
|
||||
public FieldEqFinder(String fieldName, Wrapper<?> wrapper) {
|
||||
this.fieldName = fieldName;
|
||||
state = State.INIT;
|
||||
find(wrapper);
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否已存在
|
||||
*/
|
||||
public boolean isPresent() {
|
||||
return state == State.VERSION_VALUE_PRESENT;
|
||||
}
|
||||
|
||||
private boolean find(Wrapper<?> wrapper) {
|
||||
Matcher matcher;
|
||||
final NormalSegmentList segments = wrapper.getExpression().getNormal();
|
||||
for (ISqlSegment segment : segments) {
|
||||
// 如果字段已找到并且当前segment为EQ
|
||||
if (state == State.FIELD_FOUND && segment == SqlKeyword.EQ) {
|
||||
this.state = State.EQ_FOUND;
|
||||
// 如果EQ找到并且value已找到
|
||||
} else if (state == State.EQ_FOUND
|
||||
&& (matcher = PARAM_PAIRS_RE.matcher(segment.getSqlSegment())).matches()) {
|
||||
this.valueKey = matcher.group(1);
|
||||
this.state = State.VERSION_VALUE_PRESENT;
|
||||
return true;
|
||||
// 处理嵌套
|
||||
} else if (segment instanceof Wrapper) {
|
||||
if (find((Wrapper<?>) segment)) {
|
||||
return true;
|
||||
}
|
||||
// 判断字段是否是要查找字段
|
||||
} else if (segment.getSqlSegment().equals(this.fieldName)) {
|
||||
this.state = State.FIELD_FOUND;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static class VersionFactory {
|
||||
|
||||
/**
|
||||
* 存放版本号类型与获取更新后版本号的map
|
||||
*/
|
||||
private static final Map<Class<?>, Function<Object, Object>> VERSION_FUNCTION_MAP = new HashMap<>();
|
||||
|
||||
static {
|
||||
VERSION_FUNCTION_MAP.put(long.class, version -> (long) version + 1);
|
||||
VERSION_FUNCTION_MAP.put(Long.class, version -> (long) version + 1);
|
||||
VERSION_FUNCTION_MAP.put(int.class, version -> (int) version + 1);
|
||||
VERSION_FUNCTION_MAP.put(Integer.class, version -> (int) version + 1);
|
||||
VERSION_FUNCTION_MAP.put(Date.class, version -> new Date());
|
||||
VERSION_FUNCTION_MAP.put(Timestamp.class, version -> new Timestamp(System.currentTimeMillis()));
|
||||
VERSION_FUNCTION_MAP.put(LocalDateTime.class, version -> LocalDateTime.now());
|
||||
VERSION_FUNCTION_MAP.put(Instant.class, version -> Instant.now());
|
||||
}
|
||||
|
||||
public static Object getUpdatedVersionVal(Class<?> clazz, Object originalVersionVal) {
|
||||
Function<Object, Object> versionFunction = VERSION_FUNCTION_MAP.get(clazz);
|
||||
if (versionFunction == null) {
|
||||
// not supported type, return original val.
|
||||
return originalVersionVal;
|
||||
}
|
||||
return versionFunction.apply(originalVersionVal);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method provides the control for version value.<BR>
|
||||
* Returned value type must be the same as original one.
|
||||
*
|
||||
* @param originalVersionVal ignore
|
||||
* @return updated version val
|
||||
*/
|
||||
protected Object getUpdatedVersionVal(Class<?> clazz, Object originalVersionVal) {
|
||||
return VersionFactory.getUpdatedVersionVal(clazz, originalVersionVal);
|
||||
}
|
||||
}
|
@ -0,0 +1,478 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.DbType;
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.baomidou.mybatisplus.core.metadata.OrderItem;
|
||||
import com.baomidou.mybatisplus.core.toolkit.*;
|
||||
import com.baomidou.mybatisplus.extension.parser.JsqlParserGlobal;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
|
||||
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.select.*;
|
||||
import org.apache.ibatis.cache.CacheKey;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.logging.Log;
|
||||
import org.apache.ibatis.logging.LogFactory;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.mapping.ParameterMapping;
|
||||
import org.apache.ibatis.mapping.ResultMap;
|
||||
import org.apache.ibatis.session.Configuration;
|
||||
import org.apache.ibatis.session.ResultHandler;
|
||||
import org.apache.ibatis.session.RowBounds;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 分页拦截器
|
||||
* <p>
|
||||
* 默认对 left join 进行优化,虽然能优化count,但是加上分页的话如果1对多本身结果条数就是不正确的
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.4.0
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@SuppressWarnings({"rawtypes"})
|
||||
public class PaginationInnerInterceptor implements InnerInterceptor {
|
||||
/**
|
||||
* 获取jsqlparser中count的SelectItem
|
||||
*/
|
||||
protected static final List<SelectItem<?>> COUNT_SELECT_ITEM = Collections.singletonList(
|
||||
new SelectItem<>(new Column().withColumnName("COUNT(*)")).withAlias(new Alias("total"))
|
||||
);
|
||||
protected static final Map<String, MappedStatement> countMsCache = new ConcurrentHashMap<>();
|
||||
protected final Log logger = LogFactory.getLog(this.getClass());
|
||||
|
||||
|
||||
/**
|
||||
* 溢出总页数后是否进行处理
|
||||
*/
|
||||
protected boolean overflow;
|
||||
/**
|
||||
* 单页分页条数限制
|
||||
*/
|
||||
protected Long maxLimit;
|
||||
/**
|
||||
* 数据库类型
|
||||
* <p>
|
||||
* 查看 {@link #findIDialect(Executor)} 逻辑
|
||||
*/
|
||||
private DbType dbType;
|
||||
/**
|
||||
* 方言实现类
|
||||
* <p>
|
||||
* 查看 {@link #findIDialect(Executor)} 逻辑
|
||||
*/
|
||||
private IDialect dialect;
|
||||
/**
|
||||
* 生成 countSql 优化掉 join
|
||||
* 现在只支持 left join
|
||||
*
|
||||
* @since 3.4.2
|
||||
*/
|
||||
protected boolean optimizeJoin = true;
|
||||
|
||||
public PaginationInnerInterceptor(DbType dbType) {
|
||||
this.dbType = dbType;
|
||||
}
|
||||
|
||||
public PaginationInnerInterceptor(IDialect dialect) {
|
||||
this.dialect = dialect;
|
||||
}
|
||||
|
||||
/**
|
||||
* 这里进行count,如果count为0这返回false(就是不再执行sql了)
|
||||
*/
|
||||
@Override
|
||||
public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
|
||||
IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
|
||||
if (page == null || page.getSize() < 0 || !page.searchCount() || resultHandler != Executor.NO_RESULT_HANDLER) {
|
||||
return true;
|
||||
}
|
||||
|
||||
BoundSql countSql;
|
||||
MappedStatement countMs = buildCountMappedStatement(ms, page.countId());
|
||||
if (countMs != null) {
|
||||
countSql = countMs.getBoundSql(parameter);
|
||||
} else {
|
||||
countMs = buildAutoCountMappedStatement(ms);
|
||||
String countSqlStr = autoCountSql(page, boundSql.getSql());
|
||||
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
|
||||
countSql = new BoundSql(countMs.getConfiguration(), countSqlStr, mpBoundSql.parameterMappings(), parameter);
|
||||
PluginUtils.setAdditionalParameter(countSql, mpBoundSql.additionalParameters());
|
||||
}
|
||||
|
||||
CacheKey cacheKey = executor.createCacheKey(countMs, parameter, rowBounds, countSql);
|
||||
List<Object> result = executor.query(countMs, parameter, rowBounds, resultHandler, cacheKey, countSql);
|
||||
long total = 0;
|
||||
if (CollectionUtils.isNotEmpty(result)) {
|
||||
// 个别数据库 count 没数据不会返回 0
|
||||
Object o = result.get(0);
|
||||
if (o != null) {
|
||||
total = Long.parseLong(o.toString());
|
||||
}
|
||||
}
|
||||
page.setTotal(total);
|
||||
return continuePage(page);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
|
||||
IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
|
||||
if (null == page) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 处理 orderBy 拼接
|
||||
boolean addOrdered = false;
|
||||
String buildSql = boundSql.getSql();
|
||||
List<OrderItem> orders = page.orders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
addOrdered = true;
|
||||
buildSql = this.concatOrderBy(buildSql, orders);
|
||||
}
|
||||
|
||||
// size 小于 0 且不限制返回值则不构造分页sql
|
||||
Long _limit = page.maxLimit() != null ? page.maxLimit() : maxLimit;
|
||||
if (page.getSize() < 0 && null == _limit) {
|
||||
if (addOrdered) {
|
||||
PluginUtils.mpBoundSql(boundSql).sql(buildSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
handlerLimit(page, _limit);
|
||||
IDialect dialect = findIDialect(executor);
|
||||
|
||||
final Configuration configuration = ms.getConfiguration();
|
||||
DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
|
||||
PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
|
||||
|
||||
List<ParameterMapping> mappings = mpBoundSql.parameterMappings();
|
||||
Map<String, Object> additionalParameter = mpBoundSql.additionalParameters();
|
||||
model.consumers(mappings, configuration, additionalParameter);
|
||||
mpBoundSql.sql(model.getDialectSql());
|
||||
mpBoundSql.parameterMappings(mappings);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取分页方言类的逻辑
|
||||
*
|
||||
* @param executor Executor
|
||||
* @return 分页方言类
|
||||
*/
|
||||
protected IDialect findIDialect(Executor executor) {
|
||||
if (dialect != null) {
|
||||
return dialect;
|
||||
}
|
||||
if (dbType != null) {
|
||||
dialect = DialectFactory.getDialect(dbType);
|
||||
return dialect;
|
||||
}
|
||||
return DialectFactory.getDialect(JdbcUtils.getDbType(executor));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定的 id 的 MappedStatement
|
||||
*
|
||||
* @param ms MappedStatement
|
||||
* @param countId id
|
||||
* @return MappedStatement
|
||||
*/
|
||||
protected MappedStatement buildCountMappedStatement(MappedStatement ms, String countId) {
|
||||
if (StringUtils.isNotBlank(countId)) {
|
||||
final String id = ms.getId();
|
||||
if (!countId.contains(StringPool.DOT)) {
|
||||
countId = id.substring(0, id.lastIndexOf(StringPool.DOT) + 1) + countId;
|
||||
}
|
||||
final Configuration configuration = ms.getConfiguration();
|
||||
try {
|
||||
return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> configuration.getMappedStatement(key, false));
|
||||
} catch (Exception e) {
|
||||
logger.warn(String.format("can not find this countId: [\"%s\"]", countId));
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 mp 自用自动的 MappedStatement
|
||||
*
|
||||
* @param ms MappedStatement
|
||||
* @return MappedStatement
|
||||
*/
|
||||
protected MappedStatement buildAutoCountMappedStatement(MappedStatement ms) {
|
||||
final String countId = ms.getId() + "_mpCount";
|
||||
final Configuration configuration = ms.getConfiguration();
|
||||
return CollectionUtils.computeIfAbsent(countMsCache, countId, key -> {
|
||||
MappedStatement.Builder builder = new MappedStatement.Builder(configuration, key, ms.getSqlSource(), ms.getSqlCommandType());
|
||||
builder.resource(ms.getResource());
|
||||
builder.fetchSize(ms.getFetchSize());
|
||||
builder.statementType(ms.getStatementType());
|
||||
builder.timeout(ms.getTimeout());
|
||||
builder.parameterMap(ms.getParameterMap());
|
||||
builder.resultMaps(Collections.singletonList(new ResultMap.Builder(configuration, Constants.MYBATIS_PLUS, Long.class, Collections.emptyList()).build()));
|
||||
builder.resultSetType(ms.getResultSetType());
|
||||
builder.cache(ms.getCache());
|
||||
builder.flushCacheRequired(ms.isFlushCacheRequired());
|
||||
builder.useCache(ms.isUseCache());
|
||||
return builder.build();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取自动优化的 countSql
|
||||
*
|
||||
* @param page 参数
|
||||
* @param sql sql
|
||||
* @return countSql
|
||||
*/
|
||||
public String autoCountSql(IPage<?> page, String sql) {
|
||||
if (!page.optimizeCountSql()) {
|
||||
return lowLevelCountSql(sql);
|
||||
}
|
||||
try {
|
||||
Select select = (Select) JsqlParserGlobal.parse(sql);
|
||||
// https://github.com/baomidou/mybatis-plus/issues/3920 分页增加union语法支持
|
||||
if (select instanceof SetOperationList) {
|
||||
return lowLevelCountSql(sql);
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) select;
|
||||
|
||||
// 优化 order by 在非分组情况下
|
||||
List<OrderByElement> orderBy = plainSelect.getOrderByElements();
|
||||
if (CollectionUtils.isNotEmpty(orderBy)) {
|
||||
boolean canClean = true;
|
||||
for (OrderByElement order : orderBy) {
|
||||
// order by 里带参数,不去除order by
|
||||
Expression expression = order.getExpression();
|
||||
if (!(expression instanceof Column) && expression.toString().contains(StringPool.QUESTION_MARK)) {
|
||||
canClean = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (canClean) {
|
||||
plainSelect.setOrderByElements(null);
|
||||
}
|
||||
}
|
||||
Distinct distinct = plainSelect.getDistinct();
|
||||
GroupByElement groupBy = plainSelect.getGroupBy();
|
||||
// 包含 distinct、groupBy 不优化
|
||||
if (null != distinct || null != groupBy) {
|
||||
return lowLevelCountSql(select.toString());
|
||||
}
|
||||
//#95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
|
||||
for (SelectItem item : plainSelect.getSelectItems()) {
|
||||
if (item.toString().contains(StringPool.QUESTION_MARK)) {
|
||||
return lowLevelCountSql(select.toString());
|
||||
}
|
||||
}
|
||||
|
||||
// 包含 join 连表,进行判断是否移除 join 连表
|
||||
if (optimizeJoin && page.optimizeJoinOfCountSql()) {
|
||||
List<Join> joins = plainSelect.getJoins();
|
||||
if (CollectionUtils.isNotEmpty(joins)) {
|
||||
boolean canRemoveJoin = true;
|
||||
String whereS = Optional.ofNullable(plainSelect.getWhere()).map(Expression::toString).orElse(StringPool.EMPTY);
|
||||
// 不区分大小写
|
||||
whereS = whereS.toLowerCase();
|
||||
for (Join join : joins) {
|
||||
if (!join.isLeft()) {
|
||||
canRemoveJoin = false;
|
||||
break;
|
||||
}
|
||||
FromItem rightItem = join.getRightItem();
|
||||
String str = "";
|
||||
if (rightItem instanceof Table) {
|
||||
Table table = (Table) rightItem;
|
||||
str = Optional.ofNullable(table.getAlias()).map(Alias::getName).orElse(table.getName()) + StringPool.DOT;
|
||||
} else if (rightItem instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect subSelect = (ParenthesedSelect) rightItem;
|
||||
/* 如果 left join 是子查询,并且子查询里包含 ?(代表有入参) 或者 where 条件里包含使用 join 的表的字段作条件,就不移除 join */
|
||||
if (subSelect.toString().contains(StringPool.QUESTION_MARK)) {
|
||||
canRemoveJoin = false;
|
||||
break;
|
||||
}
|
||||
str = subSelect.getAlias().getName() + StringPool.DOT;
|
||||
}
|
||||
// 不区分大小写
|
||||
str = str.toLowerCase();
|
||||
|
||||
if (whereS.contains(str)) {
|
||||
/* 如果 where 条件里包含使用 join 的表的字段作条件,就不移除 join */
|
||||
canRemoveJoin = false;
|
||||
break;
|
||||
}
|
||||
|
||||
for (Expression expression : join.getOnExpressions()) {
|
||||
if (expression.toString().contains(StringPool.QUESTION_MARK)) {
|
||||
/* 如果 join 里包含 ?(代表有入参) 就不移除 join */
|
||||
canRemoveJoin = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (canRemoveJoin) {
|
||||
plainSelect.setJoins(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 优化 SQL
|
||||
plainSelect.setSelectItems(COUNT_SELECT_ITEM);
|
||||
return select.toString();
|
||||
} catch (JSQLParserException e) {
|
||||
// 无法优化使用原 SQL
|
||||
logger.warn("optimize this sql to a count sql has exception, sql:\"" + sql + "\", exception:\n" + e.getCause());
|
||||
} catch (Exception e) {
|
||||
logger.warn("optimize this sql to a count sql has error, sql:\"" + sql + "\", exception:\n" + e);
|
||||
}
|
||||
return lowLevelCountSql(sql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 无法进行count优化时,降级使用此方法
|
||||
*
|
||||
* @param originalSql 原始sql
|
||||
* @return countSql
|
||||
*/
|
||||
protected String lowLevelCountSql(String originalSql) {
|
||||
return SqlParserUtils.getOriginalCountSql(originalSql);
|
||||
}
|
||||
|
||||
/**
|
||||
* 查询SQL拼接Order By
|
||||
*
|
||||
* @param originalSql 需要拼接的SQL
|
||||
* @return ignore
|
||||
*/
|
||||
public String concatOrderBy(String originalSql, List<OrderItem> orderList) {
|
||||
try {
|
||||
Select selectBody = (Select) JsqlParserGlobal.parse(originalSql);
|
||||
if (selectBody instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
||||
List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
|
||||
plainSelect.setOrderByElements(orderByElementsReturn);
|
||||
return plainSelect.toString();
|
||||
} else if (selectBody instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) selectBody;
|
||||
List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
|
||||
List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
|
||||
setOperationList.setOrderByElements(orderByElementsReturn);
|
||||
return setOperationList.toString();
|
||||
} else if (selectBody instanceof WithItem) {
|
||||
// todo: don't known how to resole
|
||||
return originalSql;
|
||||
} else {
|
||||
return originalSql;
|
||||
}
|
||||
} catch (JSQLParserException e) {
|
||||
logger.warn("failed to concat orderBy from IPage, exception:\n" + e.getCause());
|
||||
} catch (Exception e) {
|
||||
logger.warn("failed to concat orderBy from IPage, exception:\n" + e);
|
||||
}
|
||||
return originalSql;
|
||||
}
|
||||
|
||||
protected List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) {
|
||||
List<OrderByElement> additionalOrderBy = orderList.stream()
|
||||
.filter(item -> StringUtils.isNotBlank(item.getColumn()))
|
||||
.map(item -> {
|
||||
OrderByElement element = new OrderByElement();
|
||||
element.setExpression(new Column(item.getColumn()));
|
||||
element.setAsc(item.isAsc());
|
||||
element.setAscDescPresent(true);
|
||||
return element;
|
||||
}).collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(orderByElements)) {
|
||||
return additionalOrderBy;
|
||||
}
|
||||
// github pull/3550 优化排序,比如:默认 order by id 前端传了name排序,设置为 order by name,id
|
||||
additionalOrderBy.addAll(orderByElements);
|
||||
return additionalOrderBy;
|
||||
}
|
||||
|
||||
/**
|
||||
* count 查询之后,是否继续执行分页
|
||||
*
|
||||
* @param page 分页对象
|
||||
* @return 是否
|
||||
*/
|
||||
protected boolean continuePage(IPage<?> page) {
|
||||
if (page.getTotal() <= 0) {
|
||||
return false;
|
||||
}
|
||||
if (page.getCurrent() > page.getPages()) {
|
||||
if (overflow) {
|
||||
//溢出总页数处理
|
||||
handlerOverflow(page);
|
||||
} else {
|
||||
// 超过最大范围,未设置溢出逻辑中断 list 执行
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理超出分页条数限制,默认归为限制数
|
||||
*
|
||||
* @param page IPage
|
||||
*/
|
||||
protected void handlerLimit(IPage<?> page, Long limit) {
|
||||
final long size = page.getSize();
|
||||
if (limit != null && limit > 0 && (size > limit || size < 0)) {
|
||||
page.setSize(limit);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理页数溢出,默认设置为第一页
|
||||
*
|
||||
* @param page IPage
|
||||
*/
|
||||
protected void handlerOverflow(IPage<?> page) {
|
||||
page.setCurrent(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setProperties(Properties properties) {
|
||||
PropertyMapper.newInstance(properties)
|
||||
.whenNotBlank("overflow", Boolean::parseBoolean, this::setOverflow)
|
||||
.whenNotBlank("dbType", DbType::getDbType, this::setDbType)
|
||||
.whenNotBlank("dialect", ClassUtils::newInstance, this::setDialect)
|
||||
.whenNotBlank("maxLimit", Long::parseLong, this::setMaxLimit)
|
||||
.whenNotBlank("optimizeJoin", Boolean::parseBoolean, this::setOptimizeJoin);
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* Copyright (c) 2011-2024, baomidou (jobob@qq.com).
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.baomidou.mybatisplus.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.config.GlobalConfig;
|
||||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
||||
import com.baomidou.mybatisplus.core.toolkit.sql.SqlUtils;
|
||||
import org.apache.ibatis.executor.Executor;
|
||||
import org.apache.ibatis.logging.Log;
|
||||
import org.apache.ibatis.logging.LogFactory;
|
||||
import org.apache.ibatis.mapping.BoundSql;
|
||||
import org.apache.ibatis.mapping.MappedStatement;
|
||||
import org.apache.ibatis.session.ResultHandler;
|
||||
import org.apache.ibatis.session.RowBounds;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 功能类似于 {@link GlobalConfig.DbConfig#isReplacePlaceholder()},
|
||||
* 只是这个是在运行时实时替换,适用范围更广
|
||||
*
|
||||
* @author miemie
|
||||
* @since 2020-11-19
|
||||
*/
|
||||
public class ReplacePlaceholderInnerInterceptor implements InnerInterceptor {
|
||||
|
||||
protected final Log logger = LogFactory.getLog(this.getClass());
|
||||
|
||||
private String escapeSymbol;
|
||||
|
||||
public ReplacePlaceholderInnerInterceptor() {
|
||||
}
|
||||
|
||||
public ReplacePlaceholderInnerInterceptor(String escapeSymbol) {
|
||||
this.escapeSymbol = escapeSymbol;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
|
||||
String sql = boundSql.getSql();
|
||||
List<String> find = SqlUtils.findPlaceholder(sql);
|
||||
if (CollectionUtils.isNotEmpty(find)) {
|
||||
sql = SqlUtils.replaceSqlPlaceholder(sql, find, escapeSymbol);
|
||||
PluginUtils.mpBoundSql(boundSql).sql(sql);
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
package com.baomidou.mybatisplus.test.extension.parser;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.parser.cache.FstFactory;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.statement.Statement;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.condition.EnabledOnJre;
|
||||
import org.junit.jupiter.api.condition.JRE;
|
||||
import org.springframework.util.SerializationUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2023-08-03
|
||||
*/
|
||||
class JsqlParserSimpleSerialTest {
|
||||
private final static int len = 1000;
|
||||
private final static String sql = "SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 " +
|
||||
"LEFT JOIN entity2 e2 ON e2.id = e1.id " +
|
||||
"ON e1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.NAME = ?)";
|
||||
|
||||
@Test
|
||||
@EnabledOnJre(JRE.JAVA_8)
|
||||
void test() throws JSQLParserException {
|
||||
System.out.println("循环次数: " + len);
|
||||
noSerial();
|
||||
jdkSerial();
|
||||
fstSerial();
|
||||
}
|
||||
|
||||
void noSerial() throws JSQLParserException {
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < len; i++) {
|
||||
CCJSqlParserUtil.parse(sql);
|
||||
}
|
||||
long endTime = System.currentTimeMillis();
|
||||
long e1 = endTime - startTime;
|
||||
System.out.printf("普通解析执行耗时: %s 毫秒, 均耗时: %s%n", e1, (double) e1 / len);
|
||||
}
|
||||
|
||||
void jdkSerial() throws JSQLParserException {
|
||||
Statement statement = CCJSqlParserUtil.parse(sql);
|
||||
String target = statement.toString();
|
||||
byte[] serial = null;
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < len; i++) {
|
||||
serial = SerializationUtils.serialize(statement);
|
||||
}
|
||||
long endTime = System.currentTimeMillis();
|
||||
long et = endTime - startTime;
|
||||
System.out.printf("jdk serialize 执行耗时: %s 毫秒,byte大小: %s, 均耗时: %s%n", et, serial.length, (double) et / len);
|
||||
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < len; i++) {
|
||||
statement = (Statement) SerializationUtils.deserialize(serial);
|
||||
}
|
||||
endTime = System.currentTimeMillis();
|
||||
et = endTime - startTime;
|
||||
System.out.printf("jdk deserialize 执行耗时: %s 毫秒, 均耗时: %s%n", et, (double) et / len);
|
||||
assertThat(statement).isNotNull();
|
||||
assertThat(statement.toString()).isEqualTo(target);
|
||||
}
|
||||
|
||||
void fstSerial() throws JSQLParserException {
|
||||
Statement statement = CCJSqlParserUtil.parse(sql);
|
||||
String target = statement.toString();
|
||||
FstFactory factory = FstFactory.getDefaultFactory();
|
||||
byte[] serial = null;
|
||||
long startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < len; i++) {
|
||||
serial = factory.asByteArray(statement);
|
||||
}
|
||||
long endTime = System.currentTimeMillis();
|
||||
long et = endTime - startTime;
|
||||
System.out.printf("fst serialize 执行耗时: %s 毫秒,byte大小: %s, 均耗时: %s%n", et, serial.length, (double) et / len);
|
||||
|
||||
|
||||
startTime = System.currentTimeMillis();
|
||||
for (int i = 0; i < len; i++) {
|
||||
statement = (Statement) factory.asObject(serial);
|
||||
}
|
||||
endTime = System.currentTimeMillis();
|
||||
et = endTime - startTime;
|
||||
System.out.printf("fst deserialize 执行耗时: %s 毫秒, 均耗时: %s%n", et, (double) et / len);
|
||||
assertThat(statement).isNotNull();
|
||||
assertThat(statement.toString()).isEqualTo(target);
|
||||
}
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
package com.baomidou.mybatisplus.test.extension.parser.cache;
|
||||
|
||||
import io.github.classgraph.ClassGraph;
|
||||
import io.github.classgraph.ClassInfo;
|
||||
import io.github.classgraph.ScanResult;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2023-08-06
|
||||
*/
|
||||
class FstFactoryTest {
|
||||
|
||||
@Test
|
||||
void clazz() {
|
||||
List<ClassInfo> list = new ArrayList<>();
|
||||
List<ClassInfo> absList = new ArrayList<>();
|
||||
try (ScanResult scanResult = new ClassGraph().enableClassInfo().acceptPackages("net.sf.jsqlparser").scan()) {
|
||||
for (ClassInfo classInfo : scanResult.getAllClasses()) {
|
||||
if (!classInfo.isInterface() && classInfo.implementsInterface(Serializable.class)) {
|
||||
if (classInfo.isAbstract()) {
|
||||
absList.add(classInfo);
|
||||
continue;
|
||||
}
|
||||
list.add(classInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
list.forEach(i -> System.out.printf("conf.registerClass(%s.class);%n", i.getName().replace("$", ".")));
|
||||
System.out.println("↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓");
|
||||
absList.forEach(i -> System.out.printf("conf.registerClass(%s.class);%n", i.getName().replace("$", ".")));
|
||||
}
|
||||
}
|
@ -0,0 +1,40 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.DbType;
|
||||
import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2020-06-30
|
||||
*/
|
||||
class MybatisPlusInterceptorTest {
|
||||
|
||||
@Test
|
||||
void setProperties() {
|
||||
Properties properties = new Properties();
|
||||
properties.setProperty("@page", "com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor");
|
||||
properties.setProperty("page:maxLimit", "10");
|
||||
properties.setProperty("page:dbType", "h2");
|
||||
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
|
||||
interceptor.setProperties(properties);
|
||||
List<InnerInterceptor> interceptors = interceptor.getInterceptors();
|
||||
|
||||
assertThat(interceptors).isNotEmpty();
|
||||
assertThat(interceptors.size()).isEqualTo(1);
|
||||
|
||||
InnerInterceptor page = interceptors.get(0);
|
||||
assertThat(page).isInstanceOf(PaginationInnerInterceptor.class);
|
||||
|
||||
PaginationInnerInterceptor pii = (PaginationInnerInterceptor) page;
|
||||
assertThat(pii.getMaxLimit()).isEqualTo(10);
|
||||
assertThat(pii.getDbType()).isEqualTo(DbType.H2);
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.BlockAttackInnerInterceptor;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2020-08-18
|
||||
*/
|
||||
class BlockAttackInnerInterceptorTest {
|
||||
|
||||
private final BlockAttackInnerInterceptor interceptor = new BlockAttackInnerInterceptor();
|
||||
|
||||
@Test
|
||||
void update() {
|
||||
checkEx("update user set name = null", "null where");
|
||||
checkEx("update user set name = null where 1=1", "1=1");
|
||||
checkEx("update user set name = null where 1<>2", "1<>2");
|
||||
checkEx("update user set name = null where 1!=2", "1!=2");
|
||||
checkEx("update user set name = null where 1=1 and 2=2", "1=1 and 2=2");
|
||||
checkEx("update user set name = null where 1=1 and 2=3 or 1=1", "1=1 and 2=3 or 1=1");
|
||||
checkEx("update user set name = null where 1=1 and (2=2)", "1=1 and (2=2)");
|
||||
checkEx("update user set name = null where (1=1 and 2=2)", "(1=1 and 2=2)");
|
||||
checkEx("update user set name = null where (1=1 and (2=3 or 3=3))", "(1=1 and (2=3 or 3=3))");
|
||||
|
||||
checkNotEx("update user set name = null where 1=?", "1=?");
|
||||
}
|
||||
|
||||
@Test
|
||||
void delete() {
|
||||
checkEx("delete from user", "null where");
|
||||
checkEx("delete from user where 1=1", "1=1");
|
||||
checkEx("delete from user where 1<>2", "1<>2");
|
||||
checkEx("delete from user where 1!=2", "1!=2");
|
||||
checkEx("delete from user where 1=1 and 2=2", "1=1 and 2=2");
|
||||
checkEx("delete from user where 1=1 and 2=3 or 1=1", "1=1 and 2=3 or 1=1");
|
||||
}
|
||||
|
||||
void checkEx(String sql, String as) {
|
||||
Exception e = null;
|
||||
try {
|
||||
interceptor.parserSingle(sql, null);
|
||||
} catch (Exception x) {
|
||||
e = x;
|
||||
}
|
||||
assertThat(e).as(as).isNotNull();
|
||||
assertThat(e).as(as).isInstanceOf(MybatisPlusException.class);
|
||||
}
|
||||
|
||||
void checkNotEx(String sql, String as) {
|
||||
Exception e = null;
|
||||
try {
|
||||
interceptor.parserSingle(sql, null);
|
||||
} catch (Exception x) {
|
||||
e = x;
|
||||
}
|
||||
assertThat(e).as(as).isNull();
|
||||
}
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.DataChangeRecorderInnerInterceptor;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.insert.Insert;
|
||||
import net.sf.jsqlparser.statement.update.Update;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2020-06-28
|
||||
*/
|
||||
class DataChangeRecorderInnerInterceptorTest {
|
||||
|
||||
private final DataChangeRecorderInnerInterceptor interceptor = new DataChangeRecorderInnerInterceptor();
|
||||
|
||||
@BeforeEach
|
||||
public void initProperties() {
|
||||
Properties properties = new Properties();
|
||||
properties.put("ignoredTableColumns", "table_name1.column1,column2; h2user.*; *.column1,COLUMN2");
|
||||
interceptor.setProperties(properties);
|
||||
}
|
||||
@Test
|
||||
void setProperties() throws Exception {
|
||||
final Object ignoreAllColumns = getFieldValue(interceptor, "ignoreAllColumns");
|
||||
Assertions.assertEquals(Set.of("COLUMN1", "COLUMN2"), ignoreAllColumns);
|
||||
final Object ignoredTableColumns = getFieldValue(interceptor, "ignoredTableColumns");
|
||||
Assertions.assertEquals(Map.of("H2USER", Set.of("*"), "TABLE_NAME1", Set.of("COLUMN1", "COLUMN2")), ignoredTableColumns);
|
||||
}
|
||||
|
||||
private Object getFieldValue(Object obj, String fieldName) throws NoSuchFieldException, IllegalAccessException {
|
||||
final Field field = DataChangeRecorderInnerInterceptor.class.getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
return field.get(obj);
|
||||
}
|
||||
|
||||
@Test
|
||||
void processInsert() {
|
||||
final Insert insert = new Insert();
|
||||
insert.setTable(new Table("H2USER"));
|
||||
final DataChangeRecorderInnerInterceptor.OperationResult operationResult = interceptor.processInsert(insert, null);
|
||||
Assertions.assertEquals(operationResult.getTableName(), "H2USER:*");
|
||||
Assertions.assertFalse(operationResult.isRecordStatus());
|
||||
Assertions.assertNull(operationResult.getChangedData());
|
||||
}
|
||||
|
||||
@Test
|
||||
void processUpdate() {
|
||||
final Update update = new Update();
|
||||
update.setTable(new Table("H2USER"));
|
||||
final DataChangeRecorderInnerInterceptor.OperationResult operationResult = interceptor.processUpdate(update, null, null, null);
|
||||
Assertions.assertEquals(operationResult.getTableName(), "H2USER:*");
|
||||
Assertions.assertFalse(operationResult.isRecordStatus());
|
||||
Assertions.assertNull(operationResult.getChangedData());
|
||||
}
|
||||
}
|
@ -0,0 +1,105 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* 数据权限拦截器测试
|
||||
*
|
||||
* @author hubin
|
||||
* @since 3.4.1 +
|
||||
*/
|
||||
public class DataPermissionInterceptorTest {
|
||||
private static String TEST_1 = "com.baomidou.userMapper.selectByUsername";
|
||||
private static String TEST_2 = "com.baomidou.userMapper.selectById";
|
||||
private static String TEST_3 = "com.baomidou.roleMapper.selectByCompanyId";
|
||||
private static String TEST_4 = "com.baomidou.roleMapper.selectById";
|
||||
private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
|
||||
|
||||
/**
|
||||
* 这里可以理解为数据库配置的数据权限规则 SQL
|
||||
*/
|
||||
private static Map<String, String> sqlSegmentMap = new HashMap<String, String>() {
|
||||
{
|
||||
put(TEST_1, "username='123' or userId IN (1,2,3)");
|
||||
put(TEST_2, "u.state=1 and u.amount > 1000");
|
||||
put(TEST_3, "companyId in (1,2,3)");
|
||||
put(TEST_4, "username like 'abc%'");
|
||||
put(TEST_5, "id=1 and role_id in (select id from sys_role)");
|
||||
}
|
||||
};
|
||||
|
||||
private static DataPermissionInterceptor interceptor = new DataPermissionInterceptor(new DataPermissionHandler() {
|
||||
|
||||
@Override
|
||||
public Expression getSqlSegment(Expression where, String mappedStatementId) {
|
||||
try {
|
||||
String sqlSegment = sqlSegmentMap.get(mappedStatementId);
|
||||
Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
|
||||
if (null != where) {
|
||||
System.out.println("原 where : " + where.toString());
|
||||
if (mappedStatementId.equals(TEST_4)) {
|
||||
// 这里测试返回 OR 条件
|
||||
return new OrExpression(where, sqlSegmentExpression);
|
||||
}
|
||||
return new AndExpression(where, sqlSegmentExpression);
|
||||
}
|
||||
return sqlSegmentExpression;
|
||||
} catch (JSQLParserException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
@Test
|
||||
void test1() {
|
||||
assertSql(TEST_1, "select * from sys_user",
|
||||
"SELECT * FROM sys_user WHERE username = '123' OR userId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test2() {
|
||||
assertSql(TEST_2, "select u.username from sys_user u join sys_user_role r on u.id=r.user_id where r.role_id=3",
|
||||
"SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE r.role_id = 3 AND u.state = 1 AND u.amount > 1000");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test3() {
|
||||
assertSql(TEST_3, "select * from sys_role where company_id=6",
|
||||
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test3unionAll() {
|
||||
assertSql(TEST_3, "select * from sys_role where company_id=6 union all select * from sys_role where company_id=7",
|
||||
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3) UNION ALL SELECT * FROM sys_role WHERE company_id = 7 AND companyId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test4() {
|
||||
assertSql(TEST_4, "select * from sys_role where id=3",
|
||||
"SELECT * FROM sys_role WHERE id = 3 OR username LIKE 'abc%'");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test5() {
|
||||
assertSql(TEST_5, "select * from sys_role where id=3",
|
||||
"SELECT * FROM sys_role WHERE id = 3 AND id = 1 AND role_id IN (SELECT id FROM sys_role)");
|
||||
}
|
||||
|
||||
void assertSql(String mappedStatementId, String sql, String targetSql) {
|
||||
assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.DynamicTableNameInnerInterceptor;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* 动态表名内部拦截器测试
|
||||
*
|
||||
* @author miemie, hcl
|
||||
* @since 2020-07-16
|
||||
*/
|
||||
class DynamicTableNameInnerInterceptorTest {
|
||||
|
||||
/**
|
||||
* 测试 SQL 中的动态表名替换
|
||||
*/
|
||||
@Test
|
||||
@SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"})
|
||||
void doIt() {
|
||||
DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
|
||||
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
|
||||
|
||||
// 表名相互包含
|
||||
String origin = "SELECT * FROM t_user, t_user_role";
|
||||
assertEquals("SELECT * FROM t_user_r, t_user_role_r", interceptor.changeTable(origin));
|
||||
|
||||
// 表名在末尾
|
||||
origin = "SELECT * FROM t_user";
|
||||
assertEquals("SELECT * FROM t_user_r", interceptor.changeTable(origin));
|
||||
|
||||
// 表名前后有注释
|
||||
origin = "SELECT * FROM /**/t_user/* t_user */";
|
||||
assertEquals("SELECT * FROM /**/t_user_r/* t_user */", interceptor.changeTable(origin));
|
||||
|
||||
// 值中带有表名
|
||||
origin = "SELECT * FROM t_user WHERE name = 't_user'";
|
||||
assertEquals("SELECT * FROM t_user_r WHERE name = 't_user'", interceptor.changeTable(origin));
|
||||
|
||||
// 别名被声明要替换
|
||||
origin = "SELECT t_user.* FROM t_user_real t_user";
|
||||
assertEquals("SELECT t_user.* FROM t_user_real_r t_user", interceptor.changeTable(origin));
|
||||
|
||||
// 别名被声明要替换
|
||||
origin = "SELECT t.* FROM t_user_real t left join entity e on e.id = t.id";
|
||||
assertEquals("SELECT t.* FROM t_user_real_r t left join entity_r e on e.id = t.id", interceptor.changeTable(origin));
|
||||
}
|
||||
}
|
@ -0,0 +1,123 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.IllegalSQLInnerInterceptor;
|
||||
import org.apache.ibatis.jdbc.SqlRunner;
|
||||
import org.h2.jdbcx.JdbcDataSource;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.sql.Connection;
|
||||
import java.sql.SQLException;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2022-04-11
|
||||
*/
|
||||
class IllegalSQLInnerInterceptorTest {
|
||||
|
||||
private final IllegalSQLInnerInterceptor interceptor = new IllegalSQLInnerInterceptor();
|
||||
|
||||
private static DataSource dataSource;
|
||||
|
||||
@BeforeAll
|
||||
public static void beforeAll() throws SQLException {
|
||||
var jdbcDataSource = new JdbcDataSource();
|
||||
jdbcDataSource.setURL("jdbc:h2:mem:test;MODE=mysql;DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE");
|
||||
jdbcDataSource.setPassword("");
|
||||
jdbcDataSource.setUser("sa");
|
||||
dataSource = jdbcDataSource;
|
||||
Connection connection = jdbcDataSource.getConnection();
|
||||
var sql = """
|
||||
CREATE TABLE T_DEMO (
|
||||
`a` int DEFAULT NULL,
|
||||
`b` int DEFAULT NULL,
|
||||
`c` int DEFAULT NULL,
|
||||
KEY `ab_index1` (`a`,`b`)
|
||||
);
|
||||
CREATE TABLE T_TEST (
|
||||
`a` int DEFAULT NULL,
|
||||
`b` int DEFAULT NULL,
|
||||
`c` int DEFAULT NULL,
|
||||
KEY `ab_index2` (`a`,`b`)
|
||||
);
|
||||
""";
|
||||
SqlRunner sqlRunner = new SqlRunner(connection);
|
||||
sqlRunner.run(sql);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test() {
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT COUNT(*) AS total FROM t_user WHERE (client_id = ?)", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 18", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete from t_user set age = 18", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age != 1", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age = 1 or name = 'test'", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where (age = 1 or name = 'test')", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where age = 1 or name = 'test'", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("update t_user set age = 1 where (age = 1 or name = 'test')", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where age = 1 or name = 'test'", null));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete t_user where (age = 1 or name = 'test')", null));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from T_DEMO where `a` = 1 and `b` = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where `a` = 1 and `b` = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = 1 and `b` = 2", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where c = 3", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from T_DEMO where c = 3", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.`T_DEMO` where c = 3", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.T_DEMO where c = 3", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a LEFT JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a LEFT JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where (c = 3 OR b = 2)", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where c = 3 OR b = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = 3 AND (c = 3 OR b = 2)", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where (a = 3 AND c = 3 OR b = 2)", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2)", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2) or b = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a in (1,3,2) AND b = 2", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where (a = 3 AND c = 3 AND b = 2)", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` a INNER JOIN T_TEST b ON a.a = b.a where a.a = 3 AND (b.c = 3 OR b.b = 2)", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where a != (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
//TODO 低版本这里的抛异常了.看着应该不用抛出
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a = (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a >= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `T_DEMO` where a <= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b >= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b <= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
|
||||
}
|
||||
@Test
|
||||
void testCount() {
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a", dataSource.getConnection()));
|
||||
Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a) c", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO`) a ", dataSource.getConnection()));
|
||||
Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,155 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.DataPermissionInterceptor;
|
||||
import com.google.common.collect.HashBasedTable;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* SQL多表场景的数据权限拦截器测试
|
||||
*
|
||||
* @author houkunlin
|
||||
* @since 3.5.2 +
|
||||
*/
|
||||
public class MultiDataPermissionInterceptorTest {
|
||||
private static final Logger logger = LoggerFactory.getLogger(MultiDataPermissionInterceptorTest.class);
|
||||
/**
|
||||
* 这里可以理解为数据库配置的数据权限规则 SQL
|
||||
*/
|
||||
private static final com.google.common.collect.Table<String, String, String> sqlSegmentMap;
|
||||
private static final DataPermissionInterceptor interceptor;
|
||||
private static String TEST_1 = "com.baomidou.userMapper.selectByUsername";
|
||||
private static String TEST_2 = "com.baomidou.userMapper.selectById";
|
||||
private static String TEST_3 = "com.baomidou.roleMapper.selectByCompanyId";
|
||||
private static String TEST_4 = "com.baomidou.roleMapper.selectById";
|
||||
private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
|
||||
private static String TEST_6 = "com.baomidou.roleMapper.selectUserInfo";
|
||||
private static String TEST_7 = "com.baomidou.roleMapper.summarySum";
|
||||
private static String TEST_8_1 = "com.baomidou.CustomMapper.selectByOnlyMyData";
|
||||
private static String TEST_8_2 = "com.baomidou.CustomMapper.selectByOnlyOrgData";
|
||||
private static String TEST_8_3 = "com.baomidou.CustomMapper.selectByOnlyDeptData";
|
||||
private static String TEST_8_4 = "com.baomidou.CustomMapper.selectByMyDataOrDeptData";
|
||||
private static String TEST_8_5 = "com.baomidou.CustomMapper.selectByMyData";
|
||||
|
||||
static {
|
||||
sqlSegmentMap = HashBasedTable.create();
|
||||
sqlSegmentMap.put(TEST_1, "sys_user", "username='123' or userId IN (1,2,3)");
|
||||
sqlSegmentMap.put(TEST_2, "sys_user", "u.state=1 and u.amount > 1000");
|
||||
sqlSegmentMap.put(TEST_3, "sys_role", "companyId in (1,2,3)");
|
||||
sqlSegmentMap.put(TEST_4, "sys_role", "username like 'abc%'");
|
||||
sqlSegmentMap.put(TEST_5, "sys_role", "id=1 and role_id in (select id from sys_role)");
|
||||
sqlSegmentMap.put(TEST_6, "sys_user", "u.state=1 and u.amount > 1000");
|
||||
sqlSegmentMap.put(TEST_6, "sys_user_role", "r.role_id=3 AND r.role_id IN (7,9,11)");
|
||||
sqlSegmentMap.put(TEST_7, "`fund`", "a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111");
|
||||
sqlSegmentMap.put(TEST_7, "`fund_month`", "b.fund_id = 2 AND b.month <= '2022-05'");
|
||||
sqlSegmentMap.put(TEST_8_1, "fund", "user_id=1");
|
||||
sqlSegmentMap.put(TEST_8_2, "fund", "org_id=1");
|
||||
sqlSegmentMap.put(TEST_8_3, "fund", "dept_id=1");
|
||||
sqlSegmentMap.put(TEST_8_4, "fund", "user_id=1 or dept_id=1");
|
||||
sqlSegmentMap.put(TEST_8_5, "table1", "u.user_id=1");
|
||||
sqlSegmentMap.put(TEST_8_5, "table2", "u.dept_id=1");
|
||||
interceptor = new DataPermissionInterceptor(new MultiDataPermissionHandler() {
|
||||
|
||||
@Override
|
||||
public Expression getSqlSegment(final Table table, final Expression where, final String mappedStatementId) {
|
||||
try {
|
||||
String sqlSegment = sqlSegmentMap.get(mappedStatementId, table.getName());
|
||||
if (sqlSegment == null) {
|
||||
logger.info("{} {} AS {} : NOT FOUND", mappedStatementId, table.getName(), table.getAlias());
|
||||
return null;
|
||||
}
|
||||
if (table.getAlias() != null) {
|
||||
// 替换表别名
|
||||
sqlSegment = sqlSegment.replaceAll("u\\.", table.getAlias().getName() + StringPool.DOT);
|
||||
}
|
||||
Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
|
||||
logger.info("{} {} AS {} : {}", mappedStatementId, table.getName(), table.getAlias(), sqlSegmentExpression.toString());
|
||||
return sqlSegmentExpression;
|
||||
} catch (JSQLParserException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void test1() {
|
||||
assertSql(TEST_1, "select * from sys_user",
|
||||
"SELECT * FROM sys_user WHERE username = '123' OR userId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test2() {
|
||||
assertSql(TEST_2, "select u.username from sys_user u join sys_user_role r on u.id=r.user_id where r.role_id=3",
|
||||
"SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE r.role_id = 3 AND u.state = 1 AND u.amount > 1000");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test3() {
|
||||
assertSql(TEST_3, "select * from sys_role where company_id=6",
|
||||
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test3unionAll() {
|
||||
assertSql(TEST_3, "select * from sys_role where company_id=6 union all select * from sys_role where company_id=7",
|
||||
"SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3) UNION ALL SELECT * FROM sys_role WHERE company_id = 7 AND companyId IN (1, 2, 3)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test4() {
|
||||
assertSql(TEST_4, "select * from sys_role where id=3",
|
||||
"SELECT * FROM sys_role WHERE id = 3 AND username LIKE 'abc%'");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test5() {
|
||||
assertSql(TEST_5, "select * from sys_role where id=3",
|
||||
"SELECT * FROM sys_role WHERE id = 3 AND id = 1 AND role_id IN (SELECT id FROM sys_role)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test6() {
|
||||
// 显式指定 JOIN 类型时 JOIN 右侧表才能进行拼接条件
|
||||
assertSql(TEST_6, "select u.username from sys_user u LEFT join sys_user_role r on u.id=r.user_id",
|
||||
"SELECT u.username FROM sys_user u LEFT JOIN sys_user_role r ON u.id = r.user_id AND r.role_id = 3 AND r.role_id IN (7, 9, 11) WHERE u.state = 1 AND u.amount > 1000");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test7() {
|
||||
assertSql(TEST_7, "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE) c WHERE c.row_index = 1 GROUP BY title LIMIT 20",
|
||||
"SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE AND b.fund_id = 2 AND b.month <= '2022-05' WHERE a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111) c WHERE c.row_index = 1 GROUP BY title LIMIT 20");
|
||||
}
|
||||
|
||||
@Test
|
||||
void test8() {
|
||||
assertSql(TEST_8_1, "select * from fund where id=3",
|
||||
"SELECT * FROM fund WHERE id = 3 AND user_id = 1");
|
||||
assertSql(TEST_8_2, "select * from fund where id=3",
|
||||
"SELECT * FROM fund WHERE id = 3 AND org_id = 1");
|
||||
assertSql(TEST_8_3, "select * from fund where id=3",
|
||||
"SELECT * FROM fund WHERE id = 3 AND dept_id = 1");
|
||||
assertSql(TEST_8_4, "select * from fund where id=3",
|
||||
"SELECT * FROM fund WHERE id = 3 AND user_id = 1 OR dept_id = 1");
|
||||
// 修改之前旧版的多表数据权限对这个SQL的表现形式:
|
||||
// 输入 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp"
|
||||
// 输出 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid) SELECT * FROM temp"
|
||||
// 修改之后的多表数据权限对这个SQL的表现形式
|
||||
assertSql(TEST_8_5, "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp",
|
||||
"WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid AND t2.dept_id = 1 WHERE t1.user_id = 1) SELECT * FROM temp");
|
||||
}
|
||||
|
||||
void assertSql(String mappedStatementId, String sql, String targetSql) {
|
||||
assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
|
||||
}
|
||||
}
|
@ -0,0 +1,109 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.OrderItem;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2020-06-28
|
||||
*/
|
||||
class PaginationInnerInterceptorTest {
|
||||
|
||||
private final PaginationInnerInterceptor interceptor = new PaginationInnerInterceptor();
|
||||
|
||||
@Test
|
||||
void optimizeCount() {
|
||||
/* 能进行优化的 SQL */
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id",
|
||||
"SELECT COUNT(*) AS total FROM user u");
|
||||
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xx = ?",
|
||||
"SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
|
||||
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id",
|
||||
"SELECT COUNT(*) AS total FROM user u");
|
||||
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id WHERE u.xx = ?",
|
||||
"SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
|
||||
|
||||
assertsCountSql("select distinct id from table order by id", "SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table) TOTAL");
|
||||
|
||||
assertsCountSql("select distinct id from table", "SELECT COUNT(*) FROM (SELECT DISTINCT id FROM table) TOTAL");
|
||||
}
|
||||
|
||||
@Test
|
||||
void notOptimizeCount() {
|
||||
/* 不能进行优化的 SQL */
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? where u.xx = ?",
|
||||
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? WHERE u.xx = ?");
|
||||
|
||||
/* join 表与 where 条件大小写不同的情况 */
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id where R.NAME = ?",
|
||||
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE R.NAME = ?");
|
||||
|
||||
assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?",
|
||||
"SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?");
|
||||
}
|
||||
|
||||
@Test
|
||||
void optimizeCountOrderBy() {
|
||||
/* order by 里不带参数,去除order by */
|
||||
assertsCountSql("SELECT * FROM comment ORDER BY name",
|
||||
"SELECT COUNT(*) AS total FROM comment");
|
||||
|
||||
/* order by 里带参数,不去除order by */
|
||||
assertsCountSql("SELECT * FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)",
|
||||
"SELECT COUNT(*) AS total FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void withAsCount() {
|
||||
assertsCountSql("with A as (select * from class) select * from A",
|
||||
"WITH A AS (SELECT * FROM class) SELECT COUNT(*) AS total FROM A");
|
||||
}
|
||||
|
||||
@Test
|
||||
void withAsOrderBy() {
|
||||
assertsConcatOrderBy("with A as (select * from class) select * from A",
|
||||
"WITH A AS (SELECT * FROM class) SELECT * FROM A ORDER BY column ASC",
|
||||
OrderItem.asc("column"));
|
||||
}
|
||||
|
||||
@Test
|
||||
void groupByCount() {
|
||||
assertsCountSql("SELECT * FROM record_1 WHERE id = ? GROUP BY date(date_time)",
|
||||
"SELECT COUNT(*) FROM (SELECT * FROM record_1 WHERE id = ? GROUP BY date(date_time)) TOTAL");
|
||||
}
|
||||
|
||||
@Test
|
||||
void leftJoinSelectCount() {
|
||||
assertsCountSql("select r.id, r.name, r.phone,rlr.total_top_up from reseller r " +
|
||||
"left join (select ral.reseller_id, sum(ral.top_up_money) as total_top_up, sum(ral.acquire_money) as total_acquire " +
|
||||
"from reseller_acquire_log ral " +
|
||||
"group by ral.reseller_id) rlr on r.id = rlr.reseller_id " +
|
||||
"order by r.created_at desc",
|
||||
"SELECT COUNT(*) AS total FROM reseller r");
|
||||
|
||||
// 不优化
|
||||
assertsCountSql("SELECT f.ca, f.cb FROM table_a f LEFT JOIN " +
|
||||
"(SELECT ca FROM table_b WHERE cc = ?) rf on rf.ca = f.ca",
|
||||
"SELECT COUNT(*) AS total FROM table_a f LEFT JOIN (SELECT ca FROM table_b WHERE cc = ?) rf ON rf.ca = f.ca");
|
||||
|
||||
assertsCountSql("select * from order_info left join (select count(1) from order_info where create_time between ? and ?) tt on 1=1 WHERE equipment_id=?",
|
||||
"SELECT COUNT(*) AS total FROM order_info LEFT JOIN (SELECT count(1) FROM order_info WHERE create_time BETWEEN ? AND ?) tt ON 1 = 1 WHERE equipment_id = ?");
|
||||
}
|
||||
|
||||
void assertsCountSql(String sql, String targetSql) {
|
||||
assertThat(interceptor.autoCountSql(new Page<>(), sql)).isEqualTo(targetSql);
|
||||
}
|
||||
|
||||
void assertsConcatOrderBy(String sql, String targetSql, OrderItem... orderItems) {
|
||||
assertThat(interceptor.concatOrderBy(sql, Arrays.asList(orderItems))).isEqualTo(targetSql);
|
||||
}
|
||||
}
|
@ -0,0 +1,470 @@
|
||||
package com.baomidou.mybatisplus.test.extension.plugins.inner;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* @author miemie
|
||||
* @since 2020-07-30
|
||||
*/
|
||||
class TenantLineInnerInterceptorTest {
|
||||
|
||||
private final TenantLineInnerInterceptor interceptor = new TenantLineInnerInterceptor(new TenantLineHandler() {
|
||||
private boolean ignoreFirst;// 需要执行 getTenantId 前必须先执行 ignoreTable
|
||||
|
||||
@Override
|
||||
public Expression getTenantId() {
|
||||
assertThat(ignoreFirst).isEqualTo(true);
|
||||
ignoreFirst = false;
|
||||
return new LongValue(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean ignoreTable(String tableName) {
|
||||
ignoreFirst = true;
|
||||
return tableName.startsWith("with_as");
|
||||
}
|
||||
});
|
||||
|
||||
@Test
|
||||
void insert() {
|
||||
// plain
|
||||
assertSql("insert into entity (id) values (?)",
|
||||
"INSERT INTO entity (id, tenant_id) VALUES (?, 1)");
|
||||
assertSql("insert into entity (id,name) values (?,?)",
|
||||
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1)");
|
||||
// batch
|
||||
assertSql("insert into entity (id) values (?),(?)",
|
||||
"INSERT INTO entity (id, tenant_id) VALUES (?, 1), (?, 1)");
|
||||
assertSql("insert into entity (id,name) values (?,?),(?,?)",
|
||||
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1), (?, ?, 1)");
|
||||
// 无 insert的列
|
||||
assertSql("insert into entity value (?,?)",
|
||||
"INSERT INTO entity VALUES (?, ?)");
|
||||
// 自己加了insert的列
|
||||
assertSql("insert into entity (id,name,tenant_id) value (?,?,?)",
|
||||
"INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, ?)");
|
||||
// insert into select
|
||||
assertSql("insert into entity (id,name) select id,name from entity2",
|
||||
"INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE tenant_id = 1");
|
||||
|
||||
assertSql("insert into entity (id,name) select * from entity2 e2",
|
||||
"INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 e2 WHERE e2.tenant_id = 1");
|
||||
|
||||
assertSql("insert into entity (id,name) select id,name from (select id,name from entity3 e3) t",
|
||||
"INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
|
||||
|
||||
assertSql("insert into entity (id,name) select * from (select id,name from entity3 e3) t",
|
||||
"INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
|
||||
|
||||
assertSql("insert into entity (id,name) select t.* from (select id,name from entity3 e3) t",
|
||||
"INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 e3 WHERE e3.tenant_id = 1) t");
|
||||
}
|
||||
|
||||
@Test
|
||||
void delete() {
|
||||
assertSql("delete from entity where id = ?",
|
||||
"DELETE FROM entity WHERE id = ? AND tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void update() {
|
||||
assertSql("update entity set name = ? where id = ?",
|
||||
"UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1");
|
||||
|
||||
// set subSelect
|
||||
assertSql("UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ?) WHERE e.id = ?",
|
||||
"UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ? AND e1.tenant_id = 1) " +
|
||||
"WHERE e.id = ? AND e.tenant_id = 1");
|
||||
|
||||
assertSql("UPDATE sys_user SET (name, age) = ('秋秋', 18), address = 'test'",
|
||||
"UPDATE sys_user SET (name, age) = ('秋秋', 18), address = 'test' WHERE tenant_id = 1");
|
||||
|
||||
assertSql("UPDATE entity t1 INNER JOIN entity t2 ON t1.a= t2.a SET t1.b = t2.b, t1.c = t2.c",
|
||||
"UPDATE entity t1 INNER JOIN entity t2 ON t1.a = t2.a SET t1.b = t2.b, t1.c = t2.c WHERE t1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSingle() {
|
||||
// 单表
|
||||
assertSql("select * from entity where id = ?",
|
||||
"SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
|
||||
|
||||
assertSql("select * from entity where id = ? or name = ?",
|
||||
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
|
||||
"SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
|
||||
|
||||
/* not */
|
||||
assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
|
||||
"SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity u WHERE not (u.id = ? OR u.name = ?)",
|
||||
"SELECT * FROM entity u WHERE NOT (u.id = ? OR u.name = ?) AND u.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSubSelectIn() {
|
||||
/* in */
|
||||
assertSql("SELECT * FROM entity e WHERE e.id IN (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id IN (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
// 在最前
|
||||
assertSql("SELECT * FROM entity e WHERE e.id IN " +
|
||||
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
|
||||
"SELECT * FROM entity e WHERE e.id IN " +
|
||||
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
|
||||
// 在最后
|
||||
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
|
||||
"(select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
|
||||
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
// 在中间
|
||||
assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
|
||||
"(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
|
||||
"SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
|
||||
"(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSubSelectEq() {
|
||||
/* = */
|
||||
assertSql("SELECT * FROM entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSubSelectInnerNotEq() {
|
||||
/* inner not = */
|
||||
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?))",
|
||||
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1)) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e WHERE not (e.id = (select e1.id from entity1 e1 where e1.id = ?) and e.id = ?)",
|
||||
"SELECT * FROM entity e WHERE NOT (e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ?) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSubSelectExists() {
|
||||
/* EXISTS */
|
||||
assertSql("SELECT * FROM entity e WHERE EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT EXISTS (SELECT 1 FROM entity1 e WHERE e.id = ? LIMIT 1)","SELECT EXISTS (SELECT 1 FROM entity1 e WHERE e.id = ? AND e.tenant_id = 1 LIMIT 1)");
|
||||
|
||||
/* NOT EXISTS */
|
||||
assertSql("SELECT * FROM entity e WHERE NOT EXISTS (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE NOT EXISTS (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectWhereSubSelect() {
|
||||
/* >= */
|
||||
assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
|
||||
|
||||
/* <= */
|
||||
assertSql("SELECT * FROM entity e WHERE e.id <= (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id <= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
|
||||
|
||||
/* <> */
|
||||
assertSql("SELECT * FROM entity e WHERE e.id <> (select e1.id from entity1 e1 where e1.id = ?)",
|
||||
"SELECT * FROM entity e WHERE e.id <> (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectFromSelect() {
|
||||
assertSql("SELECT * FROM (select e.id from entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?))",
|
||||
"SELECT * FROM (SELECT e.id FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1)");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectBodySubSelect() {
|
||||
assertSql("select t1.col1,(select t2.col2 from t2 t2 where t1.col1=t2.col1) from t1 t1",
|
||||
"SELECT t1.col1, (SELECT t2.col2 FROM t2 t2 WHERE t1.col1 = t2.col1 AND t2.tenant_id = 1) FROM t1 t1 WHERE t1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectBodyFuncSubSelect() {
|
||||
assertSql("SELECT e1.*, IF((SELECT e2.id FROM entity2 e2 WHERE e2.id = 1) = 1, e2.type, e1.type) AS type " +
|
||||
"FROM entity e1 WHERE e1.id = ?",
|
||||
"SELECT e1.*, IF((SELECT e2.id FROM entity2 e2 WHERE e2.id = 1 AND e2.tenant_id = 1) = 1, e2.type, e1.type) AS type " +
|
||||
"FROM entity e1 WHERE e1.id = ? AND e1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectLeftJoin() {
|
||||
// left join
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"left join entity1 e1 on e1.id = e.id " +
|
||||
"WHERE e.id = ? OR e.name = ?",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"left join entity1 e1 on e1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.name = ?)",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"left join entity1 e1 on e1.id = e.id " +
|
||||
"left join entity2 e2 on e1.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
|
||||
"WHERE e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectRightJoin() {
|
||||
// right join
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"right join entity1 e1 on e1.id = e.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM with_as_1 e " +
|
||||
"right join entity1 e1 on e1.id = e.id",
|
||||
"SELECT * FROM with_as_1 e " +
|
||||
"RIGHT JOIN entity1 e1 ON e1.id = e.id " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"right join entity1 e1 on e1.id = e.id " +
|
||||
"WHERE e.id = ? OR e.name = ?",
|
||||
"SELECT * FROM entity e " +
|
||||
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"right join entity1 e1 on e1.id = e.id " +
|
||||
"right join entity2 e2 on e1.id = e2.id ",
|
||||
"SELECT * FROM entity e " +
|
||||
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
|
||||
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
|
||||
"WHERE e2.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectMixJoin() {
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"right join entity1 e1 on e1.id = e.id " +
|
||||
"left join entity2 e2 on e1.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
|
||||
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"left join entity1 e1 on e1.id = e.id " +
|
||||
"right join entity2 e2 on e1.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 " +
|
||||
"WHERE e2.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"left join entity1 e1 on e1.id = e.id " +
|
||||
"inner join entity2 e2 on e1.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void selectJoinSubSelect() {
|
||||
assertSql("select * from (select * from entity e) e1 " +
|
||||
"left join entity2 e2 on e1.id = e2.id",
|
||||
"SELECT * FROM (SELECT * FROM entity e WHERE e.tenant_id = 1) e1 " +
|
||||
"LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
|
||||
|
||||
assertSql("select * from entity1 e1 " +
|
||||
"left join (select * from entity2 e2) e22 " +
|
||||
"on e1.id = e22.id",
|
||||
"SELECT * FROM entity1 e1 " +
|
||||
"LEFT JOIN (SELECT * FROM entity2 e2 WHERE e2.tenant_id = 1) e22 " +
|
||||
"ON e1.id = e22.id " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSubJoin() {
|
||||
assertSql("select * FROM " +
|
||||
"(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
|
||||
"SELECT * FROM " +
|
||||
"(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
|
||||
"WHERE e2.tenant_id = 1");
|
||||
|
||||
assertSql("select * FROM " +
|
||||
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
|
||||
"SELECT * FROM " +
|
||||
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
|
||||
|
||||
assertSql("select * FROM " +
|
||||
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
|
||||
"right join entity3 e3 on e1.id = e3.id",
|
||||
"SELECT * FROM " +
|
||||
"(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
|
||||
"RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
|
||||
"WHERE e3.tenant_id = 1");
|
||||
|
||||
|
||||
assertSql("select * FROM entity e " +
|
||||
"LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
|
||||
"on e.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
|
||||
"ON e.id = e2.id AND e2.tenant_id = 1 " +
|
||||
"WHERE e.tenant_id = 1");
|
||||
|
||||
assertSql("select * FROM entity e " +
|
||||
"LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
|
||||
"on e.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
|
||||
"ON e.id = e2.id AND e1.tenant_id = 1 " +
|
||||
"WHERE e.tenant_id = 1");
|
||||
|
||||
assertSql("select * FROM entity e " +
|
||||
"RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
|
||||
"on e.id = e2.id",
|
||||
"SELECT * FROM entity e " +
|
||||
"RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
|
||||
"ON e.id = e2.id AND e.tenant_id = 1 " +
|
||||
"WHERE e1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectLeftJoinMultipleTrailingOn() {
|
||||
// 多个 on 尾缀的
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 " +
|
||||
"LEFT JOIN entity2 e2 ON e2.id = e1.id " +
|
||||
"ON e1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.NAME = ?)",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 " +
|
||||
"LEFT JOIN entity2 e2 ON e2.id = e1.id AND e2.tenant_id = 1 " +
|
||||
"ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 " +
|
||||
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
|
||||
"ON e1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.NAME = ?)",
|
||||
"SELECT * FROM entity e " +
|
||||
"LEFT JOIN entity1 e1 " +
|
||||
"LEFT JOIN with_as_A e2 ON e2.id = e1.id " +
|
||||
"ON e1.id = e.id AND e1.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.NAME = ?) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectInnerJoin() {
|
||||
// inner join
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"inner join entity1 e1 on e1.id = e.id " +
|
||||
"WHERE e.id = ? OR e.name = ?",
|
||||
"SELECT * FROM entity e " +
|
||||
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
|
||||
"WHERE e.id = ? OR e.name = ?");
|
||||
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"inner join entity1 e1 on e1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.name = ?)",
|
||||
"SELECT * FROM entity e " +
|
||||
"INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.name = ?)");
|
||||
|
||||
// ignore table
|
||||
assertSql("SELECT * FROM entity e " +
|
||||
"inner join with_as_1 w1 on w1.id = e.id " +
|
||||
"WHERE (e.id = ? OR e.name = ?)",
|
||||
"SELECT * FROM entity e " +
|
||||
"INNER JOIN with_as_1 w1 ON w1.id = e.id AND e.tenant_id = 1 " +
|
||||
"WHERE (e.id = ? OR e.name = ?)");
|
||||
|
||||
// 隐式内连接
|
||||
assertSql("SELECT * FROM entity e,entity1 e1 " +
|
||||
"WHERE e.id = e1.id",
|
||||
"SELECT * FROM entity e, entity1 e1 " +
|
||||
"WHERE e.id = e1.id AND e.tenant_id = 1 AND e1.tenant_id = 1");
|
||||
|
||||
// 隐式内连接
|
||||
assertSql("SELECT * FROM entity a, with_as_entity1 b " +
|
||||
"WHERE a.id = b.id",
|
||||
"SELECT * FROM entity a, with_as_entity1 b " +
|
||||
"WHERE a.id = b.id AND a.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM with_as_entity a, with_as_entity1 b " +
|
||||
"WHERE a.id = b.id",
|
||||
"SELECT * FROM with_as_entity a, with_as_entity1 b " +
|
||||
"WHERE a.id = b.id");
|
||||
|
||||
// SubJoin with 隐式内连接
|
||||
assertSql("SELECT * FROM (entity e,entity1 e1) " +
|
||||
"WHERE e.id = e1.id",
|
||||
"SELECT * FROM (entity e, entity1 e1) " +
|
||||
"WHERE e.id = e1.id " +
|
||||
"AND e.tenant_id = 1 AND e1.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM ((entity e,entity1 e1),entity2 e2) " +
|
||||
"WHERE e.id = e1.id and e.id = e2.id",
|
||||
"SELECT * FROM ((entity e, entity1 e1), entity2 e2) " +
|
||||
"WHERE e.id = e1.id AND e.id = e2.id " +
|
||||
"AND e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM (entity e,(entity1 e1,entity2 e2)) " +
|
||||
"WHERE e.id = e1.id and e.id = e2.id",
|
||||
"SELECT * FROM (entity e, (entity1 e1, entity2 e2)) " +
|
||||
"WHERE e.id = e1.id AND e.id = e2.id " +
|
||||
"AND e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
|
||||
|
||||
// 沙雕的括号写法
|
||||
assertSql("SELECT * FROM (((entity e,entity1 e1))) " +
|
||||
"WHERE e.id = e1.id",
|
||||
"SELECT * FROM (((entity e, entity1 e1))) " +
|
||||
"WHERE e.id = e1.id " +
|
||||
"AND e.tenant_id = 1 AND e1.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectSingleJoin() {
|
||||
// join
|
||||
assertSql("SELECT * FROM entity e join entity1 e1 on e1.id = e.id WHERE e.id = ? OR e.name = ?",
|
||||
"SELECT * FROM entity e JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
|
||||
|
||||
assertSql("SELECT * FROM entity e join entity1 e1 on e1.id = e.id WHERE (e.id = ? OR e.name = ?)",
|
||||
"SELECT * FROM entity e JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void selectWithAs() {
|
||||
assertSql("with with_as_A as (select * from entity) select * from with_as_A",
|
||||
"WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testDuplicateKeyUpdate() {
|
||||
assertSql("INSERT INTO entity (name,age) VALUES ('秋秋',18),('秋秋','22') ON DUPLICATE KEY UPDATE age=18",
|
||||
"INSERT INTO entity (name, age, tenant_id) VALUES ('秋秋', 18, 1), ('秋秋', '22', 1) ON DUPLICATE KEY UPDATE age = 18, tenant_id = 1");
|
||||
}
|
||||
|
||||
void assertSql(String sql, String targetSql) {
|
||||
assertThat(interceptor.parserSingle(sql, null)).isEqualTo(targetSql);
|
||||
}
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
package com.baomidou.mybatisplus.test.pagination;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.OrderItem;
|
||||
import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.Select;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
/**
|
||||
* SelectBody强转PlainSelect不支持sql里面最外层带union
|
||||
* 用SetOperationList处理sql带union的语句
|
||||
*/
|
||||
class SelectBodyToPlainSelectTest {
|
||||
|
||||
private static final List<OrderItem> ITEMS = new ArrayList<>();
|
||||
|
||||
static {
|
||||
ITEMS.add(OrderItem.asc("column"));
|
||||
}
|
||||
|
||||
/**
|
||||
* 报错的测试
|
||||
*/
|
||||
@Test
|
||||
void testSelectBodyToPlainSelectThrowException() {
|
||||
Select selectStatement = null;
|
||||
try {
|
||||
String originalUnionSql = "select * from test union select * from test";
|
||||
selectStatement = (Select) CCJSqlParserUtil.parse(originalUnionSql);
|
||||
} catch (JSQLParserException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
assert selectStatement != null;
|
||||
Select finalSelectStatement = selectStatement;
|
||||
Assertions.assertThrows(ClassCastException.class, () -> {
|
||||
PlainSelect plainSelect = (PlainSelect) finalSelectStatement.getSelectBody();
|
||||
});
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void setup() {
|
||||
List<OrderItem> orderItems = new ArrayList<>();
|
||||
OrderItem order = new OrderItem();
|
||||
order.setAsc(true);
|
||||
order.setColumn("column");
|
||||
orderItems.add(order);
|
||||
OrderItem orderEmptyColumn = new OrderItem();
|
||||
orderEmptyColumn.setAsc(false);
|
||||
orderEmptyColumn.setColumn("");
|
||||
orderItems.add(orderEmptyColumn);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPaginationInterceptorConcatOrderByBefore() {
|
||||
String actualSql = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test", ITEMS);
|
||||
|
||||
assertThat(actualSql).isEqualTo("SELECT * FROM test ORDER BY column ASC");
|
||||
|
||||
String actualSqlWhere = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test where 1 = 1", ITEMS);
|
||||
|
||||
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 ORDER BY column ASC");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPaginationInterceptorConcatOrderByFix() {
|
||||
List<OrderItem> orderList = new ArrayList<>();
|
||||
// 测试可能的 sql 注入 https://github.com/baomidou/mybatis-plus/issues/5745
|
||||
orderList.add(OrderItem.asc("col umn"));
|
||||
String actualSql = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test union select * from test2", orderList);
|
||||
assertThat(actualSql).isEqualTo("SELECT * FROM test UNION SELECT * FROM test2 ORDER BY column ASC");
|
||||
|
||||
String actualSqlUnionAll = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test union all select * from test2", orderList);
|
||||
assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test UNION ALL SELECT * FROM test2 ORDER BY column ASC");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPaginationInterceptorConcatOrderByFixWithWhere() {
|
||||
String actualSqlWhere = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test where 1 = 1 union select * from test2 where 1 = 1", ITEMS);
|
||||
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION SELECT * FROM test2 WHERE 1 = 1 ORDER BY column ASC");
|
||||
|
||||
String actualSqlUnionAll = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test where 1 = 1 union all select * from test2 where 1 = 1 ", ITEMS);
|
||||
assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION ALL SELECT * FROM test2 WHERE 1 = 1 ORDER BY column ASC");
|
||||
}
|
||||
|
||||
@Test
|
||||
void testPaginationInterceptorOrderByEmptyColumnFix() {
|
||||
String actualSql = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test", ITEMS);
|
||||
|
||||
assertThat(actualSql).isEqualTo("SELECT * FROM test ORDER BY column ASC");
|
||||
|
||||
String actualSqlWhere = new PaginationInnerInterceptor()
|
||||
.concatOrderBy("select * from test where 1 = 1", ITEMS);
|
||||
|
||||
assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 ORDER BY column ASC");
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
dependencies {
|
||||
api project(":mybatis-plus-extension")
|
||||
implementation "${lib."slf4j-api"}"
|
||||
}
|
||||
|
||||
compileJava.dependsOn(processResources)
|
@ -28,5 +28,6 @@ dependencies {
|
||||
testImplementation "${lib.'logback-classic'}"
|
||||
testImplementation "${lib.cglib}"
|
||||
testImplementation "${lib.postgresql}"
|
||||
testImplementation project(":mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-5.0")
|
||||
// testCompile ('org.apache.phoenix:phoenix-core:5.0.0-HBase-2.0')
|
||||
}
|
||||
|
@ -29,4 +29,8 @@ include ':spring-boot-starter:mybatis-plus-spring-boot-autoconfigure'
|
||||
include ':spring-boot-starter:mybatis-plus-spring-boot-test-autoconfigure'
|
||||
include ':spring-boot-starter:mybatis-plus-spring-boot3-starter'
|
||||
include ':spring-boot-starter:mybatis-plus-spring-boot3-starter-test'
|
||||
include 'mybatis-plus-jsqlparser'
|
||||
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-common'
|
||||
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-4.9'
|
||||
include ':mybatis-plus-jsqlparser:mybatis-plus-jsqlparser-5.0'
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user