跳转至

7.1 Go标准数据库接口深度解析

作为一名长期从事Go语言开发与教学的老师,我很高兴带你深入理解Go的标准数据库接口。database/sql 包是Go语言操作数据库的核心标准库,它提供了一套优雅的接口抽象,让我们能够用统一的方式操作各种关系型数据库。

学习目标

  • 理解database/sql包的设计理念和架构
  • 掌握数据库驱动的注册与管理机制
  • 熟练配置各种数据库连接参数
  • 建立标准化的数据库操作基础

核心内容

1. database/sql包架构设计

1.1 接口抽象层设计

database/sql 包采用接口抽象的设计模式,将数据库操作分为两个层次:

  • 对应用层:提供统一的API(如 *sql.DB、*sql.Tx、*sql.Rows)
  • 对驱动层:定义驱动需要实现的接口(位于 database/sql/driver 包,如 Driver、Conn)

这种设计使得应用程序可以用相同的方式操作不同的数据库,只需更换底层驱动即可。

package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

// main shows that application code talks only to database/sql: swapping the
// blank-imported driver is all it takes to target a different database.
func main() {
    // The standard sql API: nothing here depends on MySQL specifically.
    db, err := sql.Open("mysql", "user:password@/dbname")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // Uniform query API; the "?" placeholder keeps the argument out of the SQL text.
    rows, err := db.Query("SELECT id, name FROM users WHERE age > ?", 18)
    if err != nil {
        log.Fatal(err)
    }
    defer rows.Close()

    for rows.Next() {
        var id int
        var name string
        if err := rows.Scan(&id, &name); err != nil {
            log.Fatal(err)
        }
        fmt.Printf("id: %d, name: %s\n", id, name)
    }
    // rows.Next returns false both at the end of the result set and when
    // iteration fails; only rows.Err distinguishes the two.
    if err := rows.Err(); err != nil {
        log.Fatal(err)
    }
}

1.2 驱动注册机制

Go使用隐式注册机制:驱动包在自己的 init 函数中调用 sql.Register 完成注册,应用程序只需以空白标识符 _ 导入驱动包即可触发注册:

package main

import (
    "database/sql"
    "fmt"

    // 导入MySQL驱动,执行其init函数进行注册
    _ "github.com/go-sql-driver/mysql"
    // 可以同时导入多个驱动
    _ "github.com/lib/pq"
)

// main lists the registered drivers and verifies a MySQL connection.
func main() {
    // sql.Drivers reports every driver whose package has called sql.Register.
    fmt.Println("Registered drivers:", sql.Drivers())

    // Open may only validate its arguments; it need not dial the server yet.
    db, err := sql.Open("mysql", "user:password@/dbname")
    if err != nil {
        panic(err)
    }
    defer db.Close()

    // Ping forces a real round trip, proving the DSN and credentials work.
    if err = db.Ping(); err != nil {
        panic(err)
    }
    fmt.Println("Successfully connected to MySQL database!")
}

1.3 连接管理原理

database/sql 包内置了连接池管理,自动处理连接的获取、释放与复用,并会丢弃驱动报告为失效的连接。注意 sql.Open 只校验参数、不一定立即建立连接,需要时应调用 Ping 主动验证:

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

// main tunes the connection pool, runs one query and then reports pool stats.
func main() {
    db, err := sql.Open("mysql", "user:password@/dbname")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // Pool tuning knobs.
    db.SetMaxOpenConns(25)                 // max open connections (in use + idle)
    db.SetMaxIdleConns(5)                  // max idle connections kept for reuse
    db.SetConnMaxLifetime(5 * time.Minute) // max lifetime of a connection
    db.SetConnMaxIdleTime(1 * time.Minute) // max time a connection may sit idle

    // The query borrows a connection from the pool and returns it after Scan.
    var version string
    if err := db.QueryRow("SELECT VERSION()").Scan(&version); err != nil {
        log.Fatal(err)
    }
    fmt.Println("MySQL version:", version)

    // Sample the stats only after the query. Before it, sql.Open has
    // typically not dialed anything, so every counter would still be zero.
    stats := db.Stats()
    fmt.Printf("Pool stats: OpenConnections=%d, InUse=%d, Idle=%d\n",
        stats.OpenConnections, stats.InUse, stats.Idle)
}

2. 数据库驱动管理

2.1 驱动的注册与发现

package main

import (
    "database/sql"
    "fmt"
    "reflect"

    _ "github.com/go-sql-driver/mysql"
    _ "github.com/lib/pq"
)

// main picks a registered driver from configuration and reports its concrete type.
func main() {
    fmt.Println("Available drivers:", sql.Drivers())

    // In a real project these values would come from a config source.
    config := map[string]string{
        "db_type": "mysql",
        "dsn":     "user:password@/dbname",
    }

    // Map the configured database type onto a registered driver name.
    var driverName string
    switch config["db_type"] {
    case "mysql":
        driverName = "mysql"
    case "postgres":
        driverName = "postgres"
    default:
        panic("Unsupported database type")
    }

    db, err := sql.Open(driverName, config["dsn"])
    if err != nil {
        panic(err)
    }
    defer db.Close()

    // db.Driver exposes the driver instance; reflection shows its concrete type.
    fmt.Printf("Driver type: %v\n", reflect.TypeOf(db.Driver()))
}

2.2 多驱动环境下的管理

package main

import (
    "database/sql"
    "fmt"
    "log"
    "sync"

    _ "github.com/go-sql-driver/mysql"
    _ "github.com/lib/pq"
)

// DatabaseManager keeps a set of named *sql.DB pools guarded by an RWMutex.
type DatabaseManager struct {
    connections map[string]*sql.DB
    mutex       sync.RWMutex
}

