001/*
002 *  Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.datasource;
017
import com.mybatisflex.core.FlexGlobalConfig;
import com.mybatisflex.core.dialect.DbType;
import com.mybatisflex.core.dialect.DbTypeUtil;
import com.mybatisflex.core.transaction.TransactionContext;
import com.mybatisflex.core.transaction.TransactionalManager;
import com.mybatisflex.core.util.ArrayUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;

import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
041
042/**
043 * @author michael
044 */
045public class FlexDataSource extends AbstractDataSource {
046
047    private static final char LOAD_BALANCE_KEY_SUFFIX = '*';
048    private static final Log log = LogFactory.getLog(FlexDataSource.class);
049
050    private final Map<String, DataSource> dataSourceMap = new HashMap<>();
051    private final Map<String, DbType> dbTypeHashMap = new HashMap<>();
052
053    private DbType defaultDbType;
054    private String defaultDataSourceKey;
055    private DataSource defaultDataSource;
056
057    public FlexDataSource(String dataSourceKey, DataSource dataSource) {
058        this(dataSourceKey, dataSource, true);
059    }
060
061    public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
062        this(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource);
063    }
064
065    public FlexDataSource(String dataSourceKey, DataSource dataSource, DbType dbType, boolean needDecryptDataSource) {
066        if (needDecryptDataSource) {
067            DataSourceManager.decryptDataSource(dataSource);
068        }
069
070        // 处理dbType
071        dbType = Optional.ofNullable(dbType).orElseGet(() -> DbTypeUtil.getDbType(dataSource));
072
073        this.defaultDataSourceKey = dataSourceKey;
074        this.defaultDataSource = dataSource;
075        this.defaultDbType = dbType;
076
077        dataSourceMap.put(dataSourceKey, dataSource);
078        dbTypeHashMap.put(dataSourceKey, dbType);
079    }
080
081    /**
082     * 设置默认数据源(提供动态可控性)
083     */
084    public void setDefaultDataSource(String dataSourceKey) {
085        DataSource ds = dataSourceMap.get(dataSourceKey);
086
087        if (Objects.isNull(ds)) {
088            throw new IllegalStateException("DataSource not found by key: \"" + dataSourceKey + "\"");
089        }
090
091        // 优先取缓存,否则根据数据源返回数据库类型
092        DbType dbType = Optional.ofNullable(dbTypeHashMap.get(dataSourceKey)).orElseGet(() -> DbTypeUtil.getDbType(ds));
093
094        this.defaultDataSourceKey = dataSourceKey;
095        this.defaultDataSource = ds;
096        this.defaultDbType = dbType;
097    }
098
099    public void addDataSource(String dataSourceKey, DataSource dataSource) {
100        addDataSource(dataSourceKey, dataSource, true);
101    }
102
103    public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
104        addDataSource(dataSourceKey, dataSource, DbTypeUtil.getDbType(dataSource), needDecryptDataSource);
105    }
106
107    public void addDataSource(String dataSourceKey, DataSource dataSource, DbType dbType, boolean needDecryptDataSource) {
108        if (needDecryptDataSource) {
109            DataSourceManager.decryptDataSource(dataSource);
110        }
111
112        dbType = Optional.ofNullable(dbType).orElseGet(() -> DbTypeUtil.getDbType(dataSource));
113
114        dataSourceMap.put(dataSourceKey, dataSource);
115        dbTypeHashMap.put(dataSourceKey, dbType);
116    }
117
118
119    public void removeDatasource(String dataSourceKey) {
120        dataSourceMap.remove(dataSourceKey);
121        dbTypeHashMap.remove(dataSourceKey);
122    }
123
124    public Map<String, DataSource> getDataSourceMap() {
125        return dataSourceMap;
126    }
127
128    public Map<String, DbType> getDbTypeHashMap() {
129        return dbTypeHashMap;
130    }
131
132    public String getDefaultDataSourceKey() {
133        return defaultDataSourceKey;
134    }
135
136    public DataSource getDefaultDataSource() {
137        return defaultDataSource;
138    }
139
140    public DbType getDefaultDbType() {
141        return defaultDbType;
142    }
143
144    public DbType getDbType(String dataSourceKey) {
145        return dbTypeHashMap.get(dataSourceKey);
146    }
147
148
149    @Override
150    public Connection getConnection() throws SQLException {
151        String xid = TransactionContext.getXID();
152        if (StringUtil.hasText(xid)) {
153            String dataSourceKey = DataSourceKey.get();
154            if (StringUtil.noText(dataSourceKey)) {
155                dataSourceKey = defaultDataSourceKey;
156            }
157
158            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
159            if (connection == null) {
160                connection = proxy(getDataSource().getConnection(), xid);
161                TransactionalManager.hold(xid, dataSourceKey, connection);
162            }
163            return connection;
164        } else {
165            return getDataSource().getConnection();
166        }
167    }
168
169
170    @Override
171    public Connection getConnection(String username, String password) throws SQLException {
172        String xid = TransactionContext.getXID();
173        if (StringUtil.hasText(xid)) {
174            String dataSourceKey = DataSourceKey.get();
175            if (StringUtil.noText(dataSourceKey)) {
176                dataSourceKey = defaultDataSourceKey;
177            }
178            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
179            if (connection == null) {
180                connection = proxy(getDataSource().getConnection(username, password), xid);
181                TransactionalManager.hold(xid, dataSourceKey, connection);
182            }
183            return connection;
184        } else {
185            return getDataSource().getConnection(username, password);
186        }
187    }
188
189    static void closeAutoCommit(Connection connection) {
190        try {
191            connection.setAutoCommit(false);
192        } catch (SQLException e) {
193            if (log.isDebugEnabled()) {
194                log.debug("Error set autoCommit to false. Cause: " + e);
195            }
196        }
197    }
198
199    static void resetAutoCommit(Connection connection) {
200        try {
201            if (!connection.getAutoCommit()) {
202                connection.setAutoCommit(true);
203            }
204        } catch (SQLException e) {
205            if (log.isDebugEnabled()) {
206                log.debug("Error resetting autoCommit to true before closing the connection. " +
207                    "Cause: " + e);
208            }
209        }
210    }
211
212
213    public Connection proxy(Connection connection, String xid) {
214        return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
215            , new Class[]{Connection.class}
216            , new ConnectionHandler(connection, xid)
217        );
218    }
219
220    /**
221     * 方便用于 {@link DbTypeUtil#getDbType(DataSource)}
222     */
223    public String getUrl() {
224        return DbTypeUtil.getJdbcUrl(defaultDataSource);
225    }
226
227
228    @Override
229    @SuppressWarnings("unchecked")
230    public <T> T unwrap(Class<T> iface) throws SQLException {
231        if (iface.isInstance(this)) {
232            return (T) this;
233        }
234        return getDataSource().unwrap(iface);
235    }
236
237    @Override
238    public boolean isWrapperFor(Class<?> iface) throws SQLException {
239        return (iface.isInstance(this) || getDataSource().isWrapperFor(iface));
240    }
241
242    /**
243     * 获取数据源缺失处理器。
244     *
245     * @return DataSourceMissingHandler 数据源缺失处理器实例,用于自定义处理逻辑(如:记录日志、抛出异常或提供默认数据源)。
246     */
247    public DataSourceMissingHandler getDataSourceMissingHandler() {
248        return FlexGlobalConfig.getDefaultConfig().getDataSourceMissingHandler();
249    }
250
251    protected DataSource getDataSource() {
252        DataSource dataSource = defaultDataSource;
253        DataSourceMissingHandler dataSourceMissingHandler = getDataSourceMissingHandler();
254
255        if (dataSourceMap.size() > 1) {
256            String dataSourceKey = DataSourceKey.get();
257
258            if (StringUtil.hasText(dataSourceKey)) {
259                // 负载均衡 key
260                if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) {
261                    String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1);
262                    List<String> matchedKeys = new ArrayList<>();
263
264                    for (String key : dataSourceMap.keySet()) {
265                        if (key.startsWith(prefix)) {
266                            matchedKeys.add(key);
267                        }
268                    }
269
270                    // 当找不到匹配的 key 时,尝试后备匹配
271                    if (matchedKeys.isEmpty() && dataSourceMissingHandler != null) {
272                        Map<String, DataSource> dsMap = dataSourceMissingHandler.handle(dataSourceKey, dataSourceMap);
273
274                        if (dsMap != null && !dsMap.isEmpty()) {
275                            for (String key : dsMap.keySet()) {
276                                if (key.startsWith(prefix)) {
277                                    matchedKeys.add(key);
278                                }
279                            }
280                        }
281                    }
282
283                    if (matchedKeys.isEmpty()) {
284                        throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\"");
285                    }
286
287                    String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size()));
288
289                    return dataSourceMap.get(randomKey);
290                }
291                // 非负载均衡 key
292                else {
293                    dataSource = dataSourceMap.get(dataSourceKey);
294
295                    // 当找不到匹配的 key 时,尝试后备匹配
296                    if (dataSource == null && dataSourceMissingHandler != null) {
297                        Map<String, DataSource> dsMap = dataSourceMissingHandler.handle(dataSourceKey, dataSourceMap);
298
299                        if (dsMap != null && !dsMap.isEmpty()) {
300                            dataSource = dsMap.get(dataSourceKey);
301                        }
302                    }
303
304                    if (dataSource == null) {
305                        throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\"");
306                    }
307                }
308            }
309        }
310
311        return dataSource;
312    }
313
314    private static class ConnectionHandler implements InvocationHandler {
315        private static final String[] proxyMethods = new String[]{"commit", "rollback", "close", "setAutoCommit"};
316        private final Connection original;
317        private final String xid;
318
319        public ConnectionHandler(Connection original, String xid) {
320            closeAutoCommit(original);
321            this.original = original;
322            this.xid = xid;
323        }
324
325        @Override
326        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
327            if (ArrayUtil.contains(proxyMethods, method.getName())
328                && isTransactional()) {
329                // do nothing
330                return null;
331            }
332
333            // setAutoCommit: true
334            if ("close".equalsIgnoreCase(method.getName())) {
335                resetAutoCommit(original);
336            }
337
338            return method.invoke(original, args);
339        }
340
341        private boolean isTransactional() {
342            return Objects.equals(xid, TransactionContext.getXID());
343        }
344
345    }
346}