通过Filter和Redis实现分布式应用接口限流


通过Filter和Redis实现分布式应用接口限流

背景需求分析

  • 公司的 RESTful 公共接口供第三方调用,在没有限流的情况下,突发的大量请求可能导致服务应用程序崩溃
  • 需要按地区限制第三方的请求次数,比如每个地区一天只能发起 1 万次请求

思路分析

  • 第三方发起的是 POST 请求,请求体 Body 是 JSON 数据,其中的 userId 字段可作为调用方的唯一标识

  • 请求体Body

{"encryptParam":"NrHhFoPcpvRS8OlUMAvDhrof9k7KTvKQkXSxejazUC6BNia9YXb2EKYkQakpraYD48mmFZeQ1UTXm2Av1sg+orlJ8wsclsWuSEjjtl5/nDlEb2N5DMEqd","secretKey":"043B6BB2AE96D62A9BEE7AB4B177167FE0D0811EAB250DADC2A8","userId":"1_330104_1"}
  • 我们可以通过Filter+redis实现接口限流,也可以通过aop切面自定义注解实现接口限流,本教程实现思路是通过Filter+redis实现接口限流

实现过程

工具类

IpUtil

package com.gisquest.realestate.supervise.core.utils;

import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/**
 * @Author:suny
 * @Date: 2021/11/26 9:02
 * @Description: Helpers for resolving the client IP of a request and hashing URLs.
 */
@Component
public class IpUtil {

    /** Proxy headers that may carry the original client address, checked in this order. */
    private static final String[] IP_HEADERS = {
            "x-forwarded-for",
            "Proxy-Client-IP",
            "WL-Proxy-Client-IP",
            "HTTP_CLIENT_IP",
            "HTTP_X_FORWARDED_FOR"
    };

    /**
     * Resolves the client IP address of {@code request}.
     *
     * <p>The proxy headers in {@link #IP_HEADERS} are checked in order. A header may hold a comma-separated
     * chain ("client, proxy1, proxy2"); the first entry that is non-empty and not "unknown" is returned, i.e. the
     * originating client rather than the whole chain. If no header yields an address,
     * {@link HttpServletRequest#getRemoteAddr()} is used.
     *
     * <p>NOTE: these headers are supplied by clients/proxies and can be forged. The result is fine for logging,
     * but should only be trusted for access control when a trusted proxy overwrites them.
     *
     * @param request the current request
     * @return the best-guess client IP
     */
    public static String getIp(HttpServletRequest request) {
        for (String header : IP_HEADERS) {
            String ip = firstKnownAddress(request.getHeader(header));
            if (ip != null) {
                return ip;
            }
        }
        return request.getRemoteAddr();
    }

    /**
     * Returns the first non-empty, non-"unknown" entry of a comma-separated header value.
     *
     * @param headerValue raw header value, may be {@code null}
     * @return the trimmed address, or {@code null} if the header has none
     */
    private static String firstKnownAddress(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        for (String candidate : headerValue.split(",")) {
            String ip = candidate.trim();
            if (!ip.isEmpty() && !"unknown".equalsIgnoreCase(ip)) {
                return ip;
            }
        }
        return null;
    }

    /**
     * Returns the MD5 digest of {@code url} (encoded as UTF-8) as a lower-case hex string.
     *
     * <p>The result is always 32 characters: leading zero nibbles are kept, so callers that take a fixed-length
     * prefix of it get a stable width.
     *
     * @param url text to hash
     * @return 32-character hex MD5 digest
     */
    public static String getMod5Url(String url) {
        try {
            MessageDigest digest = MessageDigest.getInstance("MD5");
            byte[] hash = digest.digest(url.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            return String.format("%032x", new BigInteger(1, hash));
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to provide MD5, so this indicates a broken runtime.
            throw new IllegalStateException("MD5 algorithm not available", e);
        }
    }
}

RedisService

package com.gisquest.realestate.supervise.core.service;

import java.util.concurrent.TimeUnit;

/**
 * @Author:suny
 * @Date: 2021/11/26 9:05
 * @Description: Minimal set of Redis string-value operations used by the rate limiter.
 */
public interface RedisService {

    /**
     * Reads the string value stored under {@code key}.
     *
     * @param key Redis key
     * @return the stored value; presumably {@code null} when the key is absent (implementation-defined)
     */
    String getValue(String key);

    /**
     * Increments the counter stored under {@code key}, applying an expiry when the counter is first created.
     *
     * @param key      counter key
     * @param expire   expiry length
     * @param timeUnit unit of {@code expire}
     */
    void increaseOrExpire(String key, long expire, TimeUnit timeUnit);

    /**
     * Removes {@code key}.
     *
     * @param key Redis key
     * @return implementation-defined deletion result
     */
    Boolean delete(String key);

}

RedisServiceImpl

package com.gisquest.realestate.supervise.core.service.impl;

import com.gisquest.realestate.supervise.core.service.RedisService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;

/**
 * @Author:suny
 * @Date: 2021/11/26 9:07
 * @Description: {@link RedisService} implementation on top of Spring's {@link StringRedisTemplate}.
 */
@Service
@Slf4j
public class RedisServiceImpl implements RedisService {

    private final StringRedisTemplate redisTemplate;

    public RedisServiceImpl(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Deletes {@code key}.
     *
     * <p>Runs synchronously. It used to be annotated {@code @Async}, but Spring only propagates the result of an
     * {@code @Async} method declared to return {@code void} or a {@code Future}; for a {@code Boolean} return the
     * proxy hands the caller {@code null}, so the deletion result was always lost.
     *
     * @param key Redis key
     * @return the result of {@code RedisTemplate#delete} for that key
     */
    @Override
    public Boolean delete(String key) {
        String fullKey = genKey(key);
        log.debug("RedisUtils delete key:{}", fullKey);
        return redisTemplate.delete(fullKey);
    }

    /**
     * Reads the string value stored under {@code key}.
     *
     * @param key Redis key
     * @return the stored value, or {@code null} if the key does not exist
     */
    @Override
    public String getValue(String key) {
        return redisTemplate.opsForValue().get(genKey(key));
    }

    /**
     * Increments the counter at {@code key}; when this call creates the counter, its expiry is set as well.
     *
     * <p>Uses one atomic INCR instead of the former GET followed by SET/INCR. That sequence let two concurrent
     * callers both see "absent" and reset the counter to 1, and, if the key expired between the GET and the INCR,
     * recreated it without a TTL so the counter (and any rate limit built on it) never reset.
     *
     * @param key      counter key
     * @param expire   window length applied when the counter is created
     * @param timeUnit unit of {@code expire}
     */
    @Override
    public void increaseOrExpire(String key, long expire, TimeUnit timeUnit) {
        String fullKey = genKey(key);

        // INCR creates a missing key with value 1, so exactly one caller observes 1 and owns setting the TTL.
        Long cnt = redisTemplate.opsForValue().increment(fullKey);
        if (cnt != null && cnt == 1L) {
            redisTemplate.expire(fullKey, expire, timeUnit);
            log.debug("create rateLimiter expire:{} - {} - {} ", fullKey, expire, timeUnit);
            return;
        }
        log.debug("key:{} cnt:{} increment ", fullKey, cnt);

        // Self-heal counters that exist without a TTL (left by the previous implementation, or by a crash between
        // INCR and EXPIRE); Redis reports -1 for a key that exists but never expires.
        Long ttl = redisTemplate.getExpire(fullKey);
        if (ttl != null && ttl == -1L) {
            redisTemplate.expire(fullKey, expire, timeUnit);
        }
    }

    /**
     * Maps a caller-supplied key to the Redis key actually used.
     *
     * <p>Currently the identity. This is the single place to add a common prefix (a "rework-stats-" prefix was
     * once considered) so every operation in this class stays consistent.
     *
     * @param key key value
     * @return the Redis key
     */
    private static String genKey(String key) {
        return key;
    }

}

CustomHttpServletRequestWrapper

package com.gisquest.realestate.restful.supervise.collection.filter;

import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;

/**
 * @Author:suny
 * @Date: 2021/11/26 16:09
 * @Description: Request wrapper that caches the request body so it can be read more than once.
 *
 * <p>A servlet request body can normally be consumed only once ("getReader()/getInputStream() has already been
 * called for this request"). This wrapper reads the raw body bytes once in its constructor; every later call to
 * {@link #getInputStream()} or {@link #getReader()} returns a fresh stream over that in-memory copy.
 *
 * <p>The body is kept as raw bytes rather than a String, so downstream consumers receive exactly the bytes the
 * client sent, independent of the request's or the JVM's default character encoding.
 */
public class CustomHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Raw request body, captured once in the constructor. */
    private final byte[] body;

    /**
     * Buffers the whole body of {@code request} in memory.
     *
     * @param request the request to wrap; its input stream is fully consumed here
     * @throws IOException      if reading the body fails
     * @throws ServletException kept for source compatibility with existing callers
     */
    public CustomHttpServletRequestWrapper(HttpServletRequest request) throws IOException, ServletException {
        super(request);
        body = read(request);
    }

    /**
     * Returns a reader over the cached body, decoded with the request's character encoding, or UTF-8 when the
     * request declares none (or an unsupported one).
     */
    @Override
    public BufferedReader getReader() {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), charset()));
    }

    /**
     * Returns a new stream positioned at the start of the cached body.
     */
    @Override
    public ServletInputStream getInputStream() {
        final ByteArrayInputStream bais = new ByteArrayInputStream(body);
        return new ServletInputStream() {

            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The whole body is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking (async) reads are not supported for the cached body; intentionally a no-op.
            }

            @Override
            public int read() {
                return bais.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return bais.read(b, off, len);
            }
        };
    }

    /**
     * Resolves the charset used by {@link #getReader()}.
     *
     * @return the request's declared charset, or UTF-8 if it is missing or not supported by this JVM
     */
    private java.nio.charset.Charset charset() {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            return java.nio.charset.StandardCharsets.UTF_8;
        }
        try {
            return java.nio.charset.Charset.forName(encoding);
        } catch (IllegalArgumentException e) {
            // IllegalCharsetNameException / UnsupportedCharsetException: fall back instead of failing the request.
            return java.nio.charset.StandardCharsets.UTF_8;
        }
    }

    /**
     * Reads the complete raw body of {@code request}.
     *
     * @param request request whose input stream is drained
     * @return the body bytes (empty array for an empty body)
     * @throws IOException if reading fails
     */
    private static byte[] read(HttpServletRequest request) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        InputStream in = request.getInputStream();
        byte[] buf = new byte[1024 * 8];
        int n;
        while ((n = in.read(buf)) != -1) {
            out.write(buf, 0, n);
        }
        return out.toByteArray();
    }
}

RateLimiterFilter

package com.gisquest.realestate.restful.supervise.collection.filter;

import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.gisquest.realestate.data.supervise.collection.conf.GisqplatformSuperviseProperties;
import com.gisquest.realestate.data.supervise.collection.conf.help.LogbackHelper;
import com.gisquest.realestate.supervise.core.service.RedisService;
import com.gisquest.realestate.supervise.core.utils.IpUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * @Author:suny
 * @Date: 2021/11/26 9:09
 * @Description:
 */
@Slf4j
@Component
public class RateLimiterFilter extends OncePerRequestFilter {

    //存放需要限流地区的url和对应的userId
    private List<String> rateLimiterList = new ArrayList<>();

    private final RedisService redisService;

    private final GisqplatformSuperviseProperties properties;

    public RateLimiterFilter(RedisService redisService,GisqplatformSuperviseProperties properties) {
        this.redisService = redisService;
        this.properties = properties;
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        String url = request.getRequestURI();
        System.out.println(url);
        String userId = "";
        CustomHttpServletRequestWrapper requestWrapper = new CustomHttpServletRequestWrapper(request);
        String body = ReadAsChars(requestWrapper);

        //根据yml配置和request对象获得的url动态匹配配置限流url和请求的userId
        if(properties.getRateLimiterMap().get(url) != null){
            JSONObject jsonObject = JSONObject.parseObject(body);
            System.out.println(body);
            String userIdSrc = jsonObject.getString("userId");
            //截取userId取中间区县代码的前4位
            userId = userIdSrc.substring(0,6)+userIdSrc.substring(8,10);
            System.out.println(userId);
            rateLimiterList.add(url);
            //添加url对应的userId
            for(int i = 0; i < properties.getRateLimiterMap().get(url).size(); i++){
                rateLimiterList.add(properties.getRateLimiterMap().get(url).get(i));
            }
        }
        if (rateLimiterList.contains(url) && rateLimiterList.contains(userId)) {
            String ip = IpUtil.getIp(request);
            //String key = IpUtil.getMod5Url(url).substring(0, 8) + "-" + ip + userId;
            //将userId用为redis的key
            String key = userId;
            String cnt = redisService.getValue(key);
            System.out.println(cnt);
                if(properties.getRate().equals(cnt)){
                    LogbackHelper.logLimit.info("【IP】="+ip+"【URL】="+url+"【KEY】="+key+"【CNT】="+cnt);
                }
                if (!Strings.isBlank(cnt) && Integer.parseInt(cnt) > Integer.parseInt(properties.getRate())) {
                    log.warn("[{}] - [{}] 访问频率上限[{}]次/分钟. key:[{}]", url,ip, cnt, key);
                    PrintWriter out = response.getWriter();
                    response.setCharacterEncoding("utf-8");
                    response.setContentType("application/json; charset=utf-8");
                    //这里返回固定的BaseResponse对象给前端,直接抛异常时发现,在filter中的异常,GlobalExceptionHandler全局异常处理类无法处理
                    out.print(JSONObject.toJSONString( "please wait some times!", SerializerFeature.WriteNullStringAsEmpty));
                    out.flush();
                    out.close();
                    return;
                }
                redisService.increaseOrExpire(key, Integer.parseInt(properties.getExpireTime()), TimeUnit.MINUTES);
            }
        //责任链模式
       filterChain.doFilter(requestWrapper,response);

    }

    public static String ReadAsChars(HttpServletRequest request) {

        BufferedReader br = null;
        StringBuilder sb = new StringBuilder("");
        try {
            br = request.getReader();
            String str;
            while ((str = br.readLine()) != null) {
                sb.append(str);
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (null != br) {
                try {
                    br.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
        return sb.toString();
    }
}

配置类

GisqplatformSuperviseConfiguration

package com.gisquest.realestate.data.supervise.collection.conf;

import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
 * Enables binding of {@link GisqplatformSuperviseProperties} so it is available as a Spring bean
 * (it is constructor-injected into RateLimiterFilter, for example).
 */
@Configuration
@EnableConfigurationProperties({GisqplatformSuperviseProperties.class})
public class GisqplatformSuperviseConfiguration {
}

GisqplatformSuperviseProperties

package com.gisquest.realestate.data.supervise.collection.conf;

import com.fasterxml.jackson.annotation.JsonFormat;
import com.gisquest.realestate.data.supervise.collection.conf.check.CheckProperties;
import com.gisquest.realestate.data.supervise.collection.conf.dom.CollectionProperties;
import org.springframework.boot.context.properties.ConfigurationProperties;
import lombok.Data;
import org.springframework.format.annotation.DateTimeFormat;
import java.util.Date;
import java.util.List;
import java.util.Map;

/**
 * Rate-limiting settings bound from the {@code gisq.platform.supervise.collection} prefix of the application
 * configuration. Getters/setters are generated by Lombok {@code @Data}.
 */
@Data
@ConfigurationProperties(prefix = "gisq.platform.supervise.collection")
public class GisqplatformSuperviseProperties {
    /**
     * Number of requests allowed within one {@code expireTime} window (unit: requests).
     * Stored as a String; RateLimiterFilter parses it with Integer.parseInt.
     */
    private String rate;

    /**
     * Expiry time of the Redis counter key, i.e. the length of the rate-limit window (unit: minutes).
     */
    private String expireTime;

    /**
     * Maps each rate-limited URL to the list of userId-derived keys that are limited on it.
     * URL keys must use bracket notation in YAML, e.g. '[/rec/decrypt/data]'.
     */
    private Map<String,List<String>> rateLimiterMap;

}

application.yml

gisq:
  platform:
    supervise:
      collection:
        #单位时间内expireTime 访问接口的频率 单位:次 10000
        rate: 100
        #key的过期时间,单位:分钟 1440
        expireTime: 1
        rateLimiterMap:
          '[/rec/decrypt/data]':
            - '1_3301_1'
            - '1_3302_1'
          '[/rg/rec/ywxx/data]':
            - '1_3304_1'
            - '1_3310_1'
            - '1_3301_1'                    

小总结

  • 问题1
getReader()/getInputStream() has already been called for this request

参考https://blog.csdn.net/qq_40161158/article/details/106060820 解决

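原因是请求体的输入流只能读取一次:Filter 中读取了 Body 之后,后续的 Controller 再次读取就会报上述错误。本案例通过实现 CustomHttpServletRequestWrapper 类解决:在构造时把请求体缓存到内存中,之后每次调用 getInputStream()/getReader() 都返回基于缓存内容的新流,从而可以重复读取。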

  • 问题2

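当 Map 的 key 是 URL 路径这类包含 `/` 等特殊字符的格式时,需要在 yml 配置中用 '[]' 包裹(如 '[/rec/decrypt/data]'),否则 Spring Boot 绑定时会去掉这些特殊字符,导致 key 无法与请求 URL 匹配。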


文章作者: fejxc
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 fejxc !
评论
  目录