// NewDatabaseManager returns an empty manager ready for AddConnection.
func NewDatabaseManager() *DatabaseManager {
    return &DatabaseManager{
        connections: make(map[string]*sql.DB),
    }
}

// AddConnection opens a pool for driver/dsn, verifies it with Ping and stores
// it under name. If the pool cannot be opened or pinged, the error is returned
// and nothing is registered.
//
// The mutex is held only while the map is updated, not during the network
// round trip of Ping. A slow or unreachable server therefore does not block
// GetConnection or other AddConnection calls. Re-adding an existing name
// closes the pool it replaces instead of leaking it.
func (dm *DatabaseManager) AddConnection(name, driver, dsn string) error {
    db, err := sql.Open(driver, dsn)
    if err != nil {
        return err
    }

    // Test the connection before publishing it.
    if err := db.Ping(); err != nil {
        db.Close()
        return err
    }

    dm.mutex.Lock()
    previous, replaced := dm.connections[name]
    dm.connections[name] = db
    dm.mutex.Unlock()

    if replaced {
        if err := previous.Close(); err != nil {
            log.Printf("Error closing connection %s: %v", name, err)
        }
    }
    return nil
}

// GetConnection returns the pool registered under name, or an error if none is.
func (dm *DatabaseManager) GetConnection(name string) (*sql.DB, error) {
    dm.mutex.RLock()
    defer dm.mutex.RUnlock()

    db, exists := dm.connections[name]
    if !exists {
        return nil, fmt.Errorf("connection %s not found", name)
    }
    return db, nil
}

// CloseAll closes and unregisters every pool. Close errors are logged, not returned.
func (dm *DatabaseManager) CloseAll() {
    dm.mutex.Lock()
    defer dm.mutex.Unlock()

    for name, db := range dm.connections {
        if err := db.Close(); err != nil {
            log.Printf("Error closing connection %s: %v", name, err)
        }
        // Deleting the current key while ranging over a map is permitted in Go.
        delete(dm.connections, name)
    }
}

// main registers a MySQL and a PostgreSQL pool and queries the MySQL one.
func main() {
    manager := NewDatabaseManager()
    defer manager.CloseAll()

    // One named pool per backend, registered in order.
    targets := []struct{ name, driver, dsn string }{
        {"mysql_db", "mysql", "user:password@/testdb"},
        {"postgres_db", "postgres", "user=username password=password dbname=testdb sslmode=disable"},
    }
    for _, target := range targets {
        if err := manager.AddConnection(target.name, target.driver, target.dsn); err != nil {
            log.Fatal(err)
        }
    }

    // Look up a specific pool by name.
    mysqlDB, err := manager.GetConnection("mysql_db")
    if err != nil {
        log.Fatal(err)
    }

    var greeting string
    if err := mysqlDB.QueryRow("SELECT 'Hello from MySQL'").Scan(&greeting); err != nil {
        log.Fatal(err)
    }
    fmt.Println(greeting)
}

2.3 驱动兼容性考虑

package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

// DatabaseCompatibility smooths over SQL dialect differences by probing
// several equivalent spellings of a query until one succeeds.
type DatabaseCompatibility struct {
    db *sql.DB
}

// GetNow returns the server's current time as a string. The candidate
// queries are tried in order. Each failure is logged and the next candidate is
// attempted. An error is returned only when every candidate fails.
func (dc *DatabaseCompatibility) GetNow() (string, error) {
    candidates := [...]string{
        "SELECT NOW()",             // MySQL
        "SELECT datetime('now')",   // SQLite
        "SELECT CURRENT_TIMESTAMP", // PostgreSQL, SQL Server
    }

    for i := 0; i < len(candidates); i++ {
        var now string
        if err := dc.db.QueryRow(candidates[i]).Scan(&now); err != nil {
            log.Printf("Query failed: %s, error: %v", candidates[i], err)
            continue
        }
        return now, nil
    }
    return "", fmt.Errorf("all datetime queries failed")
}

// main prints the database server's current time using the compatibility probe.
func main() {
    db, err := sql.Open("mysql", "user:password@/testdb")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    currentTime, err := (&DatabaseCompatibility{db: db}).GetNow()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("Current database time:", currentTime)
}

3. 连接字符串与配置

3.1 标准连接字符串格式

package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

// main shows DSN formats for several drivers and opens a MySQL pool with one.
func main() {
    // MySQL: user:password@tcp(host:port)/dbname?param=value&...
    mysqlDSN := "username:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"

    // PostgreSQL (key=value form):
    // postgresDSN := "host=localhost port=5432 user=username password=password dbname=dbname sslmode=disable"

    // SQLite (URI form):
    // sqliteDSN := "file:test.db?cache=shared&mode=memory"

    db, err := sql.Open("mysql", mysqlDSN)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // This demo DSN holds a placeholder password; never print real DSNs.
    fmt.Println("连接字符串:", mysqlDSN)

    // sql.Open does not guarantee a live connection, so only claim success
    // once Ping has actually reached the server.
    if err := db.Ping(); err != nil {
        log.Fatal("连接测试失败:", err)
    }
    fmt.Println("成功连接到数据库")
    fmt.Println("连接测试成功!")
}

3.2 连接参数详解

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

// ConnectionConfig holds the pieces of a MySQL DSN of the form
// "user:pass@tcp(host:port)/db?params".
type ConnectionConfig struct {
    Username     string
    Password     string
    Host         string
    Port         string
    Database     string
    Charset      string        // omitted from the DSN when empty
    ParseTime    bool          // always written
    Loc          string        // omitted when empty; written verbatim (not URL-escaped)
    Timeout      time.Duration // omitted when zero
    ReadTimeout  time.Duration // omitted when zero
    WriteTimeout time.Duration // omitted when zero
}

