Go语言快速入门Go语言快速入门
首页
基础篇
进阶篇
高阶篇
实战篇
Go官方网站
编程指南
首页
基础篇
进阶篇
高阶篇
实战篇
Go官方网站
编程指南
  • 高阶篇

    • 🎯 高阶篇
    • Context
    • 反射
    • 泛型
    • 性能优化
    • 最佳实践

Context

Context是Go并发编程的关键,用于传递取消信号、超时、截止时间和请求范围的值。我刚开始接触 Context 的时候,觉得它的作用不太明显。但在实际项目中处理HTTP请求、数据库操作、微服务调用时,才真正理解了它的重要性。没有 Context,你很难优雅地取消正在进行的操作!

为什么需要Context?

想象一个场景:用户发起HTTP请求,服务器需要查询数据库、调用外部API。如果用户取消了请求,我们希望所有操作都能停止,不浪费资源。

// ❌ 没有Context:无法优雅取消
func handleRequest() {
    go queryDatabase()   // 用户取消后还在执行
    go callExternalAPI() // 用户取消后还在执行
}

// ✅ 有Context:优雅取消
func handleRequest(ctx context.Context) {
    go queryDatabase(ctx)   // 会响应取消
    go callExternalAPI(ctx) // 会响应取消
}

Context接口

type Context interface {
    // 返回截止时间
    Deadline() (deadline time.Time, ok bool)
    
    // 返回一个Channel,Context被取消时关闭
    Done() <-chan struct{}
    
    // 返回取消原因
    Err() error
    
    // 获取键值对
    Value(key interface{}) interface{}
}

创建Context

context.Background()

根Context,永不取消,用于main函数、初始化、测试:

ctx := context.Background()

context.TODO()

不确定用什么时的占位符:

ctx := context.TODO()

context.WithCancel()

可取消的Context:

ctx, cancel := context.WithCancel(context.Background())

go func() {
    // 工作...
    select {
    case <-ctx.Done():
        fmt.Println("收到取消信号:", ctx.Err())
        return
    default:
        // 继续工作
    }
}()

// 取消
cancel()

context.WithTimeout()

带超时的Context:

// 3秒后自动取消
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()  // 好习惯:即使超时也调用cancel

select {
case <-ctx.Done():
    if ctx.Err() == context.DeadlineExceeded {
        fmt.Println("超时了!")
    }
}

context.WithDeadline()

带截止时间的Context:

deadline := time.Now().Add(5 * time.Second)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()

context.WithValue()

携带值的Context:

type key string

ctx := context.WithValue(context.Background(), key("userID"), 123)

// 获取值
if userID, ok := ctx.Value(key("userID")).(int); ok {
    fmt.Println("User ID:", userID)
}

使用模式

模式1:超时控制

func fetchWithTimeout(url string, timeout time.Duration) ([]byte, error) {
    ctx, cancel := context.WithTimeout(context.Background(), timeout)
    defer cancel()
    
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return nil, err
    }
    
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    
    return io.ReadAll(resp.Body)
}

// 使用
data, err := fetchWithTimeout("https://api.example.com", 5*time.Second)
if err != nil {
    if errors.Is(err, context.DeadlineExceeded) {
        fmt.Println("请求超时")
    }
}

模式2:取消长时间操作

func longOperation(ctx context.Context) error {
    for i := 0; i < 100; i++ {
        select {
        case <-ctx.Done():
            return ctx.Err()
        default:
            // 每次迭代检查是否取消
            time.Sleep(100 * time.Millisecond)
            fmt.Printf("步骤 %d\n", i)
        }
    }
    return nil
}

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    
    go func() {
        time.Sleep(500 * time.Millisecond)
        fmt.Println("取消操作")
        cancel()
    }()
    
    if err := longOperation(ctx); err != nil {
        fmt.Println("操作被取消:", err)
    }
}

模式3:传递请求信息

type contextKey string

const (
    requestIDKey contextKey = "requestID"
    userIDKey    contextKey = "userID"
)

