在Go 1.18版本引入泛型特性后,结合接口可以编写出适配多种数据模型的通用数据库查询函数,避免为不同表结构重复编写查询逻辑。这种方式既保留了泛型的类型安全优势,又通过接口实现了不同数据模型的统一适配。

核心实现思路
要实现泛型式数据库查询函数,核心分为三步:首先定义数据模型需要实现的通用接口,明确查询后数据扫描的规范;然后定义泛型约束,限制泛型参数必须实现该通用接口;最后编写通用的查询函数,在函数内部通过泛型参数完成数据扫描和结果组装。
步骤1:定义通用数据接口
首先定义一个Scanner接口,所有需要被查询的数据模型都需要实现这个接口,接口中声明从数据库行扫描数据的方法:
// Scanner 定义数据从数据库行扫描的通用接口
type Scanner interface {
// Scan 从sql.Row中扫描数据到当前实例
Scan(row *sql.Row) error
// ScanRows 从sql.Rows中扫描多行数据到切片
ScanRows(rows *sql.Rows) error
}
步骤2:定义数据模型并实现接口
以用户表和订单表两个数据模型为例,分别实现Scanner接口:
import (
"database/sql"
"fmt"
)
// User 用户表数据模型
type User struct {
ID int64
Name string
Age int
Email string
}
// Scan 实现Scanner接口的Scan方法,扫描单条用户数据
func (u *User) Scan(row *sql.Row) error {
return row.Scan(&u.ID, &u.Name, &u.Age, &u.Email)
}
// ScanRows 实现Scanner接口的ScanRows方法,扫描多条用户数据
func (u *User) ScanRows(rows *sql.Rows) error {
return rows.Scan(&u.ID, &u.Name, &u.Age, &u.Email)
}
// Order 订单表数据模型
type Order struct {
ID int64
UserID int64
Product string
Amount float64
CreatedAt string
}
// Scan 实现Scanner接口的Scan方法,扫描单条订单数据
func (o *Order) Scan(row *sql.Row) error {
return row.Scan(&o.ID, &o.UserID, &o.Product, &o.Amount, &o.CreatedAt)
}
// ScanRows 实现Scanner接口的ScanRows方法,扫描多条订单数据
func (o *Order) ScanRows(rows *sql.Rows) error {
return rows.Scan(&o.ID, &o.UserID, &o.Product, &o.Amount, &o.CreatedAt)
}
步骤3:定义泛型查询函数
接下来定义泛型约束,限制泛型参数T必须实现Scanner接口,然后编写通用的单条查询和批量查询函数:
// QuerySingle 泛型单条数据查询函数
// T 泛型参数,约束为实现Scanner接口的类型
// db 数据库连接实例
// query SQL查询语句
// args 查询参数
func QuerySingle[T Scanner](db *sql.DB, query string, args ...any) (T, error) {
var t T
row := db.QueryRow(query, args...)
err := t.Scan(row)
if err != nil {
return t, fmt.Errorf("查询单条数据失败: %w", err)
}
return t, nil
}
// QueryMulti 泛型多条数据查询函数
// T 泛型参数,约束为实现Scanner接口的类型
// db 数据库连接实例
// query SQL查询语句
// args 查询参数
func QueryMulti[T Scanner](db *sql.DB, query string, args ...any) ([]T, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("执行查询失败: %w", err)
}
defer rows.Close()
var result []T
for rows.Next() {
var t T
err := t.ScanRows(rows)
if err != nil {
return nil, fmt.Errorf("扫描行数据失败: %w", err)
}
result = append(result, t)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("遍历行数据出错: %w", err)
}
return result, nil
}
使用示例
完成上述定义后,就可以直接调用通用查询函数查询不同表的数据,无需重复编写查询逻辑:
func main() {
// 假设已经初始化好db连接
// db, _ := sql.Open("mysql", "root:password@tcp(127.0.0.1:3306)/test")
// 查询单条用户数据
userQuery := "SELECT id, name, age, email FROM user WHERE id = ?"
user, err := QuerySingle[User](db, userQuery, 1)
if err != nil {
fmt.Printf("查询用户失败: %vn", err)
} else {
fmt.Printf("用户信息: %+vn", user)
}
// 查询多条订单数据
orderQuery := "SELECT id, user_id, product, amount, created_at FROM order WHERE user_id = ?"
orders, err := QueryMulti[Order](db, orderQuery, 1)
if err != nil {
fmt.Printf("查询订单失败: %vn", err)
} else {
fmt.Printf("订单列表: %+vn", orders)
}
}
注意事项
- 泛型约束必须明确,确保所有需要查询的数据模型都实现了
Scanner接口,否则编译时会报错。 - 数据库查询的字段顺序需要和
Scan方法中的参数顺序一致,避免数据赋值错误。 - 查询多行数据后要及时关闭
sql.Rows对象,避免资源泄漏,通用函数中已经通过defer处理了这个逻辑。 - 如果数据模型有嵌套结构,需要在
Scan方法中处理嵌套字段的扫描逻辑,保证数据正确赋值。