175 lines
4.4 KiB
Go
175 lines
4.4 KiB
Go
package gorm
|
||
|
||
import (
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/postgres"
|
||
"gorm.io/gorm"
|
||
gormlogger "gorm.io/gorm/logger"
|
||
)
|
||
|
||
// _dbCache maps a Config.Id() value to the *gorm.DB that open created for
// that configuration, so later opens of an equal config reuse the same pool.
// The zero value of sync.Map is ready to use, so no initializer is needed.
var _dbCache sync.Map
|
||
|
||
type (
|
||
DB = gorm.DB
|
||
Session = gorm.Session
|
||
)
|
||
|
||
// Config describes how to connect to one MySQL or PostgreSQL database.
// The exported fields carry TOML tags for loading from configuration files.
// They are also the input to Id, which hashes their JSON encoding, so field
// names and order determine the id.
type Config struct {
	id string // value returned by Id when non-empty; package-internal use only

	Type            string        `toml:"type"`              // database type: "mysql" or "pgsql"
	Host            string        `toml:"host"`              // database server host name or IP address
	Port            int           `toml:"port"`              // database server port
	Debug           bool          `toml:"debug"`             // debug switch: when true, SQL is printed to the log
	Dbname          string        `toml:"dbname"`            // database name
	Username        string        `toml:"username"`          // database user name
	Password        string        `toml:"password"`          // database connection password
	MaxIdleConns    int           `toml:"max_idle_conns"`    // maximum idle connections (<= 0 selects open's default)
	MaxOpenConns    int           `toml:"max_open_conns"`    // maximum open connections (<= 0 selects open's default)
	MaxConnLifetime time.Duration `toml:"max_conn_lifetime"` // maximum connection lifetime (<= 0 selects open's default)
	ExtraParameters string        `toml:"extra_parameters"`  // extra driver parameters appended to the DSN
}
|
||
|
||
func NewDB(dbcfg *Config) (*gorm.DB, error) {
|
||
return open(dbcfg)
|
||
}
|
||
|
||
func NewSession(db *gorm.DB) *gorm.DB {
|
||
return db.Session(&gorm.Session{})
|
||
}
|
||
|
||
func DefaultLogger() gormlogger.Interface {
|
||
return defaultLogger
|
||
}
|
||
|
||
func DefaultRecorder() traceRecorder {
|
||
return defautRecorder
|
||
}
|
||
|
||
func (dbcfg *Config) Id() string {
|
||
if dbcfg.id != "" {
|
||
return dbcfg.id
|
||
}
|
||
jsonDbCfg, err := json.Marshal(dbcfg)
|
||
if err != nil {
|
||
panic("unreachable code: Marshal ocour error:" + err.Error())
|
||
}
|
||
md5sum := md5.Sum(jsonDbCfg)
|
||
dbcfg.id = hex.EncodeToString(md5sum[:])
|
||
return dbcfg.id
|
||
}
|
||
|
||
func GetDsn(dbcfg *Config) (dsn string) {
|
||
switch dbcfg.Type {
|
||
case "mysql":
|
||
//dsn := "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
|
||
dsn = fmt.Sprintf(
|
||
"%s:%s@tcp(%s:%d)/%s?%s",
|
||
dbcfg.Username,
|
||
dbcfg.Password,
|
||
dbcfg.Host,
|
||
dbcfg.Port,
|
||
dbcfg.Dbname,
|
||
dbcfg.ExtraParameters,
|
||
)
|
||
case "pgsql":
|
||
//dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
|
||
arrConfStr := []string{
|
||
"host=" + dbcfg.Host,
|
||
"port=" + strconv.Itoa(dbcfg.Port),
|
||
"user=" + dbcfg.Username,
|
||
"password=" + dbcfg.Password,
|
||
"dbname=" + dbcfg.Dbname,
|
||
dbcfg.ExtraParameters,
|
||
}
|
||
dsn = strings.Join(arrConfStr, " ")
|
||
default:
|
||
panic("DATABASE TYPE '" + dbcfg.Type + "' NOT SUPPORT (only mysql or pgsql)")
|
||
}
|
||
return dsn
|
||
}
|
||
|
||
func open(dbcfg *Config) (*gorm.DB, error) {
|
||
var (
|
||
db *gorm.DB
|
||
err error
|
||
dsn = GetDsn(dbcfg)
|
||
cacheValue, inCacheYes = _dbCache.Load(dbcfg.Id())
|
||
cachedDb, inCacheYes2 = cacheValue.(*gorm.DB)
|
||
)
|
||
// return from cache if have been cached before
|
||
if inCacheYes && inCacheYes2 {
|
||
if cachedDb == nil {
|
||
panic("unreachable code: cachedDb MUST not be nil if it cached")
|
||
}
|
||
return cachedDb.Session(&gorm.Session{}), nil
|
||
}
|
||
|
||
// dialector support 'mysql' and 'pgsql'
|
||
var dialector gorm.Dialector
|
||
switch dbcfg.Type {
|
||
case "mysql":
|
||
dialector = mysql.New(
|
||
mysql.Config{
|
||
DSN: dsn,
|
||
DefaultStringSize: 512,
|
||
},
|
||
)
|
||
case "pgsql":
|
||
dialector = postgres.New(
|
||
postgres.Config{
|
||
DSN: dsn,
|
||
PreferSimpleProtocol: true,
|
||
},
|
||
)
|
||
default:
|
||
panic("UNKNOWN DATABASE TYPE:" + dbcfg.Type)
|
||
}
|
||
|
||
// gorm open
|
||
var gormconfig = &gorm.Config{
|
||
NowFunc: func() time.Time { return time.Now().UTC() },
|
||
QueryFields: true,
|
||
}
|
||
db, err = gorm.Open(dialector, gormconfig)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("gorm.Open error:%s", err)
|
||
}
|
||
if dbcfg.Debug {
|
||
db = db.Session(&gorm.Session{Logger: DefaultLogger()})
|
||
}
|
||
|
||
// set connection pool parameters
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("gormDB get sqlDB error:%s", err)
|
||
}
|
||
if dbcfg.MaxIdleConns > 0 {
|
||
sqlDB.SetMaxIdleConns(dbcfg.MaxIdleConns)
|
||
} else {
|
||
sqlDB.SetMaxIdleConns(100)
|
||
}
|
||
if dbcfg.MaxOpenConns > 0 {
|
||
sqlDB.SetMaxOpenConns(dbcfg.MaxOpenConns)
|
||
} else {
|
||
sqlDB.SetMaxOpenConns(500)
|
||
}
|
||
if dbcfg.MaxConnLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(dbcfg.MaxConnLifetime)
|
||
} else {
|
||
sqlDB.SetConnMaxLifetime(3 * time.Minute)
|
||
}
|
||
|
||
// store to cache and return
|
||
_dbCache.Store(dbcfg.Id(), db)
|
||
return db, nil
|
||
}
|