gorm/gorm.go
2025-04-08 08:49:31 +08:00

185 lines
4.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package gorm
import (
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
"strings"
"sync"
"time"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
gormlogger "gorm.io/gorm/logger"
)
// _dbCache maps Config.Id() to the *gorm.DB opened for that configuration.
// The zero value of sync.Map is ready to use, so no initializer is needed.
var _dbCache sync.Map
// DB is an alias for gorm.DB.
type DB = gorm.DB

// Session is an alias for gorm.Session.
type Session = gorm.Session
// Config describes one database connection and its connection-pool settings.
//
// Id() hashes the JSON encoding of this struct. The exported fields, their
// order and their values therefore determine which cached *gorm.DB open
// reuses, so do not reorder fields casually.
type Config struct {
	// id holds the identifier computed by Id (internal use only). Being
	// unexported, it is not part of the JSON encoding that Id hashes.
	id string // inner use
	// Type is the database type: "mysql" or "pgsql".
	Type string `toml:"type"`
	// Host is the database server host name or address.
	Host string `toml:"host"`
	// Port is the database server port.
	Port int `toml:"port"`
	// Debug is the debug switch: when true, open attaches DefaultLogger()
	// to the DB so that SQL statements are logged.
	Debug bool `toml:"debug"`
	// Dbname is the database name.
	Dbname string `toml:"dbname"`
	// Username is the database user name.
	Username string `toml:"username"`
	// Password is the database connection password.
	Password string `toml:"password"`
	// MaxIdleConns is the maximum number of idle connections; open uses 100
	// when this is <= 0.
	MaxIdleConns int `toml:"max_idle_conns"`
	// MaxOpenConns is the maximum number of open connections; open uses 500
	// when this is <= 0.
	MaxOpenConns int `toml:"max_open_conns"`
	// MaxConnLifetime is the maximum lifetime of a connection; open uses
	// 3 minutes when this is <= 0.
	MaxConnLifetime time.Duration `toml:"max_conn_lifetime"`
	// ExtraParameters are extra connection parameters appended by GetDsn:
	// after "?" for mysql (e.g. "charset=utf8mb4&parseTime=True&loc=Local"),
	// or as space-separated keyword=value pairs for pgsql
	// (e.g. "sslmode=disable TimeZone=Asia/Shanghai").
	ExtraParameters string `toml:"extra_parameters"`
}
// NewDB returns a *gorm.DB for dbcfg, reusing a previously opened one when a
// configuration with the same Id() was opened before.
//
// dbcfg is received by value, so open works on a private copy and the
// caller's Config is never modified.
func NewDB(dbcfg Config) (*gorm.DB, error) {
	db, err := open(&dbcfg)
	return db, err
}
// NewSession returns a new session derived from db, created with an empty
// gorm.Session configuration.
func NewSession(db *gorm.DB) *gorm.DB {
	sessionCfg := gorm.Session{}
	return db.Session(&sessionCfg)
}
// DefaultLogger returns the package's defaultLogger. open attaches this
// logger to the DB when Config.Debug is true.
func DefaultLogger() gormlogger.Interface {
	logger := defaultLogger
	return logger
}
// DefaultRecorder returns the package's default trace recorder
// (defautRecorder, declared elsewhere in this package).
func DefaultRecorder() traceRecorder {
	recorder := defautRecorder
	return recorder
}
// Id returns a stable identifier for the configuration: the hex-encoded MD5
// of its JSON encoding (the unexported id field is not part of that
// encoding). open uses it as the cache key, so Configs with identical
// exported fields share one *gorm.DB.
//
// The identifier is recomputed on every call and then stored in dbcfg.id.
// Returning a value cached by an earlier call would go stale if the Config
// were modified afterwards (a copy made by NewDB carries the stale id too),
// and open would then hand back the pool that belongs to the old settings.
//
// NOTE(review): Id writes dbcfg.id, so concurrent calls on the same *Config
// race with each other.
func (dbcfg *Config) Id() string {
	jsonDbCfg, err := json.Marshal(dbcfg)
	if err != nil {
		// Config has only strings, ints, bools and a Duration, which always
		// marshal successfully.
		panic("unreachable code: Marshal ocour error:" + err.Error())
	}
	md5sum := md5.Sum(jsonDbCfg)
	dbcfg.id = hex.EncodeToString(md5sum[:])
	return dbcfg.id
}
// GetDsn builds the driver-specific data source name for dbcfg.
//
// For Type "mysql" it returns the go-sql-driver form
//
//	user:pass@tcp(host:port)/dbname?extra
//
// The "?extra" suffix is omitted when ExtraParameters is empty. An IPv6
// literal host that is not already bracketed is wrapped in "[...]", as the
// driver requires.
//
// For Type "pgsql" it returns a libpq-style keyword/value string
//
//	host=... port=... user=... password=... dbname=... extra
//
// Values that are empty or contain whitespace, single quotes or backslashes
// are single-quoted and escaped. Without this, such a value would swallow or
// split the neighbouring keyword; for example, an empty password would absorb
// "dbname=...". ExtraParameters is appended verbatim when non-empty.
//
// It panics for any other Type.
func (dbcfg *Config) GetDsn() (dsn string) {
	switch dbcfg.Type {
	case "mysql":
		// e.g. "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
		host := dbcfg.Host
		if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
			// IPv6 literal: the driver expects tcp([::1]:3306).
			host = "[" + host + "]"
		}
		dsn = fmt.Sprintf(
			"%s:%s@tcp(%s:%d)/%s",
			dbcfg.Username,
			dbcfg.Password,
			host,
			dbcfg.Port,
			dbcfg.Dbname,
		)
		if dbcfg.ExtraParameters != "" {
			dsn += "?" + dbcfg.ExtraParameters
		}
	case "pgsql":
		// e.g. "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
		arrConfStr := []string{
			"host=" + quotePgDsnValue(dbcfg.Host),
			"port=" + strconv.Itoa(dbcfg.Port),
			"user=" + quotePgDsnValue(dbcfg.Username),
			"password=" + quotePgDsnValue(dbcfg.Password),
			"dbname=" + quotePgDsnValue(dbcfg.Dbname),
		}
		if dbcfg.ExtraParameters != "" {
			arrConfStr = append(arrConfStr, dbcfg.ExtraParameters)
		}
		dsn = strings.Join(arrConfStr, " ")
	default:
		panic("DATABASE TYPE '" + dbcfg.Type + "' NOT SUPPORT (only mysql or pgsql)")
	}
	return dsn
}

