3.4 泛型编程深度解析¶
概述¶
Go语言在1.18版本中正式引入了泛型(Generics)特性,这是自Go语言诞生以来最重要的语言特性之一。泛型允许我们编写不依赖于具体类型的代码,从而在保证类型安全的同时极大地提高代码复用性。本章节将深入探讨Go泛型的设计理念、语法规则、最佳实践以及实际应用场景。
学习目标¶
- 深入理解Go语言泛型的设计理念
- 掌握类型约束的定义与使用
- 熟练编写泛型函数与泛型类型
- 了解泛型的优势、限制与最佳实践
学习内容¶
泛型的基本语法(Go 1.18+)¶
泛型的核心思想是将类型参数化,允许函数或类型在定义时不指定具体类型,而是在使用时再确定。
类型参数的定义¶
在Go中,类型参数通过方括号[]来定义,放在函数名或类型名之后:
泛型函数的声明¶
泛型函数在普通函数的基础上增加了类型参数列表:
下面是一个简单的泛型函数示例,实现两个值的交换:
package main
import "fmt"
// 泛型交换函数,可以交换任意类型的两个值
func Swap[T any](a, b *T) {
*a, *b = *b, *a
}
func main() {
// 交换整数
x, y := 10, 20
fmt.Printf("交换前: x=%d, y=%d\n", x, y)
Swap(&x, &y)
fmt.Printf("交换后: x=%d, y=%d\n", x, y)
// 交换字符串
s1, s2 := "hello", "world"
fmt.Printf("交换前: s1=%s, s2=%s\n", s1, s2)
Swap(&s1, &s2)
fmt.Printf("交换后: s1=%s, s2=%s\n", s1, s2)
}
泛型类型的定义¶
除了函数,结构体、接口、映射等也可以是泛型的:
package main
import "fmt"
// 泛型容器类型
type Container[T any] struct {
items []T
}
// 泛型容器的方法
func (c *Container[T]) Add(item T) {
c.items = append(c.items, item)
}
func (c *Container[T]) Get(index int) (T, bool) {
if index >= 0 && index < len(c.items) {
return c.items[index], true
}
var zero T
return zero, false
}
func main() {
// 整数容器
intContainer := &Container[int]{}
intContainer.Add(10)
intContainer.Add(20)
if val, ok := intContainer.Get(0); ok {
fmt.Println("整数容器第一个元素:", val)
}
// 字符串容器
strContainer := &Container[string]{}
strContainer.Add("hello")
strContainer.Add("generics")
if val, ok := strContainer.Get(1); ok {
fmt.Println("字符串容器第二个元素:", val)
}
}
类型推断机制¶
Go的泛型实现了强大的类型推断功能,在很多情况下不需要显式指定类型参数:
package main
import "fmt"
// 泛型函数:返回两个值中较大的一个
func Max[T int | float64 | string](a, b T) T {
if a > b {
return a
}
return b
}
func main() {
// 不需要显式指定类型参数,编译器会自动推断
fmt.Println(Max(10, 20)) // int类型
fmt.Println(Max(3.14, 2.71)) // float64类型
fmt.Println(Max("apple", "banana")) // string类型
// 也可以显式指定类型参数
fmt.Println(Max[int](100, 200))
}
类型约束的定义与使用¶
类型约束限定了类型参数可以接受的类型范围,确保泛型代码能够安全地使用类型的特定操作。
接口作为类型约束¶
在Go中,接口可以作为类型约束,指定类型参数必须实现的方法:
package main
import "fmt"
// 定义一个接口作为类型约束
type MathOperations interface {
int | float64
}
// 泛型函数:计算两个数的和
func Add[T MathOperations](a, b T) T {
return a + b
}
// 另一个接口约束示例
type Stringer interface {
String() string
}
// 泛型函数:打印任何实现了String()方法的类型
func PrintAsString[T Stringer](t T) {
fmt.Println(t.String())
}
// 自定义类型实现Stringer接口
type MyInt int
func (m MyInt) String() string {
return fmt.Sprintf("MyInt: %d", m)
}
func main() {
fmt.Println(Add(10, 20)) // 30
fmt.Println(Add(3.14, 2.71)) // 5.85
var num MyInt = 42
PrintAsString(num) // 输出: MyInt: 42
}
内置约束类型¶
Go标准库在constraints包中提供了一些常用的约束类型(需要导入golang.org/x/exp/constraints):
package main
import (
"fmt"
"golang.org/x/exp/constraints"
)
// 使用内置的Ordered约束,支持比较操作
func FindMax[T constraints.Ordered](slice []T) T {
if len(slice) == 0 {
panic("slice is empty")
}
max := slice[0]
for _, v := range slice[1:] {
if v > max {
max = v
}
}
return max
}
func main() {
ints := []int{3, 1, 4, 1, 5, 9}
fmt.Println("最大整数:", FindMax(ints))
floats := []float64{3.14, 1.59, 2.65}
fmt.Println("最大浮点数:", FindMax(floats))
strings := []string{"apple", "banana", "cherry"}
fmt.Println("最大字符串:", FindMax(strings))
}
自定义约束接口¶
我们可以定义更复杂的自定义约束,组合多个类型和接口:
package main
import "fmt"
// 定义一个自定义约束
type Number interface {
int | int8 | int16 | int32 | int64 |
uint | uint8 | uint16 | uint32 | uint64 | uintptr |
float32 | float64 | complex64 | complex128
}
// 另一个自定义约束,组合了Number和一个方法
type NumericWithAbs[T Number] interface {
Number
Abs() T
}
// 为float64实现Abs方法
type MyFloat float64
func (f MyFloat) Abs() MyFloat {
if f < 0 {
return -f
}
return f
}
// 泛型函数,使用自定义约束
func GetAbsoluteValue[T NumericWithAbs[T]](x T) T {
return x.Abs()
}
func main() {
var f MyFloat = -3.14
fmt.Println("绝对值:", GetAbsoluteValue(f)) // 输出: 3.14
}
约束的组合与嵌套¶
约束可以通过嵌套和组合创建更复杂的约束条件:
package main
import "fmt"
// 基础约束:数值类型
type Number interface {
int | float64
}
// 基础约束:可打印类型
type Printable interface {
fmt.Stringer
}
// 组合约束:既是数值又是可打印的类型
type PrintableNumber interface {
Number
Printable
}
// 实现约束的类型
type Measurable int
func (m Measurable) String() string {
return fmt.Sprintf("Measurable(%d)", m)
}
// 使用组合约束的泛型函数
func Process[T PrintableNumber](value T) {
fmt.Printf("Processing %s, squared value: %v\n", value, value*value)
}
func main() {
var m Measurable = 5
Process(m) // 输出: Processing Measurable(5), squared value: 25
}
泛型函数与泛型类型¶
泛型函数的实现¶
泛型函数可以处理多种类型的输入,同时保持类型安全:
package main
import "fmt"
// 泛型函数:计算切片的总和
func Sum[T int | float64](slice []T) T {
var total T
for _, v := range slice {
total += v
}
return total
}
// 泛型函数:过滤切片元素
func Filter[T any](slice []T, predicate func(T) bool) []T {
var result []T
for _, v := range slice {
if predicate(v) {
result = append(result, v)
}
}
return result
}
func main() {
// 测试Sum函数
ints := []int{1, 2, 3, 4, 5}
fmt.Println("整数总和:", Sum(ints)) // 15
floats := []float64{1.1, 2.2, 3.3}
fmt.Println("浮点数总和:", Sum(floats)) // 6.6
// 测试Filter函数
evenNumbers := Filter(ints, func(n int) bool {
return n%2 == 0
})
fmt.Println("偶数:", evenNumbers) // [2 4]
largeNumbers := Filter(floats, func(f float64) bool {
return f > 2.0
})
fmt.Println("大于2.0的数:", largeNumbers) // [2.2 3.3]
}
泛型结构体设计¶
泛型结构体可以包含任意类型的字段,提供通用的数据结构:
package main
import "fmt"
// 泛型键值对结构体
type Pair[K comparable, V any] struct {
Key K
Value V
}
// 泛型链表节点
type Node[T any] struct {
Data T
Next *Node[T]
}
// 泛型链表
type LinkedList[T any] struct {
Head *Node[T]
Size int
}
// 向链表添加元素
func (l *LinkedList[T]) Add(data T) {
newNode := &Node[T]{Data: data}
if l.Head == nil {
l.Head = newNode
} else {
current := l.Head
for current.Next != nil {
current = current.Next
}
current.Next = newNode
}
l.Size++
}
// 打印链表元素
func (l *LinkedList[T]) Print() {
current := l.Head
for current != nil {
fmt.Printf("%v ", current.Data)
current = current.Next
}
fmt.Println()
}
func main() {
// 使用Pair结构体
p1 := Pair[string, int]{Key: "age", Value: 30}
p2 := Pair[int, string]{Key: 100, Value: "score"}
fmt.Println(p1, p2)
// 使用链表
intList := &LinkedList[int]{}
intList.Add(10)
intList.Add(20)
intList.Add(30)
fmt.Print("整数链表: ")
intList.Print()
strList := &LinkedList[string]{}
strList.Add("hello")
strList.Add("generics")
fmt.Print("字符串链表: ")
strList.Print()
}
泛型方法定义¶
泛型类型可以有方法,这些方法可以使用类型参数或引入新的类型参数:
package main
import "fmt"
// 泛型集合类型
type Set[T comparable] struct {
elements map[T]struct{}
}
// 创建新集合
func NewSet[T comparable]() *Set[T] {
return &Set[T]{
elements: make(map[T]struct{}),
}
}
// 向集合添加元素
func (s *Set[T]) Add(element T) {
s.elements[element] = struct{}{}
}
// 检查元素是否在集合中
func (s *Set[T]) Contains(element T) bool {
_, exists := s.elements[element]
return exists
}
// 泛型方法:与另一个集合的交集
func (s *Set[T]) Intersection(other *Set[T]) *Set[T] {
result := NewSet[T]()
for elem := range s.elements {
if other.Contains(elem) {
result.Add(elem)
}
}
return result
}
// 泛型方法:转换为切片
func (s *Set[T]) ToSlice() []T {
slice := make([]T, 0, len(s.elements))
for elem := range s.elements {
slice = append(slice, elem)
}
return slice
}
// 带有额外类型参数的泛型方法
func (s *Set[T]) Map[U any](f func(T) U) []U {
result := make([]U, 0, len(s.elements))
for elem := range s.elements {
result = append(result, f(elem))
}
return result
}
func main() {
set1 := NewSet[int]()
set1.Add(1)
set1.Add(2)
set1.Add(3)
set2 := NewSet[int]()
set2.Add(2)
set2.Add(3)
set2.Add(4)
intersection := set1.Intersection(set2)
fmt.Println("交集:", intersection.ToSlice()) // [2 3]
// 使用Map方法转换类型
strSet := NewSet[string]()
strSet.Add("1")
strSet.Add("2")
strSet.Add("3")
intSlice := strSet.Map(func(s string) int {
// 简化示例,实际使用中应处理错误
var result int
fmt.Sscanf(s, "%d", &result)
return result
})
fmt.Println("转换后的整数切片:", intSlice) // [1 2 3]
}
类型实例化过程¶
类型实例化是将泛型类型或函数与具体类型结合的过程:
package main
import "fmt"
// 泛型函数
func Reverse[T any](slice []T) []T {
length := len(slice)
result := make([]T, length)
for i := 0; i < length; i++ {
result[i] = slice[length-1-i]
}
return result
}
// 泛型结构体
type Box[T any] struct {
Content T
}
func main() {
// 函数实例化 - 显式
reversedInts := Reverse[int]([]int{1, 2, 3, 4})
fmt.Println("反转整数切片:", reversedInts)
// 函数实例化 - 隐式(类型推断)
reversedStrings := Reverse([]string{"a", "b", "c"})
fmt.Println("反转字符串切片:", reversedStrings)
// 结构体实例化
intBox := Box[int]{Content: 42}
strBox := Box[string]{Content: "hello"}
fmt.Println("整数盒子内容:", intBox.Content)
fmt.Println("字符串盒子内容:", strBox.Content)
// 可以创建实例化类型的变量
type IntBox Box[int]
var myIntBox IntBox
myIntBox.Content = 100
fmt.Println("自定义IntBox内容:", myIntBox.Content)
}
泛型的优势与限制¶
类型安全的提升¶
泛型相比interface{}提供了更强的类型安全,在编译时就能发现类型错误:
package main
import "fmt"
// 使用泛型的安全实现
func GenericAdd[T int | float64](a, b T) T {
return a + b
}
// 使用interface{}的不安全实现
func InterfaceAdd(a, b interface{}) interface{} {
switch a.(type) {
case int:
if bVal, ok := b.(int); ok {
return a.(int) + bVal
}
case float64:
if bVal, ok := b.(float64); ok {
return a.(float64) + bVal
}
}
return nil // 类型不匹配时返回nil,运行时错误
}
func main() {
// 泛型版本在编译时就能检查类型错误
fmt.Println(GenericAdd(10, 20)) // 正确
fmt.Println(GenericAdd(3.14, 2.71)) // 正确
// fmt.Println(GenericAdd(10, 3.14)) // 编译错误,类型不匹配
// interface{}版本在运行时才能发现错误
fmt.Println(InterfaceAdd(10, 20)) // 正确
fmt.Println(InterfaceAdd(3.14, 2.71)) // 正确
fmt.Println(InterfaceAdd(10, 3.14)) // 运行时返回nil,可能导致后续错误
}
代码复用性增强¶
泛型允许我们编写一次算法,适用于多种类型:
package main
import "fmt"
// 泛型冒泡排序,适用于任何可比较的类型
func BubbleSort[T constraints.Ordered](slice []T) {
n := len(slice)
for i := 0; i < n-1; i++ {
for j := 0; j < n-i-1; j++ {
if slice[j] > slice[j+1] {
slice[j], slice[j+1] = slice[j+1], slice[j]
}
}
}
}
func main() {
// 排序整数
ints := []int{5, 2, 9, 1, 5, 6}
BubbleSort(ints)
fmt.Println("排序后的整数:", ints)
// 排序浮点数
floats := []float64{3.14, 1.59, 2.65, 0.58}
BubbleSort(floats)
fmt.Println("排序后的浮点数:", floats)
// 排序字符串
strings := []string{"banana", "apple", "cherry", "date"}
BubbleSort(strings)
fmt.Println("排序后的字符串:", strings)
}
性能优化潜力¶
泛型在某些情况下可以提供比interface{}更好的性能,因为避免了类型断言的开销:
package main
import (
"fmt"
"time"
)
// 泛型版本的求和函数
func GenericSum[T int | float64](numbers []T) T {
var sum T
for _, n := range numbers {
sum += n
}
return sum
}
// interface{}版本的求和函数
func InterfaceSum(numbers []interface{}) interface{} {
var sumInt int
var sumFloat float64
isInt := true
for i, n := range numbers {
if i == 0 {
// 确定类型
if _, ok := n.(int); ok {
isInt = true
} else if _, ok := n.(float64); ok {
isInt = false
} else {
return nil // 不支持的类型
}
}
if isInt {
if val, ok := n.(int); ok {
sumInt += val
} else {
return nil // 类型不一致
}
} else {
if val, ok := n.(float64); ok {
sumFloat += val
} else {
return nil // 类型不一致
}
}
}
if isInt {
return sumInt
}
return sumFloat
}
func main() {
// 性能测试
ints := make([]int, 1_000_000)
for i := 0; i < 1_000_000; i++ {
ints[i] = i
}
// 测试泛型版本
start := time.Now()
genericResult := GenericSum(ints)
genericDuration := time.Since(start)
fmt.Printf("泛型版本: 结果=%d, 耗时=%v\n", genericResult, genericDuration)
// 准备interface{}版本的输入
interfaceInts := make([]interface{}, 1_000_000)
for i := 0; i < 1_000_000; i++ {
interfaceInts[i] = i
}
// 测试interface{}版本
start = time.Now()
interfaceResult := InterfaceSum(interfaceInts)
interfaceDuration := time.Since(start)
fmt.Printf("Interface版本: 结果=%d, 耗时=%v\n", interfaceResult, interfaceDuration)
}
当前版本的限制¶
Go 1.18+的泛型实现有一些限制:
package main
import "fmt"
// 限制1: 不能直接使用类型参数实例化新类型
type MyType[T any] struct {
Value T
}
// 这会导致编译错误
// func NewMyType[T any]() MyType[T] {
// return MyType[T]{} // 不能直接使用类型参数实例化
// }
// 正确的做法是使用具体类型或让编译器推断
func NewMyType[T any](value T) MyType[T] {
return MyType[T]{Value: value}
}
// 限制2: 不能将类型参数用作接收器
type AnotherType struct{}
// 这会导致编译错误
// func (t T) SomeMethod() {}
// 限制3: 类型参数不能用于非类型参数的比较
func Compare[T any](a, b T) bool {
// 这会导致编译错误,除非T有comparable约束
// return a == b
return false
}
// 正确的做法是添加comparable约束
func CompareComparable[T comparable](a, b T) bool {
return a == b
}
// 限制4: 不能在包级别变量中使用类型参数
// var globalVar T // 编译错误
func main() {
numType := NewMyType(42)
fmt.Println(numType.Value)
strType := NewMyType("hello")
fmt.Println(strType.Value)
fmt.Println(CompareComparable(10, 10)) // true
fmt.Println(CompareComparable("a", "b")) // false
}
迁移策略与最佳实践¶
从interface{}迁移到泛型¶
将使用interface{}的代码迁移到泛型可以提高类型安全性:
package main
import "fmt"
// 旧版本: 使用interface{}
type OldStack struct {
elements []interface{}
}
func (s *OldStack) Push(element interface{}) {
s.elements = append(s.elements, element)
}
func (s *OldStack) Pop() interface{} {
if len(s.elements) == 0 {
return nil
}
element := s.elements[len(s.elements)-1]
s.elements = s.elements[:len(s.elements)-1]
return element
}
// 新版本: 使用泛型
type NewStack[T any] struct {
elements []T
}
func (s *NewStack[T]) Push(element T) {
s.elements = append(s.elements, element)
}
func (s *NewStack[T]) Pop() (T, bool) {
if len(s.elements) == 0 {
var zero T
return zero, false
}
element := s.elements[len(s.elements)-1]
s.elements = s.elements[:len(s.elements)-1]
return element, true
}
func main() {
// 旧版本使用方式 - 缺乏类型安全
oldStack := &OldStack{}
oldStack.Push(10)
oldStack.Push("hello") // 可以推入不同类型,导致潜在错误
// 弹出时需要类型断言
if val, ok := oldStack.Pop().(string); ok {
fmt.Println("旧栈弹出:", val)
}
// 新版本使用方式 - 类型安全
newStack := &NewStack[int]{}
newStack.Push(10)
newStack.Push(20)
// newStack.Push("hello") // 编译错误,类型不匹配
// 弹出时不需要类型断言
if val, ok := newStack.Pop(); ok {
fmt.Println("新栈弹出:", val)
}
}
泛型库的设计原则¶
设计泛型库时应遵循一些关键原则:
package main
import (
"fmt"
"golang.org/x/exp/constraints"
)
// 原则1: 最小权限原则 - 使用最严格的必要约束
// 好的设计:只要求必要的操作
func Sum[T constraints.Integer | constraints.Float](numbers []T) T {
var sum T
for _, n := range numbers {
sum += n
}
return sum
}
// 原则2: 提供具体类型的别名,提高可用性
type IntList = []int
// 为常用类型提供专用函数
func SumInts(numbers IntList) int {
return Sum(numbers)
}
// 原则3: 考虑类型推断,简化使用
func Map[T, U any](input []T, f func(T) U) []U {
result := make([]U, len(input))
for i, v := range input {
result[i] = f(v)
}
return result
}
// 原则4: 提供清晰的错误处理
func SafeDivide[T constraints.Float](a, b T) (T, error) {
if b == 0 {
return 0, fmt.Errorf("除数不能为零")
}
return a / b, nil
}
func main() {
// 测试Sum函数
ints := []int{1, 2, 3, 4}
fmt.Println("整数和:", Sum(ints))
floats := []float64{1.5, 2.5, 3.5}
fmt.Println("浮点数和:", Sum(floats))
// 测试Map函数(类型推断)
strs := []string{"1", "2", "3"}
nums := Map(strs, func(s string) int {
var n int
fmt.Sscanf(s, "%d", &n)
return n
})
fmt.Println("转换后的数字:", nums)
// 测试安全除法
if result, err := SafeDivide(10.0, 2.0); err == nil {
fmt.Println("除法结果:", result)
}
if _, err := SafeDivide(10.0, 0.0); err != nil {
fmt.Println("除法错误:", err)
}
}
向后兼容性考虑¶
引入泛型时需要考虑与旧版本代码的兼容性:
package main
import "fmt"
// 旧版本:非泛型函数
func IntMax(a, b int) int {
if a > b {
return a
}
return b
}
// 新版本:泛型函数,同时保持旧函数的兼容性
func Max[T constraints.Ordered](a, b T) T {
if a > b {
return a
}
return b
}
// 为了向后兼容,将旧函数用新的泛型函数实现
func IntMaxV2(a, b int) int {
return Max(a, b)
}
// 泛型结构体与旧类型的兼容
type OldIntQueue struct {
elements []int
}
func (q *OldIntQueue) Enqueue(v int) {
q.elements = append(q.elements, v)
}
func (q *OldIntQueue) Dequeue() int {
if len(q.elements) == 0 {
return 0
}
v := q.elements[0]
q.elements = q.elements[1:]
return v
}
// 新的泛型队列
type NewQueue[T any] struct {
elements []T
}
func (q *NewQueue[T]) Enqueue(v T) {
q.elements = append(q.elements, v)
}
func (q *NewQueue[T]) Dequeue() (T, bool) {
if len(q.elements) == 0 {
var zero T
return zero, false
}
v := q.elements[0]
q.elements = q.elements[1:]
return v, true
}
// 提供转换函数,便于从旧类型迁移到新类型
func ConvertOldQueueToNew(old *OldIntQueue) *NewQueue[int] {
newQueue := &NewQueue[int]{}
newQueue.elements = append(newQueue.elements, old.elements...)
return newQueue
}
func main() {
// 旧函数仍可使用
fmt.Println("旧版本IntMax:", IntMax(10, 20))
// 新版本泛型函数
fmt.Println("新版本Max(int):", Max(10, 20))
fmt.Println("新版本Max(float):", Max(3.14, 2.71))
// 队列兼容性示例
oldQueue := &OldIntQueue{}
oldQueue.Enqueue(1)
oldQueue.Enqueue(2)
newQueue := ConvertOldQueueToNew(oldQueue)
if v, ok := newQueue.Dequeue(); ok {
fmt.Println("从转换后的队列中取出:", v)
}
}
渐进式迁移策略¶
大型项目可以采用渐进式策略迁移到泛型:
package main
import "fmt"
// 阶段1: 现有代码使用具体类型
type IntList struct {
items []int
}
func (l *IntList) Add(item int) {
l.items = append(l.items, item)
}
func (l *IntList) Get(index int) int {
return l.items[index]
}
// 阶段2: 创建泛型版本,但保持旧版本
type GenericList[T any] struct {
items []T
}
func (l *GenericList[T]) Add(item T) {
l.items = append(l.items, item)
}
func (l *GenericList[T]) Get(index int) T {
return l.items[index]
}
// 阶段3: 为旧类型创建适配器,使用泛型实现
type IntListV2 struct {
genericList GenericList[int]
}
func (l *IntListV2) Add(item int) {
l.genericList.Add(item)
}
func (l *IntListV2) Get(index int) int {
return l.genericList.Get(index)
}
// 阶段4: 新功能只使用泛型版本
func ProcessList[T any](list *GenericList[T], processor func(T)) {
for i := 0; i < len(list.items); i++ {
processor(list.Get(i))
}
}
func main() {
// 旧代码仍可工作
oldList := &IntList{}
oldList.Add(10)
oldList.Add(20)
fmt.Println("旧列表第一个元素:", oldList.Get(0))
// 新代码使用泛型版本
newIntList := &GenericList[int]{}
newIntList.Add(30)
newIntList.Add(40)
fmt.Println("新整数列表第一个元素:", newIntList.Get(0))
newStrList := &GenericList[string]{}
newStrList.Add("hello")
newStrList.Add("world")
fmt.Println("新字符串列表第一个元素:", newStrList.Get(0))
// 处理列表的新功能
fmt.Print("处理整数列表: ")
ProcessList(newIntList, func(item int) {
fmt.Printf("%d ", item)
})
fmt.Println()
fmt.Print("处理字符串列表: ")
ProcessList(newStrList, func(item string) {
fmt.Printf("%s ", item)
})
fmt.Println()
}
实战练习¶
练习1:实现泛型数据结构¶
实现要求: - 泛型栈(Stack)实现 - 泛型队列(Queue)实现 - 泛型链表(LinkedList)实现 - 类型安全的操作接口
package main
import "fmt"
// 1. 泛型栈实现
type Stack[T any] struct {
elements []T
}
// 压栈
func (s *Stack[T]) Push(element T) {
s.elements = append(s.elements, element)
}
// 出栈
func (s *Stack[T]) Pop() (T, bool) {
if len(s.elements) == 0 {
var zero T
return zero, false
}
// 获取最后一个元素
top := s.elements[len(s.elements)-1]
// 移除最后一个元素
s.elements = s.elements[:len(s.elements)-1]
return top, true
}
// 获取栈顶元素
func (s *Stack[T]) Peek() (T, bool) {
if len(s.elements) == 0 {
var zero T
return zero, false
}
return s.elements[len(s.elements)-1], true
}
// 获取栈的大小
func (s *Stack[T]) Size() int {
return len(s.elements)
}
// 检查栈是否为空
func (s *Stack[T]) IsEmpty() bool {
return len(s.elements) == 0
}
// 2. 泛型队列实现
type Queue[T any] struct {
elements []T
}
// 入队
func (q *Queue[T]) Enqueue(element T) {
q.elements = append(q.elements, element)
}
// 出队
func (q *Queue[T]) Dequeue() (T, bool) {
if len(q.elements) == 0 {
var zero T
return zero, false
}
// 获取第一个元素
front := q.elements[0]
// 移除第一个元素
q.elements = q.elements[1:]
return front, true
}
// 获取队首元素
func (q *Queue[T]) Front() (T, bool) {
if len(q.elements) == 0 {
var zero T
return zero, false
}
return q.elements[0], true
}
// 获取队列大小
func (q *Queue[T]) Size() int {
return len(q.elements)
}
// 检查队列是否为空
func (q *Queue[T]) IsEmpty() bool {
return len(q.elements) == 0
}
// 3. 泛型链表实现
type LinkedListNode[T any] struct {
Data T
Next *LinkedListNode[T]
}
type LinkedList[T any] struct {
head *LinkedListNode[T]
tail *LinkedListNode[T]
size int
}
// 在链表末尾添加元素
func (l *LinkedList[T]) Append(data T) {
newNode := &LinkedListNode[T]{Data: data}
if l.size == 0 {
l.head = newNode
l.tail = newNode
} else {
l.tail.Next = newNode
l.tail = newNode
}
l.size++
}
// 在链表头部添加元素
func (l *LinkedList[T]) Prepend(data T) {
newNode := &LinkedListNode[T]{Data: data}
if l.size == 0 {
l.head = newNode
l.tail = newNode
} else {
newNode.Next = l.head
l.head = newNode
}
l.size++
}
// 根据索引获取元素
func (l *LinkedList[T]) Get(index int) (T, bool) {
if index < 0 || index >= l.size {
var zero T
return zero, false
}
current := l.head
for i := 0; i < index; i++ {
current = current.Next
}
return current.Data, true
}
// 删除指定索引的元素
func (l *LinkedList[T]) Remove(index int) bool {
if index < 0 || index >= l.size {
return false
}
if index == 0 {
l.head = l.head.Next
if l.size == 1 {
l.tail = nil
}
} else {
prev := l.head
for i := 0; i < index-1; i++ {
prev = prev.Next
}
prev.Next = prev.Next.Next
if index == l.size-1 {
l.tail = prev
}
}
l.size--
return true
}
// 获取链表大小
func (l *LinkedList[T]) Size() int {
return l.size
}
// 打印链表元素
func (l *LinkedList[T]) Print() {
current := l.head
for current != nil {
fmt.Printf("%v ", current.Data)
current = current.Next
}
fmt.Println()
}
func main() {
// 测试栈
fmt.Println("=== 测试栈 ===")
stack := &Stack[int]{}
stack.Push(1)
stack.Push(2)
stack.Push(3)
fmt.Println("栈大小:", stack.Size())
if val, ok := stack.Pop(); ok {
fmt.Println("出栈元素:", val)
}
if val, ok := stack.Peek(); ok {
fmt.Println("栈顶元素:", val)
}
// 测试队列
fmt.Println("\n=== 测试队列 ===")
queue := &Queue[string]{}
queue.Enqueue("a")
queue.Enqueue("b")
queue.Enqueue("c")
fmt.Println("队列大小:", queue.Size())
if val, ok := queue.Dequeue(); ok {
fmt.Println("出队元素:", val)
}
if val, ok := queue.Front(); ok {
fmt.Println("队首元素:", val)
}
// 测试链表
fmt.Println("\n=== 测试链表 ===")
list := &LinkedList[float64]{}
list.Append(1.1)
list.Append(2.2)
list.Prepend(0.0)
fmt.Print("链表元素: ")
list.Print()
fmt.Println("链表大小:", list.Size())
if val, ok := list.Get(1); ok {
fmt.Println("索引1处的元素:", val)
}
list.Remove(1)
fmt.Print("删除索引1后的元素: ")
list.Print()
}
练习2:设计泛型工具函数库¶
核心功能: - Map、Filter、Reduce函数 - 排序与搜索算法 - 数据转换工具 - 集合操作函数
package main
import (
"fmt"
"golang.org/x/exp/constraints"
)
// 1. 基础函数式工具
// Map 对切片中的每个元素应用函数f,并返回新的切片
func Map[T, U any](slice []T, f func(T) U) []U {
result := make([]U, len(slice))
for i, v := range slice {
result[i] = f(v)
}
return result
}
// Filter 过滤出满足条件的元素
func Filter[T any](slice []T, f func(T) bool) []T {
result := []T{}
for _, v := range slice {
if f(v) {
result = append(result, v)
}
}
return result
}
// Reduce 对切片元素进行归约操作
func Reduce[T, U any](slice []T, initial U, f func(U, T) U) U {
result := initial
for _, v := range slice {
result = f(result, v)
}
return result
}
// 2. 排序与搜索算法
// 冒泡排序
func BubbleSort[T constraints.Ordered](slice []T) {
n := len(slice)
for i := 0; i < n-1; i++ {
for j := 0; j < n-i-1; j++ {
if slice[j] > slice[j+1] {
slice[j], slice[j+1] = slice[j+1], slice[j]
}
}
}
}
// 二分查找
func BinarySearch[T constraints.Ordered](slice []T, target T) int {
low, high := 0, len(slice)-1
for low <= high {
mid := (low + high) / 2
if slice[mid] == target {
return mid
} else if slice[mid] < target {
low = mid + 1
} else {
high = mid - 1
}
}
return -1 // 未找到
}
// 3. 数据转换工具
// 转换为集合(去重)
func ToSet[T comparable](slice []T) map[T]struct{} {
set := make(map[T]struct{})
for _, v := range slice {
set[v] = struct{}{}
}
return set
}
// 集合转换为切片
func SetToSlice[T comparable](set map[T]struct{}) []T {
slice := make([]T, 0, len(set))
for v := range set {
slice = append(slice, v)
}
return slice
}
// 4. 集合操作函数
// 并集
func Union[T comparable](a, b []T) []T {
setA := ToSet(a)
for _, v := range b {
setA[v] = struct{}{}
}
return SetToSlice(setA)
}
// 交集
func Intersection[T comparable](a, b []T) []T {
setA := ToSet(a)
result := []T{}
for _, v := range b {
if _, exists := setA[v]; exists {
result = append(result, v)
// 避免重复添加
delete(setA, v)
}
}
return result
}
// 差集 (a - b)
func Difference[T comparable](a, b []T) []T {
setB := ToSet(b)
result := []T{}
for _, v := range a {
if _, exists := setB[v]; !exists {
result = append(result, v)
}
}
return result
}
func main() {
// 测试Map、Filter、Reduce
numbers := []int{1, 2, 3, 4, 5}
// Map: 整数转字符串
strNumbers := Map(numbers, func(n int) string {
return fmt.Sprintf("num:%d", n)
})
fmt.Println("Map结果:", strNumbers)
// Filter: 筛选偶数
evenNumbers := Filter(numbers, func(n int) bool {
return n%2 == 0
})
fmt.Println("Filter结果:", evenNumbers)
// Reduce: 计算总和
sum := Reduce(numbers, 0, func(acc, n int) int {
return acc + n
})
fmt.Println("Reduce求和结果:", sum)
// 测试排序和搜索
unsorted := []int{5, 2, 9, 1, 5, 6}
BubbleSort(unsorted)
fmt.Println("排序结果:", unsorted)
target := 5
index := BinarySearch(unsorted, target)
fmt.Printf("元素%d的索引: %d\n", target, index)
// 测试集合操作
a := []string{"a", "b", "c", "d"}
b := []string{"c", "d", "e", "f"}
fmt.Println("并集:", Union(a, b))
fmt.Println("交集:", Intersection(a, b))
fmt.Println("差集(a-b):", Difference(a, b))
}
练习3:构建泛型缓存系统¶
实现要求: - 支持任意类型的键值对 - LRU淘汰策略 - 线程安全设计 - 过期时间管理
package main
import (
"container/list"
"fmt"
"sync"
"time"
)
// 缓存项
type cacheItem[K comparable, V any] struct {
key K
value V
expiry time.Time // 过期时间,零值表示永不过期
}
// 泛型缓存结构
type Cache[K comparable, V any] struct {
mu sync.RWMutex
items map[K]*list.Element // 键到链表元素的映射
lruList *list.List // LRU淘汰策略的双向链表
capacity int // 缓存最大容量
defaultTTL time.Duration // 默认过期时间
cleanupTicker *time.Ticker // 定期清理过期项的定时器
}
// 创建新缓存
func NewCache[K comparable, V any](capacity int, defaultTTL time.Duration) *Cache[K, V] {
c := &Cache[K, V]{
items: make(map[K]*list.Element),
lruList: list.New(),
capacity: capacity,
defaultTTL: defaultTTL,
cleanupTicker: time.NewTicker(5 * time.Minute), // 每5分钟清理一次过期项
}
// 启动清理过期项的 goroutine
go c.cleanupExpired()
return c
}
// 关闭缓存,停止清理定时器
func (c *Cache[K, V]) Close() {
c.cleanupTicker.Stop()
}
// 定期清理过期项
func (c *Cache[K, V]) cleanupExpired() {
for range c.cleanupTicker.C {
c.mu.Lock()
// 遍历链表,删除过期项
for e := c.lruList.Front(); e != nil; {
next := e.Next
item := e.Value.(*cacheItem[K, V])
if !item.expiry.IsZero() && time.Now().After(item.expiry) {
delete(c.items, item.key)
c.lruList.Remove(e)
}
e = next
}
c.mu.Unlock()
}
}
// 添加或更新缓存项
func (c *Cache[K, V]) Set(key K, value V, ttl ...time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
// 确定过期时间
var expiry time.Time
if len(ttl) > 0 && ttl[0] > 0 {
expiry = time.Now().Add(ttl[0])
} else if c.defaultTTL > 0 {
expiry = time.Now().Add(c.defaultTTL)
}
// 如果键已存在,更新值并移到链表头部
if elem, exists := c.items[key]; exists {
c.lruList.MoveToFront(elem)
elem.Value.(*cacheItem[K, V]).value = value
elem.Value.(*cacheItem[K, V]).expiry = expiry
return
}
// 如果缓存已满,删除最近最少使用的项
if c.lruList.Len() >= c.capacity {
oldest := c.lruList.Back()
if oldest != nil {
oldestItem := oldest.Value.(*cacheItem[K, V])
delete(c.items, oldestItem.key)
c.lruList.Remove(oldest)
}
}
// 添加新项到链表头部
elem := c.lruList.PushFront(&cacheItem[K, V]{
key: key,
value: value,
expiry: expiry,
})
c.items[key] = elem
}
// 获取缓存项
func (c *Cache[K, V]) Get(key K) (V, bool) {
c.mu.RLock()
elem, exists := c.items[key]
if !exists {
var zero V
c.mu.RUnlock()
return zero, false
}
item := elem.Value.(*cacheItem[K, V])
// 检查是否过期
if !item.expiry.IsZero() && time.Now().After(item.expiry) {
c.mu.RUnlock()
// 这里需要获取写锁来删除过期项
c.mu.Lock()
defer c.mu.Unlock()
// 再次检查,防止并发问题
if elem, exists := c.items[key]; exists {
delete(c.items, key)
c.lruList.Remove(elem)
}
var zero V
return zero, false
}
c.mu.RUnlock()
// 将访问的项移到链表头部,表示最近使用
c.mu.Lock()
c.lruList.MoveToFront(elem)
c.mu.Unlock()
return item.value, true
}
// 删除缓存项
func (c *Cache[K, V]) Delete(key K) {
c.mu.Lock()
defer c.mu.Unlock()
if elem, exists := c.items[key]; exists {
delete(c.items, key)
c.lruList.Remove(elem)
}
}
// 清空缓存
func (c *Cache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.items = make(map[K]*list.Element)
c.lruList.Init()
}
// 获取缓存当前大小
func (c *Cache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.lruList.Len()
}
func main() {
// 创建一个容量为100,默认过期时间为1小时的缓存
userCache := NewCache[string, map[string]interface{}](100, time.Hour)
defer userCache.Close()
// 添加缓存项
userCache.Set("user1", map[string]interface{}{
"id": "1",
"name": "Alice",
"age": 30,
})
// 添加一个5秒后过期的缓存项
userCache.Set("user2", map[string]interface{}{
"id": "2",
"name": "Bob",
"age": 25,
}, 5*time.Second)
// 获取缓存项
if user, exists := userCache.Get("user1"); exists {
fmt.Println("获取到user1:", user)
} else {
fmt.Println("未找到user1")
}
fmt.Println("当前缓存大小:", userCache.Size())
// 测试过期功能
fmt.Println("等待6秒,让user2过期...")
time.Sleep(6 * time.Second)
if _, exists := userCache.Get("user2"); exists {
fmt.Println("获取到user2")
} else {
fmt.Println("user2已过期")
}
// 测试LRU淘汰策略
productCache := NewCache[int, string](3, 0) // 容量3,永不过期
defer productCache.Close()
productCache.Set(1, "产品1")
productCache.Set(2, "产品2")
productCache.Set(3, "产品3")
fmt.Println("添加3个产品后,缓存大小:", productCache.Size())
// 访问产品1,使其成为最近使用
productCache.Get(1)
// 添加第4个产品,会淘汰最近最少使用的产品2
productCache.Set(4, "产品4")
if _, exists := productCache.Get(2); exists {
fmt.Println("产品2仍然存在(不符合预期)")
} else {
fmt.Println("产品2已被LRU淘汰(符合预期)")
}
if val, exists := productCache.Get(1); exists {
fmt.Println("产品1仍然存在:", val)
}
}
总结¶
泛型是Go语言中一项强大的特性,它允许我们编写更加通用且类型安全的代码。通过本章的学习,你应该已经掌握了泛型的基本语法、类型约束的使用、泛型函数与类型的实现,以及泛型在实际项目中的应用策略。
泛型的引入并不意味着我们应该在所有地方都使用它。在实际开发中,应根据具体场景权衡使用泛型的利弊,遵循"简单性优先"的Go语言哲学。当需要处理多种类型且逻辑相同时,泛型是一个优秀的选择;而对于简单的场景,使用具体类型可能更加清晰直接。
随着Go语言的不断发展,泛型特性也将不断完善,为我们提供更强大的编程能力。掌握泛型编程,将使你能够编写更加高效、可维护的Go代码。
学习检查点¶
- 理解泛型的基本语法与概念
- 掌握类型约束的定义与使用
- 能够设计泛型函数与类型
- 了解泛型的优势与限制
- 完成所有实战练习
下节预告:实战练习与面试重点 - 综合应用与面试准备