From b6fbfe815b864d5df75ff498658574ceefa73809 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Tue, 7 Apr 2026 12:12:51 +0530 Subject: [PATCH] feat(go): add schema for pending users table; 1. Schema Definition for Pending Users table. 2. Use the newer table everywhere. 3. Migration stubs for v1.5.2. 4. Migration code for Pending Users table; --- database/database.go | 4 +- logic/users.go | 65 +++------------------- migrate/migrate_schema.go | 78 ++++++++++++++++++++++++++ migrate/migrate_v1_5_1.go | 54 +----------------- migrate/migrate_v1_5_2.go | 55 +++++++++++++++++++ pro/auth/azure-ad.go | 8 +-- pro/auth/github.go | 7 +-- pro/auth/google.go | 8 +-- pro/auth/headless_callback.go | 7 +-- pro/auth/oidc.go | 8 +-- pro/controllers/users.go | 100 +++++++++++++++++++--------------- schema/models.go | 1 + schema/pending_users.go | 86 +++++++++++++++++++++++++++++ 13 files changed, 307 insertions(+), 174 deletions(-) create mode 100644 migrate/migrate_schema.go create mode 100644 migrate/migrate_v1_5_2.go create mode 100644 schema/pending_users.go diff --git a/database/database.go b/database/database.go index 874b5c48..adf73105 100644 --- a/database/database.go +++ b/database/database.go @@ -114,8 +114,6 @@ var Tables = []string{ CACHE_TABLE_NAME, ENROLLMENT_KEYS_TABLE_NAME, HOST_ACTIONS_TABLE_NAME, - PENDING_USERS_TABLE_NAME, - USER_INVITES_TABLE_NAME, TAG_TABLE_NAME, ACLS_TABLE_NAME, PEER_ACK_TABLE, @@ -127,6 +125,8 @@ var Tables = []string{ USER_PERMISSIONS_TABLE_NAME, NETWORKS_TABLE_NAME, HOSTS_TABLE_NAME, + PENDING_USERS_TABLE_NAME, + USER_INVITES_TABLE_NAME, } func getCurrentDB() map[string]interface{} { diff --git a/logic/users.go b/logic/users.go index 84579cc4..e7efcd83 100644 --- a/logic/users.go +++ b/logic/users.go @@ -94,70 +94,23 @@ func GetSuperAdmin() (models.ReturnUser, error) { return ToReturnUser(_user), nil } -func InsertPendingUser(u *models.User) error { - data, err := json.Marshal(u) - if err != nil { - return err - } - return database.Insert(u.UserName, string(data), database.PENDING_USERS_TABLE_NAME) -} - func DeletePendingUser(username string) error { - return database.DeleteRecord(database.PENDING_USERS_TABLE_NAME, username) + return (&schema.PendingUser{ + Username: username, + }).Delete(db.WithContext(context.TODO())) } func IsPendingUser(username string) bool { - records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) - if err != nil { - return false + exists, err := (&schema.PendingUser{ + Username: username, + }).Exists(db.WithContext(context.TODO())) + if err == nil { + return exists + } - } - for _, record := range records { - u := models.ReturnUser{} - err := json.Unmarshal([]byte(record), &u) - if err == nil && u.UserName == username { - return true - } - } return false } -func ListPendingReturnUsers() ([]models.ReturnUser, error) { - pendingUsers := []models.ReturnUser{} - records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) - if err != nil && !database.IsEmptyRecord(err) { - return pendingUsers, err - } - for _, record := range records { - user := models.ReturnUser{} - err = json.Unmarshal([]byte(record), &user) - if err == nil { - user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole - user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole - pendingUsers = append(pendingUsers, user) - } - } - return pendingUsers, nil -} - -func ListPendingUsers() ([]models.User, error) { - var pendingUsers []models.User - records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) - if err != nil && !database.IsEmptyRecord(err) { - return pendingUsers, err - } - for _, record := range records { - var user models.User - err = json.Unmarshal([]byte(record), &user) - if err == nil { - user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole - user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole - pendingUsers = append(pendingUsers, user) - } - } - return pendingUsers, nil -} - func GetUserMap() (map[string]schema.User, error) { users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO())) if err != nil { diff --git a/migrate/migrate_schema.go b/migrate/migrate_schema.go new file mode 100644 index 00000000..9a0342b0 --- /dev/null +++ b/migrate/migrate_schema.go @@ -0,0 +1,78 @@ +package migrate + +import ( + "context" + "errors" + "fmt" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/schema" + "gorm.io/gorm" +) + +type migrationFunc func(ctx context.Context) error + +// ToSQLSchema migrates the data from key-value +// db to sql db. +func ToSQLSchema() error { + // begin a new transaction. + dbctx := db.BeginTx(context.TODO()) + commit := false + defer func() { + if commit { + db.FromContext(dbctx).Commit() + } else { + db.FromContext(dbctx).Rollback() + } + }() + + // v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables. + // future table migrations should be made below this block, + // with a different version number and a similar check for whether the + // migration was already done. + err := ensureMigrationCompleted(dbctx, "migration-v1.5.1", migrateV1_5_1) + if err != nil { + return err + } + + // v1.5.2 migration includes migrating the pending users and user invites tables. + err = ensureMigrationCompleted(dbctx, "migration-v1.5.2", migrateV1_5_2) + if err != nil { + return err + } + + commit = true + return nil +} + +func ensureMigrationCompleted(ctx context.Context, version string, migrate migrationFunc) error { + migrationJob := &schema.Job{ + ID: version, + } + err := migrationJob.Get(ctx) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID)) + // migrate. + err = migrate(ctx) + if err != nil { + return err + } + + // mark migration job completed. + err = migrationJob.Create(ctx) + if err != nil { + return err + } + + logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID)) + } else { + logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID)) + } + + return nil +} diff --git a/migrate/migrate_v1_5_1.go b/migrate/migrate_v1_5_1.go index 93f49d85..0ddb79f3 100644 --- a/migrate/migrate_v1_5_1.go +++ b/migrate/migrate_v1_5_1.go @@ -3,71 +3,19 @@ package migrate import ( "context" "encoding/json" - "errors" "fmt" "net" "time" "github.com/google/uuid" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "gorm.io/datatypes" - "gorm.io/gorm" ) -// ToSQLSchema migrates the data from key-value -// db to sql db. -func ToSQLSchema() error { - // begin a new transaction. - dbctx := db.BeginTx(context.TODO()) - commit := false - defer func() { - if commit { - db.FromContext(dbctx).Commit() - } else { - db.FromContext(dbctx).Rollback() - } - }() - - // v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables. - // future table migrations should be made below this block, - // with a different version number and a similar check for whether the - // migration was already done. - migrationJob := &schema.Job{ - ID: "migration-v1.5.1", - } - err := migrationJob.Get(dbctx) - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - - logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID)) - // migrate. - err = migrateV1_5_1(dbctx) - if err != nil { - return err - } - - // mark migration job completed. - err = migrationJob.Create(dbctx) - if err != nil { - return err - } - - logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID)) - commit = true - } else { - logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID)) - } - - return nil -} - func migrateV1_5_1(ctx context.Context) error { err := migrateUsers(ctx) if err != nil { @@ -116,7 +64,7 @@ func migrateUsers(ctx context.Context) error { } } - _user := schema.User{ + _user := &schema.User{ ID: "", Username: user.UserName, DisplayName: user.DisplayName, diff --git a/migrate/migrate_v1_5_2.go b/migrate/migrate_v1_5_2.go new file mode 100644 index 00000000..34be5a04 --- /dev/null +++ b/migrate/migrate_v1_5_2.go @@ -0,0 +1,55 @@ +package migrate + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" +) + +func migrateV1_5_2(ctx context.Context) error { + err := migratePendingUsers(ctx) + if err != nil { + return err + } + + return migrateUserInvites(ctx) +} + +func migratePendingUsers(ctx context.Context) error { + records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + for _, record := range records { + var pendingUser models.User + err = json.Unmarshal([]byte(record), &pendingUser) + if err != nil { + return err + } + + _pendingUser := &schema.PendingUser{ + Username: pendingUser.UserName, + ExternalIdentityProviderID: pendingUser.ExternalIdentityProviderID, + } + + logger.Log(4, fmt.Sprintf("migrating pending user %s", _pendingUser.Username)) + + err = _pendingUser.Create(ctx) + if err != nil { + logger.Log(4, fmt.Sprintf("migrating pending user %s failed: %v", _pendingUser.Username, err)) + return err + } + } + + return nil +} + +func migrateUserInvites(ctx context.Context) error { + return nil +} diff --git a/pro/auth/azure-ad.go b/pro/auth/azure-ad.go index b6f1e9ff..1db44c4f 100644 --- a/pro/auth/azure-ad.go +++ b/pro/auth/azure-ad.go @@ -116,11 +116,11 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowedToSignUp(w) return } - err = logic.InsertPendingUser(&models.User{ - UserName: content.UserPrincipalName, + + err := (&schema.PendingUser{ + Username: content.UserPrincipalName, ExternalIdentityProviderID: string(content.ID), - AuthType: schema.OAuth, - }) + }).Create(r.Context()) if err != nil { handleSomethingWentWrong(w) return diff --git a/pro/auth/github.go b/pro/auth/github.go index 9a8d6eb6..12f0aae9 100644 --- a/pro/auth/github.go +++ b/pro/auth/github.go @@ -136,11 +136,10 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowedToSignUp(w) return } - err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + err = (&schema.PendingUser{ + Username: content.Email, ExternalIdentityProviderID: string(content.ID), - AuthType: schema.OAuth, - }) + }).Create(r.Context()) if err != nil { handleSomethingWentWrong(w) return diff --git a/pro/auth/google.go b/pro/auth/google.go index fdbd12b9..329ad644 100644 --- a/pro/auth/google.go +++ b/pro/auth/google.go @@ -118,11 +118,11 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowedToSignUp(w) return } - err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + + err = (&schema.PendingUser{ + Username: content.Email, ExternalIdentityProviderID: string(content.ID), - AuthType: schema.OAuth, - }) + }).Create(r.Context()) if err != nil { handleSomethingWentWrong(w) return diff --git a/pro/auth/headless_callback.go b/pro/auth/headless_callback.go index 7cd13423..61165307 100644 --- a/pro/auth/headless_callback.go +++ b/pro/auth/headless_callback.go @@ -66,11 +66,10 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) { err = user.Get(r.Context()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one - err = logic.InsertPendingUser(&models.User{ - UserName: userClaims.getUserName(), + err = (&schema.PendingUser{ + Username: userClaims.getUserName(), ExternalIdentityProviderID: string(userClaims.ID), - AuthType: schema.OAuth, - }) + }).Create(r.Context()) if err != nil { handleSomethingWentWrong(w) return diff --git a/pro/auth/oidc.go b/pro/auth/oidc.go index 0dab8190..1462d0ef 100644 --- a/pro/auth/oidc.go +++ b/pro/auth/oidc.go @@ -127,11 +127,11 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowedToSignUp(w) return } - err = logic.InsertPendingUser(&models.User{ - UserName: content.Email, + + err = (&schema.PendingUser{ + Username: content.Email, ExternalIdentityProviderID: string(content.ID), - AuthType: schema.OAuth, - }) + }).Create(r.Context()) if err != nil { handleSomethingWentWrong(w) return diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 6d80c484..034ee242 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -29,6 +29,7 @@ import ( "github.com/gravitl/netmaker/utils" "golang.org/x/exp/slog" "gorm.io/datatypes" + "gorm.io/gorm" ) func UserHandlers(r *mux.Router) { @@ -1923,16 +1924,18 @@ func getPendingUsers(w http.ResponseWriter, r *http.Request) { // set header. w.Header().Set("Content-Type", "application/json") - users, err := logic.ListPendingReturnUsers() + pendingUsers, err := (&schema.PendingUser{}).ListAll( + r.Context(), + dbtypes.InAscOrder("username"), + ) if err != nil { logger.Log(0, "failed to fetch users: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logic.SortUsers(users[:]) logger.Log(2, r.Header.Get("user"), "fetched pending users") - json.NewEncoder(w).Encode(users) + json.NewEncoder(w).Encode(pendingUsers) } // @Summary Approve a pending user @@ -1948,37 +1951,41 @@ func approvePendingUser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var params = mux.Vars(r) username := params["username"] - users, err := logic.ListPendingUsers() + pendingUser := &schema.PendingUser{ + Username: username, + } + err := pendingUser.Get(r.Context()) if err != nil { - logger.Log(0, "failed to fetch users: ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + errType := logic.Internal + if errors.Is(err, gorm.ErrRecordNotFound) { + errType = logic.NotFound + } + err = fmt.Errorf("failed to approve pending user (%s): error fetching pending user: %w", username, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType)) return } - for _, user := range users { - if user.UserName == username { - var newPass, fetchErr = logic.FetchPassValue("") - if fetchErr != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(fetchErr, "internal")) - return - } - if err = logic.CreateUser(&schema.User{ - Username: user.UserName, - ExternalIdentityProviderID: user.ExternalIdentityProviderID, - Password: newPass, - AuthType: user.AuthType, - PlatformRoleID: schema.ServiceUser, - }); err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to create user: %s", err), "internal")) - return - } - err = logic.DeletePendingUser(username) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal")) - return - } - break - } + + var newPass, fetchErr = logic.FetchPassValue("") + if fetchErr != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fetchErr, "internal")) + return + } + if err = logic.CreateUser(&schema.User{ + Username: pendingUser.Username, + ExternalIdentityProviderID: pendingUser.ExternalIdentityProviderID, + Password: newPass, + AuthType: schema.OAuth, + PlatformRoleID: schema.ServiceUser, + }); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to create user: %s", err), "internal")) + return + } + err = logic.DeletePendingUser(username) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal")) + return } logic.LogEvent(&models.Event{ Action: schema.Create, @@ -2011,23 +2018,30 @@ func deletePendingUser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") var params = mux.Vars(r) username := params["username"] - users, err := logic.ListPendingReturnUsers() + pendingUser := &schema.PendingUser{ + Username: username, + } + err := pendingUser.Get(r.Context()) if err != nil { - logger.Log(0, "failed to fetch users: ", err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + errType := logic.Internal + if errors.Is(err, gorm.ErrRecordNotFound) { + errType = logic.NotFound + } + err = fmt.Errorf("failed to delete pending user (%s): error fetching pending user: %w", username, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType)) return } - for _, user := range users { - if user.UserName == username { - err = logic.DeletePendingUser(username) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal")) - return - } - break - } + + err = pendingUser.Delete(r.Context()) + if err != nil { + err = fmt.Errorf("failed to delete pending user (%s): %w", username, err) + logger.Log(0, err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return } + logic.LogEvent(&models.Event{ Action: schema.Delete, Source: models.Subject{ @@ -2061,7 +2075,7 @@ func deletePendingUser(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} models.ErrorResponse func deleteAllPendingUsers(w http.ResponseWriter, r *http.Request) { // set header. - err := database.DeleteAllRecords(database.PENDING_USERS_TABLE_NAME) + err := (&schema.PendingUser{}).DeleteAll(r.Context()) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to delete all pending users "+err.Error()), "internal")) return diff --git a/schema/models.go b/schema/models.go index 5f0c29d2..1887f7c2 100644 --- a/schema/models.go +++ b/schema/models.go @@ -17,5 +17,6 @@ func ListModels() []interface{} { &JITRequest{}, &JITGrant{}, &Host{}, + &PendingUser{}, } } diff --git a/schema/pending_users.go b/schema/pending_users.go new file mode 100644 index 00000000..6cf0c843 --- /dev/null +++ b/schema/pending_users.go @@ -0,0 +1,86 @@ +package schema + +import ( + "context" + "errors" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/db" + dbtypes "github.com/gravitl/netmaker/db/types" +) + +var ( + ErrPendingUserIdentifiersNotProvided = errors.New("pending user identifiers not provided") +) + +type PendingUser struct { + ID string `gorm:"primaryKey" json:"id"` + Username string `gorm:"unique" json:"username"` + ExternalIdentityProviderID string `json:"external_identity_provider_id"` + CreatedAt time.Time `json:"created_at"` +} + +func (p *PendingUser) TableName() string { + return "pending_users_v1" +} + +func (p *PendingUser) Create(ctx context.Context) error { + if p.ID == "" { + p.ID = uuid.NewString() + } + + return db.FromContext(ctx).Model(&PendingUser{}).Create(p).Error +} + +func (p *PendingUser) Exists(ctx context.Context) (bool, error) { + if p.ID == "" && p.Username == "" { + return false, ErrPendingUserIdentifiersNotProvided + } + + var exists bool + err := db.FromContext(ctx).Raw( + "SELECT EXISTS (SELECT 1 FROM pending_users_v1 WHERE id = ? OR username = ?)", + p.ID, + p.Username, + ).Scan(&exists).Error + return exists, err +} + +func (p *PendingUser) Get(ctx context.Context) error { + if p.ID == "" && p.Username == "" { + return ErrPendingUserIdentifiersNotProvided + } + + return db.FromContext(ctx).Model(&PendingUser{}). + Where("id = ? OR username = ?", p.ID, p.Username). + First(p). + Error +} + +func (p *PendingUser) ListAll(ctx context.Context, options ...dbtypes.Option) ([]PendingUser, error) { + var pendingUsers []PendingUser + query := db.FromContext(ctx).Model(&PendingUser{}) + + for _, option := range options { + query = option(query) + } + + err := query.Find(&pendingUsers).Error + return pendingUsers, err +} + +func (p *PendingUser) Delete(ctx context.Context) error { + if p.ID == "" && p.Username == "" { + return ErrPendingUserIdentifiersNotProvided + } + + return db.FromContext(ctx).Model(&PendingUser{}). + Where("id = ? OR username = ?", p.ID, p.Username). + Delete(p). + Error +} + +func (p *PendingUser) DeleteAll(ctx context.Context) error { + return db.FromContext(ctx).Exec("DELETE FROM pending_users_v1").Error +}