001/* 002 * Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com). 003 * <p> 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * <p> 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * <p> 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package com.mybatisflex.core.datasource; 017 018import com.mybatisflex.core.FlexGlobalConfig; 019import com.mybatisflex.core.dialect.DbType; 020import com.mybatisflex.core.dialect.DbTypeUtil; 021import com.mybatisflex.core.transaction.TransactionContext; 022import com.mybatisflex.core.transaction.TransactionalManager; 023import com.mybatisflex.core.util.ArrayUtil; 024import com.mybatisflex.core.util.StringUtil; 025import org.apache.ibatis.logging.Log; 026import org.apache.ibatis.logging.LogFactory; 027 028import javax.sql.DataSource; 029import java.lang.reflect.InvocationHandler; 030import java.lang.reflect.Method; 031import java.lang.reflect.Proxy; 032import java.sql.Connection; 033import java.sql.SQLException; 034import java.util.ArrayList; 035import java.util.HashMap; 036import java.util.List; 037import java.util.Map; 038import java.util.Objects; 039import java.util.Optional; 040import java.util.concurrent.ThreadLocalRandom; 041 042/** 043 * @author michael 044 */ 045public class FlexDataSource extends AbstractDataSource { 046 047 private static final char LOAD_BALANCE_KEY_SUFFIX = '*'; 048 private static final Log log = LogFactory.getLog(FlexDataSource.class); 049 050 private final Map<String, DataSource> dataSourceMap = new HashMap<>(); 051 private final Map<String, DbType> dbTypeHashMap = new HashMap<>(); 052 053 private DbType defaultDbType; 054 private String defaultDataSourceKey; 055 private DataSource defaultDataSource; 056 057 public FlexDataSource(String dataSourceKey, DataSource dataSource) { 058 this(dataSourceKey, dataSource, true); 059 } 060 061 public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { 062 this(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource); 063 } 064 065 public FlexDataSource(String dataSourceKey, DataSource dataSource, DbType dbType, boolean needDecryptDataSource) { 066 if (needDecryptDataSource) { 067 DataSourceManager.decryptDataSource(dataSource); 068 } 069 070 // 处理dbType 071 dbType = Optional.ofNullable(dbType).orElseGet(() -> DbTypeUtil.getDbType(dataSource)); 072 073 this.defaultDataSourceKey = dataSourceKey; 074 this.defaultDataSource = dataSource; 075 this.defaultDbType = dbType; 076 077 dataSourceMap.put(dataSourceKey, dataSource); 078 dbTypeHashMap.put(dataSourceKey, dbType); 079 } 080 081 /** 082 * 设置默认数据源(提供动态可控性) 083 */ 084 public void setDefaultDataSource(String dataSourceKey) { 085 DataSource ds = dataSourceMap.get(dataSourceKey); 086 087 if (Objects.isNull(ds)) { 088 throw new IllegalStateException("DataSource not found by key: \"" + dataSourceKey + "\""); 089 } 090 091 // 优先取缓存,否则根据数据源返回数据库类型 092 DbType dbType = Optional.ofNullable(dbTypeHashMap.get(dataSourceKey)).orElseGet(() -> DbTypeUtil.getDbType(ds)); 093 094 this.defaultDataSourceKey = dataSourceKey; 095 this.defaultDataSource = ds; 096 this.defaultDbType = dbType; 097 } 098 099 public void addDataSource(String dataSourceKey, DataSource dataSource) { 100 addDataSource(dataSourceKey, dataSource, true); 101 } 102 103 public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { 104 addDataSource(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource); 105 } 106 107 public void addDataSource(String dataSourceKey, DataSource dataSource, DbType dbType, boolean needDecryptDataSource) { 108 if (needDecryptDataSource) { 109 DataSourceManager.decryptDataSource(dataSource); 110 } 111 112 dbType = Optional.ofNullable(dbType).orElseGet(() -> DbTypeUtil.getDbType(dataSource)); 113 114 dataSourceMap.put(dataSourceKey, dataSource); 115 dbTypeHashMap.put(dataSourceKey, dbType); 116 } 117 118 119 public void removeDatasource(String dataSourceKey) { 120 dataSourceMap.remove(dataSourceKey); 121 dbTypeHashMap.remove(dataSourceKey); 122 } 123 124 public Map<String, DataSource> getDataSourceMap() { 125 return dataSourceMap; 126 } 127 128 public Map<String, DbType> getDbTypeHashMap() { 129 return dbTypeHashMap; 130 } 131 132 public String getDefaultDataSourceKey() { 133 return defaultDataSourceKey; 134 } 135 136 public DataSource getDefaultDataSource() { 137 return defaultDataSource; 138 } 139 140 public DbType getDefaultDbType() { 141 return defaultDbType; 142 } 143 144 public DbType getDbType(String dataSourceKey) { 145 return dbTypeHashMap.get(dataSourceKey); 146 } 147 148 149 @Override 150 public Connection getConnection() throws SQLException { 151 String xid = TransactionContext.getXID(); 152 if (StringUtil.hasText(xid)) { 153 String dataSourceKey = DataSourceKey.get(); 154 if (StringUtil.noText(dataSourceKey)) { 155 dataSourceKey = defaultDataSourceKey; 156 } 157 158 Connection connection = TransactionalManager.getConnection(xid, dataSourceKey); 159 if (connection == null) { 160 connection = proxy(getDataSource().getConnection(), xid); 161 TransactionalManager.hold(xid, dataSourceKey, connection); 162 } 163 return connection; 164 } else { 165 return getDataSource().getConnection(); 166 } 167 } 168 169 170 @Override 171 public Connection getConnection(String username, String password) throws SQLException { 172 String xid = TransactionContext.getXID(); 173 if (StringUtil.hasText(xid)) { 174 String dataSourceKey = DataSourceKey.get(); 175 if (StringUtil.noText(dataSourceKey)) { 176 dataSourceKey = defaultDataSourceKey; 177 } 178 Connection connection = TransactionalManager.getConnection(xid, dataSourceKey); 179 if (connection == null) { 180 connection = proxy(getDataSource().getConnection(username, password), xid); 181 TransactionalManager.hold(xid, dataSourceKey, connection); 182 } 183 return connection; 184 } else { 185 return getDataSource().getConnection(username, password); 186 } 187 } 188 189 static void closeAutoCommit(Connection connection) { 190 try { 191 connection.setAutoCommit(false); 192 } catch (SQLException e) { 193 if (log.isDebugEnabled()) { 194 log.debug("Error set autoCommit to false. Cause: " + e); 195 } 196 } 197 } 198 199 static void resetAutoCommit(Connection connection) { 200 try { 201 if (!connection.getAutoCommit()) { 202 connection.setAutoCommit(true); 203 } 204 } catch (SQLException e) { 205 if (log.isDebugEnabled()) { 206 log.debug("Error resetting autoCommit to true before closing the connection. " + 207 "Cause: " + e); 208 } 209 } 210 } 211 212 213 public Connection proxy(Connection connection, String xid) { 214 return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader() 215 , new Class[]{Connection.class} 216 , new ConnectionHandler(connection, xid) 217 ); 218 } 219 220 /** 221 * 方便用于 {@link DbTypeUtil#getDbType(DataSource)} 222 */ 223 public String getUrl() { 224 return DbTypeUtil.getJdbcUrl(defaultDataSource); 225 } 226 227 228 @Override 229 @SuppressWarnings("unchecked") 230 public <T> T unwrap(Class<T> iface) throws SQLException { 231 if (iface.isInstance(this)) { 232 return (T) this; 233 } 234 return getDataSource().unwrap(iface); 235 } 236 237 @Override 238 public boolean isWrapperFor(Class<?> iface) throws SQLException { 239 return (iface.isInstance(this) || getDataSource().isWrapperFor(iface)); 240 } 241 242 /** 243 * 获取数据源缺失处理器。 244 * 245 * @return DataSourceMissingHandler 数据源缺失处理器实例,用于自定义处理逻辑(如:记录日志、抛出异常或提供默认数据源)。 246 */ 247 public DataSourceMissingHandler getDataSourceMissingHandler() { 248 return FlexGlobalConfig.getDefaultConfig().getDataSourceMissingHandler(); 249 } 250 251 protected DataSource getDataSource() { 252 DataSource dataSource = defaultDataSource; 253 DataSourceMissingHandler dataSourceMissingHandler = getDataSourceMissingHandler(); 254 255 if (dataSourceMap.size() > 1) { 256 String dataSourceKey = DataSourceKey.get(); 257 258 if (StringUtil.hasText(dataSourceKey)) { 259 // 负载均衡 key 260 if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) { 261 String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1); 262 List<String> matchedKeys = new ArrayList<>(); 263 264 for (String key : dataSourceMap.keySet()) { 265 if (key.startsWith(prefix)) { 266 matchedKeys.add(key); 267 } 268 } 269 270 // 当找不到匹配的 key 时,尝试后备匹配 271 if (matchedKeys.isEmpty() && dataSourceMissingHandler != null) { 272 Map<String, DataSource> dsMap = dataSourceMissingHandler.handle(dataSourceKey, dataSourceMap); 273 274 if (dsMap != null && !dsMap.isEmpty()) { 275 for (String key : dsMap.keySet()) { 276 if (key.startsWith(prefix)) { 277 matchedKeys.add(key); 278 } 279 } 280 } 281 } 282 283 if (matchedKeys.isEmpty()) { 284 throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\""); 285 } 286 287 String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size())); 288 289 return dataSourceMap.get(randomKey); 290 } 291 // 非负载均衡 key 292 else { 293 dataSource = dataSourceMap.get(dataSourceKey); 294 295 // 当找不到匹配的 key 时,尝试后备匹配 296 if (dataSource == null && dataSourceMissingHandler != null) { 297 Map<String, DataSource> dsMap = dataSourceMissingHandler.handle(dataSourceKey, dataSourceMap); 298 299 if (dsMap != null && !dsMap.isEmpty()) { 300 dataSource = dsMap.get(dataSourceKey); 301 } 302 } 303 304 if (dataSource == null) { 305 throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\""); 306 } 307 } 308 } 309 } 310 311 return dataSource; 312 } 313 314 private static class ConnectionHandler implements InvocationHandler { 315 private static final String[] proxyMethods = new String[]{"commit", "rollback", "close", "setAutoCommit"}; 316 private final Connection original; 317 private final String xid; 318 319 public ConnectionHandler(Connection original, String xid) { 320 closeAutoCommit(original); 321 this.original = original; 322 this.xid = xid; 323 } 324 325 @Override 326 public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { 327 if (ArrayUtil.contains(proxyMethods, method.getName()) 328 && isTransactional()) { 329 // do nothing 330 return null; 331 } 332 333 // setAutoCommit: true 334 if ("close".equalsIgnoreCase(method.getName())) { 335 resetAutoCommit(original); 336 } 337 338 return method.invoke(original, args); 339 } 340 341 private boolean isTransactional() { 342 return Objects.equals(xid, TransactionContext.getXID()); 343 } 344 345 } 346}