首页 > 其他分享 >超简单!手把手实现axum简易中间件

超简单!手把手实现axum简易中间件

时间:2024-02-06 20:44:25浏览次数:42  
标签:use axum 手把手 中间件 let new fn

axum是Rust语言tokio生态中的重要一环,以轻量、模块化、易用而闻名于世。它的中间件系统集成自另一个叫tower的框架,这就意味着如果我们要写axum的中间件的话,就得了解一下这个tower的各个核心概念,并学习它的用法。但是,很多时候我们可能只是想写一点简单的小工具,为了小需求去学习一个复杂的大库也许并不是很值得。

下图是tower文档里的中间件范例,可以在axum中使用……但是还是算了吧,让我们看看别的解决方法?

还好,axum贴心地为我们提供了一个小工具:middleware::from_fn 。它可以让我们像写Handler一样地写中间件,就像这样:

没有乱七八糟的trait、生命周期和Into<Box<dyn>>,一切都是如此的自然和简单!

axum的文档这样描述这种构建中间件的方法:

当你满足这些条件时,使用middleware::from_fn

  • 你不打算实现自己的Future,而更愿意用熟悉的async/await语法。
  • 你不打算将自己的中间件发布为一个crate供别人使用,因为像这样编写的中间件只能与axum兼容。

显然,我们没有上面的两种顾虑(自己的Future实现?这不是月薪三千还随时可能被裁的可怜rust程序员该考虑的)。我们的目标就是十分钟内搞定需求,所以我们就快乐地选择middleware::from_fn了。

如果读者想更加详细地了解axum的中间件生态,可以去看看官方文档:axum::middleware

准备

接下来,我将会手把手带领你在10分钟内(根据熟悉axum的程度因人而异)写完一个简单的axum中间件。中间件的功能很简单:假设我们系统的登录接口没有设置验证码,为了防止系统内的用户密码被爆破,我们需要对登录接口加一层限流的中间件,限制每个用户每分钟能调用登录接口10次左右。

为了简化教程的复杂度,我们用timedmap库来代替Redis;在实际生产中,建议用Redis来保证服务的可扩展性。


首先,我们创建一个项目,并导入这样的依赖:

[dependencies]
anyhow = "1.0.79"
axum = "0.7.4"
axum-client-ip = "0.5.0"
lazy_static = "1.4.0"
timedmap = { version = "1.0.1", features = ["tokio"] }
tokio = { version = "1.36.0", features = ["full"] }

接下来,我们快速地构建一个简单的axum应用:

use anyhow::Result;
use axum::{routing::get, Router};
use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> Result<()> {
    let app = Router::new().route("/login", get(|| async {"Username or password invalid."}));
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(listener, app).await?;

    Ok(())
}

这个12行的程序可以在访问/login时返回一段文本,模拟一个成熟系统的登录功能。

运行效果

接下来,我们在它的基础上进行改造,为它编写一个中间件。

中间件

axum的from_fn中间件有这样的模板:

use axum::{
    response::Response,
    middleware::Next,
    extract::Request,
};

async fn my_middleware(
    request: Request,
    next: Next,
) -> Response {
    // do something with `request`...

    let response = next.run(request).await;

    // do something with `response`...

    response
}

简单介绍一下这个模板中的关键事项:

  • 必须是async fn形式的异步函数
  • 必须有一个或多个提取器作为函数参数(在上面的模板中,是Request
  • 最后一个参数必须是Next
  • 返回值必须实现了IntoResponse

在了解了这些事项之后,我们就可以着手写中间件了。

最简单的中间件

话不多说,直接上代码:

use std::time::{self, SystemTime};

async fn hello_world(request: Request, next: Next) -> Response {
    let now = SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    if now % 2 == 0 {
        "Surprise!".into_response()
    } else {
        next.run(request).await
    }
}

这个中间件以二分之一的概率拦截请求并返回一个“Surprise”,虽然它并没有什么用 ,但是它很好地为我们的后续开发开了一个头。它就像一个真正的handler一样,接受各种各样的参数,并返回一个Response;它也像真正的handler一样,可以直接返回内容,而不管后续的中间件和handler。

搞定了中间件,接下来就是把它插入到路由里了:

let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn(hello_world)); // 看这一行

我们调用了axum::middleware::from_fn方法,这是我们之前所提到过的;它将我们的函数转换为一个真正的中间件,如下图所示,它用一个宏把我们的短短几行代码展开成了一个towerService

现在我们再启动程序,刷新刚刚的页面,可以看到新的结果:

程序当前的代码如下:

use std::time::{self, SystemTime};

use anyhow::Result;
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use tokio::net::TcpListener;

