Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 16 additions & 60 deletions models/db/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,8 @@ type engineContextKeyType struct{}

var engineContextKey = engineContextKeyType{}

type xormContextType struct {
context.Context
engine Engine
}

var xormContext *xormContextType

func newContext(ctx context.Context, e Engine) *xormContextType {
return &xormContextType{Context: ctx, engine: e}
}

// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
func (ctx *xormContextType) Value(key any) any {
if key == engineContextKey {
return ctx
}
return ctx.Context.Value(key)
}

// WithContext returns this engine tied to this context
func (ctx *xormContextType) WithContext(other context.Context) *xormContextType {
return newContext(ctx, ctx.engine.Context(other))
func withContextEngine(ctx context.Context, e Engine) context.Context {
return context.WithValue(ctx, engineContextKey, e)
}

var (
Expand Down Expand Up @@ -89,8 +69,8 @@ func contextSafetyCheck(e Engine) {
// GetEngine gets an existing db Engine/Statement or creates a new Session
func GetEngine(ctx context.Context) (e Engine) {
defer func() { contextSafetyCheck(e) }()
if e := getExistingEngine(ctx); e != nil {
return e
if engine, ok := ctx.Value(engineContextKey).(Engine); ok {
return engine
}
return xormEngine.Context(ctx)
}
Expand All @@ -99,17 +79,6 @@ func GetXORMEngineForTesting() *xorm.Engine {
return xormEngine
}

// getExistingEngine gets an existing db Engine/Statement from this context or returns nil
func getExistingEngine(ctx context.Context) (e Engine) {
if engined, ok := ctx.(*xormContextType); ok {
return engined.engine
}
if engined, ok := ctx.Value(engineContextKey).(*xormContextType); ok {
return engined.engine
}
return nil
}

// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Expand Down Expand Up @@ -152,24 +121,23 @@ func (c *halfCommitter) Close() error {
// And all operations submitted by the caller stack will be rollbacked as well, not only the operations in the current function.
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
func TxContext(parentCtx context.Context) (context.Context, Committer, error) {
if sess, ok := inTransaction(parentCtx); ok {
return newContext(parentCtx, sess), &halfCommitter{committer: sess}, nil
if sess := getTransactionSession(parentCtx); sess != nil {
return withContextEngine(parentCtx, sess), &halfCommitter{committer: sess}, nil
}

sess := xormEngine.NewSession()
if err := sess.Begin(); err != nil {
_ = sess.Close()
return nil, nil, err
}

return newContext(xormContext, sess), sess, nil
return withContextEngine(parentCtx, sess), sess, nil
}

// WithTx represents executing database operations on a transaction, if the transaction exist,
// this function will reuse it otherwise will create a new one and close it when finished.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if sess, ok := inTransaction(parentCtx); ok {
err := f(newContext(parentCtx, sess))
if sess := getTransactionSession(parentCtx); sess != nil {
err := f(withContextEngine(parentCtx, sess))
if err != nil {
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
_ = sess.Close()
Expand All @@ -195,7 +163,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
return err
}

if err := f(newContext(parentCtx, sess)); err != nil {
if err := f(withContextEngine(parentCtx, sess)); err != nil {
return err
}

Expand Down Expand Up @@ -340,25 +308,13 @@ func TableName(bean any) string {

// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
_, ok := inTransaction(ctx)
return ok
return getTransactionSession(ctx) != nil
}

func inTransaction(ctx context.Context) (*xorm.Session, bool) {
e := getExistingEngine(ctx)
if e == nil {
return nil, false
}

switch t := e.(type) {
case *xorm.Engine:
return nil, false
case *xorm.Session:
if t.IsInTx() {
return t, true
}
return nil, false
default:
return nil, false
func getTransactionSession(ctx context.Context) *xorm.Session {
e, _ := ctx.Value(engineContextKey).(Engine)
if sess, ok := e.(*xorm.Session); ok && sess.IsInTx() {
return sess
}
return nil
}
2 changes: 0 additions & 2 deletions models/db/engine_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ func InitEngine(ctx context.Context) error {
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
xormEngine = eng
xormEngine.SetDefaultContext(ctx)
xormContext = &xormContextType{Context: ctx, engine: xormEngine}
}

// UnsetDefaultEngine closes and unsets the default engine
Expand All @@ -98,7 +97,6 @@ func UnsetDefaultEngine() {
_ = xormEngine.Close()
xormEngine = nil
}
xormContext = nil
}

// InitEngineWithMigration initializes a new xorm.Engine and sets it as the XORM's default context
Expand Down