跳转至

6.5 中间件开发与应用

核心目标

理解中间件工作原理,能开发自定义中间件并合理应用


1. 中间件基础

什么是中间件?

在Gin框架中,中间件是一个特殊的函数,它符合func(*gin.Context)签名,能够拦截请求和响应的处理过程。中间件可以在请求到达处理器之前执行某些操作,也可以在处理器处理完成之后执行某些操作。

中间件的作用

中间件主要用于处理那些需要在多个请求处理器之间共享的功能,常见用途包括: - 日志记录:记录请求详情、处理时间等 - 认证授权:验证用户身份和权限 - 异常处理:捕获和处理程序运行时错误 - 性能监控:统计请求处理时间 - 数据转换:处理请求和响应数据格式 - 缓存控制:管理HTTP缓存策略

执行流程:洋葱模型

Gin中间件的执行流程遵循"洋葱模型",可以形象地理解为请求从外层逐层进入核心处理器,处理完成后再从核心逐层返回外层。

请求 → 中间件1 → 中间件2 → ... → 处理器 → ... → 中间件2 → 中间件1 → 响应

在代码中,通过c.Next()函数控制流程进入下一层,而c.Next()之后的代码会在后续中间件或处理器执行完成后才执行。


2. 内置中间件使用

Gin框架提供了多个实用的内置中间件,我们可以直接使用它们来快速实现常见功能。

gin.Logger():请求日志中间件

gin.Logger()用于记录请求日志,包括请求方法、路径、状态码、处理时间等信息。

package main

import (
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
)

func main() {
    // 创建Gin引擎
    r := gin.Default() // 默认已包含Logger和Recovery中间件

    // 如果使用gin.New(),需要手动添加中间件
    // r := gin.New()
    // r.Use(gin.Logger())

    // 定义路由
    r.GET("/hello", func(c *gin.Context) {
        c.String(http.StatusOK, "Hello, World!")
    })

    // 启动服务器
    r.Run(":8080")
}

你也可以自定义日志格式和输出位置:

package main

import (
    "io"
    "net/http"
    "os"
    "time"

    "github.com/gin-gonic/gin"
)

func main() {
    r := gin.New()

    // 自定义日志格式
    logFormat := "%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n"
    r.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
        return fmt.Sprintf(
            logFormat,
            param.ClientIP,
            param.TimeStamp.Format(time.RFC1123),
            param.Method,
            param.Path,
            param.Request.Proto,
            param.StatusCode,
            param.Latency,
            param.Request.UserAgent(),
            param.ErrorMessage,
        )
    }))

    // 将日志输出到文件
    f, _ := os.Create("gin.log")
    r.Use(gin.LoggerWithWriter(io.MultiWriter(f, os.Stdout))) // 同时输出到文件和控制台

    r.GET("/hello", func(c *gin.Context) {
        c.String(http.StatusOK, "Hello, World!")
    })

    r.Run(":8080")
}

gin.Recovery():异常恢复中间件

gin.Recovery()用于捕获程序中的panic异常,防止服务器崩溃,并返回500状态码。

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

func main() {
    r := gin.New()
    // 添加Recovery中间件,捕获panic
    r.Use(gin.Recovery())

    // 这个路由会触发panic
    r.GET("/panic", func(c *gin.Context) {
        panic("故意触发的错误")
    })

    // 正常路由
    r.GET("/hello", func(c *gin.Context) {
        c.String(http.StatusOK, "Hello, World!")
    })

    r.Run(":8080")
}

你也可以自定义Recovery中间件的行为:

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

func main() {
    r := gin.New()

    // 自定义Recovery中间件
    r.Use(gin.CustomRecovery(func(c *gin.Context, err interface{}) {
        c.JSON(http.StatusInternalServerError, gin.H{
            "code":    500,
            "message": "服务器内部错误",
            "error":   err,
        })
    }))

    r.GET("/panic", func(c *gin.Context) {
        panic("故意触发的错误")
    })

    r.Run(":8080")
}

gin.BasicAuth():HTTP基础认证中间件

gin.BasicAuth()实现了HTTP基础认证功能,可以简单地保护路由访问。

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

func main() {
    r := gin.Default()

    // 定义认证用户
    authorized := r.Group("/admin", gin.BasicAuth(gin.Accounts{
        "admin": "password", // 用户名:密码
        "user":  "123456",
    }))

    // 受保护的路由
    authorized.GET("/dashboard", func(c *gin.Context) {
        // 获取当前认证的用户名
        user := c.MustGet(gin.AuthUserKey).(string)
        c.JSON(http.StatusOK, gin.H{"message": "欢迎访问管理员面板", "user": user})
    })

    // 公开路由
    r.GET("/", func(c *gin.Context) {
        c.String(http.StatusOK, "这是公开页面")
    })

    r.Run(":8080")
}