async fn hello_world(request: Request, next: Next) -> Response {
    let now = SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap()
        .as_millis();
    if now % 2 == 0 {
        "Surprise!".into_response()
    } else {
        next.run(request).await
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let app = Router::new()
        .route("/login", get(|| async { "Username or password invalid." }))
        .route_layer(middleware::from_fn(hello_world));
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(listener, app).await?;

    Ok(())
}

实战:登录限流

现在,我们可以继续我们的实战演练了。首先,我们准备一个数据结构来检查用户是否发出了过多请求:

struct RateLimiter {
    pub map: TimedMap<IpAddr, i32>,
}

impl RateLimiter {
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    pub fn check(&self, ip_addr: &IpAddr) -> bool {
        let count = self.map.get(ip_addr).unwrap_or(0);
        if count > 10 {
            false
        }else{
            self.map.insert(ip_addr.clone(), count+1, Duration::from_secs(6));
            true
        }
    }
}

这里我们用了一个比较简单的算法来判断用户是否发送了过多请求:每次用户的请求到来时,我们从含有效期的哈希表中获取用户最近的连续请求次数;如果次数大于10次,我们判定用户请求量过大,返回false;否则,我们更新哈希表并返回true。这里可以注意到我们用的超时时长是6秒,这是因为每次更新数据时,有效期都会刷新,因此我们需要设置有效期为60秒/10次=6次/秒。


现在我们已经准备好了限流工具,那么我们要怎样将它放到中间件里呢?

也许有的读者刚接触axum,会说:

lazy_static!{
    pub static ref RATE_LIMITER: Mutex<RateLimiter> = Mutex::new(RateLimiter::new());
}

这种写法其实是不适合axum和异步编程的,具体的原因可以询问GPT;正确的做法是通过axum提供的State机制来将限流工具注入到中间件内:

#[derive(Clone)]
struct LimitState {
    pub limiter: Arc<RateLimiter>,
}

impl LimitState {
    pub fn new() -> Self {
        Self {
            limiter: Arc::new(RateLimiter::new()),
        }
    }
}

在这里,我们定义了一个用于中间件的State。应当注意,它内部的限流工具使用了Arc来包裹,并且State实现了Clone特质。这种写法可以保证limiter可以在多个线程/协程之间使用同一个RateLimiter实例。接下来,我们稍微修改中间件部分的from_fn,更换为from_fn_with_state

let state = LimitState::new();
let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn_with_state(state, rate_limit));  // 变化在这里

拦在我们面前的是最后一道难关:怎么获取用户的IP地址?这里我们直接使用axum-client-ip库,它提供了提取用户真实IP的工具,详情和配置方法可以参阅它的官方文档。在这里,我们只需要在主函数里额外加一行配置就好了:

use axum_client_ip::SecureClientIpSource;

let app = Router::new()
    .route("/login", get(|| async { "Username or password invalid." }))
    .route_layer(middleware::from_fn_with_state(state, rate_limit))
    .layer(SecureClientIpSource::ConnectInfo.into_extension());  // 在这里

在生产环境中,不要忘了把SecureClientIpSource更换成你应用的部署平台/CDN/反向代理使用的HTTP头;如果没有用CDN或者反代,则保留这样不变即可

此外,serve一行也要做一些修改:

axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;

一切障碍都已经清除,我们现在直接手起刀落,写完我们的中间件:

async fn rate_limit(
    State(state): State<LimitState>,
    SecureClientIp(ip_addr): SecureClientIp,
    request: Request,
    next: Next,
) -> Response {
    if state.limiter.check(&ip_addr) {
        next.run(request).await
    } else {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Too many requests! Please wait for one minute.",
        )
            .into_response()
    }
}

接下来就是测试环节!

可以看到,我们的中间件已经开始工作啦~

最后附上程序的完整代码:

use std::{
    net::{IpAddr, SocketAddr},
    sync::Arc,
    time::Duration,
};

use anyhow::Result;
use axum::{
    extract::{Request, State},
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::get,
    Router,
};
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use timedmap::TimedMap;
use tokio::net::TcpListener;

struct RateLimiter {
    pub map: TimedMap<IpAddr, i32>,
}

impl RateLimiter {
    pub fn new() -> Self {
        Self {
            map: Default::default(),
        }
    }

    pub fn check(&self, ip_addr: &IpAddr) -> bool {
        let count = self.map.get(ip_addr).unwrap_or(0);
        if count > 10 {
            false
        } else {
            self.map
                .insert(ip_addr.clone(), count + 1, Duration::from_secs(6));
            true
        }
    }
}

#[derive(Clone)]
struct LimitState {
    pub limiter: Arc<RateLimiter>,
}

impl LimitState {
    pub fn new() -> Self {
        Self {
            limiter: Arc::new(RateLimiter::new()),
        }
    }
}

