Go语言快速入门Go语言快速入门
首页
基础篇
进阶篇
高阶篇
实战篇
Go官方网站
编程指南
首页
基础篇
进阶篇
高阶篇
实战篇
Go官方网站
编程指南
  • 进阶篇

    • 🚀 进阶篇
    • 方法
    • 接口
    • 错误处理
    • Goroutine
    • Channel
    • 包管理
    • 单元测试

单元测试

Go内置了强大的测试框架,写测试从未如此简单!我以前在其他语言里写测试,都需要安装各种第三方库。Go 把测试直接内置到标准库,还提供了基准测试、覆盖率等功能,真的太贴心了!

测试基础

测试文件命名

测试文件以 _test.go 结尾,和被测代码放在同一目录:

math/
├── math.go       // 源代码
└── math_test.go  // 测试代码

第一个测试

// math.go
package math

func Add(a, b int) int {
    return a + b
}
// math_test.go
package math

import "testing"

func TestAdd(t *testing.T) {
    result := Add(1, 2)
    expected := 3
    
    if result != expected {
        t.Errorf("Add(1, 2) = %d; want %d", result, expected)
    }
}

运行测试

# 运行当前目录测试
go test

# 显示详细输出
go test -v

# 运行所有包的测试
go test ./...

# 运行特定测试
go test -run TestAdd

# 运行匹配的测试
go test -run "TestAdd|TestSub"

测试函数类型

函数签名类型用途
TestXxx(t *testing.T)单元测试测试功能正确性
BenchmarkXxx(b *testing.B)基准测试测试性能
ExampleXxx()示例函数文档和验证

testing.T 方法

func TestSomething(t *testing.T) {
    // 记录日志(仅-v时显示)
    t.Log("开始测试")
    t.Logf("测试 %s", "某功能")
    
    // 标记失败,继续执行
    t.Error("出错了")
    t.Errorf("期望 %d,得到 %d", 1, 2)
    
    // 标记失败,立即停止
    t.Fatal("严重错误")
    t.Fatalf("严重错误: %v", err)
    
    // 跳过测试
    if testing.Short() {
        t.Skip("跳过长时间测试")
    }
    
    // 并行执行
    t.Parallel()
}

表格驱动测试

Go推荐的测试风格,清晰且易于扩展:

func TestAdd(t *testing.T) {
    tests := []struct {
        name     string
        a, b     int
        expected int
    }{
        {"positive numbers", 1, 2, 3},
        {"negative numbers", -1, -2, -3},
        {"mixed numbers", -1, 2, 1},
        {"zeros", 0, 0, 0},
    }
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := Add(tt.a, tt.b)
            if result != tt.expected {
                t.Errorf("Add(%d, %d) = %d; want %d",
                    tt.a, tt.b, result, tt.expected)
            }
        })
    }
}

输出:

=== RUN   TestAdd
=== RUN   TestAdd/positive_numbers
=== RUN   TestAdd/negative_numbers
=== RUN   TestAdd/mixed_numbers
=== RUN   TestAdd/zeros
--- PASS: TestAdd (0.00s)
    --- PASS: TestAdd/positive_numbers (0.00s)
    --- PASS: TestAdd/negative_numbers (0.00s)
    --- PASS: TestAdd/mixed_numbers (0.00s)
    --- PASS: TestAdd/zeros (0.00s)

测试覆盖率

# 显示覆盖率
go test -cover

# 生成覆盖率文件
go test -coverprofile=coverage.out

# 查看详细覆盖
go tool cover -func=coverage.out

# HTML报告(推荐!)
go tool cover -html=coverage.out -o coverage.html

基准测试

测试代码性能:

func BenchmarkAdd(b *testing.B) {
    for i := 0; i < b.N; i++ {
        Add(1, 2)
    }
}

// 带输入的基准测试
func BenchmarkFibonacci(b *testing.B) {
    for i := 0; i < b.N; i++ {
        Fibonacci(20)
    }
}

// 多组输入
func BenchmarkFibonacci10(b *testing.B) {
    for i := 0; i < b.N; i++ {
        Fibonacci(10)
    }
}