// BuildDSN renders the config as "user:pass@tcp(host:port)/db?k=v&...".
//
// Parameters are emitted in a fixed order: charset, parseTime, loc, timeout,
// readTimeout, writeTimeout. Empty strings and zero durations are left out,
// so "charset=" or "timeout=0s" is never produced. parseTime is always
// written. Values are inserted verbatim: a Loc containing '/' (for example
// Asia/Shanghai) must be supplied URL-escaped, e.g. "Asia%2FShanghai".
func (c *ConnectionConfig) BuildDSN() string {
    dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.Username, c.Password, c.Host, c.Port, c.Database)

    // The first parameter is introduced with '?', every later one with '&'.
    sep := "?"
    add := func(key, value string) {
        dsn += sep + key + "=" + value
        sep = "&"
    }

    if c.Charset != "" {
        add("charset", c.Charset)
    }
    add("parseTime", fmt.Sprint(c.ParseTime))
    if c.Loc != "" {
        add("loc", c.Loc)
    }
    if c.Timeout != 0 {
        add("timeout", c.Timeout.String())
    }
    if c.ReadTimeout != 0 {
        add("readTimeout", c.ReadTimeout.String())
    }
    if c.WriteTimeout != 0 {
        add("writeTimeout", c.WriteTimeout.String())
    }

    return dsn
}

// main builds a DSN from a ConnectionConfig and verifies it with Ping.
func main() {
    const ioTimeout = 30 * time.Second

    config := &ConnectionConfig{
        Username:     "root",
        Password:     "password",
        Host:         "127.0.0.1",
        Port:         "3306",
        Database:     "testdb",
        Charset:      "utf8mb4",
        ParseTime:    true,
        Loc:          "Local",
        Timeout:      ioTimeout,
        ReadTimeout:  ioTimeout,
        WriteTimeout: ioTimeout,
    }

    dsn := config.BuildDSN()
    fmt.Println("生成的DSN:", dsn)

    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // Ping is what actually reaches the server.
    if err = db.Ping(); err != nil {
        log.Fatal("连接失败:", err)
    }
    fmt.Println("成功连接到数据库!")
}

3.3 安全配置最佳实践

package main

import (
    "crypto/tls"
    "crypto/x509"
    "database/sql"
    "fmt"
    "io/ioutil"
    "log"
    "time"

    "github.com/go-sql-driver/mysql"
)

// main opens a TLS-protected MySQL connection using values supplied by getEnv.
func main() {
    // 1. Encrypt the connection with TLS, trusting a custom CA bundle.
    rootCertPool := x509.NewCertPool()
    pem, err := ioutil.ReadFile("/path/to/ca-cert.pem")
    if err != nil {
        log.Fatal(err)
    }
    if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
        log.Fatal("Failed to append PEM.")
    }

    // Register the TLS config under the key that "tls=custom" in the DSN
    // refers to. RegisterTLSConfig reports failures (e.g. a reserved key name)
    // through its error return. Ignoring that error would leave the DSN
    // naming a config that was never registered.
    if err := mysql.RegisterTLSConfig("custom", &tls.Config{
        RootCAs: rootCertPool,
    }); err != nil {
        log.Fatal(err)
    }

    // 2. Build the DSN from externally supplied values rather than
    // hard-coding credentials.
    dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?tls=custom",
        getEnv("DB_USER", "root"),
        getEnv("DB_PASSWORD", ""),
        getEnv("DB_HOST", "localhost"),
        3306,
        getEnv("DB_NAME", "testdb"),
    )

    db, err := sql.Open("mysql", dsn)
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    // 3. Bound the connection pool.
    db.SetMaxOpenConns(25)
    db.SetMaxIdleConns(5)
    db.SetConnMaxLifetime(5 * time.Minute)

    // 4. Ping opens a real connection, so TLS or credential problems surface here.
    if err := db.Ping(); err != nil {
        log.Fatal("安全连接测试失败:", err)
    }

    fmt.Println("安全连接建立成功!")
}

// getEnv simulates reading configuration from the environment. It returns a
// fixed demo password for "DB_PASSWORD" and defaultValue for every other key.
// A real implementation would consult os.Getenv (or a secret store) and fall
// back to defaultValue only when the variable is unset.
func getEnv(key, defaultValue string) string {
    switch key {
    case "DB_PASSWORD":
        return "secure_password" // demo value only; load real secrets securely
    default:
        return defaultValue
    }
}

4. 基础操作封装

4.1 连接的建立与关闭

package main

