/*
* Copyright (c) 2011-2022, baomidou ([email protected]).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.advance.injector;
import com.baomidou.mybatisplus.core.conditions.AbstractJoinWrapper;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.baomidou.mybatisplus.extension.conditions.query.BasicJoinQueryWrapper;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionUtils;
import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* 通过此类去给wrapper添加参数和执行sql返回结果
*
* @author wanglei
* @since 2022-03-14
*/
public class FuntionTools {
/**
 * Flag value handed to {@code LambdaMetafactory.altMetafactory} so that the
 * generated lambda is serializable.
 */
private static final int FLAG_SERIALIZABLE = 1;
/**
 * Cache of generated getter functions. It is static and therefore shared by
 * every caller, so a concurrent map is used instead of a plain HashMap.
 */
private static final Map<String, SFunction> functionMap = new ConcurrentHashMap<>();
/**
 * Appends a single-value action (for example a {@code like} condition) to the action list.
 *
 * @param actions  action list to append to
 * @param operator operator name, e.g. like
 * @param po       PO object; its runtime class is stored as the action's model class
 * @param property property name
 * @param value    value used by the operator
 */
public static void addAction(List<Action> actions, String operator, Object po, String property, Object value) {
    final Class<?> modelClass = po.getClass();
    // min, max and the select properties are not used by a single-value action.
    final Action condition = new Action(modelClass, property, operator, value, null, null, null);
    actions.add(condition);
}
/**
 * Appends a range action (a minimum and a maximum value) to the action list.
 *
 * @param actions  action list to append to
 * @param operator operator name, e.g. like
 * @param po       PO object; its runtime class is stored as the action's model class
 * @param property property name
 * @param minValue lower bound
 * @param maxValue upper bound
 */
public static void addAction(List<Action> actions, String operator, Object po, String property, Object minValue, Object maxValue) {
    final Class<?> modelClass = po.getClass();
    // The single value and the select properties are not used by a range action.
    final Action range = new Action(modelClass, property, operator, null, minValue, maxValue, null);
    actions.add(range);
}
/**
 * Builds the wrapper for a single-table query.
 * <p>
 * Only actions whose model class equals the class of {@code po} are used.
 * Join actions are skipped entirely, including a join onto the PO's own class,
 * which has no property and therefore cannot be turned into a where condition.
 *
 * @param actions action list
 * @param po      PO object that defines the queried table
 * @return the populated LambdaQueryWrapper
 */
public static LambdaQueryWrapper buildQueryWrapper(List<Action> actions, Object po) {
    LambdaQueryWrapper<?> queryWrapper = new LambdaQueryWrapper();
    for (Action action : actions) {
        // A plain query keeps only its own where conditions and select columns.
        if (!action.getModelClass().equals(po.getClass())) {
            continue;
        }
        String operator = action.getAction();
        // Join markers are ignored here, even if the caller invoked join.
        if (operator != null && operator.contains(OperatorConstant.JOIN)) {
            continue;
        }
        // Same constant that addSelect uses when it records a select action.
        if (OperatorConstant.SELECT.equals(operator)) {
            SFunction[] columns = new SFunction[action.getProperties().length];
            int i = 0;
            for (String property : action.getProperties()) {
                columns[i++] = getSFunction(action.getModelClass(), property);
            }
            queryWrapper.select(columns);
        } else {
            SFunction column = getSFunction(action.getModelClass(), action.getProperty());
            buildWhere(queryWrapper, action, column);
        }
    }
    return queryWrapper;
}
/**
 * Applies one action to the wrapper as a where condition or an ordering clause.
 * <p>
 * Operators that are not listed in the switch are ignored; there is no default branch.
 *
 * @param queryWrapper wrapper that receives the condition
 * @param action       action describing the operator and its value(s)
 * @param column       column reference in the form the wrapper expects
 */
public static void buildWhere(AbstractWrapper queryWrapper, Action action, Object column) {
    final String operator = action.getAction();
    final Object value = action.getValue();
    switch (operator) {
        // comparisons against a single value
        case OperatorConstant.EQ:
            queryWrapper.eq(column, value);
            break;
        case OperatorConstant.LT:
            queryWrapper.lt(column, value);
            break;
        case OperatorConstant.GT:
            queryWrapper.gt(column, value);
            break;
        case OperatorConstant.LE:
            queryWrapper.le(column, value);
            break;
        case OperatorConstant.GE:
            queryWrapper.ge(column, value);
            break;
        case OperatorConstant.NE:
            queryWrapper.ne(column, value);
            break;
        // pattern matching
        case OperatorConstant.LIKE:
            queryWrapper.like(column, value);
            break;
        case OperatorConstant.LIKE_LEFT:
            queryWrapper.likeLeft(column, value);
            break;
        case OperatorConstant.LIKE_RIGHT:
            queryWrapper.likeRight(column, value);
            break;
        case OperatorConstant.NOT_LIKE:
            queryWrapper.notLike(column, value);
            break;
        // null checks take no value
        case OperatorConstant.IS_NULL:
            queryWrapper.isNull(column);
            break;
        case OperatorConstant.NOT_NULL:
            queryWrapper.isNotNull(column);
            break;
        // membership: the value must be a Collection
        case OperatorConstant.IN:
            queryWrapper.in(column, (Collection) value);
            break;
        case OperatorConstant.NOT_IN:
            queryWrapper.notIn(column, (Collection) value);
            break;
        // ordering
        case OperatorConstant.ORDER_BY_ASC:
            queryWrapper.orderByAsc(column);
            break;
        case OperatorConstant.ORDER_BY_DESC:
            queryWrapper.orderByDesc(column);
            break;
        // ranges use min/max instead of the single value
        case OperatorConstant.BETWEEN:
            queryWrapper.between(column, action.getMin(), action.getMax());
            break;
        case OperatorConstant.NOT_BETWEEN:
            queryWrapper.notBetween(column, action.getMin(), action.getMax());
            break;
    }
}
/**
 * Builds the wrapper for a multi-table (join) query.
 * <p>
 * The action list is walked twice: first to register every join, then to add
 * the where conditions and select columns.
 *
 * @param actions action list
 * @param po      PO object whose class is the main table of the join
 * @return the populated BasicJoinQueryWrapper
 */
public static BasicJoinQueryWrapper buildJoinWrapper(List<Action> actions, Object po) {
    BasicJoinQueryWrapper joinWrapper = new BasicJoinQueryWrapper(po.getClass());
    // Joins go first, otherwise adding conditions and selects would fail.
    for (Action candidate : actions) {
        String operator = candidate.getAction();
        if (!operator.contains(OperatorConstant.JOIN)) {
            continue;
        }
        if (OperatorConstant.JOIN.equals(operator)) {
            joinWrapper.innerJoin(candidate.getModelClass());
        } else if (OperatorConstant.LEFT_JOIN.equals(operator)) {
            joinWrapper.leftJoin(candidate.getModelClass());
        }
    }
    for (Action candidate : actions) {
        String operator = candidate.getAction();
        if (OperatorConstant.SELECT.equals(operator)) {
            // select handling: one ModelProperty per requested property
            String[] properties = candidate.getProperties();
            BasicJoinQueryWrapper.ModelProperty[] columns = new BasicJoinQueryWrapper.ModelProperty[properties.length];
            for (int idx = 0; idx < properties.length; idx++) {
                columns[idx] = new BasicJoinQueryWrapper.ModelProperty(candidate.getModelClass(), properties[idx]);
            }
            joinWrapper.select(columns);
        } else if (!operator.contains(OperatorConstant.JOIN)) {
            // Anything that is neither a select nor a join is a where condition.
            BasicJoinQueryWrapper.ModelProperty column =
                    new BasicJoinQueryWrapper.ModelProperty(candidate.getModelClass(), candidate.getProperty());
            buildWhere(joinWrapper, candidate, column);
        }
    }
    return joinWrapper;
}
/**
 * Tells whether the action list contains at least one join action.
 *
 * @param actions action list
 * @return true if some action name contains {@code OperatorConstant.JOIN}, false otherwise
 */
public static boolean isJoin(List<Action> actions) {
    // anyMatch stops at the first join instead of counting every match.
    return actions.stream().anyMatch(action -> action.getAction().contains(OperatorConstant.JOIN));
}
/**
 * Joins another table: creates an instance of the target entity class that
 * shares the given action list, then records a join action for that class.
 *
 * @param target   entity class of the table to join
 * @param actions  action list shared between the entities
 * @param joinType join type, stored as the action name
 * @return a new instance of the target class, never null
 * @throws RuntimeException (MybatisPlusException) if the target cannot be instantiated
 *                          or has no accessible {@code actions} field
 */
public static <T> T join(Class<T> target, List<Action> actions, String joinType) {
    T result = newInstanceSharingActions(target, actions);
    // Recorded only after the instance was created and wired successfully.
    actions.add(new Action(target, null, joinType, null, null, null, null));
    return result;
}
/**
 * Returns to the main table: creates an instance of the target entity class
 * that shares the given action list, without recording any new action.
 *
 * @param target  entity class to return to
 * @param actions action list shared between the entities
 * @return a new instance of the target class, never null
 * @throws RuntimeException (MybatisPlusException) if the target cannot be instantiated
 *                          or has no accessible {@code actions} field
 */
public static <T> T end(Class<T> target, List<Action> actions) {
    return newInstanceSharingActions(target, actions);
}
/**
 * Instantiates {@code target} through its no-arg constructor and assigns
 * {@code actions} to its field named "actions".
 */
private static <T> T newInstanceSharingActions(Class<T> target, List<Action> actions) {
    try {
        T result = target.getDeclaredConstructor().newInstance();
        // Fetch the related object's actions field and point it at the shared list.
        Field field = ReflectionKit.getFieldMap(target).get("actions");
        if (field == null) {
            throw ExceptionUtils.mpe("This class %s is not have field %s ", target.getName(), "actions");
        }
        field.setAccessible(true);
        field.set(result, actions);
        return result;
    } catch (ReflectiveOperationException e) {
        // Fail loudly with the cause instead of printing the stack trace and returning null.
        throw ExceptionUtils.mpe("Failed to create an instance of %s sharing the action list", e, target.getName());
    }
}
/**
 * Runs the select-list statement of the PO's class with a wrapper built from the actions.
 * <p>
 * A join wrapper is built when the actions contain a join, otherwise a
 * single-table wrapper. The session is closed in all cases.
 *
 * @param actions action list
 * @param po      PO object
 * @return the result list
 */
public static List list(List<Action> actions, Object po) {
    final AbstractWrapper wrapper;
    if (isJoin(actions)) {
        wrapper = buildJoinWrapper(actions, po);
    } else {
        wrapper = buildQueryWrapper(actions, po);
    }
    final Class<?> poClass = po.getClass();
    final SqlSession session = sqlSession(poClass);
    final Map<String, Object> params = CollectionUtils.newHashMapWithExpectedSize(1);
    params.put(Constants.WRAPPER, wrapper);
    try {
        final String statement = sqlStatement(SqlMethod.SELECT_LIST, poClass);
        return session.selectList(statement, params);
    } finally {
        closeSqlSession(session, poClass);
    }
}
/**
 * Records an explicit list of fields to select.
 *
 * @param actions action list to append to
 * @param po      PO object; its runtime class is stored as the action's model class
 * @param fields  property names to select
 */
public static void addSelect(List<Action> actions, Object po, String... fields) {
    final Class<?> modelClass = po.getClass();
    // A select action carries only the property array; property, value, min and max stay null.
    final Action selection = new Action(modelClass, null, OperatorConstant.SELECT, null, null, null, fields);
    actions.add(selection);
}
/**
 * Resolves the database column name for an entity property.
 *
 * @param entityClass entity class
 * @param fieldName   property name
 * @return the column name held by the column cache entry
 * @throws RuntimeException (MybatisPlusException) if no cache entry exists for the property
 */
public static String getDBField(Class entityClass, String fieldName) {
    final ColumnCache cached = AbstractJoinWrapper.getCache(entityClass, fieldName);
    if (cached == null) {
        throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName);
    }
    return cached.getColumn();
}
/**
 * Queries with the wrapper built from the actions and returns the first row.
 *
 * @param actions action list
 * @param po      PO object
 * @return the first element of the result list, or null when the list is empty
 */