func BenchmarkFibonacci20(b *testing.B) {
    for i := 0; i < b.N; i++ {
        Fibonacci(20)
    }
}

运行基准测试:

# 运行基准测试
go test -bench=.

# 显示内存分配
go test -bench=. -benchmem

# 运行特定基准测试
go test -bench=BenchmarkAdd

输出:

BenchmarkAdd-8          1000000000           0.25 ns/op        0 B/op       0 allocs/op
BenchmarkFibonacci-8        39428         29892 ns/op        0 B/op       0 allocs/op

示例函数

示例函数既是文档,也是测试:

func ExampleAdd() {
    result := Add(1, 2)
    fmt.Println(result)
    // Output: 3
}

func ExampleCalculator_Add() {
    calc := NewCalculator()
    result := calc.Add(1, 2)
    fmt.Println(result)
    // Output: 3
}

// 无序输出
func ExampleShuffle() {
    nums := []int{1, 2, 3}
    Shuffle(nums)
    fmt.Println(nums)
    // Unordered output:
    // [1 2 3]
}

测试辅助函数

测试Main

func TestMain(m *testing.M) {
    // 测试前的设置
    setup()
    
    // 运行所有测试
    code := m.Run()
    
    // 测试后的清理
    teardown()
    
    os.Exit(code)
}

func setup() {
    fmt.Println("测试开始前的设置")
}

func teardown() {
    fmt.Println("测试结束后的清理")
}

t.Helper()

func assertEqual(t *testing.T, got, want int) {
    t.Helper()  // 标记为辅助函数,错误显示调用者位置
    if got != want {
        t.Errorf("got %d, want %d", got, want)
    }
}

func TestAdd(t *testing.T) {
    assertEqual(t, Add(1, 2), 3)  // 失败时显示这一行
}

t.Cleanup()

func TestWithCleanup(t *testing.T) {
    // 创建临时文件
    f, err := os.CreateTemp("", "test")
    if err != nil {
        t.Fatal(err)
    }
    
    // 注册清理函数
    t.Cleanup(func() {
        f.Close()
        os.Remove(f.Name())
    })
    
    // 使用文件进行测试
    // ...
}

模拟和依赖注入

接口模拟

// 定义接口
type UserRepository interface {
    GetByID(id int) (*User, error)
    Save(user *User) error
}

// 实际实现
type MySQLUserRepository struct {
    db *sql.DB
}

// 模拟实现
type MockUserRepository struct {
    users map[int]*User
}

func (m *MockUserRepository) GetByID(id int) (*User, error) {
    if user, ok := m.users[id]; ok {
        return user, nil
    }
    return nil, errors.New("user not found")
}

func (m *MockUserRepository) Save(user *User) error {
    m.users[user.ID] = user
    return nil
}

// 服务使用接口
type UserService struct {
    repo UserRepository
}

func (s *UserService) GetUser(id int) (*User, error) {
    return s.repo.GetByID(id)
}

// 测试
func TestGetUser(t *testing.T) {
    mock := &MockUserRepository{
        users: map[int]*User{
            1: {ID: 1, Name: "张三"},
        },
    }
    
    service := &UserService{repo: mock}
    
    user, err := service.GetUser(1)
    if err != nil {
        t.Fatal(err)
    }
    
    if user.Name != "张三" {
        t.Errorf("got %s, want 张三", user.Name)
    }
}

实战案例:完整测试示例

被测代码

// calculator.go
package calculator

import "errors"

var (
    ErrDivideByZero = errors.New("division by zero")
)

type Calculator struct {
    memory float64
}

func New() *Calculator {
    return &Calculator{}
}

func (c *Calculator) Add(a, b float64) float64 {
    return a + b
}

func (c *Calculator) Subtract(a, b float64) float64 {
    return a - b
}

func (c *Calculator) Multiply(a, b float64) float64 {
    return a * b
}

func (c *Calculator) Divide(a, b float64) (float64, error) {
    if b == 0 {
        return 0, ErrDivideByZero
    }
    return a / b, nil
}

func (c *Calculator) MemoryStore(value float64) {
    c.memory = value
}

func (c *Calculator) MemoryRecall() float64 {
    return c.memory
}