import (
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

// DatabaseClient wraps one *sql.DB together with the settings used to open it.
type DatabaseClient struct {
    db          *sql.DB
    config      *DBConfig
    isConnected bool
}

// DBConfig lists the settings DatabaseClient.Connect applies.
type DBConfig struct {
    Driver    string // driver name passed to sql.Open
    DSN       string // data source name passed to sql.Open
    MaxConns  int    // passed to SetMaxOpenConns
    IdleConns int    // passed to SetMaxIdleConns
}

// NewDatabaseClient returns an unconnected client; call Connect before GetDB.
func NewDatabaseClient(config *DBConfig) *DatabaseClient {
    return &DatabaseClient{
        config: config,
    }
}

// Connect opens a pool from the client's config, applies the pool limits and
// a one-hour connection lifetime, and pings it.
//
// On success the new pool replaces any pool from an earlier Connect call, and
// that earlier pool is closed so its connections are not leaked. On failure
// the client's existing state is left untouched.
func (dc *DatabaseClient) Connect() error {
    db, err := sql.Open(dc.config.Driver, dc.config.DSN)
    if err != nil {
        return fmt.Errorf("failed to open database: %w", err)
    }

    // Configure the pool.
    db.SetMaxOpenConns(dc.config.MaxConns)
    db.SetMaxIdleConns(dc.config.IdleConns)
    db.SetConnMaxLifetime(1 * time.Hour)

    // Verify the connection before adopting the pool.
    if err := db.Ping(); err != nil {
        db.Close()
        return fmt.Errorf("failed to ping database: %w", err)
    }

    if dc.db != nil {
        // Best effort: the new pool is already usable, so a failure closing
        // the old one is not reported to the caller.
        _ = dc.db.Close()
    }
    dc.db = db
    dc.isConnected = true
    return nil
}

// Close closes the pool, if one was opened, and marks the client disconnected.
// GetDB keeps returning the (now closed) pool afterwards.
func (dc *DatabaseClient) Close() error {
    if dc.db != nil {
        err := dc.db.Close()
        dc.isConnected = false
        return err
    }
    return nil
}

// GetDB returns the current pool, or nil before the first successful Connect.
func (dc *DatabaseClient) GetDB() *sql.DB {
    return dc.db
}

// IsConnected reports whether Connect has succeeded since the last Close.
func (dc *DatabaseClient) IsConnected() bool {
    return dc.isConnected
}

func main() {
    config := &DBConfig{
        Driver:   "mysql",
        DSN:      "root:password@tcp(127.0.0.1:3306)/testdb",
        MaxConns: 25,
        IdleConns: 5,
    }

    client := NewDatabaseClient(config)

    // 建立连接
    if err := client.Connect(); err != nil {
        log.Fatal("连接失败:", err)
    }
    defer client.Close()

    fmt.Println("数据库连接状态:", client.IsConnected())

    // 使用连接执行查询
    var version string
    err := client.GetDB().QueryRow("SELECT VERSION()").Scan(&version)
    if err != nil {
        log.Fatal(err)
    }

    fmt.Println("数据库版本:", version)
}

4.2 查询操作的标准化

package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

// QueryExecutor is a thin, standardized facade over *sql.DB.
type QueryExecutor struct {
    db *sql.DB
}

// NewQueryExecutor wraps db.
func NewQueryExecutor(db *sql.DB) *QueryExecutor {
    return &QueryExecutor{db: db}
}

// QueryRow runs a query expected to return at most one row. Errors surface
// when Scan is called on the returned *sql.Row.
func (qe *QueryExecutor) QueryRow(query string, args ...interface{}) *sql.Row {
    return qe.db.QueryRow(query, args...)
}

// Query runs a query returning rows; the caller must Close the result.
func (qe *QueryExecutor) Query(query string, args ...interface{}) (*sql.Rows, error) {
    return qe.db.Query(query, args...)
}

// Exec runs a statement that returns no rows (INSERT, UPDATE, DELETE, ...).
func (qe *QueryExecutor) Exec(query string, args ...interface{}) (sql.Result, error) {
    return qe.db.Exec(query, args...)
}

// QueryMap runs query and returns every row as a column-name -> value map.
// []byte values are converted to string for readability. If iterating the
// result set fails, that error (reported by rows.Err) is returned rather than
// a silently truncated result.
func (qe *QueryExecutor) QueryMap(query string, args ...interface{}) ([]map[string]interface{}, error) {
    rows, err := qe.Query(query, args...)
    if err != nil {
        return nil, err
    }
    defer rows.Close()

    columns, err := rows.Columns()
    if err != nil {
        return nil, err
    }

    var results []map[string]interface{}
    for rows.Next() {
        // Scan needs one pointer per column; each points into values.
        values := make([]interface{}, len(columns))
        valuePtrs := make([]interface{}, len(columns))
        for i := range values {
            valuePtrs[i] = &values[i]
        }

        if err := rows.Scan(valuePtrs...); err != nil {
            return nil, err
        }

        rowMap := make(map[string]interface{}, len(columns))
        for i, col := range columns {
            if b, ok := values[i].([]byte); ok {
                rowMap[col] = string(b)
            } else {
                rowMap[col] = values[i]
            }
        }
        results = append(results, rowMap)
    }
    // Next returns false on both normal exhaustion and iteration failure.
    if err := rows.Err(); err != nil {
        return nil, err
    }

    return results, nil
}

// main demonstrates single-row, multi-row and write operations via QueryExecutor.
func main() {
    db, err := sql.Open("mysql", "root:password@tcp(127.0.0.1:3306)/testdb")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    qe := NewQueryExecutor(db)

    // Example 1: a single row.
    var userName string
    if err := qe.QueryRow("SELECT name FROM users WHERE id = ?", 1).Scan(&userName); err != nil {
        log.Fatal(err)
    }
    fmt.Println("用户名:", userName)

    // Example 2: several rows returned as maps.
    rowsAsMaps, err := qe.QueryMap("SELECT id, name, email FROM users LIMIT 5")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println("查询结果:")
    for _, r := range rowsAsMaps {
        fmt.Printf("ID: %v, Name: %v, Email: %v\n", r["id"], r["name"], r["email"])
    }

    // Example 3: a write, then the generated key.
    res, err := qe.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "张三", "zhangsan@example.com")
    if err != nil {
        log.Fatal(err)
    }
    newID, err := res.LastInsertId()
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("插入成功,ID: %d\n", newID)
}

4.3 错误处理机制

package main

import (
    "database/sql"
    "errors"
    "fmt"
    "log"

    _ "github.com/go-sql-driver/mysql"
)

// DBError annotates a database error with the operation and SQL that produced it.
type DBError struct {
    Op      string // step that failed, e.g. "query" or "scan"
    Query   string // SQL text being executed
    Err     error  // underlying error, exposed through Unwrap
    IsEmpty bool   // set when Err matches sql.ErrNoRows
}

// Error formats the operation, the wrapped error and the SQL text.
func (e *DBError) Error() string {
    return fmt.Sprintf("数据库操作错误 [%s]: %s, 查询: %s", e.Op, e.Err, e.Query)
}

