新增重载方法支持自定义条件.

This commit is contained in:
nieqiurong 2024-04-10 13:21:12 +08:00
parent bb05cdbdd8
commit 52c7497c5b
2 changed files with 63 additions and 6 deletions

View File

@ -15,6 +15,7 @@
*/
package com.baomidou.mybatisplus.core.mapper;
import com.baomidou.mybatisplus.core.batch.BatchSqlSession;
import com.baomidou.mybatisplus.core.batch.MybatisBatch;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
@ -28,6 +29,7 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.core.toolkit.reflect.GenericTypeUtils;
@ -44,6 +46,7 @@ import java.lang.reflect.Proxy;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
/*
@ -432,17 +435,28 @@ public interface BaseMapper<T> extends Mapper<T> {
* @since 3.5.7
*/
default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList) {
MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
Class<?> entityClass = GenericTypeUtils.resolveTypeArguments(getClass(), BaseMapper.class)[0];
TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
String keyProperty = tableInfo.getKeyProperty();
String statement = mybatisMapperProxy.getMapperInterface().getName() + StringPool.DOT + SqlMethod.SELECT_BY_ID.getMethod();
return saveOrUpdateBatch(entityList, (sqlSession, entity) -> {
Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statement, entity));
});
}
/**
* 批量修改或插入
*
* @param entityList 实体对象集合
* @since 3.5.7
*/
default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, BiPredicate<BatchSqlSession, T> insertPredicate) {
MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
String keyProperty = tableInfo.getKeyProperty();
String statementId = method.get(SqlMethod.SELECT_BY_ID.getMethod()).getStatementId();
return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), (sqlSession, entity) -> {
Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statementId, entity));
}, method.updateById());
return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), insertPredicate, method.updateById());
}
}

View File

@ -272,6 +272,49 @@ class H2UserMapperTest extends BaseTest {
}
}
@Test
void testSaveOrUpdateBatchMapper2() {
int batchSize = 10;
List<H2User> h2UserList = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
h2UserList.add(new H2User(Long.valueOf(40000 + i), "test" + i));
}
List<BatchResult> batchResults = userMapper.saveOrUpdateBatch(h2UserList,((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null));
// 没有使用共享的sqlSession,由于都是新增返回还是一个批次
int[] updateCounts = batchResults.get(0).getUpdateCounts();
Assertions.assertEquals(batchSize, updateCounts.length);
for (int updateCount : updateCounts) {
Assertions.assertEquals(1, updateCount);
}
}
@Test
void testSaveOrUpdateBatchMapper3() {
var id = IdWorker.getId();
var h2UserList = List.of(new H2User(id, "testSaveOrUpdateBatchMapper3"), new H2User(id, "testSaveOrUpdateBatchMapper3-1"));
// 由于没有共享一个sqlSession,第二条记录selectById的时候第一个sqlSession的数据还没提交,会执行插入导致主键冲突.
Assertions.assertThrowsExactly(PersistenceException.class, () -> {
userMapper.saveOrUpdateBatch(h2UserList, ((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null));
});
}
@Test
void testSaveOrUpdateBatchMapper4() {
var id = IdWorker.getId();
var h2UserList = List.of(new H2User(id, "testSaveOrUpdateBatchMapper4"), new H2User(id, "testSaveOrUpdateBatchMapper4-1"));
var mapperMethod = new MybatisBatch.Method<H2User>(H2UserMapper.class);
// 共享一个sqlSession,每次selectById都会刷新一下,第二条记录为update.
var batchResults = userMapper.saveOrUpdateBatch(h2UserList,
((sqlSession, h2User) -> sqlSession.selectList(mapperMethod.get("selectById").getStatementId(), h2User.getTestId()).isEmpty()));
var updateCounts = batchResults.get(0).getUpdateCounts();
for (int updateCount : updateCounts) {
Assertions.assertEquals(1, updateCount);
}
Assertions.assertEquals(userMapper.selectById(id).getName(), "testSaveOrUpdateBatchMapper4-1");
}
@Test
void testSaveOrUpdateBatch2() {