package gorm import ( "crypto/md5" "encoding/hex" "encoding/json" "fmt" "strconv" "strings" "sync" "time" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/gorm" gormlogger "gorm.io/gorm/logger" ) var _dbCache = sync.Map{} type ( DB = gorm.DB Session = gorm.Session ) type Config struct { id string // inner use Type string `toml:"type"` //数据库类型 mysql or pgsql Host string `toml:"host"` //数据库名称 Port int `toml:"port"` //数据库名称 Debug bool `toml:"debug"` //调试开关(会在日志打印SQL) Dbname string `toml:"dbname"` //数据库名称 Username string `toml:"username"` //数据库用户名 Password string `toml:"password"` //数据库连接密码 MaxIdleConns int `toml:"max_idle_conns"` //最大空闲连接数 MaxOpenConns int `toml:"max_open_conns"` //最大打开连接数 MaxConnLifetime time.Duration `toml:"max_conn_lifetime"` //连接最长生命周期 ExtraParameters string `toml:"extra_parameters"` //数据库连接扩展参数 } func NewDB(dbcfg *Config) (*gorm.DB, error) { return open(dbcfg) } func NewSession(db *gorm.DB) *gorm.DB { return db.Session(&gorm.Session{}) } func DefaultLogger() gormlogger.Interface { return defaultLogger } func DefaultRecorder() traceRecorder { return defautRecorder } func (dbcfg *Config) Id() string { if dbcfg.id != "" { return dbcfg.id } jsonDbCfg, err := json.Marshal(dbcfg) if err != nil { panic("unreachable code: Marshal ocour error:" + err.Error()) } md5sum := md5.Sum(jsonDbCfg) dbcfg.id = hex.EncodeToString(md5sum[:]) return dbcfg.id } func GetDsn(dbcfg *Config) (dsn string) { switch dbcfg.Type { case "mysql": //dsn := "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local" dsn = fmt.Sprintf( "%s:%s@tcp(%s:%d)/%s?%s", dbcfg.Username, dbcfg.Password, dbcfg.Host, dbcfg.Port, dbcfg.Dbname, dbcfg.ExtraParameters, ) case "pgsql": //dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" arrConfStr := []string{ "host=" + dbcfg.Host, "port=" + strconv.Itoa(dbcfg.Port), "user=" + dbcfg.Username, "password=" + dbcfg.Password, "dbname=" + dbcfg.Dbname, dbcfg.ExtraParameters, } dsn = strings.Join(arrConfStr, " ") default: panic("DATABASE TYPE '" + dbcfg.Type + "' NOT SUPPORT (only mysql or pgsql)") } return dsn } func open(dbcfg *Config) (*gorm.DB, error) { var ( db *gorm.DB err error dsn = GetDsn(dbcfg) cacheValue, inCacheYes = _dbCache.Load(dbcfg.Id()) cachedDb, inCacheYes2 = cacheValue.(*gorm.DB) ) // return from cache if have been cached before if inCacheYes && inCacheYes2 { if cachedDb == nil { panic("unreachable code: cachedDb MUST not be nil if it cached") } return cachedDb.Session(&gorm.Session{}), nil } // dialector support 'mysql' and 'pgsql' var dialector gorm.Dialector switch dbcfg.Type { case "mysql": dialector = mysql.New( mysql.Config{ DSN: dsn, DefaultStringSize: 512, }, ) case "pgsql": dialector = postgres.New( postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, }, ) default: panic("UNKNOWN DATABASE TYPE:" + dbcfg.Type) } // gorm open var gormconfig = &gorm.Config{ NowFunc: func() time.Time { return time.Now().UTC() }, QueryFields: true, } db, err = gorm.Open(dialector, gormconfig) if err != nil { return nil, fmt.Errorf("gorm.Open error:%s", err) } if dbcfg.Debug { db = db.Session(&gorm.Session{Logger: DefaultLogger()}) } // set connection pool parameters sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("gormDB get sqlDB error:%s", err) } if dbcfg.MaxIdleConns > 0 { sqlDB.SetMaxIdleConns(dbcfg.MaxIdleConns) } else { sqlDB.SetMaxIdleConns(100) } if dbcfg.MaxOpenConns > 0 { sqlDB.SetMaxOpenConns(dbcfg.MaxOpenConns) } else { sqlDB.SetMaxOpenConns(500) } if dbcfg.MaxConnLifetime > 0 { sqlDB.SetConnMaxLifetime(dbcfg.MaxConnLifetime) } else { sqlDB.SetConnMaxLifetime(3 * time.Minute) } // store to cache and return _dbCache.Store(dbcfg.Id(), db) return db, nil }