public static Object one(List<Action> actions, Object po) {
    final List rows = list(actions, po);
    return rows.isEmpty() ? null : rows.get(0);
}
/**
 * Runs the select-count statement of the PO's class with a wrapper built from the actions.
 * <p>
 * A join wrapper is built when the actions contain a join, otherwise a
 * single-table wrapper. The session is closed in all cases.
 *
 * @param actions action list
 * @param po      PO object
 * @return the total count
 */
public static Long count(List<Action> actions, Object po) {
    final AbstractWrapper wrapper;
    if (isJoin(actions)) {
        wrapper = buildJoinWrapper(actions, po);
    } else {
        wrapper = buildQueryWrapper(actions, po);
    }
    final Class<?> poClass = po.getClass();
    final SqlSession session = sqlSession(poClass);
    final Map<String, Object> params = CollectionUtils.newHashMapWithExpectedSize(1);
    params.put(Constants.WRAPPER, wrapper);
    try {
        final String statement = sqlStatement(SqlMethod.SELECT_COUNT, poClass);
        final Long total = session.selectOne(statement, params);
        return total;
    } finally {
        closeSqlSession(session, poClass);
    }
}
/**
 * Updates rows matched by the actions and returns the number of affected rows.
 * <p>
 * The single-table wrapper is always used. The PO itself is bound as the
 * entity, and the session is closed in all cases.
 *
 * @param actions action list
 * @param po      PO object passed to the statement as the entity
 * @return number of affected rows
 */
