package orm

import (
	"context"
	"errors"
	"fmt"
	"time"

	times "git.icechen.cn/monorepo/backend/pkg/time"
	"github.com/gofiber/fiber/v2"
	ctxLogger "github.com/luizsuper/ctxLoggers"
	"github.com/spf13/cast"
	"go.uber.org/zap"
	"gorm.io/gorm"
	"gorm.io/gorm/utils"
)

// GetContextDB makes sure TracePlugin is installed on db and returns a
// session bound to ctx, so statements executed through it are logged
// together with ctx's request ID.
//
// db.Use reports gorm.ErrRegistered when a plugin with the same name is
// already installed; that is the normal case for every call after the first,
// so it is treated as success instead of failing the caller.
//
// NOTE(review): db.Use mutates gorm's plugin registry without locking; if
// this is called from concurrent requests, installing the plugin once at
// startup would avoid a data race — confirm how callers use it.
func GetContextDB(ctx context.Context, db *gorm.DB) (*gorm.DB, error) {
	if err := db.Use(&TracePlugin{}); err != nil && !errors.Is(err, gorm.ErrRegistered) {
		return nil, err
	}
	return db.WithContext(ctx), nil
}

// SQL is the structured record logged for each traced statement.
type SQL struct {
	Timestamp   string  `json:"timestamp"`     // log time, format 2006-01-02 15:04:05 (produced by times.CSTLayoutString)
	Stack       string  `json:"stack"`         // file path and line number of the calling code
	SQL         string  `json:"sql"`           // SQL statement with its vars interpolated
	Rows        int64   `json:"rows_affected"` // number of rows affected
	CostSeconds float64 `json:"cost_seconds"`  // execution time, in seconds
}

const (
	callBackBeforeName = "core:before" // name of the callback that records the start time
	callBackAfterName  = "core:after"  // name of the callback that logs the statement
	startTime          = "_start_time" // statement-instance key holding the start time.Time
)

// TracePlugin is a gorm plugin that logs every executed statement together
// with its duration, affected rows and caller location.
type TracePlugin struct{}

// Name returns the name under which the plugin is registered with gorm.
func (op *TracePlugin) Name() string {
	return "tracePlugin"
}

// Initialize hooks the before/after callbacks into the create, query, delete,
// update, row and raw callback chains. Every registration is attempted (in
// the same order as before); the first failure, if any, is returned so a
// misconfigured callback order is reported instead of silently ignored.
func (op *TracePlugin) Initialize(db *gorm.DB) error {
	cb := db.Callback()
	registrations := []struct {
		stage string
		err   error
	}{
		// Record the start time before the named built-in step of each chain.
		{"create before", cb.Create().Before("gorm:before_create").Register(callBackBeforeName, before)},
		{"query before", cb.Query().Before("gorm:query").Register(callBackBeforeName, before)},
		{"delete before", cb.Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)},
		{"update before", cb.Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)},
		{"row before", cb.Row().Before("gorm:row").Register(callBackBeforeName, before)},
		{"raw before", cb.Raw().Before("gorm:raw").Register(callBackBeforeName, before)},
		// Log the statement after the named built-in step of each chain.
		{"create after", cb.Create().After("gorm:after_create").Register(callBackAfterName, after)},
		{"query after", cb.Query().After("gorm:after_query").Register(callBackAfterName, after)},
		{"delete after", cb.Delete().After("gorm:after_delete").Register(callBackAfterName, after)},
		{"update after", cb.Update().After("gorm:after_update").Register(callBackAfterName, after)},
		{"row after", cb.Row().After("gorm:row").Register(callBackAfterName, after)},
		{"raw after", cb.Raw().After("gorm:raw").Register(callBackAfterName, after)},
	}
	for _, r := range registrations {
		if r.err != nil {
			return fmt.Errorf("tracePlugin: register %s callback: %w", r.stage, r.err)
		}
	}
	return nil
}

// before stores the current time on the statement instance so that after
// can compute how long the statement took.
func before(db *gorm.DB) {
	db.InstanceSet(startTime, time.Now())
}

// after logs the finished statement as an SQL record, tagged with the value
// stored in the statement's context under fiber.HeaderXRequestID. It does
// nothing when the statement has no context or when before did not record a
// start time.
func after(db *gorm.DB) {
	ctx := db.Statement.Context
	if ctx == nil {
		return
	}
	v, ok := db.InstanceGet(startTime)
	if !ok {
		return
	}
	start, ok := v.(time.Time)
	if !ok {
		return
	}
	// Take the elapsed time first so the logging work below (interpolating
	// the SQL, locating the caller) is not counted as query time.
	cost := time.Since(start).Seconds()

	sqlInfo := &SQL{
		Timestamp:   times.CSTLayoutString(),
		Stack:       utils.FileWithLineNum(),
		SQL:         db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...),
		Rows:        db.Statement.RowsAffected,
		CostSeconds: cost,
	}
	// NOTE(review): FInfo is given a nil first argument and the trace ID is
	// attached explicitly as a field — confirm this matches ctxLogger's API.
	ctxLogger.FInfo(nil, "sql", zap.Any("info", sqlInfo), zap.String("trace_id", cast.ToString(ctx.Value(fiber.HeaderXRequestID))))
}