// Unwrap lets errors.Is and errors.As inspect the underlying error.
func (e *DBError) Unwrap() error { return e.Err }

// ErrorHandler converts raw database errors into *DBError values.
type ErrorHandler struct{}

// HandleError returns nil when err is nil. Otherwise it wraps err in a
// *DBError tagged with op and query, setting IsEmpty when err matches
// sql.ErrNoRows (checked with errors.Is).
func (eh *ErrorHandler) HandleError(op, query string, err error) error {
    if err == nil {
        return nil
    }
    // Additional driver-specific classification can be added here.
    return &DBError{
        Op:      op,
        Query:   query,
        Err:     err,
        IsEmpty: errors.Is(err, sql.ErrNoRows),
    }
}

// SafeQueryExecutor runs queries and reports failures as *DBError values.
type SafeQueryExecutor struct {
    db      *sql.DB
    handler *ErrorHandler
}

// NewSafeQueryExecutor binds an executor to db with a fresh ErrorHandler.
func NewSafeQueryExecutor(db *sql.DB) *SafeQueryExecutor {
    return &SafeQueryExecutor{db: db, handler: new(ErrorHandler)}
}

// SafeQueryRow runs query and returns its first row as a column-name -> value
// map, converting []byte values to string. Every failure is wrapped by the
// executor's ErrorHandler. An empty result set yields a *DBError with IsEmpty
// set. An iteration failure is reported as that failure, not as "no rows".
func (sqe *SafeQueryExecutor) SafeQueryRow(query string, args ...interface{}) (map[string]interface{}, error) {
    rows, err := sqe.db.Query(query, args...)
    if err != nil {
        return nil, sqe.handler.HandleError("query", query, err)
    }
    defer rows.Close()

    columns, err := rows.Columns()
    if err != nil {
        return nil, sqe.handler.HandleError("get columns", query, err)
    }

    if !rows.Next() {
        // Next also returns false when iteration fails; only report "no rows"
        // once rows.Err confirms the result set is genuinely empty.
        if err := rows.Err(); err != nil {
            return nil, sqe.handler.HandleError("query", query, err)
        }
        return nil, sqe.handler.HandleError("query", query, sql.ErrNoRows)
    }

    // Scan needs one pointer per column; each points into values.
    values := make([]interface{}, len(columns))
    valuePtrs := make([]interface{}, len(columns))
    for i := range values {
        valuePtrs[i] = &values[i]
    }

    if err := rows.Scan(valuePtrs...); err != nil {
        return nil, sqe.handler.HandleError("scan", query, err)
    }

    result := make(map[string]interface{}, len(columns))
    for i, col := range columns {
        if b, ok := values[i].([]byte); ok {
            result[col] = string(b)
        } else {
            result[col] = values[i]
        }
    }

    return result, nil
}

// main exercises SafeQueryRow on an existing row, a missing row and a bad column.
func main() {
    db, err := sql.Open("mysql", "root:password@tcp(127.0.0.1:3306)/testdb")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    executor := NewSafeQueryExecutor(db)

    // lookup fetches user id. An empty result prints emptyMsg; any other
    // error is fatal.
    lookup := func(id int, emptyMsg string) {
        row, err := executor.SafeQueryRow("SELECT name, email FROM users WHERE id = ?", id)
        var dbErr *DBError
        switch {
        case err == nil:
            fmt.Println("查询结果:", row)
        case errors.As(err, &dbErr) && dbErr.IsEmpty:
            fmt.Println(emptyMsg)
        default:
            log.Fatal("查询错误:", err)
        }
    }

    // Example 1: a row that exists.
    lookup(1, "查询结果为空")

    // Example 2: a row that does not exist.
    lookup(999, "提示: 查询的记录不存在")

    // Example 3: an invalid column produces a non-empty DBError.
    if row, err := executor.SafeQueryRow("SELECT invalid_column FROM users WHERE id = ?", 1); err != nil {
        fmt.Println("捕获到预期错误:", err)
    } else {
        fmt.Println("查询结果:", row)
    }
}

实战练习

练习1:多数据库驱动管理

package main

import (
    "database/sql"
    "fmt"
    "log"
    "sync"

    _ "github.com/go-sql-driver/mysql"
    _ "github.com/lib/pq"
)

// MultiDBManager owns a set of named *sql.DB pools guarded by an RWMutex.
type MultiDBManager struct {
    databases map[string]*sql.DB
    mutex     sync.RWMutex
}

// NewMultiDBManager returns an empty manager.
func NewMultiDBManager() *MultiDBManager {
    return &MultiDBManager{
        databases: make(map[string]*sql.DB),
    }
}

// AddDatabase opens a pool for driver/dsn, applies the pool limits, verifies
// it with Ping and registers it under name.
//
// The mutex is held only for the map update, never across Ping's network
// round trip. A slow or unreachable server therefore cannot block
// GetDatabase, ExecuteOnAll or other AddDatabase calls. Re-adding an existing
// name closes the pool it replaces instead of leaking it.
func (m *MultiDBManager) AddDatabase(name, driver, dsn string) error {
    db, err := sql.Open(driver, dsn)
    if err != nil {
        return fmt.Errorf("failed to open database %s: %w", name, err)
    }

    // Apply pool limits before the first connection is made.
    db.SetMaxOpenConns(10)
    db.SetMaxIdleConns(2)

    if err := db.Ping(); err != nil {
        db.Close()
        return fmt.Errorf("failed to ping database %s: %w", name, err)
    }

    m.mutex.Lock()
    previous, replaced := m.databases[name]
    m.databases[name] = db
    m.mutex.Unlock()

    if replaced {
        if err := previous.Close(); err != nil {
            log.Printf("Error closing database %s: %v", name, err)
        }
    }
    return nil
}

