Spring Cloud Gateway: Reading Request Parameters

By 程序员胖胖胖虎阿 (2022)

Background

A business requirement called for a set of API endpoints to be exposed to a third party.

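Before any business endpoint runs, the design calls for a simple authentication step. Once the identity parameters had been agreed on, and since the project already used Spring Cloud Gateway, we decided to perform identity verification centrally in the gateway module.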

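So when the server receives a request, it must first intercept it and read the request parameters before the authentication logic can run.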

That raises one question: how do you read request parameters in Spring Cloud Gateway?

Search keywords: spring cloud gateway get request body

Problem description

Problem: reading request parameters in Spring Cloud Gateway

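Only two cases are handled here, GET requests and POST requests:

If the request is a GET, take the parameters from the URL query string;
If the request is a POST, read the content of the body.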
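For the GET branch, Spring already parses the query string for you. The call request.getQueryParams().toSingleValueMap(), which ApiVerifyFilter uses below, returns the first value of each parameter. The plain-JDK sketch below only illustrates roughly what that call produces. QueryStringParser is a hypothetical helper and not a Spring class, and it assumes UTF-8, form-style percent encoding.

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical helper: roughly what getQueryParams().toSingleValueMap()
 * yields for a GET request. Not part of Spring; for illustration only.
 */
final class QueryStringParser {

    private QueryStringParser() {
    }

    static Map<String, String> toSingleValueMap(String rawQuery) {
        Map<String, String> result = new LinkedHashMap<>();
        if (rawQuery == null || rawQuery.isEmpty()) {
            return result;
        }
        for (String pair : rawQuery.split("&")) {
            if (pair.isEmpty()) {
                continue; // tolerate "a=1&&b=2"
            }
            int eq = pair.indexOf('=');
            String key = eq >= 0 ? pair.substring(0, eq) : pair;
            String value = eq >= 0 ? pair.substring(eq + 1) : "";
            // keep the first value for repeated keys, like toSingleValueMap()
            result.putIfAbsent(decode(key), decode(value));
        }
        return result;
    }

    // assumption: UTF-8 percent encoding, "+" decoded as a space
    private static String decode(String s) {
        return URLDecoder.decode(s, StandardCharsets.UTF_8);
    }
}
```

For a query such as secretId=abc&secretKey=x%2By, the map holds secretId to "abc" and secretKey to "x+y". No caching is needed, because the URL can be read as often as you like.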

Solution

Reference: https://github.com/spring-cloud/spring-cloud-gateway/issues/747

Two filters are defined. The first, ApiRequestFilter, reads the parameters and stores them in a context object, GatewayContext.

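Note that for a POST request the body can be consumed only once. After reading it, the filter must rebuild the request and put the body back.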
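The body has to be rebuilt because it arrives as a stream that can be consumed only once. If a filter reads it, a later reader (such as the downstream service) would get nothing, or an error. The sketch below shows the same read-once, replay-many idea with plain JDK streams instead of Reactor's Flux<DataBuffer>. CachedBody is a hypothetical class used only for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

/**
 * Hypothetical illustration of read-once / replay-many with plain JDK streams.
 * The gateway filter does the same with Flux<DataBuffer> and a ServerHttpRequestDecorator.
 */
final class CachedBody {

    private final byte[] bytes;

    CachedBody(InputStream original) {
        // drain the one-shot stream exactly once, then close it
        try (InputStream in = original) {
            this.bytes = in.readAllBytes();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** A fresh stream over the cached bytes on every call, like Flux.defer(...) in the filter. */
    ByteArrayInputStream newStream() {
        return new ByteArrayInputStream(bytes);
    }

    /** The cached body as text, which is what GatewayContext.cacheBody holds for JSON. */
    String asUtf8() {
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
```

In ApiRequestFilter, DataBufferUtils.join plus the copy into a byte array plays the role of the constructor. The decorator's getBody() plays the role of newStream().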

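The second filter, ApiVerifyFilter, then takes the POST parameters directly from the context. GET parameters are read straight from the query string, which needs no caching.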

If other business logic later needs the parameters, it can read them from the context as well, without duplicating the extraction logic.
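The two-filter design depends on ordering. The filter that caches the body must run before the filter that verifies it, and the two share state through the exchange attributes. The toy model below uses hypothetical names and no Spring types. It sorts filters by their order value, the way Ordered global filters are sorted (lower value runs first). It also shows that verification fails if it runs before the body has been cached.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/**
 * Toy model of the two-filter design (hypothetical names, no Spring types).
 * Filters run in ascending order; a filter returns false to stop the chain.
 */
final class ToyFilterChain {

    static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

    static final class Filter {
        final int order;
        final Predicate<Map<String, Object>> action;

        Filter(int order, Predicate<Map<String, Object>> action) {
            this.order = order;
            this.action = action;
        }
    }

    /** Runs the filters sorted by order; true if every filter let the request through. */
    static boolean run(List<Filter> filters, Map<String, Object> attributes) {
        List<Filter> sorted = new ArrayList<>(filters);
        sorted.sort(Comparator.comparingInt(f -> f.order));
        for (Filter filter : sorted) {
            if (!filter.action.test(attributes)) {
                return false;
            }
        }
        return true;
    }

    /** Plays the role of ApiRequestFilter: cache the body in the shared attributes. */
    static Filter requestFilter(int order, String body) {
        return new Filter(order, attributes -> {
            attributes.put(CACHE_GATEWAY_CONTEXT, body);
            return true;
        });
    }

    /** Plays the role of ApiVerifyFilter: only ever look at the cached body. */
    static Filter verifyFilter(int order) {
        return new Filter(order, attributes -> {
            Object cached = attributes.get(CACHE_GATEWAY_CONTEXT);
            return cached != null && cached.toString().contains("secretId");
        });
    }
}
```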

Implementation

GatewayContext

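import lombok.Data;
import org.springframework.http.codec.multipart.Part;
import org.springframework.util.MultiValueMap;

/**
 * Per-request context, stored in the exchange attributes under CACHE_GATEWAY_CONTEXT.
 */
@Data
public class GatewayContext {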
    public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

    /**
     * cache json body
     */
    private String cacheBody;
    /**
     * cache form data
     */
    private MultiValueMap<String, Part> formData;
    /**
     * cache request path
     */
    private String path;
}

ApiRequestFilter


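import io.netty.buffer.ByteBufAllocator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.ResolvableType;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
import org.springframework.http.HttpMethod;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.http.codec.multipart.Part;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// FilterOrderConstant (used in getOrder) is a project-specific class not shown in this post
@Component
@Slf4j
public class ApiRequestFilter implements GlobalFilter, Ordered {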

    private static AntPathMatcher antPathMatcher;

    static {
        antPathMatcher = new AntPathMatcher();
    }

    /**
     * default HttpMessageReader
     */
    private static final List<HttpMessageReader<?>> messageReaders = HandlerStrategies.withDefaults().messageReaders();

    private static final ResolvableType MULTIPART_DATA_TYPE = ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, Part.class);

    private static final Mono<MultiValueMap<String, Part>> EMPTY_MULTIPART_DATA = Mono.just(CollectionUtils.unmodifiableMultiValueMap(new LinkedMultiValueMap<String, Part>(0))).cache();

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();

        if (HttpMethod.GET.equals(request.getMethod())) {
            // GET: handle query parameters
            return handleGetMethod(exchange, chain, request);
        }

        if (HttpMethod.POST.equals(request.getMethod())) {
            // POST: read and cache the body
            return handlePostMethod(exchange, chain, request);
        }

        return chain.filter(exchange);
    }

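    /**
     * GET request: query parameters can be read at any time via
     * request.getQueryParams(), so nothing needs to be cached here
     * @param exchange
     * @param chain
     * @param request
     * @return
     */
    private Mono<Void> handleGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
        // TODO: not handled for now; ApiVerifyFilter reads the query string directly

        return chain.filter(exchange);
    }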

    /**
     * POST request: read the body and cache it in the GatewayContext
     * @param exchange
     * @param chain
     * @param request
     * @return
     */
    private Mono<Void> handlePostMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request){
        GatewayContext gatewayContext = new GatewayContext();
        gatewayContext.setPath(request.getPath().pathWithinApplication().value());
        /*
         * save gateway context into exchange
         */
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);

        MediaType contentType = request.getHeaders().getContentType();
        // isCompatibleWith ignores parameters such as ";charset=UTF-8" (equals() does not)
        // and returns false for a missing (null) content type
        if (MediaType.APPLICATION_JSON.isCompatibleWith(contentType)) {
            // application/json body: read it, then rebuild the request body
            return readJsonBody(exchange, chain, gatewayContext);
        }

        if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
            // multipart/form-data body
            return readFormData(exchange, chain, gatewayContext);
        }
        return chain.filter(exchange);
    }

    /**
     * POST with a JSON body: read the body once, cache it in the GatewayContext,
     * and rebuild the request so downstream filters and services can still read it
     * @param exchange
     * @param chain
     * @param gatewayContext
     * @return
     */
    private Mono<Void> readJsonBody(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(dataBuffer -> {
                    /*
                     * read the body Flux<DataBuffer>, and release the buffer
                     * //TODO when SpringCloudGateway Version Release To G.SR2,this can be update with the new version's feature
                     * see PR https://github.com/spring-cloud/spring-cloud-gateway/pull/1095
                     */
                    byte[] bytes = new byte[dataBuffer.readableByteCount()];
                    dataBuffer.read(bytes);
                    DataBufferUtils.release(dataBuffer);
                    return bytes;
                })
                // an empty body would otherwise complete without ever calling the chain
                .defaultIfEmpty(new byte[0])
                .flatMap(bytes -> {
                    // each subscriber gets a fresh buffer over the cached bytes; no retain(),
                    // because every consumer releases the buffer it receives
                    Flux<DataBuffer> cachedFlux = Flux.defer(() ->
                            Mono.just(exchange.getResponse().bufferFactory().wrap(bytes)));
                    /*
                     * repackage ServerHttpRequest
                     */
                    ServerHttpRequest mutatedRequest =
                            new ServerHttpRequestDecorator(exchange.getRequest()) {
                                @Override
                                public Flux<DataBuffer> getBody() {
                                    return cachedFlux;
                                }
                            };
                    /*
                     * mutate exchange with new ServerHttpRequest
                     */
                    ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
                    /*
                     * read body string with default messageReaders
                     */
                    return ServerRequest.create(mutatedExchange, messageReaders)
                            .bodyToMono(String.class)
                            // save body into gatewayContext
                            .doOnNext(gatewayContext::setCacheBody)
                            // defer so the chain only continues after the body has been cached
                            .then(Mono.defer(() -> chain.filter(mutatedExchange)));
                });
    }

    /**
     * POST with multipart/form-data: cache the raw body, parse the parts,
     * store them in the GatewayContext, then forward a rebuilt request
     * @param exchange
     * @param chain
     * @param gatewayContext
     * @return
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, GatewayFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(dataBuffer -> {
                    // copy the whole body into a byte array and release the original buffer
                    final byte[] bytes = new byte[dataBuffer.readableByteCount()];
                    dataBuffer.read(bytes);
                    DataBufferUtils.release(dataBuffer);
                    return bytes;
                })
                // an empty body would otherwise complete without ever calling the chain
                .defaultIfEmpty(new byte[0])
                .flatMap(totalBytes -> {
                    final ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(exchange.getRequest()) {
                        @Override
                        public Flux<DataBuffer> getBody() {
                            // fresh buffer per subscription, so the body can be read more than once
                            return Flux.defer(() -> Flux.just(buffer(totalBytes)));
                        }
                    };
                    final ServerCodecConfigurer configurer = ServerCodecConfigurer.create();
                    return repackageMultipartData(decorator, configurer).flatMap(parts -> {
                        // keep every part (files included) available to later filters
                        gatewayContext.setFormData(parts);
                        return Flux.fromIterable(parts.keySet())
                                // skip the file field and move on to the next key
                                .filter(key -> !"file".equals(key))
                                .concatMap(key -> DataBufferUtils.join(parts.getFirst(key).content()))
                                .map(buffer -> {
                                    final byte[] bytes = new byte[buffer.readableByteCount()];
                                    buffer.read(bytes);
                                    DataBufferUtils.release(buffer);
                                    return new String(bytes, StandardCharsets.UTF_8);
                                })
                                // as before, the last text field read ends up in cacheBody
                                .doOnNext(gatewayContext::setCacheBody)
                                // wait until every field has been read before continuing the chain
                                .then(Mono.defer(() -> chain.filter(exchange.mutate().request(decorator).build())));
                    });
                });
    }

    @SuppressWarnings("unchecked")
    private static Mono<MultiValueMap<String, Part>> repackageMultipartData(ServerHttpRequest request, ServerCodecConfigurer configurer) {
        try {
            final MediaType contentType = request.getHeaders().getContentType();
            if (MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
                return ((HttpMessageReader<MultiValueMap<String, Part>>) configurer.getReaders().stream().filter(reader -> reader.canRead(MULTIPART_DATA_TYPE, MediaType.MULTIPART_FORM_DATA))
                        .findFirst().orElseThrow(() -> new IllegalStateException("No multipart HttpMessageReader."))).readMono(MULTIPART_DATA_TYPE, request, Collections.emptyMap())
                        .switchIfEmpty(EMPTY_MULTIPART_DATA).cache();
            }
        } catch (InvalidMediaTypeException ex) {
            // Ignore
        }
        return EMPTY_MULTIPART_DATA;
    }

    /**
     * addBytes.
     * @param first first
     * @param second second
     * @return byte
     */
    public byte[] addBytes(byte[] first, byte[] second) {
        final byte[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }

    private DataBuffer buffer(byte[] bytes) {
        final NettyDataBufferFactory nettyDataBufferFactory = new NettyDataBufferFactory(ByteBufAllocator.DEFAULT);
        final DataBuffer buffer = nettyDataBufferFactory.allocateBuffer(bytes.length);
        buffer.write(bytes);
        return buffer;
    }


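    @Override
    public int getOrder() {
        // must give this filter a lower order value than ApiVerifyFilter,
        // so the body is cached before it is verified
        return FilterOrderConstant.getOrder(this.getClass().getName());
    }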
}

ApiVerifyFilter

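import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Map;
import java.util.Objects;

// R, HttpCode and FilterOrderConstant are project-specific classes not shown in this post
@Component
@Slf4j
public class ApiVerifyFilter implements GlobalFilter, Ordered {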

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();

        if (HttpMethod.GET.equals(request.getMethod())) {
            // GET: verify query parameters
            return verifyGetMethod(exchange, chain, request);
        }

        if (HttpMethod.POST.equals(request.getMethod())) {
            // POST: verify parameters from the cached body
            return verifyPostMethod(exchange, chain, request);
        }

        return chain.filter(exchange);
    }

    /**
     * GET request: verify parameters
     * @param exchange
     * @param chain
     * @param request
     * @return
     */
    private Mono<Void> verifyGetMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
        // read the query parameters (first value per key)
        Map<String, String> queryParamMap = request.getQueryParams().toSingleValueMap();

        // business parameters
        String secretId = queryParamMap.get("secretId");
        String secretKey = queryParamMap.get("secretKey");

        // verification logic
        return verifyParams(exchange, chain, secretId, secretKey);
    }

    /**
     * POST request: verify parameters read from the cached body
     * @param exchange
     * @param chain
     * @param request
     * @return
     */
    private Mono<Void> verifyPostMethod(ServerWebExchange exchange, GatewayFilterChain chain, ServerHttpRequest request) {
        try {
            GatewayContext gatewayContext = exchange.getAttribute(GatewayContext.CACHE_GATEWAY_CONTEXT);
            // get body from gatewayContext (cached by ApiRequestFilter)
            String cacheBody = gatewayContext.getCacheBody();

            Map<?, ?> map = new ObjectMapper().readValue(cacheBody, Map.class);

            // business parameters; a missing key yields null rather than the string "null"
            String secretId = Objects.toString(map.get("secretId"), null);
            String secretKey = Objects.toString(map.get("secretKey"), null);

            // verification logic
            return verifyParams(exchange, chain, secretId, secretKey);

        } catch (Exception e) {
            // also covers a missing context or body, e.g. for an unsupported content type;
            // pass the exception as the last argument so SLF4J logs the stack trace
            log.error("Failed to parse the request body", e);
            // respond with business code FORBIDDEN (403)
            return response(exchange, R.fail().enumCode(HttpCode.FORBIDDEN));
        }
    }

    /**
     * Verify the parameters
     * @param exchange
     * @param chain
     * @param secretId
     * @param secretKey
     * @return
     */
    private Mono<Void> verifyParams(ServerWebExchange exchange, GatewayFilterChain chain, String secretId, String secretKey) {
        // on failure, return an error response, e.g.:
        // return response(exchange, R.fail().enumCode(HttpCode.UNAUTHORIZED));
        // todo

        // on success, this filter is done and the chain continues
        return chain.filter(exchange);
    }

    /**
     * Write the error body and end the request. The HTTP status stays 200;
     * the error code travels inside the R payload.
     * @param exchange
     * @param r
     * @return
     */
    private Mono<Void> response(ServerWebExchange exchange, R r) {
        ServerHttpResponse originalResponse = exchange.getResponse();
        originalResponse.setStatusCode(HttpStatus.OK);
        // APPLICATION_JSON_UTF8_VALUE is deprecated; JSON is UTF-8 by default
        originalResponse.getHeaders().add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);

        try {
            byte[] bytes = new ObjectMapper().writeValueAsBytes(r);
            DataBuffer buffer = originalResponse.bufferFactory().wrap(bytes);
            return originalResponse.writeWith(Flux.just(buffer));
        } catch (JsonProcessingException e) {
            // never return null from a filter; propagate the error instead
            return Mono.error(e);
        }

    }

    @Override
    public int getOrder() {
        return FilterOrderConstant.getOrder(this.getClass().getName());
    }
}
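The verifyParams method leaves the actual check as a todo. The sketch below shows one possible shape for it. It assumes the gateway has a lookup from secretId to the expected secretKey. The SecretVerifier class and its in-memory credential map are hypothetical, not part of the original project. It uses MessageDigest.isEqual so the comparison does not stop at the first mismatching byte.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Map;

/**
 * Hypothetical verifier for the secretId / secretKey pair.
 * The credential source (an in-memory map here) is an assumption.
 */
final class SecretVerifier {

    // hypothetical credential store: secretId -> expected secretKey
    private final Map<String, String> credentials;

    SecretVerifier(Map<String, String> credentials) {
        this.credentials = Map.copyOf(credentials);
    }

    boolean verify(String secretId, String secretKey) {
        if (secretId == null || secretKey == null) {
            return false;
        }
        String expected = credentials.get(secretId);
        if (expected == null) {
            return false; // unknown secretId
        }
        // MessageDigest.isEqual does not short-circuit on the first differing byte
        return MessageDigest.isEqual(
                expected.getBytes(StandardCharsets.UTF_8),
                secretKey.getBytes(StandardCharsets.UTF_8));
    }
}
```

Inside verifyParams, a false result would lead to the commented-out line there, return response(exchange, R.fail().enumCode(HttpCode.UNAUTHORIZED)). A true result falls through to chain.filter(exchange).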

Copyright notice: posted by 程序员胖胖胖虎阿 on November 1, 2022, at 1:56 PM.
When reposting, please credit: Spring Cloud Gateway: Reading Request Parameters | 胖虎的工具箱-编程导航