func (c *Calculator) MemoryClear() {
    c.memory = 0
}

测试代码

// calculator_test.go
package calculator

import (
    "errors"
    "testing"
)

func TestNew(t *testing.T) {
    c := New()
    if c == nil {
        t.Error("New() returned nil")
    }
}

func TestAdd(t *testing.T) {
    tests := []struct {
        name     string
        a, b     float64
        expected float64
    }{
        {"positive", 1, 2, 3},
        {"negative", -1, -2, -3},
        {"mixed", -1, 2, 1},
        {"zero", 0, 0, 0},
        {"float", 1.5, 2.5, 4},
    }
    
    c := New()
    
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := c.Add(tt.a, tt.b)
            if result != tt.expected {
                t.Errorf("Add(%v, %v) = %v; want %v",
                    tt.a, tt.b, result, tt.expected)
            }
        })
    }
}

func TestSubtract(t *testing.T) {
    c := New()
    
    result := c.Subtract(5, 3)
    if result != 2 {
        t.Errorf("Subtract(5, 3) = %v; want 2", result)
    }
}

func TestMultiply(t *testing.T) {
    c := New()
    
    tests := []struct {
        a, b     float64
        expected float64
    }{
        {2, 3, 6},
        {-2, 3, -6},
        {0, 5, 0},
    }
    
    for _, tt := range tests {
        result := c.Multiply(tt.a, tt.b)
        if result != tt.expected {
            t.Errorf("Multiply(%v, %v) = %v; want %v",
                tt.a, tt.b, result, tt.expected)
        }
    }
}

func TestDivide(t *testing.T) {
    c := New()
    
    t.Run("normal division", func(t *testing.T) {
        result, err := c.Divide(6, 2)
        if err != nil {
            t.Fatalf("unexpected error: %v", err)
        }
        if result != 3 {
            t.Errorf("Divide(6, 2) = %v; want 3", result)
        }
    })
    
    t.Run("division by zero", func(t *testing.T) {
        _, err := c.Divide(6, 0)
        if err == nil {
            t.Error("expected error, got nil")
        }
        if !errors.Is(err, ErrDivideByZero) {
            t.Errorf("got error %v; want ErrDivideByZero", err)
        }
    })
}

func TestMemory(t *testing.T) {
    c := New()
    
    // 测试存储和召回
    c.MemoryStore(42)
    if c.MemoryRecall() != 42 {
        t.Errorf("MemoryRecall() = %v; want 42", c.MemoryRecall())
    }
    
    // 测试清除
    c.MemoryClear()
    if c.MemoryRecall() != 0 {
        t.Errorf("after MemoryClear(), MemoryRecall() = %v; want 0",
            c.MemoryRecall())
    }
}

// 基准测试
func BenchmarkAdd(b *testing.B) {
    c := New()
    for i := 0; i < b.N; i++ {
        c.Add(1.5, 2.5)
    }
}

func BenchmarkDivide(b *testing.B) {
    c := New()
    for i := 0; i < b.N; i++ {
        c.Divide(10, 3)
    }
}

// 示例函数
func ExampleCalculator_Add() {
    c := New()
    result := c.Add(1, 2)
    fmt.Println(result)
    // Output: 3
}

func ExampleCalculator_Divide() {
    c := New()
    result, _ := c.Divide(10, 2)
    fmt.Println(result)
    // Output: 5
}

HTTP处理器测试

// handler.go
package api

import (
    "encoding/json"
    "net/http"
)

type User struct {
    ID   int    `json:"id"`
    Name string `json:"name"`
}

func GetUserHandler(w http.ResponseWriter, r *http.Request) {
    user := User{ID: 1, Name: "张三"}
    
    w.Header().Set("Content-Type", "application/json")
    json.NewEncoder(w).Encode(user)
}
// handler_test.go
package api

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "testing"
)