// GetDatabase returns the pool registered under name, or an error if none is.
func (m *MultiDBManager) GetDatabase(name string) (*sql.DB, error) {
    m.mutex.RLock()
    defer m.mutex.RUnlock()

    db, exists := m.databases[name]
    if !exists {
        return nil, fmt.Errorf("database %s not found", name)
    }
    return db, nil
}

// ExecuteOnAll calls operation once per registered pool and returns each
// result keyed by pool name. It works on a snapshot taken under the read
// lock, so operation runs without holding the mutex.
func (m *MultiDBManager) ExecuteOnAll(operation func(*sql.DB) error) map[string]error {
    m.mutex.RLock()
    databases := make(map[string]*sql.DB, len(m.databases))
    for name, db := range m.databases {
        databases[name] = db
    }
    m.mutex.RUnlock()

    results := make(map[string]error, len(databases))
    for name, db := range databases {
        results[name] = operation(db)
    }
    return results
}

// CloseAll closes and unregisters every pool. Close errors are logged, not returned.
func (m *MultiDBManager) CloseAll() {
    m.mutex.Lock()
    defer m.mutex.Unlock()

    for name, db := range m.databases {
        if err := db.Close(); err != nil {
            log.Printf("Error closing database %s: %v", name, err)
        }
        delete(m.databases, name)
    }
}

// main registers several pools, probes each one and prints a per-pool status.
func main() {
    manager := NewMultiDBManager()
    defer manager.CloseAll()

    type target struct{ driver, dsn string }
    databases := map[string]target{
        "mysql_primary":      {driver: "mysql", dsn: "root:password@tcp(127.0.0.1:3306)/db1"},
        "mysql_backup":       {driver: "mysql", dsn: "root:password@tcp(127.0.0.1:3306)/db2"},
        "postgres_analytics": {driver: "postgres", dsn: "postgres://user:password@localhost/analytics?sslmode=disable"},
    }

    // Registration failures are only logged, so the demo continues with
    // whichever databases are reachable.
    for name, cfg := range databases {
        if err := manager.AddDatabase(name, cfg.driver, cfg.dsn); err != nil {
            log.Printf("Warning: Failed to add database %s: %v", name, err)
        }
    }

    // probe tries dialect-specific queries in order until one succeeds.
    probe := func(db *sql.DB) error {
        for _, q := range []string{"SELECT VERSION()", "SELECT version()", "SELECT 'Hello' as greeting"} {
            var out string
            if db.QueryRow(q).Scan(&out) == nil {
                fmt.Printf("Query successful: %s\n", out)
                return nil
            }
        }
        return fmt.Errorf("all queries failed")
    }

    for dbName, err := range manager.ExecuteOnAll(probe) {
        if err != nil {
            fmt.Printf("Database %s: ERROR - %v\n", dbName, err)
            continue
        }
        fmt.Printf("Database %s: SUCCESS\n", dbName)
    }
}

练习2:连接配置抽象层设计

package main