// 中间件:设置请求ID
func RequestIDMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        requestID := uuid.New().String()
        ctx := context.WithValue(r.Context(), requestIDKey, requestID)
        next.ServeHTTP(w, r.WithContext(ctx))
    })
}

// 获取请求ID
func GetRequestID(ctx context.Context) string {
    if id, ok := ctx.Value(requestIDKey).(string); ok {
        return id
    }
    return ""
}

// 在处理器中使用
func handler(w http.ResponseWriter, r *http.Request) {
    requestID := GetRequestID(r.Context())
    log.Printf("[%s] 处理请求", requestID)
}

实战案例:并发请求

package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

// 模拟API调用
func callAPI(ctx context.Context, name string, duration time.Duration) (string, error) {
    select {
    case <-time.After(duration):
        return fmt.Sprintf("%s响应成功", name), nil
    case <-ctx.Done():
        return "", fmt.Errorf("%s被取消: %w", name, ctx.Err())
    }
}

func fetchAllAPIs(ctx context.Context) ([]string, error) {
    apis := []struct {
        name     string
        duration time.Duration
    }{
        {"API-A", 100 * time.Millisecond},
        {"API-B", 200 * time.Millisecond},
        {"API-C", 150 * time.Millisecond},
    }
    
    results := make([]string, len(apis))
    errs := make([]error, len(apis))
    
    var wg sync.WaitGroup
    
    for i, api := range apis {
        wg.Add(1)
        go func(idx int, name string, dur time.Duration) {
            defer wg.Done()
            result, err := callAPI(ctx, name, dur)
            results[idx] = result
            errs[idx] = err
        }(i, api.name, api.duration)
    }
    
    wg.Wait()
    
    // 检查错误
    for _, err := range errs {
        if err != nil {
            return nil, err
        }
    }
    
    return results, nil
}

func main() {
    // 250ms超时,API-B会超时
    ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
    defer cancel()
    
    results, err := fetchAllAPIs(ctx)
    if err != nil {
        fmt.Println("错误:", err)
        return
    }
    
    for _, r := range results {
        fmt.Println(r)
    }
}

实战案例:数据库查询

package main

import (
    "context"
    "database/sql"
    "fmt"
    "time"
)

func queryUser(ctx context.Context, db *sql.DB, userID int) (*User, error) {
    // 查询时使用Context
    row := db.QueryRowContext(ctx, 
        "SELECT id, name, email FROM users WHERE id = ?", userID)
    
    var user User
    if err := row.Scan(&user.ID, &user.Name, &user.Email); err != nil {
        if err == sql.ErrNoRows {
            return nil, fmt.Errorf("用户不存在")
        }
        return nil, err
    }
    
    return &user, nil
}

func getUserWithTimeout(db *sql.DB, userID int) (*User, error) {
    // 3秒超时
    ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    defer cancel()
    
    return queryUser(ctx, db, userID)
}

实战案例:HTTP服务器

package main

import (
    "context"
    "encoding/json"
    "log"
    "net/http"
    "time"
)

type Response struct {
    Message string `json:"message"`
    Time    string `json:"time"`
}

func slowHandler(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    
    // 模拟慢操作
    select {
    case <-time.After(5 * time.Second):
        resp := Response{
            Message: "操作完成",
            Time:    time.Now().Format(time.RFC3339),
        }
        json.NewEncoder(w).Encode(resp)
        
    case <-ctx.Done():
        log.Println("请求被取消:", ctx.Err())
        http.Error(w, "请求被取消", http.StatusRequestTimeout)
    }
}

func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            ctx, cancel := context.WithTimeout(r.Context(), timeout)
            defer cancel()
            
            r = r.WithContext(ctx)
            next.ServeHTTP(w, r)
        })
    }
}

func main() {
    mux := http.NewServeMux()
    mux.HandleFunc("/slow", slowHandler)
    
    // 3秒超时
    handler := timeoutMiddleware(3 * time.Second)(mux)
    
    log.Println("服务器启动在 :8080")
    http.ListenAndServe(":8080", handler)
}