public static Integer update(List<Action> actions, Object po) {
    final AbstractWrapper wrapper = buildQueryWrapper(actions, po);
    final Class<?> poClass = po.getClass();
    final SqlSession session = sqlSession(poClass);
    final Map<String, Object> params = CollectionUtils.newHashMapWithExpectedSize(1);
    params.put(Constants.WRAPPER, wrapper);
    params.put(Constants.ENTITY, po);
    try {
        final String statement = sqlStatement(SqlMethod.UPDATE, poClass);
        final int affected = session.update(statement, params);
        return affected;
    } finally {
        closeSqlSession(session, poClass);
    }
}
/**
 * Deletes rows matched by the actions and returns the number of affected rows.
 * <p>
 * The single-table wrapper is always used, and the session is closed in all cases.
 *
 * @param actions action list
 * @param po      PO object
 * @return number of affected rows
 */
public static Integer delete(List<Action> actions, Object po) {
    final AbstractWrapper wrapper = buildQueryWrapper(actions, po);
    final Class<?> poClass = po.getClass();
    final SqlSession session = sqlSession(poClass);
    final Map<String, Object> params = CollectionUtils.newHashMapWithExpectedSize(1);
    params.put(Constants.WRAPPER, wrapper);
    try {
        final String statement = sqlStatement(SqlMethod.DELETE, poClass);
        final int affected = session.delete(statement, params);
        return affected;
    } finally {
        closeSqlSession(session, poClass);
    }
}
/**
 * Obtains a SqlSession for the given PO class through {@code SqlHelper}.
 *
 * @param poClass PO class
 * @return the session returned by SqlHelper
 */
protected static SqlSession sqlSession(Class poClass) {
    final SqlSession session = SqlHelper.sqlSession(poClass);
    return session;
}
/**
 * Resolves the statement id for one of the built-in SQL methods.
 *
 * @param sqlMethod built-in SQL method
 * @param poClass   PO class
 * @return the statement id for that method on the PO's table
 */
protected static String sqlStatement(SqlMethod sqlMethod, Class poClass) {
    final String methodName = sqlMethod.getMethod();
    return sqlStatement(methodName, poClass);
}
/**
 * Resolves the statement id for a SQL method name by asking the table
 * information that {@code SqlHelper} holds for the PO class.
 *
 * @param sqlMethod SQL method name
 * @param poClass   PO class
 * @return the statement id
 */
protected static String sqlStatement(String sqlMethod, Class poClass) {
return SqlHelper.table(poClass).getSqlStatement(sqlMethod);
}
/**
 * Closes the session through {@code SqlSessionUtils}, passing the session
 * factory currently associated with the PO class.
 *
 * @param sqlSession session to close
 * @param poClass    PO class used to look up the session factory
 */
protected static void closeSqlSession(SqlSession sqlSession, Class poClass) {
SqlSessionUtils.closeSqlSession(sqlSession, GlobalConfigUtils.currentSessionFactory(poClass));
}
/**
 * Returns an SFunction that invokes the getter of the given property,
 * creating it through {@code LambdaMetafactory} on first use and caching it afterwards.
 *
 * @param entityClass entity class
 * @param fieldName   property name; the getter is expected to be {@code get} + capitalized name
 * @return the cached or newly created SFunction
 * @throws RuntimeException (MybatisPlusException) if the field does not exist or the
 *                          getter function cannot be created
 */