import (
    "database/sql"
    "encoding/json"
    "fmt"
    "log"
    "os"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

// DBConfig describes one named connection pool as stored in the JSON config
// file. time.Duration is an int64, so the duration fields round-trip through
// JSON as integer nanosecond counts rather than strings like "30m".
type DBConfig struct {
    Driver          string        `json:"driver"`
    DSN             string        `json:"dsn"`
    MaxOpenConns    int           `json:"max_open_conns"`
    MaxIdleConns    int           `json:"max_idle_conns"`
    ConnMaxLifetime time.Duration `json:"conn_max_lifetime"`
    ConnMaxIdleTime time.Duration `json:"conn_max_idle_time"`
}

// ConfigManager holds the named DBConfig entries loaded from a JSON file.
type ConfigManager struct {
    configs map[string]DBConfig
}

// NewConfigManager reads configFile and decodes it as a JSON object mapping
// names to DBConfig values. Read and parse errors are wrapped with context.
func NewConfigManager(configFile string) (*ConfigManager, error) {
    raw, err := os.ReadFile(configFile)
    if err != nil {
        return nil, fmt.Errorf("failed to read config file: %w", err)
    }

    cm := &ConfigManager{}
    if err := json.Unmarshal(raw, &cm.configs); err != nil {
        return nil, fmt.Errorf("failed to parse config: %w", err)
    }
    return cm, nil
}

// GetConfig returns the named configuration, or an error if it is absent.
func (cm *ConfigManager) GetConfig(name string) (DBConfig, error) {
    if cfg, ok := cm.configs[name]; ok {
        return cfg, nil
    }
    return DBConfig{}, fmt.Errorf("config %s not found", name)
}

// ConnectionFactory lazily opens and caches one *sql.DB per configuration
// name. It uses a plain map with no locking, so concurrent callers must
// serialize access themselves.
type ConnectionFactory struct {
    configManager *ConfigManager
    connections   map[string]*sql.DB
}

// NewConnectionFactory returns a factory with an empty cache.
func NewConnectionFactory(configManager *ConfigManager) *ConnectionFactory {
    return &ConnectionFactory{
        configManager: configManager,
        connections:   make(map[string]*sql.DB),
    }
}

// GetConnection returns the cached pool for name if it still answers Ping.
// Otherwise it closes any stale pool, builds a new one from the named
// configuration, pings it and caches it.
func (cf *ConnectionFactory) GetConnection(name string) (*sql.DB, error) {
    if cached, ok := cf.connections[name]; ok {
        if cached.Ping() == nil {
            return cached, nil
        }
        // The cached pool failed its health check: drop it and rebuild below.
        cached.Close()
        delete(cf.connections, name)
    }

    cfg, err := cf.configManager.GetConfig(name)
    if err != nil {
        return nil, err
    }

    db, err := sql.Open(cfg.Driver, cfg.DSN)
    if err != nil {
        return nil, fmt.Errorf("failed to open connection: %w", err)
    }

    db.SetMaxOpenConns(cfg.MaxOpenConns)
    db.SetMaxIdleConns(cfg.MaxIdleConns)
    db.SetConnMaxLifetime(cfg.ConnMaxLifetime)
    db.SetConnMaxIdleTime(cfg.ConnMaxIdleTime)

    if err := db.Ping(); err != nil {
        db.Close()
        return nil, fmt.Errorf("failed to ping database: %w", err)
    }

    cf.connections[name] = db
    return db, nil
}

// CloseAll closes and evicts every cached pool, logging any Close error.
func (cf *ConnectionFactory) CloseAll() {
    for name := range cf.connections {
        if err := cf.connections[name].Close(); err != nil {
            log.Printf("Error closing connection %s: %v", name, err)
        }
        delete(cf.connections, name)
    }
}

// main writes a sample JSON config to a temporary file, loads it and queries
// the "mysql_primary" pool through ConnectionFactory.
func main() {
    // Sample configuration written out for the demo.
    configData := map[string]DBConfig{
        "mysql_primary": {
            Driver:          "mysql",
            DSN:             "root:password@tcp(127.0.0.1:3306)/primary_db",
            MaxOpenConns:    20,
            MaxIdleConns:    5,
            ConnMaxLifetime: time.Hour,
            ConnMaxIdleTime: 30 * time.Minute,
        },
        "mysql_replica": {
            Driver:          "mysql",
            DSN:             "root:password@tcp(127.0.0.1:3306)/replica_db",
            MaxOpenConns:    15,
            MaxIdleConns:    3,
            ConnMaxLifetime: time.Hour,
            ConnMaxIdleTime: 30 * time.Minute,
        },
    }

    data, err := json.MarshalIndent(configData, "", "  ")
    if err != nil {
        log.Fatal(err)
    }

    // Use a uniquely named file in the OS temp directory. A hard-coded
    // "/tmp/db_config.json" does not exist on Windows and can clash with
    // concurrent runs. os.CreateTemp also creates the file with mode 0600,
    // which suits a file holding credentials.
    tmp, err := os.CreateTemp("", "db_config_*.json")
    if err != nil {
        log.Fatal(err)
    }
    tmpFile := tmp.Name()
    defer os.Remove(tmpFile)

    if _, err := tmp.Write(data); err != nil {
        tmp.Close()
        log.Fatal(err)
    }
    if err := tmp.Close(); err != nil {
        log.Fatal(err)
    }

    configManager, err := NewConfigManager(tmpFile)
    if err != nil {
        log.Fatal(err)
    }

    factory := NewConnectionFactory(configManager)
    defer factory.CloseAll()

    db, err := factory.GetConnection("mysql_primary")
    if err != nil {
        log.Fatal(err)
    }

    var version string
    if err := db.QueryRow("SELECT VERSION()").Scan(&version); err != nil {
        log.Fatal(err)
    }

    // Take one stats snapshot so the three counters are mutually consistent.
    stats := db.Stats()
    fmt.Println("Database version:", version)
    fmt.Println("Connection pool stats:")
    fmt.Printf("  OpenConnections: %d\n", stats.OpenConnections)
    fmt.Printf("  InUse: %d\n", stats.InUse)
    fmt.Printf("  Idle: %d\n", stats.Idle)
}

练习3:标准化操作接口封装

package main

import (
    "context"
    "database/sql"
    "fmt"
    "log"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

// StandardizedDB is a driver-neutral facade over a connection pool. Query
// methods return rows as column-name -> value maps instead of *sql.Rows.
type StandardizedDB interface {
    // Query returns every row produced by query.
    Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error)
    // QueryRow returns a single row produced by query.
    QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error)
    // Exec runs a statement that returns no rows.
    Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    // BeginTx starts a transaction.
    BeginTx(ctx context.Context) (StandardizedTx, error)
    // Ping checks that the database is reachable.
    Ping(ctx context.Context) error
    // Close releases the underlying pool.
    Close() error
    // Stats reports connection-pool statistics.
    Stats() sql.DBStats
}

// StandardizedTx is the transaction counterpart of StandardizedDB. Note that
// it offers Query but no QueryRow.
type StandardizedTx interface {
    // Commit makes the transaction's changes permanent.
    Commit() error
    // Rollback discards the transaction's changes.
    Rollback() error
    // Exec runs a statement inside the transaction.
    Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    // Query returns every row produced by query inside the transaction.
    Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error)
}

// MySQLDB implements the StandardizedDB methods on a *sql.DB opened with the
// "mysql" driver.
type MySQLDB struct {
    db *sql.DB
}

// NewMySQLDB opens a pool for dsn and applies fixed limits: 25 open
// connections, 5 idle, 5-minute lifetime. It then pings the pool and closes
// it again if Ping fails.
func NewMySQLDB(dsn string) (*MySQLDB, error) {
    db, err := sql.Open("mysql", dsn)
    if err != nil {
        return nil, err
    }

    db.SetMaxOpenConns(25)
    db.SetMaxIdleConns(5)
    db.SetConnMaxLifetime(5 * time.Minute)

    if err := db.Ping(); err != nil {
        db.Close()
        return nil, err
    }

    return &MySQLDB{db: db}, nil
}

