通过Filter和Redis实现分布式应用接口限流
背景需求分析
- 公司的 RESTful 公共接口供第三方调用,在没有限流的情况下,突发的大量请求可能导致服务应用程序崩溃
- 需要按地区限制第三方的请求次数,比如每个地区一天最多只能发起 1 万次请求
思路分析
第三方发起的是 POST 请求,请求体(Body)为 JSON 数据,其中的 userId 字段可以作为调用方的唯一标识(例如 1_330104_1,中间一段为区县代码)
请求体Body
{"encryptParam":"NrHhFoPcpvRS8OlUMAvDhrof9k7KTvKQkXSxejazUC6BNia9YXb2EKYkQakpraYD48mmFZeQ1UTXm2Av1sg+orlJ8wsclsWuSEjjtl5/nDlEb2N5DMEqd","secretKey":"043B6BB2AE96D62A9BEE7AB4B177167FE0D0811EAB250DADC2A8","userId":"1_330104_1"}
- 我们可以通过Filter+redis实现接口限流,也可以通过aop切面自定义注解实现接口限流,本教程实现思路是通过Filter+redis实现接口限流
实现过程
工具类
IpUtil
package com.gisquest.realestate.supervise.core.utils;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
/**
 * Static helpers for identifying the caller of an HTTP request.
 *
 * <p>Every method is static; the class is additionally registered as a Spring
 * {@code @Component}.
 */
@Component
public class IpUtil {

    /** Placeholder that proxies may put into forwarding headers when the address is unknown. */
    private static final String UNKNOWN = "unknown";

    /** Forwarding headers consulted, in this order, before falling back to the socket address. */
    private static final String[] FORWARDING_HEADERS = {
            "x-forwarded-for",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Returns a best-effort client IP address for {@code request}.
     *
     * <p>The forwarding headers are checked in order and the first usable entry wins. A header
     * may carry a comma-separated chain ({@code client, proxy1, proxy2}); only the first entry
     * that is neither empty nor {@code "unknown"} is returned, never the whole chain. If no
     * header yields an address, {@link HttpServletRequest#getRemoteAddr()} is returned.
     *
     * <p>NOTE: these headers are supplied by the client or intermediate proxies and can be
     * forged; do not base authentication or authorisation decisions on the result.
     *
     * @param request the incoming request
     * @return the client address, or whatever {@code getRemoteAddr()} returns as a fallback
     */
    public static String getIp(HttpServletRequest request) {
        for (String header : FORWARDING_HEADERS) {
            String ip = firstUsableAddress(request.getHeader(header));
            if (ip != null) {
                return ip;
            }
        }
        return request.getRemoteAddr();
    }

    /** Returns the first non-empty, non-"unknown" entry of a comma-separated header, or null. */
    private static String firstUsableAddress(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        for (String candidate : headerValue.split(",")) {
            String ip = candidate.trim();
            if (!ip.isEmpty() && !UNKNOWN.equalsIgnoreCase(ip)) {
                return ip;
            }
        }
        return null;
    }

    /**
     * Returns the MD5 digest of {@code url} as a 32-character lower-case hex string.
     *
     * <p>The text is encoded as UTF-8 before hashing. The hex form is zero-padded, so digests
     * whose leading bytes are zero keep their full 32-character length.
     *
     * @param url text to hash; must not be null
     * @return 32-character lower-case hex MD5 digest
     * @throws IllegalStateException if the JVM provides no MD5 implementation (every Java SE
     *         platform is required to provide one)
     */
    public static String getMod5Url(String url) {
        try {
            MessageDigest digest = MessageDigest.getInstance("MD5");
            byte[] hash = digest.digest(url.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            return String.format("%032x", new BigInteger(1, hash));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 MessageDigest is not available", e);
        }
    }
}
RedisService
package com.gisquest.realestate.supervise.core.service;
import java.util.concurrent.TimeUnit;
/**
 * Small facade over the Redis string operations needed for request counting.
 */
public interface RedisService {

    /**
     * Removes {@code key}.
     *
     * @param key logical key
     * @return whether a key was removed, as reported by the implementation
     */
    Boolean delete(String key);

    /**
     * Reads the string stored under {@code key}.
     *
     * @param key logical key
     * @return the stored value; implementations may return {@code null} when the key is absent
     */
    String getValue(String key);

    /**
     * Increments the counter stored under {@code key} and makes sure it expires after
     * {@code expire} units of {@code timeUnit}.
     *
     * @param key      logical key
     * @param expire   time-to-live amount
     * @param timeUnit unit of {@code expire}
     */
    void increaseOrExpire(String key, long expire, TimeUnit timeUnit);
}
RedisServiceImpl
package com.gisquest.realestate.supervise.core.service.impl;
import com.gisquest.realestate.supervise.core.service.RedisService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;
/**
 * {@link RedisService} implementation backed by a {@link StringRedisTemplate}.
 */
@Service
@Slf4j
public class RedisServiceImpl implements RedisService {

    private final StringRedisTemplate redisTemplate;

    public RedisServiceImpl(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Deletes {@code key} synchronously.
     *
     * <p>Deliberately not {@code @Async}. Spring's async proxy only propagates results for
     * {@code void}/{@code Future} return types and hands back {@code null} for anything else, so
     * an async {@code Boolean} result would never reach the caller.
     *
     * @param key logical key
     * @return the result of {@code RedisTemplate#delete}: true if a key was removed
     */
    @Override
    public Boolean delete(String key) {
        String redisKey = genKey(key);
        log.debug("RedisUtils delete key:{}", redisKey);
        return redisTemplate.delete(redisKey);
    }

    /**
     * Reads the string stored under {@code key}.
     *
     * @param key logical key
     * @return the stored value, or {@code null} if the key does not exist
     */
    @Override
    public String getValue(String key) {
        return redisTemplate.opsForValue().get(genKey(key));
    }

    /**
     * Increments the counter under {@code key} and ensures it carries a time-to-live.
     *
     * <p>Redis {@code INCR} is atomic and creates a missing key with value 1. Concurrent callers
     * therefore cannot both "create" the counter and lose a count, as a GET-then-SET sequence
     * could. The TTL is attached when this call created the key ({@code count == 1}).
     *
     * <p>INCR and EXPIRE are still two separate commands. An existing counter found without a
     * TTL is therefore given one here. Such a counter can be left over from a crash between the
     * two commands, or from a key that expired between a GET and an INCR. Without a TTL it would
     * never reset and would block its user permanently.
     *
     * @param key      logical key
     * @param expire   time-to-live amount
     * @param timeUnit unit of {@code expire}
     */
    @Override
    public void increaseOrExpire(String key, long expire, TimeUnit timeUnit) {
        String redisKey = genKey(key);
        Long count = redisTemplate.opsForValue().increment(redisKey);
        if (count != null && count == 1L) {
            redisTemplate.expire(redisKey, expire, timeUnit);
            log.debug("create rateLimiter expire:{} - {} - {} ", redisKey, expire, timeUnit);
            return;
        }
        log.debug("key:{} cnt:{} increment ", redisKey, count);
        // getExpire returns -1 for a key that exists but has no TTL.
        Long ttl = redisTemplate.getExpire(redisKey);
        if (ttl != null && ttl == -1L) {
            redisTemplate.expire(redisKey, expire, timeUnit);
            log.warn("rateLimiter key:{} had no expiry, applied {} {}", redisKey, expire, timeUnit);
        }
    }

    /**
     * Maps a logical key to the Redis key actually used.
     *
     * <p>This is currently the identity (a {@code null} key maps to the literal {@code "null"}).
     * It is kept as the single place to introduce a namespace prefix.
     *
     * @param key logical key
     * @return the Redis key
     */
    private static String genKey(String key) {
        return String.valueOf(key);
    }
}
CustomHttpServletRequestWrapper
package com.gisquest.realestate.restful.supervise.collection.filter;
import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
/**
 * Request wrapper that reads the entire request body into memory once, in the constructor.
 * The buffered body is then served to any number of {@link #getInputStream()} /
 * {@link #getReader()} calls.
 *
 * <p>The wrapper is needed because the container's body stream can be consumed only once. In
 * addition, {@code getReader()} and {@code getInputStream()} are mutually exclusive on the
 * underlying request. The body is kept as the exact bytes received, with no decode/re-encode
 * round trip, so downstream readers see exactly what the client sent.
 *
 * <p>NOTE(review): the whole body is held in memory with no size limit; only wrap requests
 * whose bodies are known to be small.
 */
public class CustomHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw request body, exactly as read from the container. */
    private final byte[] body;

    /**
     * Buffers the body of {@code request}.
     *
     * @param request request whose body has not yet been read
     * @throws IOException if reading the body fails
     * @throws ServletException declared for compatibility with existing callers
     */
    public CustomHttpServletRequestWrapper(HttpServletRequest request) throws IOException, ServletException {
        super(request);
        body = read(request);
    }

    /**
     * Returns a reader over the buffered body.
     *
     * <p>The body is decoded with the request's character encoding, or with UTF-8 when none is
     * set or the name is not supported.
     */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(getInputStream(), bodyCharset()));
    }

    /** Returns a fresh, fully readable stream over the buffered body on every call. */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported by this buffered stream; intentionally a no-op.
            }

            @Override
            public int read() {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                // Bulk read instead of the byte-at-a-time default inherited from InputStream.
                return bais.read(b, off, len);
            }
        };
    }

    /** Resolves the charset for {@link #getReader()}, falling back to UTF-8. */
    private java.nio.charset.Charset bodyCharset() {
        String encoding = getCharacterEncoding();
        if (encoding != null) {
            try {
                return java.nio.charset.Charset.forName(encoding);
            } catch (IllegalArgumentException ignored) {
                // Illegal or unsupported charset name: fall back to UTF-8 below.
            }
        }
        return java.nio.charset.StandardCharsets.UTF_8;
    }

    /** Reads the complete raw body of {@code request}. */
    private static byte[] read(HttpServletRequest request) throws IOException {
        InputStream in = request.getInputStream();
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buf = new byte[1024 * 8];
        int n;
        while ((n = in.read(buf)) != -1) {
            out.write(buf, 0, n);
        }
        return out.toByteArray();
    }
}
RateLimiterFilter
package com.gisquest.realestate.restful.supervise.collection.filter;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.gisquest.realestate.data.supervise.collection.conf.GisqplatformSuperviseProperties;
import com.gisquest.realestate.data.supervise.collection.conf.help.LogbackHelper;
import com.gisquest.realestate.supervise.core.service.RedisService;
import com.gisquest.realestate.supervise.core.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
* @Author:suny
* @Date: 2021/11/26 9:09
* @Description:
*/
@Slf4j
@Component
public class RateLimiterFilter extends OncePerRequestFilter {
//存放需要限流地区的url和对应的userId
private List<String> rateLimiterList = new ArrayList<>();
private final RedisService redisService;
private final GisqplatformSuperviseProperties properties;
public RateLimiterFilter(RedisService redisService,GisqplatformSuperviseProperties properties) {
this.redisService = redisService;
this.properties = properties;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
String url = request.getRequestURI();
System.out.println(url);
String userId = "";
CustomHttpServletRequestWrapper requestWrapper = new CustomHttpServletRequestWrapper(request);
String body = ReadAsChars(requestWrapper);
//根据yml配置和request对象获得的url动态匹配配置限流url和请求的userId
if(properties.getRateLimiterMap().get(url) != null){
JSONObject jsonObject = JSONObject.parseObject(body);
System.out.println(body);
String userIdSrc = jsonObject.getString("userId");
//截取userId取中间区县代码的前4位
userId = userIdSrc.substring(0,6)+userIdSrc.substring(8,10);
System.out.println(userId);
rateLimiterList.add(url);
//添加url对应的userId
for(int i = 0; i < properties.getRateLimiterMap().get(url).size(); i++){
rateLimiterList.add(properties.getRateLimiterMap().get(url).get(i));
}
}
if (rateLimiterList.contains(url) && rateLimiterList.contains(userId)) {
String ip = IpUtil.getIp(request);
//String key = IpUtil.getMod5Url(url).substring(0, 8) + "-" + ip + userId;
//将userId用为redis的key
String key = userId;
String cnt = redisService.getValue(key);
System.out.println(cnt);
if(properties.getRate().equals(cnt)){
LogbackHelper.logLimit.info("【IP】="+ip+"【URL】="+url+"【KEY】="+key+"【CNT】="+cnt);
}
if (!Strings.isBlank(cnt) && Integer.parseInt(cnt) > Integer.parseInt(properties.getRate())) {
log.warn("[{}] - [{}] 访问频率上限[{}]次/分钟. key:[{}]", url,ip, cnt, key);
PrintWriter out = response.getWriter();
response.setCharacterEncoding("utf-8");
response.setContentType("application/json; charset=utf-8");
//这里返回固定的BaseResponse对象给前端,直接抛异常时发现,在filter中的异常,GlobalExceptionHandler全局异常处理类无法处理
out.print(JSONObject.toJSONString( "please wait some times!", SerializerFeature.WriteNullStringAsEmpty));
out.flush();
out.close();
return;
}
redisService.increaseOrExpire(key, Integer.parseInt(properties.getExpireTime()), TimeUnit.MINUTES);
}
//责任链模式
filterChain.doFilter(requestWrapper,response);
}
public static String ReadAsChars(HttpServletRequest request) {
BufferedReader br = null;
StringBuilder sb = new StringBuilder("");
try {
br = request.getReader();
String str;
while ((str = br.readLine()) != null) {
sb.append(str);
}
br.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (null != br) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
return sb.toString();
}
}
配置类
GisqplatformSuperviseConfiguration
package com.gisquest.realestate.data.supervise.collection.conf;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that enables binding of {@link GisqplatformSuperviseProperties}.
 */
@EnableConfigurationProperties(GisqplatformSuperviseProperties.class)
@Configuration
public class GisqplatformSuperviseConfiguration {
    // Intentionally empty: the class-level annotations carry all of the configuration.
}
GisqplatformSuperviseProperties
package com.gisquest.realestate.data.supervise.collection.conf;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.gisquest.realestate.data.supervise.collection.conf.check.CheckProperties;
import com.gisquest.realestate.data.supervise.collection.conf.dom.CollectionProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import lombok.Data;
import org.springframework.format.annotation.DateTimeFormat;
import java.util.Date;
import java.util.List;
import java.util.Map;
/**
 * Rate-limiter settings bound from {@code gisq.platform.supervise.collection.*}.
 */
@Data
@ConfigurationProperties(prefix = "gisq.platform.supervise.collection")
public class GisqplatformSuperviseProperties {

    /**
     * Maximum number of requests allowed per user id within one {@code expireTime} window, as a
     * count. Expected to hold an integer.
     */
    private String rate;

    /**
     * Lifetime of a counter key, i.e. the length of the rate-limit window, in minutes.
     * Expected to hold an integer.
     */
    private String expireTime;

    /**
     * Rate-limited request URIs, each mapped to the user ids that are limited on that URI.
     *
     * <p>Defaults to an empty map, so callers never see {@code null} when the property is
     * absent. In YAML, URI keys must be written as {@code '[/path]'} so the slashes survive
     * relaxed binding.
     */
    private Map<String, List<String>> rateLimiterMap = new java.util.LinkedHashMap<>();
}
application.yml
gisq:
platform:
supervise:
collection:
#每个 userId 在一个 expireTime 周期内允许的最大请求次数,单位:次(按一天 1 万次限流时设为 10000)
rate: 100
#计数 key 的过期时间,即限流统计周期,单位:分钟(按天统计时设为 1440)
expireTime: 1
rateLimiterMap:
'[/rec/decrypt/data]':
- '1_3301_1'
- '1_3302_1'
'[/rg/rec/ywxx/data]':
- '1_3304_1'
- '1_3310_1'
- '1_3301_1'
小总结
- 问题1
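报错信息:getReader()/getInputStream() has already been called for this request
原因:Servlet 规范中,同一个请求的 getReader() 与 getInputStream() 只能二选一,并且请求体流只能被读取一次;在 Filter 中读取 Body 后,后续的 Controller 就无法再读取。
参考 https://blog.csdn.net/qq_40161158/article/details/106060820 解决:本案例实现了 CustomHttpServletRequestWrapper 类,先把请求体缓存下来,再把包装后的请求传递给后续的过滤器链,从而支持重复读取。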
- 问题2
Map 的 key 如果是 URL 这类包含 `/` 等特殊字符的字符串,在 yml 配置中需要用 `'[...]'` 包起来;否则 Spring Boot 绑定时会去掉除字母、数字、`-`、`.` 以外的字符,导致 key 与请求路径无法匹配。