commit a6a25a8b7387787c3ad7db8f954c9b32cb205a85 Author: bryan Date: Mon Apr 7 12:17:00 2025 +0800 first commit diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..0028f17 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module qoobing.com/gomod/database + +go 1.19.2 + +require qoobing.com/gomod/log v1.4.0 + +require ( + github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d // indirect + gorm.io/gorm v1.25.12 + qoobing.com/gomod/str v1.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..65aff42 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d h1:yYYPFFlbqxF5mrj5sEfETtM/Ssz2LTy0/VKlDdXYctc= +github.com/tylerb/gls v0.0.0-20150407001822-e606233f194d/go.mod h1:0MwyId/pXK5wkYYEXe7NnVknX+aNBuF73fLV3U0reU8= +github.com/tylerb/is v2.1.4+incompatible/go.mod h1:3Bw2NWEEe8Kx7/etYqgm9ug53iNDgabnloch75jjOSc= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +qoobing.com/gomod/log v1.4.0 h1:VSdV8Fm2roCkSwOTunmNrH2kl14KOCYavsh4tJ/SGj0= +qoobing.com/gomod/log v1.4.0/go.mod h1:rNXuq0d/EWog4+8hIEVGvkusLD/pzafYBQo6w+Evv6A= +qoobing.com/gomod/str v1.0.1 h1:X+JOigE9xA6cTNph7/s1KeD4zLYM9XTLPPHQcpHFoog= +qoobing.com/gomod/str v1.0.1/go.mod h1:gbhN2dba/P5gFRGVJvEI57KEJLlMHHAd6Kuuxn4GlMY= diff --git a/gorm.go b/gorm.go new file mode 100644 index 0000000..4dd05be --- /dev/null +++ b/gorm.go @@ -0,0 +1,174 @@ +package gorm + +import ( + "crypto/md5" + "encoding/hex" + "encoding/json" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/gorm" + gormlogger "gorm.io/gorm/logger" +) + +var _dbCache = sync.Map{} + +type ( + DB = gorm.DB + Session = gorm.Session +) + +type Config struct { + id string // inner use + Type string `toml:"type"` //数据库类型 mysql or pgsql + Host string `toml:"host"` //数据库名称 + Port int `toml:"port"` //数据库名称 + Debug bool `toml:"debug"` //调试开关(会在日志打印SQL) + Dbname string `toml:"dbname"` //数据库名称 + Username string `toml:"username"` //数据库用户名 + Password string `toml:"password"` //数据库连接密码 + MaxIdleConns int `toml:"max_idle_conns"` //最大空闲连接数 + MaxOpenConns int `toml:"max_open_conns"` //最大打开连接数 + MaxConnLifetime time.Duration `toml:"max_conn_lifetime"` //连接最长生命周期 + ExtraParameters string `toml:"extra_parameters"` //数据库连接扩展参数 +} + +func NewDB(dbcfg *Config) (*gorm.DB, error) { + return open(dbcfg) +} + +func NewSession(db *gorm.DB) *gorm.DB { + return db.Session(&gorm.Session{}) +} + +func DefaultLogger() gormlogger.Interface { + return defaultLogger +} + +func DefaultRecorder() traceRecorder { + return defautRecorder +} + +func (dbcfg *Config) Id() string { + if dbcfg.id != "" { + return dbcfg.id + } + jsonDbCfg, err := json.Marshal(dbcfg) + if err != nil { + panic("unreachable code: Marshal ocour error:" + err.Error()) + } + md5sum := md5.Sum(jsonDbCfg) + dbcfg.id = hex.EncodeToString(md5sum[:]) + return dbcfg.id +} + +func GetDsn(dbcfg *Config) (dsn string) { + switch dbcfg.Type { + case "mysql": + //dsn := "user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local" + dsn = fmt.Sprintf( + "%s:%s@tcp(%s:%d)/%s?%s", + dbcfg.Username, + dbcfg.Password, + dbcfg.Host, + dbcfg.Port, + dbcfg.Dbname, + dbcfg.ExtraParameters, + ) + case "pgsql": + //dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai" + arrConfStr := []string{ + "host=" + dbcfg.Host, + "port=" + strconv.Itoa(dbcfg.Port), + "user=" + dbcfg.Username, + "password=" + dbcfg.Password, + "dbname=" + dbcfg.Dbname, + dbcfg.ExtraParameters, + } + dsn = strings.Join(arrConfStr, " ") + default: + panic("DATABASE TYPE '" + dbcfg.Type + "' NOT SUPPORT (only mysql or pgsql)") + } + return dsn +} + +func open(dbcfg *Config) (*gorm.DB, error) { + var ( + db *gorm.DB + err error + dsn = GetDsn(dbcfg) + cacheValue, inCacheYes = _dbCache.Load(dbcfg.Id()) + cachedDb, inCacheYes2 = cacheValue.(*gorm.DB) + ) + // return from cache if have been cached before + if inCacheYes && inCacheYes2 { + if cachedDb == nil { + panic("unreachable code: cachedDb MUST not be nil if it cached") + } + return cachedDb.Session(&gorm.Session{}), nil + } + + // dialector support 'mysql' and 'pgsql' + var dialector gorm.Dialector + switch dbcfg.Type { + case "mysql": + dialector = mysql.New( + mysql.Config{ + DSN: dsn, + DefaultStringSize: 512, + }, + ) + case "pgsql": + dialector = postgres.New( + postgres.Config{ + DSN: dsn, + PreferSimpleProtocol: true, + }, + ) + default: + panic("UNKNOWN DATABASE TYPE:" + dbcfg.Type) + } + + // gorm open + var gormconfig = &gorm.Config{ + NowFunc: func() time.Time { return time.Now().UTC() }, + QueryFields: true, + } + db, err = gorm.Open(dialector, gormconfig) + if err != nil { + return nil, fmt.Errorf("gorm.Open error:%s", err) + } + if dbcfg.Debug { + db = db.Session(&gorm.Session{Logger: DefaultLogger()}) + } + + // set connection pool parameters + sqlDB, err := db.DB() + if err != nil { + return nil, fmt.Errorf("gormDB get sqlDB error:%s", err) + } + if dbcfg.MaxIdleConns > 0 { + sqlDB.SetMaxIdleConns(dbcfg.MaxIdleConns) + } else { + sqlDB.SetMaxIdleConns(100) + } + if dbcfg.MaxOpenConns > 0 { + sqlDB.SetMaxOpenConns(dbcfg.MaxOpenConns) + } else { + sqlDB.SetMaxOpenConns(500) + } + if dbcfg.MaxConnLifetime > 0 { + sqlDB.SetConnMaxLifetime(dbcfg.MaxConnLifetime) + } else { + sqlDB.SetConnMaxLifetime(3 * time.Minute) + } + + // store to cache and return + _dbCache.Store(dbcfg.Id(), db) + return db, nil +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..659a55f --- /dev/null +++ b/logger.go @@ -0,0 +1,200 @@ +package gorm + +import ( + "context" + "errors" + "fmt" + "time" + + gormlogger "gorm.io/gorm/logger" + "qoobing.com/gomod/log" +) + +// ErrRecordNotFound record not found error +var ErrRecordNotFound = errors.New("record not found") + +// Colors +const ( + Reset = "\033[0m" + Red = "\033[31m" + Green = "\033[32m" + Yellow = "\033[33m" + Blue = "\033[34m" + Magenta = "\033[35m" + Cyan = "\033[36m" + White = "\033[37m" + BlueBold = "\033[34;1m" + MagentaBold = "\033[35;1m" + RedBold = "\033[31;1m" + YellowBold = "\033[33;1m" +) + +// LogLevel log level +type LogLevel = gormlogger.LogLevel + +const ( + // Silent silent log level + Silent LogLevel = iota + 1 + // Error error log level + Error + // Warn warn log level + Warn + // Info info log level + Info +) + +// LogConfig logger config +type LogConfig struct { + SlowThreshold time.Duration + Colorful bool + IgnoreRecordNotFoundError bool + ParameterizedQueries bool + LogLevel LogLevel +} + +var ( + // // Discard logger will print any log to io.Discard + // Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) + // // defaultLogger defaultLogger logger + // + defaultLogger = New(LogConfig{ + SlowThreshold: 200 * time.Millisecond, + LogLevel: Info, + IgnoreRecordNotFoundError: false, + Colorful: true, + }) + + // defautRecorder logger records running SQL into a recorder instance + defautRecorder = traceRecorder{Interface: defaultLogger, BeginAt: time.Now()} +) + +// New initialize logger +func New(config LogConfig) gormlogger.Interface { + var ( + infoStr = "%s\n[info] " + warnStr = "%s\n[warn] " + errStr = "%s\n[error] " + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s\n[%.3fms] [rows:%v] %s" + ) + + if config.Colorful { + infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset + errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset + traceStr = Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + } + baseLogger := *log.New("gorm") + baseLogger.SetCalldepth(3) + return &logger{ + baseLogger: baseLogger, + LogConfig: config, + infoStr: infoStr, + warnStr: warnStr, + errStr: errStr, + traceStr: traceStr, + traceWarnStr: traceWarnStr, + traceErrStr: traceErrStr, + } +} + +type logger struct { + LogConfig + baseLogger log.Logger + infoStr, warnStr, errStr string + traceStr, traceErrStr, traceWarnStr string +} + +// LogMode log mode +func (l *logger) LogMode(level LogLevel) gormlogger.Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +// Info print info +func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Info { + l.baseLogger.Infof(l.infoStr+msg, data) + } +} + +// Warn print warn messages +func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Warn { + l.baseLogger.Warningf(l.warnStr+msg, data) + } +} + +// Error print error messages +func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= Error { + l.baseLogger.Errorf(l.errStr+msg, data) + } +} + +// Trace print sql message +// +//nolint:cyclop +func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + if l.LogLevel <= Silent { + return + } + + elapsed := time.Since(begin) + switch { + case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): + sql, rows := fc() + if rows == -1 { + l.baseLogger.Debugf(l.traceErrStr, err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.baseLogger.Debugf(l.traceErrStr, err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: + sql, rows := fc() + slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) + if rows == -1 { + l.baseLogger.Debugf(l.traceWarnStr, slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.baseLogger.Debugf(l.traceWarnStr, slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + case l.LogLevel == Info: + sql, rows := fc() + if rows == -1 { + l.baseLogger.Debugf(l.traceStr, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.baseLogger.Debugf(l.traceStr, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + } +} + +// ParamsFilter filter params +func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { + if l.LogConfig.ParameterizedQueries { + return sql, nil + } + return sql, params +} + +type traceRecorder struct { + gormlogger.Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +// New trace recorder +func (l *traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} +} + +// Trace implement logger interface +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +}