1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
import me.mingshan.util.StringUtil;
import me.mingshan.util.orm.entity.VersionPEntity;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;
@Intercepts({@Signature( type = Executor.class, method = "update", // update 包括了最常用的 insert/update/delete 三种操作 args = {MappedStatement.class, Object.class})}) public class OptimisticLockerInterceptor implements Interceptor { private static final Logger LOGGER = LoggerFactory.getLogger(OptimisticLockerInterceptor.class);
private static final Map<String, String> SUPPORT_METHODS_MAP = new HashMap<>();
static { SUPPORT_METHODS_MAP.put("updateByPrimaryKey", null); SUPPORT_METHODS_MAP.put("delete", null); SUPPORT_METHODS_MAP.put("updateByPrimaryKeySelective", null); SUPPORT_METHODS_MAP.put("updateByExample", null); SUPPORT_METHODS_MAP.put("updateByExampleSelective", null); }
@Override public Object intercept(Invocation invocation) throws Throwable { Object[] args = invocation.getArgs(); MappedStatement ms = (MappedStatement) args[0]; Object parameterObject = args[1];
Object versionEntity = fetchVersionEntity(parameterObject); if (versionEntity == null) { return invocation.proceed(); }
String id = ms.getId(); LOGGER.info("乐观锁,ID: {}", id);
String[] targetMethodArr = id.split("\\."); String method = targetMethodArr[targetMethodArr.length - 1]; boolean methodSupportIntercept = SUPPORT_METHODS_MAP.containsKey(method); if (!methodSupportIntercept) { return invocation.proceed(); }
String sqlCommandType = ms.getSqlCommandType().toString(); LOGGER.info("乐观锁,sqlCommandType: {}", sqlCommandType);
BoundSql boundSql = ms.getBoundSql(parameterObject); String origSql = boundSql.getSql(); LOGGER.info("乐观锁,原始SQL: {}", origSql);
Long version = ((VersionPEntity) versionEntity).getVersion();
String newSql = rewriteSql(origSql, version, ms.getSqlCommandType(), versionEntity); if (origSql.equals(newSql)) { return invocation.proceed(); }
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql, boundSql.getParameterMappings(), boundSql.getParameterObject());
MappedStatement newMs = newMappedStatement(ms, new BoundSqlSqlSource(newBoundSql)); for (ParameterMapping mapping : boundSql.getParameterMappings()) { String prop = mapping.getProperty(); if (boundSql.hasAdditionalParameter(prop)) { newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop)); } }
Object[] queryArgs = invocation.getArgs(); queryArgs[0] = newMs;
LOGGER.info("乐观锁,改写的SQL: {}", newSql);
return invocation.proceed(); }
private static Object fetchVersionEntity(Object parameterObject) { Object versionEntity = null;
if (!(parameterObject instanceof VersionPEntity)) { if (Objects.nonNull(parameterObject) && parameterObject instanceof Map) { Object firstValue = ((Map<?, ?>) parameterObject).values().stream() .filter(Objects::nonNull).findFirst().orElse(null); if (firstValue instanceof VersionPEntity) { versionEntity = firstValue; } else { return null; } } } else { versionEntity = parameterObject; }
return versionEntity; }
private static String rewriteSql(String origSql, Long version, SqlCommandType sqlCommandType, Object versionEntity) { if (StringUtil.isEmpty(origSql) || version == null || sqlCommandType == null || versionEntity == null) { return origSql; }
String lowCaseOrigSql = origSql.toLowerCase();
boolean existWhere = lowCaseOrigSql.contains("where"); if (!existWhere) { return origSql; }
String inReg = "^.*(\\s)+in(\\s|\\(){1}.*$"; boolean existIn = Pattern.matches(inReg, lowCaseOrigSql); if (existIn) { return origSql; }
if (SqlCommandType.UPDATE.equals(sqlCommandType)) { String[] sqlArr = origSql.split("(?i)where");
String s1 = sqlArr[0]; String s2 = sqlArr[1];
String versionReg = "^.*(\\s)+version(\\s|=){1}.*$"; boolean existVersion = Pattern.matches(versionReg, lowCaseOrigSql); if (existVersion) { try { Field versionFiled = versionEntity.getClass().getSuperclass().getDeclaredField("version"); versionFiled.setAccessible(true);
Long value = (Long) versionFiled.get(versionEntity); value = value + 1; versionFiled.set(versionEntity, value); } catch (Exception e) { return origSql; }
s2 += " and version = " + version; } else { s1 += " , version = version + 1 "; s2 += " and version = " + version; }
return s1 + " where " + s2; } if (SqlCommandType.DELETE.equals(sqlCommandType)) { String versionReg = "^.*(\\s)+version(\\s|=){1}.*$"; boolean existVersion = Pattern.matches(versionReg, lowCaseOrigSql);
if (!existVersion) { origSql += " and version = " + version; }
return origSql; }
return origSql; }
static class BoundSqlSqlSource implements SqlSource { private final BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) { this.boundSql = boundSql; }
public BoundSql getBoundSql(Object parameterObject) { return boundSql; }
}
private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) { MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType()); builder.resource(ms.getResource()); builder.fetchSize(ms.getFetchSize()); builder.statementType(ms.getStatementType()); builder.keyGenerator(ms.getKeyGenerator()); if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) { builder.keyProperty(ms.getKeyProperties()[0]); } builder.timeout(ms.getTimeout()); builder.parameterMap(ms.getParameterMap()); builder.resultMaps(ms.getResultMaps()); builder.resultSetType(ms.getResultSetType()); builder.cache(ms.getCache()); builder.flushCacheRequired(ms.isFlushCacheRequired()); builder.useCache(ms.isUseCache()); return builder.build(); }
@Override public Object plugin(Object target) { if (target instanceof Executor) { return Plugin.wrap(target, this); } return target; }
@Override public void setProperties(Properties properties) { } }
|