泛型
Go 1.18引入泛型,让你写出类型安全的通用代码!我等这个特性等了好几年。以前写通用代码要么用 interface{} 丢失类型安全,要么就得写一堆重复代码。泛型出来之后,终于可以像其他语言一样优雅地写通用代码了!
为什么需要泛型?
在泛型之前,写通用代码很痛苦:
// ❌ 方案1:为每种类型写一遍
func MaxInt(a, b int) int {
if a > b { return a }
return b
}
func MaxFloat64(a, b float64) float64 {
if a > b { return a }
return b
}
// ❌ 方案2:使用interface{},失去类型安全
func Max(a, b interface{}) interface{} {
// 需要类型断言,容易出错
}
有了泛型:
// ✅ 一次编写,处理所有可比较的类型
func Max[T constraints.Ordered](a, b T) T {
if a > b {
return a
}
return b
}
// 使用
fmt.Println(Max(1, 2)) // 2
fmt.Println(Max(1.5, 2.5)) // 2.5
fmt.Println(Max("a", "b")) // b
基础语法
泛型函数
// [T any] 是类型参数
func Print[T any](value T) {
fmt.Println(value)
}
// 多个类型参数
func Pair[K, V any](key K, value V) map[K]V {
return map[K]V{key: value}
}
// 调用时可以省略类型(编译器推断)
Print(42) // Print[int](42)
Print("hello") // Print[string]("hello")
Pair("name", 123) // Pair[string, int]("name", 123)
泛型类型
// 泛型切片
type Slice[T any] []T
// 泛型Map
type Map[K comparable, V any] map[K]V
// 泛型结构体
type Pair[T, U any] struct {
First T
Second U
}
// 使用
var nums Slice[int] = []int{1, 2, 3}
var m Map[string, int] = map[string]int{"a": 1}
p := Pair[string, int]{"name", 42}
泛型方法
type Stack[T any] struct {
items []T
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
item := s.items[len(s.items)-1]
s.items = s.items[:len(s.items)-1]
return item, true
}
func (s *Stack[T]) Len() int {
return len(s.items)
}
// 使用
stack := &Stack[int]{}
stack.Push(1)
stack.Push(2)
v, _ := stack.Pop() // 2
类型约束
any 约束
any 是 interface{} 的别名,接受任何类型:
func Print[T any](v T) {
fmt.Println(v)
}
comparable 约束
可比较的类型(可用 == 和 !=):
func Contains[T comparable](slice []T, target T) bool {
for _, v := range slice {
if v == target {
return true
}
}
return false
}
Contains([]int{1, 2, 3}, 2) // true
Contains([]string{"a", "b"}, "c") // false
constraints 包
import "golang.org/x/exp/constraints"
// Ordered:可排序的类型(支持 < > <= >=)
func Min[T constraints.Ordered](a, b T) T {
if a < b {
return a
}
return b
}
// Integer:所有整数类型
// Float:所有浮点类型
// Signed:有符号整数
// Unsigned:无符号整数
自定义约束
// 使用接口定义约束
type Stringer interface {
String() string
}
func PrintAll[T Stringer](items []T) {
for _, item := range items {
fmt.Println(item.String())
}
}
// 使用类型集合
type Number interface {
int | int32 | int64 | float32 | float64
}
func Sum[T Number](nums []T) T {
var sum T
for _, n := range nums {
sum += n
}
return sum
}
// 使用 ~ 表示底层类型
type MyInt int
type Integer interface {
~int | ~int32 | ~int64
}
func Double[T Integer](n T) T {
return n * 2
}
var x MyInt = 5
Double(x) // 可以,因为MyInt的底层类型是int
常用泛型函数
切片操作
// Map - 映射
func Map[T, U any](slice []T, fn func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = fn(v)
}
return result
}
// Filter - 过滤
func Filter[T any](slice []T, fn func(T) bool) []T {
result := make([]T, 0)
for _, v := range slice {
if fn(v) {
result = append(result, v)
}
}
return result
}
// Reduce - 归约
func Reduce[T, U any](slice []T, initial U, fn func(U, T) U) U {
result := initial
for _, v := range slice {
result = fn(result, v)
}
return result
}
// 使用
nums := []int{1, 2, 3, 4, 5}
// 每个数翻倍
doubled := Map(nums, func(n int) int { return n * 2 })
// [2, 4, 6, 8, 10]
// 过滤偶数
evens := Filter(nums, func(n int) bool { return n%2 == 0 })
// [2, 4]
// 求和
sum := Reduce(nums, 0, func(acc, n int) int { return acc + n })
// 15
查找函数
// Find - 查找第一个匹配的元素
func Find[T any](slice []T, fn func(T) bool) (T, bool) {
for _, v := range slice {
if fn(v) {
return v, true
}
}
var zero T
return zero, false
}
// FindIndex - 查找索引
func FindIndex[T any](slice []T, fn func(T) bool) int {
for i, v := range slice {
if fn(v) {
return i
}
}
return -1
}
// All - 所有元素都满足条件
func All[T any](slice []T, fn func(T) bool) bool {
for _, v := range slice {
if !fn(v) {
return false
}
}
return true
}
// Any - 任一元素满足条件
func Any[T any](slice []T, fn func(T) bool) bool {
for _, v := range slice {
if fn(v) {
return true
}
}
return false
}
泛型数据结构
泛型栈
type Stack[T any] struct {
items []T
}
func NewStack[T any]() *Stack[T] {
return &Stack[T]{items: make([]T, 0)}
}
func (s *Stack[T]) Push(item T) {
s.items = append(s.items, item)
}
func (s *Stack[T]) Pop() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
n := len(s.items) - 1
item := s.items[n]
s.items = s.items[:n]
return item, true
}
func (s *Stack[T]) Peek() (T, bool) {
if len(s.items) == 0 {
var zero T
return zero, false
}
return s.items[len(s.items)-1], true
}
func (s *Stack[T]) IsEmpty() bool {
return len(s.items) == 0
}
func (s *Stack[T]) Size() int {
return len(s.items)
}
泛型队列
type Queue[T any] struct {
items []T
}
func NewQueue[T any]() *Queue[T] {
return &Queue[T]{items: make([]T, 0)}
}
func (q *Queue[T]) Enqueue(item T) {
q.items = append(q.items, item)
}
func (q *Queue[T]) Dequeue() (T, bool) {
if len(q.items) == 0 {
var zero T
return zero, false
}
item := q.items[0]
q.items = q.items[1:]
return item, true
}
func (q *Queue[T]) Peek() (T, bool) {
if len(q.items) == 0 {
var zero T
return zero, false
}
return q.items[0], true
}
泛型链表
type Node[T any] struct {
Value T
Next *Node[T]
}
type LinkedList[T any] struct {
Head *Node[T]
size int
}
func NewLinkedList[T any]() *LinkedList[T] {
return &LinkedList[T]{}
}
func (l *LinkedList[T]) Append(value T) {
node := &Node[T]{Value: value}
if l.Head == nil {
l.Head = node
} else {
current := l.Head
for current.Next != nil {
current = current.Next
}
current.Next = node
}
l.size++
}
func (l *LinkedList[T]) ToSlice() []T {
result := make([]T, 0, l.size)
current := l.Head
for current != nil {
result = append(result, current.Value)
current = current.Next
}
return result
}
泛型集合
type Set[T comparable] struct {
items map[T]struct{}
}
func NewSet[T comparable]() *Set[T] {
return &Set[T]{items: make(map[T]struct{})}
}
func (s *Set[T]) Add(item T) {
s.items[item] = struct{}{}
}
func (s *Set[T]) Remove(item T) {
delete(s.items, item)
}
func (s *Set[T]) Contains(item T) bool {
_, exists := s.items[item]
return exists
}
func (s *Set[T]) Size() int {
return len(s.items)
}
func (s *Set[T]) ToSlice() []T {
result := make([]T, 0, len(s.items))
for item := range s.items {
result = append(result, item)
}
return result
}
func (s *Set[T]) Union(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
result.Add(item)
}
for item := range other.items {
result.Add(item)
}
return result
}
func (s *Set[T]) Intersection(other *Set[T]) *Set[T] {
result := NewSet[T]()
for item := range s.items {
if other.Contains(item) {
result.Add(item)
}
}
return result
}
实战案例:泛型缓存
package main
import (
"sync"
"time"
)
type CacheItem[V any] struct {
Value V
Expiration time.Time
}
type Cache[K comparable, V any] struct {
items map[K]CacheItem[V]
mu sync.RWMutex
}
func NewCache[K comparable, V any]() *Cache[K, V] {
c := &Cache[K, V]{
items: make(map[K]CacheItem[V]),
}
go c.cleanup()
return c
}
func (c *Cache[K, V]) Set(key K, value V, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.items[key] = CacheItem[V]{
Value: value,
Expiration: time.Now().Add(ttl),
}
}
func (c *Cache[K, V]) Get(key K) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, exists := c.items[key]
if !exists {
var zero V
return zero, false
}
if time.Now().After(item.Expiration) {
var zero V
return zero, false
}
return item.Value, true
}
func (c *Cache[K, V]) Delete(key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.items, key)
}
func (c *Cache[K, V]) cleanup() {
ticker := time.NewTicker(time.Minute)
for range ticker.C {
c.mu.Lock()
for key, item := range c.items {
if time.Now().After(item.Expiration) {
delete(c.items, key)
}
}
c.mu.Unlock()
}
}
func main() {
// 字符串->int缓存
intCache := NewCache[string, int]()
intCache.Set("count", 100, time.Hour)
if val, ok := intCache.Get("count"); ok {
fmt.Println("count:", val)
}
// 用户缓存
type User struct {
ID int
Name string
}
userCache := NewCache[int, User]()
userCache.Set(1, User{ID: 1, Name: "张三"}, time.Hour)
if user, ok := userCache.Get(1); ok {
fmt.Printf("User: %+v\n", user)
}
}
泛型最佳实践
1. 不要过度使用
// ❌ 过度使用:没必要泛型
func PrintInt[T ~int](n T) {
fmt.Println(n)
}
// ✅ 直接用int
func PrintInt(n int) {
fmt.Println(n)
}
2. 优先使用接口
// 如果接口够用,就用接口
type Writer interface {
Write([]byte) (int, error)
}
// 只有需要具体类型时才用泛型
func Max[T constraints.Ordered](a, b T) T
3. 命名约定
// T, U, V - 任意类型
// K - Key类型
// V - Value类型
// E - Element类型
type Map[K comparable, V any] map[K]V
type Slice[E any] []E
练习
- 实现一个泛型的
Reverse函数反转切片 - 实现一个泛型的
Unique函数去除切片重复元素 - 实现一个泛型的优先队列
参考答案
// 1. Reverse
func Reverse[T any](slice []T) []T {
result := make([]T, len(slice))
for i, v := range slice {
result[len(slice)-1-i] = v
}
return result
}
// 2. Unique
func Unique[T comparable](slice []T) []T {
seen := make(map[T]bool)
result := make([]T, 0)
for _, v := range slice {
if !seen[v] {
seen[v] = true
result = append(result, v)
}
}
return result
}
// 3. 优先队列
type PriorityQueue[T any] struct {
items []T
less func(a, b T) bool
}
func NewPriorityQueue[T any](less func(a, b T) bool) *PriorityQueue[T] {
return &PriorityQueue[T]{
items: make([]T, 0),
less: less,
}
}
func (pq *PriorityQueue[T]) Push(item T) {
pq.items = append(pq.items, item)
pq.up(len(pq.items) - 1)
}
func (pq *PriorityQueue[T]) Pop() (T, bool) {
if len(pq.items) == 0 {
var zero T
return zero, false
}
item := pq.items[0]
n := len(pq.items) - 1
pq.items[0] = pq.items[n]
pq.items = pq.items[:n]
if n > 0 {
pq.down(0)
}
return item, true
}
func (pq *PriorityQueue[T]) up(i int) {
for i > 0 {
parent := (i - 1) / 2
if !pq.less(pq.items[i], pq.items[parent]) {
break
}
pq.items[i], pq.items[parent] = pq.items[parent], pq.items[i]
i = parent
}
}
func (pq *PriorityQueue[T]) down(i int) {
for {
left := 2*i + 1
if left >= len(pq.items) {
break
}
smallest := left
if right := left + 1; right < len(pq.items) && pq.less(pq.items[right], pq.items[left]) {
smallest = right
}
if !pq.less(pq.items[smallest], pq.items[i]) {
break
}
pq.items[i], pq.items[smallest] = pq.items[smallest], pq.items[i]
i = smallest
}
}
泛型让Go更强大了,下一节学习性能优化!
