Context
Context是Go并发编程的关键,用于传递取消信号、超时、截止时间和请求范围的值。我刚开始接触 Context 的时候,觉得它的作用不太明显。但在实际项目中处理HTTP请求、数据库操作、微服务调用时,才真正理解了它的重要性。没有 Context,你很难优雅地取消正在进行的操作!
为什么需要Context?
想象一个场景:用户发起HTTP请求,服务器需要查询数据库、调用外部API。如果用户取消了请求,我们希望所有操作都能停止,不浪费资源。
// ❌ 没有Context:无法优雅取消
func handleRequest() {
go queryDatabase() // 用户取消后还在执行
go callExternalAPI() // 用户取消后还在执行
}
// ✅ 有Context:优雅取消
func handleRequest(ctx context.Context) {
go queryDatabase(ctx) // 会响应取消
go callExternalAPI(ctx) // 会响应取消
}
Context接口
type Context interface {
// 返回截止时间
Deadline() (deadline time.Time, ok bool)
// 返回一个Channel,Context被取消时关闭
Done() <-chan struct{}
// 返回取消原因
Err() error
// 获取键值对
Value(key interface{}) interface{}
}
创建Context
context.Background()
根Context,永不取消,用于main函数、初始化、测试:
ctx := context.Background()
context.TODO()
不确定用什么时的占位符:
ctx := context.TODO()
context.WithCancel()
可取消的Context:
ctx, cancel := context.WithCancel(context.Background())
go func() {
// 工作...
select {
case <-ctx.Done():
fmt.Println("收到取消信号:", ctx.Err())
return
default:
// 继续工作
}
}()
// 取消
cancel()
context.WithTimeout()
带超时的Context:
// 3秒后自动取消
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() // 好习惯:即使超时也调用cancel
select {
case <-ctx.Done():
if ctx.Err() == context.DeadlineExceeded {
fmt.Println("超时了!")
}
}
context.WithDeadline()
带截止时间的Context:
deadline := time.Now().Add(5 * time.Second)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
context.WithValue()
携带值的Context:
type key string
ctx := context.WithValue(context.Background(), key("userID"), 123)
// 获取值
if userID, ok := ctx.Value(key("userID")).(int); ok {
fmt.Println("User ID:", userID)
}
使用模式
模式1:超时控制
func fetchWithTimeout(url string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// 使用
data, err := fetchWithTimeout("https://api.example.com", 5*time.Second)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
fmt.Println("请求超时")
}
}
模式2:取消长时间操作
func longOperation(ctx context.Context) error {
for i := 0; i < 100; i++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
// 每次迭代检查是否取消
time.Sleep(100 * time.Millisecond)
fmt.Printf("步骤 %d\n", i)
}
}
return nil
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
fmt.Println("取消操作")
cancel()
}()
if err := longOperation(ctx); err != nil {
fmt.Println("操作被取消:", err)
}
}
模式3:传递请求信息
type contextKey string
const (
requestIDKey contextKey = "requestID"
userIDKey contextKey = "userID"
)
// 中间件:设置请求ID
func RequestIDMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := uuid.New().String()
ctx := context.WithValue(r.Context(), requestIDKey, requestID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// 获取请求ID
func GetRequestID(ctx context.Context) string {
if id, ok := ctx.Value(requestIDKey).(string); ok {
return id
}
return ""
}
// 在处理器中使用
func handler(w http.ResponseWriter, r *http.Request) {
requestID := GetRequestID(r.Context())
log.Printf("[%s] 处理请求", requestID)
}
实战案例:并发请求
package main
import (
"context"
"fmt"
"sync"
"time"
)
// 模拟API调用
func callAPI(ctx context.Context, name string, duration time.Duration) (string, error) {
select {
case <-time.After(duration):
return fmt.Sprintf("%s响应成功", name), nil
case <-ctx.Done():
return "", fmt.Errorf("%s被取消: %w", name, ctx.Err())
}
}
func fetchAllAPIs(ctx context.Context) ([]string, error) {
apis := []struct {
name string
duration time.Duration
}{
{"API-A", 100 * time.Millisecond},
{"API-B", 200 * time.Millisecond},
{"API-C", 150 * time.Millisecond},
}
results := make([]string, len(apis))
errs := make([]error, len(apis))
var wg sync.WaitGroup
for i, api := range apis {
wg.Add(1)
go func(idx int, name string, dur time.Duration) {
defer wg.Done()
result, err := callAPI(ctx, name, dur)
results[idx] = result
errs[idx] = err
}(i, api.name, api.duration)
}
wg.Wait()
// 检查错误
for _, err := range errs {
if err != nil {
return nil, err
}
}
return results, nil
}
func main() {
// 250ms超时,API-B会超时
ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
defer cancel()
results, err := fetchAllAPIs(ctx)
if err != nil {
fmt.Println("错误:", err)
return
}
for _, r := range results {
fmt.Println(r)
}
}
实战案例:数据库查询
package main
import (
"context"
"database/sql"
"fmt"
"time"
)
func queryUser(ctx context.Context, db *sql.DB, userID int) (*User, error) {
// 查询时使用Context
row := db.QueryRowContext(ctx,
"SELECT id, name, email FROM users WHERE id = ?", userID)
var user User
if err := row.Scan(&user.ID, &user.Name, &user.Email); err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("用户不存在")
}
return nil, err
}
return &user, nil
}
func getUserWithTimeout(db *sql.DB, userID int) (*User, error) {
// 3秒超时
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
return queryUser(ctx, db, userID)
}
实战案例:HTTP服务器
package main
import (
"context"
"encoding/json"
"log"
"net/http"
"time"
)
type Response struct {
Message string `json:"message"`
Time string `json:"time"`
}
func slowHandler(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// 模拟慢操作
select {
case <-time.After(5 * time.Second):
resp := Response{
Message: "操作完成",
Time: time.Now().Format(time.RFC3339),
}
json.NewEncoder(w).Encode(resp)
case <-ctx.Done():
log.Println("请求被取消:", ctx.Err())
http.Error(w, "请求被取消", http.StatusRequestTimeout)
}
}
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
})
}
}
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/slow", slowHandler)
// 3秒超时
handler := timeoutMiddleware(3 * time.Second)(mux)
log.Println("服务器启动在 :8080")
http.ListenAndServe(":8080", handler)
}
实战案例:优雅关闭
package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
func main() {
server := &http.Server{
Addr: ":8080",
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second) // 模拟处理
fmt.Fprintln(w, "Hello!")
}),
}
// 启动服务器
go func() {
log.Println("服务器启动在 :8080")
if err := server.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("服务器错误: %v", err)
}
}()
// 等待中断信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("收到关闭信号,开始优雅关闭...")
// 给正在处理的请求10秒时间完成
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatalf("强制关闭: %v", err)
}
log.Println("服务器已关闭")
}
Context 最佳实践
1. Context应该是第一个参数
// ✅ 好
func DoSomething(ctx context.Context, arg Arg) error
// ❌ 不好
func DoSomething(arg Arg, ctx context.Context) error
2. 不要存储Context
// ❌ 不要这样做
type MyService struct {
ctx context.Context // 不要存储
}
// ✅ 作为参数传递
func (s *MyService) DoWork(ctx context.Context) error
3. 即使函数允许nil,也要传递context.TODO()
// ✅ 好
DoSomething(context.TODO(), arg)
// ❌ 不要传nil
DoSomething(nil, arg)
4. 使用自定义类型作为key
// ✅ 好:使用自定义类型避免冲突
type contextKey string
const userKey contextKey = "user"
// ❌ 不好:可能和其他包冲突
ctx = context.WithValue(ctx, "user", user)
5. 只传递请求范围的数据
// ✅ 适合用Value传递
// - 请求ID
// - 用户认证信息
// - 链路追踪信息
// ❌ 不适合
// - 可选参数
// - 函数依赖
// - 大型对象
练习
- 实现一个带超时的HTTP客户端
- 使用Context取消多个并发goroutine
- 实现一个带请求追踪的日志中间件
参考答案
// 1. 带超时的HTTP客户端
type HTTPClient struct {
client *http.Client
timeout time.Duration
}
func NewHTTPClient(timeout time.Duration) *HTTPClient {
return &HTTPClient{
client: &http.Client{},
timeout: timeout,
}
}
func (c *HTTPClient) Get(url string) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return io.ReadAll(resp.Body)
}
// 2. 取消多个goroutine
func cancelMultiple() {
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for {
select {
case <-ctx.Done():
fmt.Printf("Worker %d 停止\n", id)
return
default:
fmt.Printf("Worker %d 工作中\n", id)
time.Sleep(100 * time.Millisecond)
}
}
}(i)
}
time.Sleep(500 * time.Millisecond)
cancel()
wg.Wait()
}
// 3. 请求追踪中间件
type requestIDKey struct{}
func TracingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := r.Header.Get("X-Request-ID")
if requestID == "" {
requestID = fmt.Sprintf("%d", time.Now().UnixNano())
}
ctx := context.WithValue(r.Context(), requestIDKey{}, requestID)
log.Printf("[%s] 开始处理 %s %s", requestID, r.Method, r.URL.Path)
start := time.Now()
next.ServeHTTP(w, r.WithContext(ctx))
log.Printf("[%s] 完成处理,耗时 %v", requestID, time.Since(start))
})
}
Context是Go并发的灵魂,下一节学习反射!