实战案例:优雅关闭

package main

import (
    "context"
    "fmt"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"
)

func main() {
    server := &http.Server{
        Addr: ":8080",
        Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            time.Sleep(2 * time.Second)  // 模拟处理
            fmt.Fprintln(w, "Hello!")
        }),
    }
    
    // 启动服务器
    go func() {
        log.Println("服务器启动在 :8080")
        if err := server.ListenAndServe(); err != http.ErrServerClosed {
            log.Fatalf("服务器错误: %v", err)
        }
    }()
    
    // 等待中断信号
    quit := make(chan os.Signal, 1)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
    <-quit
    
    log.Println("收到关闭信号,开始优雅关闭...")
    
    // 给正在处理的请求10秒时间完成
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()
    
    if err := server.Shutdown(ctx); err != nil {
        log.Fatalf("强制关闭: %v", err)
    }
    
    log.Println("服务器已关闭")
}

Context 最佳实践

1. Context应该是第一个参数

// ✅ 好
func DoSomething(ctx context.Context, arg Arg) error

// ❌ 不好
func DoSomething(arg Arg, ctx context.Context) error

2. 不要存储Context

// ❌ 不要这样做
type MyService struct {
    ctx context.Context  // 不要存储
}

// ✅ 作为参数传递
func (s *MyService) DoWork(ctx context.Context) error

3. 即使函数允许nil,也要传递context.TODO()

// ✅ 好
DoSomething(context.TODO(), arg)

// ❌ 不要传nil
DoSomething(nil, arg)

4. 使用自定义类型作为key

// ✅ 好:使用自定义类型避免冲突
type contextKey string
const userKey contextKey = "user"

// ❌ 不好:可能和其他包冲突
ctx = context.WithValue(ctx, "user", user)

5. 只传递请求范围的数据

// ✅ 适合用Value传递
// - 请求ID
// - 用户认证信息
// - 链路追踪信息

// ❌ 不适合
// - 可选参数
// - 函数依赖
// - 大型对象

练习

  1. 实现一个带超时的HTTP客户端
  2. 使用Context取消多个并发goroutine
  3. 实现一个带请求追踪的日志中间件
参考答案
// 1. 带超时的HTTP客户端
type HTTPClient struct {
    client  *http.Client
    timeout time.Duration
}

func NewHTTPClient(timeout time.Duration) *HTTPClient {
    return &HTTPClient{
        client:  &http.Client{},
        timeout: timeout,
    }
}

func (c *HTTPClient) Get(url string) ([]byte, error) {
    ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
    defer cancel()
    
    req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
    if err != nil {
        return nil, err
    }
    
    resp, err := c.client.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()
    
    return io.ReadAll(resp.Body)
}

// 2. 取消多个goroutine
func cancelMultiple() {
    ctx, cancel := context.WithCancel(context.Background())
    
    var wg sync.WaitGroup
    
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            for {
                select {
                case <-ctx.Done():
                    fmt.Printf("Worker %d 停止\n", id)
                    return
                default:
                    fmt.Printf("Worker %d 工作中\n", id)
                    time.Sleep(100 * time.Millisecond)
                }
            }
        }(i)
    }
    
    time.Sleep(500 * time.Millisecond)
    cancel()
    wg.Wait()
}

// 3. 请求追踪中间件
type requestIDKey struct{}

func TracingMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        requestID := r.Header.Get("X-Request-ID")
        if requestID == "" {
            requestID = fmt.Sprintf("%d", time.Now().UnixNano())
        }
        
        ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
        
        log.Printf("[%s] 开始处理 %s %s", requestID, r.Method, r.URL.Path)
        start := time.Now()
        
        next.ServeHTTP(w, r.WithContext(ctx))
        
        log.Printf("[%s] 完成处理,耗时 %v", requestID, time.Since(start))
    })
}

Context是Go并发的灵魂,下一节学习反射!

最近更新: 2025/12/27 13:26
Contributors: 王长安
Prev
🎯 高阶篇
Next
反射