async fn rate_limit(
    State(state): State<LimitState>,
    SecureClientIp(ip_addr): SecureClientIp,
    request: Request,
    next: Next,
) -> Response {
    if state.limiter.check(&ip_addr) {
        next.run(request).await
    } else {
        (
            StatusCode::TOO_MANY_REQUESTS,
            "Too many requests! Please wait for one minute.",
        )
            .into_response()
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let state = LimitState::new();
    let app = Router::new()
        .route("/login", get(|| async { "Username or password invalid." }))
        .route_layer(middleware::from_fn_with_state(state, rate_limit))
        .layer(SecureClientIpSource::ConnectInfo.into_extension());
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await?;

    Ok(())
}

标签:use,axum,手把手,中间件,let,new,fn
From: https://www.cnblogs.com/cinea/p/18010282

相关文章

  • 手把手教你如何下载有道领世上面已购买的视频课程
    前言:很多小伙伴都想知道有道领世的视频课程怎么下载,但是有道领世上面已购买的视频课程是不提供直接下载方式的,所以下面就教大家如何用学无止下载器下载高途上面已购买的视频课程。一、下载器首页输入Y回车,登录有道账号,选择课程序号,即可下载二、此时会有弹窗让你登录,选择对应的......
  • 手把手搭建QEMU ARM64开发环境
    根据上篇我们讲了搭建ARM32QEMU环境没看到的小伙伴可以看下https://mp.weixin.qq.com/s?__biz=MzUyNDUyMDQyNQ==&mid=2247483838&idx=1&sn=87a65f10e558bdfc35277153d4b42f6a&chksm=fa2d5f38cd5ad62ead217bd0efe857b2ac06e1a14042cacb488f926e8791b75f28c6ec930c4f&token=420704......
  • 第六十四天 csrf, auth,中间件插拔解释
    一、csrf跨站请求伪造1.简介 钓鱼网站:假设是一个跟银行一模一样的网址页面用户在该页面上转账 账户的钱会减少但是受益人却不是自己想要转账的那个人2.模拟一台计算机上两个服务端不同端口启动钓鱼网站提交地址改为正规网站的地址deftransfer(request):ifreques......
  • 手把手教你搭建属于自己的网站(获取被动收入),无需服务器,使用github托管
    大家好,我是亚洲著名程序员青松,本次教大家如何搭建一个属于自己的网站。下面是我自己搭建的一个网站,是一个网址导航网站。托管在了github上面,目前已经运营了三个月,每天的访问量大约有100ip左右。下图是在51.la上面的统计,这个网站是我在2023年11月份发布的,刚发布的时候流量比较高......
  • 中间件漏洞
    中间件漏洞IIS服务器漏洞IIS文件上传漏洞IIS6.0PUT上传漏洞是比较经典的ISS漏洞,如果IIS开启了PUT上传方法,就可以利用此方法上传任意文件,因此,该漏洞危害极大。漏洞产生原因IIS6.0PUT上传漏洞产生的原因IISServer在WEB服务扩展中开启了WebDAVIIS配置了可以写入的权限漏......
  • 第六十三天 cookie, session与Django中间件
    一、cookie与session简介"""HTTP协议四大特性1.基于请求响应2.基于TCP、IP作用于应用层之上协议 3.无状态服务端无法识别客户端的状态 1.互联网刚开始兴起的的时候所有人访问网址都是一样的数据 服务端无法识别客户端问题不大 2.互联网发展淘宝、京东、阿里 服务端......
  • 幻兽帕鲁专用服务器搭建教程分享(手把手教学)
    想要快速搭建幻兽帕鲁服务器,我们只需要参考以下教程即可轻易完成幻兽帕鲁服务器的搭建部署,与其他专用服务器游戏一样,可以让您和朋友在一个相对独立、稳定且私密的云端跨境中进行游戏,以获得更好、更流畅的游戏体验。幻兽帕鲁游戏和steam平台作为国外服务,使用大陆服务器会有......
  • django 项目中,用户登录功能中间件的应用
    不是完整的Demo,简单记录下。在Django项目中,中间件(Middleware)是一个轻量级、底层的插件系统,用于全局修改Django的输入或输出。每个中间件是一个处理请求或响应的钩子,可以在视图执行之前或之后运行代码。对于用户登录功能,中间件可以用来处理多种任务,比如:验证用户的登录状态:在每......
  • 手把手教你如何创建并上传modelscope模型
    参考来源:https://modelscope.cn/docs/模型的创建与文件上传1.注册modelscope相关账号(略)2.创建对应的模型3.填写模型的相关资料4.创建审核通过了之后,下载对应的模型文件夹5.拷贝对应的上传脚本,可以根据上面的页面复制使用modelscope的SDK脚本6.需要获取用户特......
  • [Express]中间件监听不同事件
    监听req的data事件在中间件中,需要监听req对象的data事件,来获取客户端发送到服务器的数据。如果数据量比较大,无法一次性发送完毕,则客户端会把数据切割后,分批发送到服务器。所以data事件可能会触发多次,每一次触发data事件时,获取到数据只是完整数据的一部分,需要手动对接收到的......