3. 自定义中间件开发

除了使用内置中间件,我们还可以根据实际需求开发自定义中间件。

日志中间件:记录详细请求信息

下面实现一个更详细的日志中间件,记录请求IP、用户代理、处理时间等信息:

package main

import (
    "fmt"
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
)

// 自定义日志中间件
func RequestLogger() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 请求前的操作:记录开始时间
        startTime := time.Now()

        // 进入下一个中间件或处理器
        c.Next()

        // 请求后的操作:计算耗时并记录日志
        endTime := time.Now()
        latency := endTime.Sub(startTime)

        // 获取请求信息
        method := c.Request.Method
        path := c.Request.URL.Path
        clientIP := c.ClientIP()
        userAgent := c.Request.UserAgent()
        statusCode := c.Writer.Status()

        // 打印日志
        fmt.Printf("[%s] %s | %d | %s | %s | %s | %s\n",
            endTime.Format("2006-01-02 15:04:05"),
            method,
            statusCode,
            clientIP,
            userAgent,
            path,
            latency,
        )
    }
}

func main() {
    r := gin.New()
    // 使用自定义日志中间件
    r.Use(RequestLogger())
    r.Use(gin.Recovery())

    r.GET("/hello", func(c *gin.Context) {
        time.Sleep(100 * time.Millisecond) // 模拟处理耗时
        c.String(http.StatusOK, "Hello, World!")
    })

    r.Run(":8080")
}

限流中间件:基于令牌桶算法限制请求频率

实现一个简单的限流中间件,控制单位时间内的请求数量:

package main

import (
    "net/http"
    "sync"
    "time"

    "github.com/gin-gonic/gin"
    "golang.org/x/time/rate"
)

// 限流中间件
func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
    // 创建一个令牌桶
    limiter := rate.NewLimiter(limit, burst)

    // 使用sync.Mutex确保并发安全
    var mu sync.Mutex

    return func(c *gin.Context) {
        mu.Lock()
        // 尝试获取令牌
        allowed := limiter.Allow()
        mu.Unlock()

        if !allowed {
            // 没有获取到令牌,返回429 Too Many Requests
            c.JSON(http.StatusTooManyRequests, gin.H{
                "code":    429,
                "message": "请求过于频繁,请稍后再试",
            })
            c.Abort() // 阻止继续处理
            return
        }

        // 允许请求继续处理
        c.Next()
    }
}

func main() {
    r := gin.Default()

    // 限制每秒最多2个请求,最多允许3个并发请求
    r.Use(RateLimiter(rate.Limit(2), 3))

    r.GET("/api/data", func(c *gin.Context) {
        c.JSON(http.StatusOK, gin.H{
            "code":    200,
            "message": "success",
            "data":    "这是受限访问的数据",
        })
    })

    // 不受限的路由
    freeGroup := r.Group("/free")
    freeGroup.Use() // 不使用限流中间件
    freeGroup.GET("/info", func(c *gin.Context) {
        c.JSON(http.StatusOK, gin.H{
            "code":    200,
            "message": "这是自由访问的信息",
        })
    })

    r.Run(":8080")
}

跨域中间件:处理CORS问题

实现一个处理跨域请求的中间件,设置必要的CORS头部:

package main

import (
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
)

// 跨域中间件
func Cors() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 允许的源,*表示允许所有源
        origin := c.Request.Header.Get("Origin")
        if origin != "" {
            c.Header("Access-Control-Allow-Origin", origin)
        }

        // 允许的请求方法
        c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")

        // 允许的请求头
        c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

        // 允许客户端获取的响应头
        c.Header("Access-Control-Expose-Headers", "Content-Length")

        // 是否允许发送Cookie
        c.Header("Access-Control-Allow-Credentials", "true")

        // 预检请求的缓存时间
        c.Header("Access-Control-Max-Age", "86400") // 24小时

        // 处理OPTIONS请求
        if c.Request.Method == "OPTIONS" {
            c.AbortWithStatus(http.StatusNoContent)
            return
        }

        c.Next()
    }
}