public static SFunction getSFunction(Class<?> entityClass, String fieldName) {
    // One key for both lookup and store; the separator keeps class name and field name apart.
    final String cacheKey = entityClass.getName() + "#" + fieldName;
    final SFunction cached = functionMap.get(cacheKey);
    if (cached != null) {
        return cached;
    }
    Field field = getDeclaredField(entityClass, fieldName);
    if (field == null) {
        throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName);
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    MethodType methodType = MethodType.methodType(field.getType(), entityClass);
    // Character.toUpperCase is locale independent, unlike String.toUpperCase().
    String getFunName = "get" + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
    try {
        final CallSite site = LambdaMetafactory.altMetafactory(lookup,
                "invoke",
                MethodType.methodType(SFunction.class),
                methodType,
                lookup.findVirtual(entityClass, getFunName, MethodType.methodType(field.getType())),
                methodType, FLAG_SERIALIZABLE);
        SFunction func = (SFunction) site.getTarget().invokeExact();
        functionMap.put(cacheKey, func);
        return func;
    } catch (Throwable e) {
        // Keep the original failure as the cause so the real reason is not lost.
        throw ExceptionUtils.mpe("This class %s is not have method %s ", e, entityClass.getName(), getFunName);
    }
}
/**
 * Finds a declared field by name, searching the class and then its
 * superclasses, stopping before {@code Object}.
 *
 * @param clazz     class to start from; may be null, an interface or a primitive type
 * @param fieldName field name; may be null
 * @return the first matching field, or null if none of the searched classes declares it
 */
public static Field getDeclaredField(Class<?> clazz, String fieldName) {
    if (fieldName == null) {
        return null;
    }
    // The null check ends the walk for interfaces and primitives, whose superclass is null.
    for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
        try {
            return current.getDeclaredField(fieldName);
        } catch (NoSuchFieldException ignored) {
            // Deliberately not rethrown: the loop must continue with the superclass.
        }
    }
    return null;
}
/**
 * One recorded action: a where condition, an ordering, a select or a join marker.
 * <p>
 * The field order defines the parameter order of the generated all-args constructor.
 *
 * @author wanglei
 * @since 2022-03-18
 */
@Data
@AllArgsConstructor
public static class Action {
/**
 * Class of the PO this action refers to.
 */
private Class<?> modelClass;
/**
 * Property name; null for select and join actions.
 */
private String property;
/**
 * Action (operator) name.
 */
private String action;
/**
 * Value used by single-value operators.
 */
private Object value;
/**
 * Minimum value, used by range operators.
 */
private Object min;
/**
 * Maximum value, used by range operators.
 */
private Object max;
/**
 * Property names, used by select actions.
 */
private String[] properties;
}
}