// scanRowsToMaps drains rows into one map per row, keyed by column name.
// []byte values are converted to string so the maps print readably. Errors
// from Columns, Scan and rows.Err are returned. It does not close rows;
// callers defer that themselves.
func scanRowsToMaps(rows *sql.Rows) ([]map[string]interface{}, error) {
    columns, err := rows.Columns()
    if err != nil {
        return nil, err
    }

    var results []map[string]interface{}
    for rows.Next() {
        // Scan needs one pointer per column; each points into values.
        values := make([]interface{}, len(columns))
        valuePtrs := make([]interface{}, len(columns))
        for i := range values {
            valuePtrs[i] = &values[i]
        }

        if err := rows.Scan(valuePtrs...); err != nil {
            return nil, err
        }

        rowMap := make(map[string]interface{}, len(columns))
        for i, col := range columns {
            if b, ok := values[i].([]byte); ok {
                rowMap[col] = string(b)
            } else {
                rowMap[col] = values[i]
            }
        }
        results = append(results, rowMap)
    }

    if err := rows.Err(); err != nil {
        return nil, err
    }
    return results, nil
}

// Query runs query and returns all rows as column-name -> value maps.
func (m *MySQLDB) Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error) {
    rows, err := m.db.QueryContext(ctx, query, args...)
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    return scanRowsToMaps(rows)
}

// QueryRow runs query and returns its first row, or sql.ErrNoRows when the
// result set is empty. All rows are read first, so callers should restrict
// the query to a single row.
func (m *MySQLDB) QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error) {
    rows, err := m.Query(ctx, query, args...)
    if err != nil {
        return nil, err
    }

    if len(rows) == 0 {
        return nil, sql.ErrNoRows
    }

    return rows[0], nil
}

// Exec runs a statement that returns no rows.
func (m *MySQLDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
    return m.db.ExecContext(ctx, query, args...)
}

// BeginTx starts a transaction with default options (nil *sql.TxOptions).
func (m *MySQLDB) BeginTx(ctx context.Context) (StandardizedTx, error) {
    tx, err := m.db.BeginTx(ctx, nil)
    if err != nil {
        return nil, err
    }
    return &MySQLTx{tx: tx}, nil
}

// Ping verifies that the server is reachable.
func (m *MySQLDB) Ping(ctx context.Context) error {
    return m.db.PingContext(ctx)
}

// Close closes the underlying pool.
func (m *MySQLDB) Close() error {
    return m.db.Close()
}

// Stats returns a snapshot of pool statistics.
func (m *MySQLDB) Stats() sql.DBStats {
    return m.db.Stats()
}

// MySQLTx implements the StandardizedTx methods on a *sql.Tx.
type MySQLTx struct {
    tx *sql.Tx
}

// Commit commits the transaction.
func (mt *MySQLTx) Commit() error {
    return mt.tx.Commit()
}

// Rollback aborts the transaction.
func (mt *MySQLTx) Rollback() error {
    return mt.tx.Rollback()
}

// Exec runs a statement inside the transaction.
func (mt *MySQLTx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
    return mt.tx.ExecContext(ctx, query, args...)
}

// Query runs query inside the transaction and returns all rows as maps.
func (mt *MySQLTx) Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error) {
    rows, err := mt.tx.QueryContext(ctx, query, args...)
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    return scanRowsToMaps(rows)
}

// main exercises the StandardizedDB facade: ping, a read, a transactional
// update with read-back, and pool statistics.
func main() {
    db, err := NewMySQLDB("root:password@tcp(127.0.0.1:3306)/testdb")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    ctx := context.Background()

    if err := db.Ping(ctx); err != nil {
        log.Fatal("连接测试失败:", err)
    }
    fmt.Println("数据库连接成功!")

    results, err := db.Query(ctx, "SELECT id, name, email FROM users LIMIT 5")
    if err != nil {
        log.Fatal("查询失败:", err)
    }

    fmt.Println("用户列表:")
    for _, row := range results {
        fmt.Printf("ID: %v, Name: %v, Email: %v\n", row["id"], row["name"], row["email"])
    }

    tx, err := db.BeginTx(ctx)
    if err != nil {
        log.Fatal("开始事务失败:", err)
    }

    // log.Fatal exits via os.Exit, which skips deferred calls. A deferred
    // Rollback would therefore never run on these failure paths, so roll back
    // explicitly before exiting.
    fail := func(msg string, err error) {
        _ = tx.Rollback() // best effort; we are exiting anyway
        log.Println("事务回滚")
        log.Fatal(msg, err)
    }

    if _, err := tx.Exec(ctx, "UPDATE users SET email = ? WHERE id = ?", "updated@example.com", 1); err != nil {
        fail("更新失败:", err)
    }

    // StandardizedTx has no QueryRow method, so read the row back with Query
    // and take the first result.
    updatedRows, err := tx.Query(ctx, "SELECT name, email FROM users WHERE id = ?", 1)
    if err != nil {
        fail("查询失败:", err)
    }
    if len(updatedRows) == 0 {
        fail("查询失败:", sql.ErrNoRows)
    }
    updatedRow := updatedRows[0]

    fmt.Printf("更新后的数据: Name=%v, Email=%v\n", updatedRow["name"], updatedRow["email"])

    if err := tx.Commit(); err != nil {
        log.Fatal("提交事务失败:", err)
    }
    fmt.Println("事务提交成功")

    stats := db.Stats()
    fmt.Printf("\n连接池统计:\n")
    fmt.Printf("  最大打开连接数: %d\n", stats.MaxOpenConnections)
    fmt.Printf("  打开连接数: %d\n", stats.OpenConnections)
    fmt.Printf("  使用中连接数: %d\n", stats.InUse)
    fmt.Printf("  空闲连接数: %d\n", stats.Idle)
}

这份教程详细介绍了Go标准数据库接口的各个方面,从架构设计到实战应用。每个代码示例都是一个独立的完整程序,运行前需要根据实际环境修改DSN,并准备好对应的数据库和数据表。通过学习和实践这些内容,你将能够建立扎实的Go数据库编程基础,为后续更高级的主题做好准备。

记住,良好的数据库操作习惯和错误处理机制是构建稳定应用程序的关键。在实际项目中,你应该根据具体需求适当调整连接池参数和错误处理策略。


详细内容待补充...