func main() {
    r := gin.Default()

    // 使用跨域中间件
    r.Use(Cors())

    r.GET("/api/data", func(c *gin.Context) {
        c.JSON(http.StatusOK, gin.H{
            "code":    200,
            "message": "success",
            "data":    "这是可以跨域访问的数据",
        })
    })

    r.POST("/api/submit", func(c *gin.Context) {
        var data struct {
            Name  string `json:"name"`
            Email string `json:"email"`
        }

        if err := c.ShouldBindJSON(&data); err != nil {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }

        c.JSON(http.StatusOK, gin.H{
            "code":    200,
            "message": "提交成功",
            "data":    data,
        })
    })

    r.Run(":8080")
}

4. 中间件作用域

在Gin中,中间件可以应用在不同的作用域,满足不同的需求。

全局中间件

全局中间件作用于所有路由,通过engine.Use(middleware)方法注册:

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

// 全局日志中间件
func GlobalLogger() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 请求前的操作
        println("全局中间件:请求开始 -", c.Request.Method, c.Request.URL.Path)

        // 继续处理
        c.Next()

        // 请求后的操作
        println("全局中间件:请求结束 -", c.Request.Method, c.Request.URL.Path)
    }
}

func main() {
    r := gin.New()

    // 注册全局中间件,作用于所有路由
    r.Use(GlobalLogger())
    r.Use(gin.Recovery())

    // 定义路由
    r.GET("/", func(c *gin.Context) {
        c.String(http.StatusOK, "首页")
    })

    r.GET("/about", func(c *gin.Context) {
        c.String(http.StatusOK, "关于我们")
    })

    r.Run(":8080")
}

分组中间件

分组中间件只作用于特定的路由分组,通过group.Use(middleware)方法注册:

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

// 全局日志中间件
func GlobalLogger() gin.HandlerFunc {
    return func(c *gin.Context) {
        println("全局中间件:请求开始")
        c.Next()
        println("全局中间件:请求结束")
    }
}

// API分组中间件
func APILogger() gin.HandlerFunc {
    return func(c *gin.Context) {
        println("API中间件:请求开始")
        c.Next()
        println("API中间件:请求结束")
    }
}

// 管理员认证中间件
func AdminAuth() gin.HandlerFunc {
    return func(c *gin.Context) {
        println("验证管理员身份...")
        // 这里可以添加实际的认证逻辑
        c.Next()
    }
}

func main() {
    r := gin.New()
    r.Use(GlobalLogger())
    r.Use(gin.Recovery())

    // 公开路由
    r.GET("/", func(c *gin.Context) {
        c.String(http.StatusOK, "首页")
    })

    // API路由分组,应用APILogger中间件
    api := r.Group("/api")
    api.Use(APILogger())
    {
        api.GET("/data", func(c *gin.Context) {
            c.JSON(http.StatusOK, gin.H{"data": "API数据"})
        })

        // 管理员API子分组,额外应用AdminAuth中间件
        admin := api.Group("/admin")
        admin.Use(AdminAuth())
        {
            admin.GET("/dashboard", func(c *gin.Context) {
                c.JSON(http.StatusOK, gin.H{"message": "管理员面板"})
            })
        }
    }

    r.Run(":8080")
}

路由中间件

路由中间件只作用于特定的路由,在定义路由时直接指定:

package main

import (
    "net/http"

    "github.com/gin-gonic/gin"
)

// 全局日志中间件
func GlobalLogger() gin.HandlerFunc {
    return func(c *gin.Context) {
        println("全局中间件:请求开始")
        c.Next()
        println("全局中间件:请求结束")
    }
}

// 特定路由的中间件
func PremiumAccess() gin.HandlerFunc {
    return func(c *gin.Context) {
        println("验证高级用户权限...")
        // 这里可以添加实际的权限验证逻辑
        c.Next()
    }
}

func main() {
    r := gin.New()
    r.Use(GlobalLogger())
    r.Use(gin.Recovery())

    // 普通路由,只使用全局中间件
    r.GET("/free-content", func(c *gin.Context) {
        c.String(http.StatusOK, "免费内容")
    })

    // 高级路由,除了全局中间件外,还使用PremiumAccess中间件
    r.GET("/premium-content", PremiumAccess(), func(c *gin.Context) {
        c.String(http.StatusOK, "高级付费内容")
    })

    // 可以为一个路由指定多个中间件
    r.GET("/super-premium", PremiumAccess(), func(c *gin.Context) {
        println("超级高级内容验证...")
        c.Next()
    }, func(c *gin.Context) {
        c.String(http.StatusOK, "超级高级付费内容")
    })

    r.Run(":8080")
}

实战:设计多层中间件

下面是一个综合示例,展示如何设计和使用多层中间件:

package main

import (
    "fmt"
    "net/http"
    "time"

    "github.com/gin-gonic/gin"
)

