拦截器版
实现思路：
1. 通过拦截器读取处理方法（或其所在类）上的限流注解；
2. 累计请求次数，超过上限则拒绝请求（限流）。
1.1 定义注解RequestLimit
复制代码
28package com.zhf.model.annotation; import java.lang.annotation.*; import java.util.concurrent.TimeUnit; @Documented @Inherited @Target({ElementType.METHOD,ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) public @interface RequestLimit { /** * 限流的时间单位 */ TimeUnit timeUnit() default TimeUnit.SECONDS; /** * 限流的时长 */ int limit() default 1; /** * 最大限流量 * @return */ int maxCount() default 1; }
1.2 定义拦截器RequestLimitInterceptor
复制代码
145package com.zhf.model.interceptor; import com.alibaba.fastjson.JSONObject; import com.zhf.model.annotation.RequestLimit; import org.springframework.stereotype.Component; import org.springframework.web.method.HandlerMethod; import org.springframework.web.servlet.handler.HandlerInterceptorAdapter; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.PrintWriter; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.HashMap; import java.util.Map; import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.TimeUnit; @Component public class RequestLimitInterceptor extends HandlerInterceptorAdapter { /** * 限流Map,懒得搭建Redis,暂时放在Map里面,然后通过定时任务实现限流 * 实际业务中可以放在Redis,利用过期时间限流 */ private final Map<String,Integer> map = new HashMap<>(); /** * 拦截请求执行的方法 * @param request * @param response * @param handler * @return * @throws Exception */ @Override public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception { System.out.println("限流的拦截器!"); //判断处理类是否为HandlerMethod if(handler.getClass().isAssignableFrom(HandlerMethod.class)){ //进行强制转换 HandlerMethod handlerMethod = (HandlerMethod)handler; //获取拦截的方法 Method method = handlerMethod.getMethod(); //获取方法上的注解对象,看是否被RequestLimit修饰 RequestLimit limiter = getTagAnnotation(method, RequestLimit.class); //判断是否限流 if(null != limiter){ if(isLimit(request,limiter)){ responseOut(response,limiter.maxCount()); return false; } } } return super.preHandle(request, response, handler); } /** * 封装返回结果 */ private void responseOut(HttpServletResponse response,Integer limit) throws IOException { response.setCharacterEncoding("UTF-8"); response.setContentType("application/json; charset=utf-8"); PrintWriter writer = response.getWriter(); Map<String,String> resultMap = new HashMap<>(); resultMap.put("status","502"); resultMap.put("msg","接口超出最大请求数:" + limit); String s = 
JSONObject.toJSON(resultMap).toString(); writer.append(s); } //获取处理类上的注解 public <T extends Annotation> T getTagAnnotation(Method method, Class<T> annotationClass){ //获取方法中是否有相关注解 T methodAnnotation = method.getAnnotation(annotationClass); //获取类上是否有相关注解 T classAnnotation = method.getDeclaringClass().getAnnotation(annotationClass); //判断是否存在相关注解 if(null != methodAnnotation){ return methodAnnotation; }else return classAnnotation; } /** * 判断接口是否限流,通过请求的SessionId进行限流 */ public boolean isLimit(HttpServletRequest request,RequestLimit limiter){ //获取请求的SessionID String id = request.getSession().getId(); //查看是否在限流map里面 Integer num = map.get(id); System.out.println("SessionId:" + id + "n" + "num:" + num + "n" + "limiterCount:" + limiter.maxCount() + "n" + "limit:" + limiter.limit()); //没有则初始化限流map,并创建定时任务(解除限流) if(null == num){ //初始化计数器 map.put(id,1); //创建定时器任务,删除限流器 Timer timer = new Timer(); //获取限流的时间毫秒数 long delay = getDelay(limiter.timeUnit(), limiter.limit()); timer.schedule(new TimerTask() { @Override public void run() { System.out.println("删除任务执行"); map.remove(id); } },delay); }else{ //累加请求 ++num; //判断是否超出最大限流次数 if(num > limiter.maxCount()){ return true; } //更新计数器 map.put(id,num); } return false; } /** * 获取限流时间,总共限流的毫秒数 * @param timeUnit * @param limit * @return */ public long getDelay(TimeUnit timeUnit,Integer limit){ if(null == timeUnit || limit == 0){ return 0; } switch (timeUnit){ case MILLISECONDS: return limit; case MINUTES: return limit*60*1000; case HOURS: return limit*60*60*1000; default: return limit*1000; } } }
1.3 将自定义拦截器加入到Spring的拦截器
复制代码
21package com.zhf.model.config; import com.zhf.model.interceptor.RequestLimitInterceptor; import org.springframework.stereotype.Component; import org.springframework.web.servlet.config.annotation.InterceptorRegistry; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; import javax.annotation.Resource; @Component public class WebMVCConfig implements WebMvcConfigurer { @Resource RequestLimitInterceptor limiter; @Override public void addInterceptors(InterceptorRegistry registry) { registry.addInterceptor(limiter); WebMvcConfigurer.super.addInterceptors(registry); } }
1.4 创建MainController进行测试
为了方便测试，这里设置为每1分钟最多请求5次。
复制代码
19package com.zhf.model.controller; import com.zhf.model.annotation.RequestLimit; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import java.util.concurrent.TimeUnit; @RestController @RequestMapping("/main") public class MainController { @RequestMapping("/lock") @RequestLimit(maxCount = 5 , limit = 1 ,timeUnit = TimeUnit.MINUTES) public String testLock(){ return "ok"; } }
16:19 连续请求5次后，第6次请求被限流

16:20（距第一次请求满1分钟后）再次请求，限流解除

AOP实现
如果项目中还没有AOP相关依赖，需先添加：
复制代码
<!-- AOP support: spring-aspects pulls in the AspectJ weaver needed for @Aspect classes.
     NOTE(review): the pinned 4.3.7.RELEASE must match the Spring version the project already
     uses; in a Spring Boot project, spring-boot-starter-aop (version managed by Boot) is the
     usual alternative. -->
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-aspects</artifactId>
    <version>4.3.7.RELEASE</version>
</dependency>
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-aop</artifactId>
    <version>4.3.7.RELEASE</version>
</dependency>
2.1 创建限流的注解类RateLimiter.java
复制代码
27package com.zhf.model.annotation; import java.lang.annotation.*; import java.util.concurrent.TimeUnit; @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface RateLimiter { /** * 限流的时间单位 */ TimeUnit timeUnit() default TimeUnit.SECONDS; /** * 限流的时长 */ int limit() default 1; /** * 最大限流量 * @return */ int maxCount() default 1; }
2.2 创建切面RateLimiterAspect.java
这里直接使用Redis的String类型保存计数器，并利用键的过期时间作为限流的时间窗口。
直接操作String不便于统一管理，而且"先读后写"不是原子操作；也可以改用RedisScript（Lua脚本）来实现。
复制代码
64package com.zhf.model.aop; import com.alibaba.fastjson.JSONObject; import com.zhf.model.annotation.RateLimiter; import com.zhf.model.exception.CommonException; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Before; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.stereotype.Component; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.PrintWriter; import java.util.HashMap; import java.util.Map; /** * 限流处理类 * @Aspect 声明该类为切面 */ @Aspect @Component public class RateLimiterAspect { @Autowired private RedisTemplate redisTemplate; /** * 执行前置方法,"@annotation(rateLimiter)"在注解rateLimiter之前执行 * @param point * @param rateLimiter * @throws Throwable */ @Before("@annotation(rateLimiter)") public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable { //获取请求 ServletRequestAttributes attributes = (ServletRequestAttributes)RequestContextHolder.getRequestAttributes(); //获取SessionID String id = attributes.getRequest().getSession().getId(); //这里使用redis过期时间作限流 Integer count = (Integer)redisTemplate.opsForValue().get(id); //如果第一次请求或上次限流已经解除 System.out.println("Count:" + count + ";TimeUnit:" + rateLimiter.timeUnit()); if(null == count){ //初始化限流器 redisTemplate.opsForValue().set(id,1,rateLimiter.limit(),rateLimiter.timeUnit()); }else{ //累加 ++count; redisTemplate.opsForValue().set(id,count,0); if(count > rateLimiter.maxCount()){ //抛出自定义异常码,然后统一返回 throw new CommonException(480); } } } }
2.3 自定义异常类CommonException，统一异常返回
自定义异常类
复制代码
24package com.zhf.model.exception; import lombok.Data; @Data public class CommonException extends RuntimeException{ private int code; private String msg; public CommonException() { } public CommonException(int code) { this.code = code; } public CommonException(int code, String msg) { this.code = code; this.msg = msg; } }
统一异常处理：捕获限流抛出的自定义异常码，并返回对应的结果

复制代码
48package com.zhf.model.handler; import com.zhf.model.exception.CommonException; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.web.bind.annotation.*; import java.time.format.DateTimeParseException; @RestControllerAdvice @Slf4j public class CommonExceptionHandler { @ExceptionHandler(NullPointerException.class) @ResponseStatus(value= HttpStatus.INTERNAL_SERVER_ERROR) public ReturnResult handleTypeMismatchException(NullPointerException ex){ log.debug(ex.getMessage()); return new ReturnResult(500,"空指针异常"); } @ExceptionHandler(ArithmeticException.class) @ResponseStatus(value= HttpStatus.INTERNAL_SERVER_ERROR) public ReturnResult handleArithmeticException(ArithmeticException ex){ ex.printStackTrace(); return new ReturnResult(500,"被除数不能为零"); } @ExceptionHandler(DateTimeParseException.class) @ResponseStatus(value= HttpStatus.INTERNAL_SERVER_ERROR) public ReturnResult handleDateTimeParseException(DateTimeParseException ex){ ex.printStackTrace(); return new ReturnResult(500,"时间转换格式错误"); } @ExceptionHandler(CommonException.class) @ResponseStatus(value= HttpStatus.INTERNAL_SERVER_ERROR) public ReturnResult handleCommonException(CommonException ex){ ex.printStackTrace(); if(ex.getCode() == 480){ return new ReturnResult(480,"接口限流"); } if(ex.getCode() == 502){ return new ReturnResult(502,"自定义异常处理502"); } return new ReturnResult(500,"时间转换格式错误"); } }
2.4创建测试类
为了方便测试，这里设置为每1分钟最多请求5次。
复制代码
29package com.zhf.model.controller; import com.zhf.model.annotation.RateLimiter; import com.zhf.model.annotation.RequestLimit; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import java.util.concurrent.TimeUnit; @RestController @RequestMapping("/main") public class MainController { /** * 测试限流 * @return */ @RequestMapping("/lock") @RequestLimit(maxCount = 5 , limit = 1 ,timeUnit = TimeUnit.MINUTES) public String testLock(){ return "ok"; } @RequestMapping("/testAopLimit") @RateLimiter(maxCount = 5 , limit = 1 ,timeUnit = TimeUnit.MINUTES) public String testAopLimit(){ return "ok"; } }
18:00 连续请求5次后，第6次请求被限流

18:01（距第一次请求满1分钟后）再次请求，限流解除

最后
以上就是阳光棒棒糖最近收集整理的关于SpringBoot自定义接口限流注解(拦截器实现,AOP实现)拦截器版AOP实现的全部内容,更多相关SpringBoot自定义接口限流注解(拦截器实现内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复