func TestGetUserHandler(t *testing.T) {
    // 创建请求
    req, err := http.NewRequest("GET", "/user", nil)
    if err != nil {
        t.Fatal(err)
    }
    
    // 创建ResponseRecorder
    rr := httptest.NewRecorder()
    
    // 调用处理器
    handler := http.HandlerFunc(GetUserHandler)
    handler.ServeHTTP(rr, req)
    
    // 检查状态码
    if status := rr.Code; status != http.StatusOK {
        t.Errorf("handler returned wrong status code: got %v want %v",
            status, http.StatusOK)
    }
    
    // 检查Content-Type
    contentType := rr.Header().Get("Content-Type")
    if contentType != "application/json" {
        t.Errorf("wrong content type: got %v want application/json",
            contentType)
    }
    
    // 检查响应体
    var user User
    if err := json.Unmarshal(rr.Body.Bytes(), &user); err != nil {
        t.Fatalf("failed to parse response: %v", err)
    }
    
    if user.Name != "张三" {
        t.Errorf("wrong name: got %v want 张三", user.Name)
    }
}

测试最佳实践

1. 测试命名

// 格式: Test<Function>_<Scenario>
func TestAdd_PositiveNumbers(t *testing.T) { }
func TestAdd_NegativeNumbers(t *testing.T) { }
func TestDivide_ByZero(t *testing.T) { }

2. 使用表格驱动测试

3. 避免测试私有函数

4. 测试边界条件

func TestSlice(t *testing.T) {
    tests := []struct {
        name  string
        input []int
    }{
        {"nil slice", nil},
        {"empty slice", []int{}},
        {"one element", []int{1}},
        {"many elements", []int{1, 2, 3, 4, 5}},
    }
    // ...
}

5. 使用t.Parallel()加速

func TestSomething(t *testing.T) {
    t.Parallel()  // 允许并行运行
    // ...
}

练习

  1. 为一个字符串反转函数编写测试
  2. 使用表格驱动测试测试一个验证邮箱的函数
  3. 为一个HTTP处理器编写测试
参考答案
// 1. 字符串反转测试
func Reverse(s string) string {
    runes := []rune(s)
    for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
        runes[i], runes[j] = runes[j], runes[i]
    }
    return string(runes)
}

func TestReverse(t *testing.T) {
    tests := []struct {
        input    string
        expected string
    }{
        {"hello", "olleh"},
        {"", ""},
        {"a", "a"},
        {"世界", "界世"},
        {"Hello世界", "界世olleH"},
    }
    
    for _, tt := range tests {
        t.Run(tt.input, func(t *testing.T) {
            result := Reverse(tt.input)
            if result != tt.expected {
                t.Errorf("Reverse(%q) = %q; want %q",
                    tt.input, result, tt.expected)
            }
        })
    }
}

// 2. 邮箱验证测试
func ValidateEmail(email string) bool {
    pattern := `^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`
    matched, _ := regexp.MatchString(pattern, email)
    return matched
}

func TestValidateEmail(t *testing.T) {
    tests := []struct {
        email    string
        expected bool
    }{
        {"test@example.com", true},
        {"user.name@domain.org", true},
        {"invalid", false},
        {"@domain.com", false},
        {"user@", false},
        {"", false},
    }
    
    for _, tt := range tests {
        t.Run(tt.email, func(t *testing.T) {
            result := ValidateEmail(tt.email)
            if result != tt.expected {
                t.Errorf("ValidateEmail(%q) = %v; want %v",
                    tt.email, result, tt.expected)
            }
        })
    }
}

// 3. HTTP处理器测试
func HealthHandler(w http.ResponseWriter, r *http.Request) {
    w.WriteHeader(http.StatusOK)
    w.Write([]byte(`{"status":"ok"}`))
}

func TestHealthHandler(t *testing.T) {
    req := httptest.NewRequest("GET", "/health", nil)
    rr := httptest.NewRecorder()
    
    HealthHandler(rr, req)
    
    if rr.Code != http.StatusOK {
        t.Errorf("status = %d; want %d", rr.Code, http.StatusOK)
    }
    
    expected := `{"status":"ok"}`
    if rr.Body.String() != expected {
        t.Errorf("body = %q; want %q", rr.Body.String(), expected)
    }
}

🎉 恭喜你完成了进阶篇!

接下来进入 高阶篇,学习Context、反射、泛型等高级特性!

最近更新: 2025/12/27 13:26
Contributors: 王长安
Prev
包管理