// 1. 日志中间件 - 记录请求基本信息和耗时
func LoggingMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        start := time.Now()
        fmt.Printf("【日志】开始处理 %s %s\n", c.Request.Method, c.Request.URL.Path)

        // 继续处理
        c.Next()

        // 请求处理完成后
        duration := time.Since(start)
        fmt.Printf("【日志】完成处理 %s %s,状态码: %d,耗时: %v\n",
            c.Request.Method, c.Request.URL.Path,
            c.Writer.Status(), duration)
    }
}

// 2. 认证中间件 - 验证用户是否已登录
func AuthMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        fmt.Println("【认证】开始验证用户身份")

        // 简单的认证逻辑:检查请求头中的Token
        token := c.GetHeader("Authorization")
        if token == "" || token != "valid-token" {
            c.JSON(http.StatusUnauthorized, gin.H{
                "code":    401,
                "message": "未授权访问,请提供有效的令牌",
            })
            c.Abort() // 阻止继续处理
            return
        }

        fmt.Println("【认证】用户身份验证通过")
        // 将用户信息存入上下文,供后续处理使用
        c.Set("user", "current_user")

        // 继续处理
        c.Next()

        fmt.Println("【认证】请求处理完成,清理认证信息")
    }
}

// 3. 权限检查中间件 - 验证用户是否有权限访问特定资源
func PermissionMiddleware(requiredRole string) gin.HandlerFunc {
    return func(c *gin.Context) {
        fmt.Printf("【权限】检查是否有 %s 权限\n", requiredRole)

        // 从上下文中获取用户信息
        user, exists := c.Get("user")
        if !exists {
            c.JSON(http.StatusUnauthorized, gin.H{
                "code":    401,
                "message": "用户信息不存在",
            })
            c.Abort()
            return
        }

        // 简单的权限检查逻辑
        // 实际应用中可能需要从数据库或缓存中查询用户权限
        userRoles := map[string][]string{
            "current_user": {"user", "admin"},
        }

        hasPermission := false
        for _, role := range userRoles[user.(string)] {
            if role == requiredRole {
                hasPermission = true
                break
            }
        }

        if !hasPermission {
            c.JSON(http.StatusForbidden, gin.H{
                "code":    403,
                "message": "没有访问权限",
            })
            c.Abort()
            return
        }

        fmt.Printf("【权限】已确认 %s 权限\n", requiredRole)
        c.Next()
    }
}

// 业务处理器
func getUserProfile(c *gin.Context) {
    user := c.MustGet("user").(string)
    c.JSON(http.StatusOK, gin.H{
        "code":    200,
        "message": "success",
        "data": gin.H{
            "username": user,
            "email":    "user@example.com",
            "role":     "admin",
        },
    })
}

func deleteUser(c *gin.Context) {
    userId := c.Param("id")
    c.JSON(http.StatusOK, gin.H{
        "code":    200,
        "message": fmt.Sprintf("用户 %s 已删除", userId),
    })
}

func main() {
    r := gin.New()

    // 全局中间件:日志和异常恢复
    r.Use(LoggingMiddleware())
    r.Use(gin.Recovery())

    // 公开路由
    r.GET("/", func(c *gin.Context) {
        c.String(http.StatusOK, "欢迎访问API服务器")
    })

    // 需要认证的路由分组
    authGroup := r.Group("/api")
    authGroup.Use(AuthMiddleware()) // 应用认证中间件
    {
        // 需要普通用户权限的路由
        authGroup.GET("/profile", getUserProfile)

        // 需要管理员权限的路由子分组
        adminGroup := authGroup.Group("/admin")
        adminGroup.Use(PermissionMiddleware("admin")) // 应用权限检查中间件
        {
            adminGroup.DELETE("/users/:id", deleteUser)
        }
    }

    r.Run(":8080")
}

这个示例展示了一个完整的多层中间件架构: 1. 全局日志中间件记录所有请求 2. 认证中间件验证用户身份 3. 权限检查中间件验证用户是否有权限访问特定资源 4. 最后才是具体的业务处理器

当访问/api/admin/users/123时,中间件的执行顺序是: 1. LoggingMiddleware(请求开始) 2. AuthMiddleware(验证身份) 3. PermissionMiddleware(检查管理员权限) 4. deleteUser(业务处理) 5. PermissionMiddleware(权限检查后处理) 6. AuthMiddleware(认证后处理) 7. LoggingMiddleware(请求结束,记录耗时)

这种多层中间件架构可以使代码职责清晰,便于维护和扩展。