001/*
002 *  Copyright (c) 2022-2024, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016
017package com.mybatisflex.spring.datasource;
018
019import com.mybatisflex.annotation.UseDataSource;
020import com.mybatisflex.core.datasource.DataSourceKey;
021import com.mybatisflex.core.util.StringUtil;
022import org.aopalliance.intercept.MethodInterceptor;
023import org.aopalliance.intercept.MethodInvocation;
024import org.springframework.core.MethodClassKey;
025
026import java.lang.reflect.Method;
027import java.util.Map;
028import java.util.concurrent.ConcurrentHashMap;
029
030/**
031 * 多数据源切换拦截器。
032 *
033 * @author 王帅
034 * @author barql
035 * @author michael
036 * @since 2023-06-25
037 */
038public class DataSourceInterceptor implements MethodInterceptor {
039
040    /**
041     * 缓存方法对应的数据源。
042     */
043    private final Map<Object, String> dsCache = new ConcurrentHashMap<>();
044
045    @Override
046    public Object invoke(MethodInvocation invocation) throws Throwable {
047        String dsKey = getDataSourceKey(invocation.getThis(), invocation.getMethod(), invocation.getArguments());
048        if (StringUtil.noText(dsKey)) {
049            return invocation.proceed();
050        }
051        try {
052            DataSourceKey.use(dsKey);
053            return invocation.proceed();
054        } finally {
055            DataSourceKey.clear();
056        }
057    }
058
059    private String getDataSourceKey(Object target, Method method, Object[] arguments) {
060        Object cacheKey = new MethodClassKey(method, target.getClass());
061        String dsKey = this.dsCache.get(cacheKey);
062        if (dsKey == null) {
063            dsKey = determineDataSourceKey(method, target.getClass());
064            // 对数据源取值进行动态取值处理
065            if (StringUtil.hasText(dsKey)) {
066                dsKey = DataSourceKey.processDataSourceKey(dsKey, target, method, arguments);
067            }
068            this.dsCache.put(cacheKey, dsKey);
069        }
070        return dsKey;
071    }
072
073    private String determineDataSourceKey(Method method, Class<?> targetClass) {
074        // 方法上定义有 UseDataSource 注解
075        UseDataSource annotation = method.getAnnotation(UseDataSource.class);
076        if (annotation != null) {
077            return annotation.value();
078        }
079        // 类上定义有 UseDataSource 注解
080        annotation = targetClass.getAnnotation(UseDataSource.class);
081        if (annotation != null) {
082            return annotation.value();
083        }
084        // 接口上定义有 UseDataSource 注解
085        Class<?>[] interfaces = targetClass.getInterfaces();
086        for (Class<?> anInterface : interfaces) {
087            annotation = anInterface.getAnnotation(UseDataSource.class);
088            if (annotation != null) {
089                return annotation.value();
090            }
091        }
092        // 哪里都没有 UseDataSource 注解
093        return "";
094    }
095
096}