// quotePgDsnValue formats v as a libpq keyword/value setting value.
//
// A non-empty value without whitespace, single quotes or backslashes is
// returned unchanged, so typical DSNs are byte-identical to the unquoted
// form. Any other value is wrapped in single quotes, with backslashes and
// single quotes backslash-escaped.
func quotePgDsnValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f'\\") {
		return v
	}
	escaped := strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v)
	return "'" + escaped + "'"
}
// GetSecDsn returns the same DSN as GetDsn, but with the password masked so
// that the result can be logged.
//
// Masking rules:
//   - Passwords longer than 5 characters keep their first 2 and last 3
//     characters around "*****" (the previous output format).
//   - Shorter non-empty passwords are replaced by "*****" entirely.
//   - An empty password stays empty.
//
// The mask is applied to the Password field before the DSN is built. It used
// to be applied by searching the finished DSN for a password fragment. That
// left short passwords in clear text, and it could mask a matching fragment
// of the username, host or database name instead, leaking the password.
func (dbcfg Config) GetSecDsn() string {
	// dbcfg is a value receiver, so rewriting Password here does not affect
	// the caller's Config. Work on runes so multi-byte characters are never
	// split.
	if p := []rune(dbcfg.Password); len(p) > 5 {
		dbcfg.Password = string(p[:2]) + "*****" + string(p[len(p)-3:])
	} else if len(p) > 0 {
		dbcfg.Password = "*****"
	}
	return dbcfg.GetDsn()
}
// open returns a *gorm.DB for dbcfg, reusing the one cached under dbcfg.Id()
// when that configuration has been opened before.
//
// Cache hit: it returns a new Session of the cached DB.
//
// Cache miss:
//   - builds the dialector for dbcfg.Type ("mysql" or "pgsql");
//   - opens the database with a UTC NowFunc and QueryFields enabled;
//   - attaches DefaultLogger() when dbcfg.Debug is set;
//   - applies connection-pool limits, defaulting to 100 idle, 500 open and a
//     3-minute lifetime when the configured value is <= 0;
//   - caches the result.
//
// If two goroutines open the same configuration concurrently, only the first
// stored DB is kept. The other caller closes its freshly opened pool and
// returns a Session of the stored one, so no connection pool is leaked.
//
// It panics for an unsupported Type (GetDsn panics first). Errors from
// gorm.Open and from obtaining the underlying *sql.DB are wrapped with %w.
func open(dbcfg *Config) (*gorm.DB, error) {
	dsn := dbcfg.GetDsn()
	id := dbcfg.Id()

	// Return from the cache if this configuration has been opened before.
	if cacheValue, inCache := _dbCache.Load(id); inCache {
		if cachedDb, isDB := cacheValue.(*gorm.DB); isDB {
			if cachedDb == nil {
				panic("unreachable code: cachedDb MUST not be nil if it cached")
			}
			return cachedDb.Session(&gorm.Session{}), nil
		}
	}

	// The dialector supports 'mysql' and 'pgsql'.
	var dialector gorm.Dialector
	switch dbcfg.Type {
	case "mysql":
		dialector = mysql.New(mysql.Config{
			DSN:               dsn,
			DefaultStringSize: 512,
		})
	case "pgsql":
		dialector = postgres.New(postgres.Config{
			DSN:                  dsn,
			PreferSimpleProtocol: true,
		})
	default:
		panic("UNKNOWN DATABASE TYPE:" + dbcfg.Type)
	}

	db, err := gorm.Open(dialector, &gorm.Config{
		NowFunc:     func() time.Time { return time.Now().UTC() },
		QueryFields: true,
	})
	if err != nil {
		return nil, fmt.Errorf("gorm.Open error:%w", err)
	}
	if dbcfg.Debug {
		db = db.Session(&gorm.Session{Logger: DefaultLogger()})
	}

	// Connection-pool parameters; non-positive config values fall back to defaults.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("gormDB get sqlDB error:%w", err)
	}
	maxIdle, maxOpen, lifetime := 100, 500, 3*time.Minute
	if dbcfg.MaxIdleConns > 0 {
		maxIdle = dbcfg.MaxIdleConns
	}
	if dbcfg.MaxOpenConns > 0 {
		maxOpen = dbcfg.MaxOpenConns
	}
	if dbcfg.MaxConnLifetime > 0 {
		lifetime = dbcfg.MaxConnLifetime
	}
	sqlDB.SetMaxIdleConns(maxIdle)
	sqlDB.SetMaxOpenConns(maxOpen)
	sqlDB.SetConnMaxLifetime(lifetime)

	// Store atomically. A plain Load-then-Store lets a concurrent open of the
	// same config overwrite our entry and leak one of the two pools.
	if existing, loaded := _dbCache.LoadOrStore(id, db); loaded {
		if winner, isDB := existing.(*gorm.DB); isDB && winner != nil {
			// Another goroutine cached its DB first: release our unused pool.
			// Close is best-effort; a failure cannot affect the returned DB.
			_ = sqlDB.Close()
			return winner.Session(&gorm.Session{}), nil
		}
		// The entry is not a usable *gorm.DB; replace it with ours.
		_dbCache.Store(id, db)
	}
	return db, nil
}