feat(go): add schema for pending users table;

1. Schema Definition for Pending Users table.
2. Use the newer table everywhere.
3. Migration stubs for v1.5.2.
4. Migration code for Pending Users table;
This commit is contained in:
VishalDalwadi
2026-04-07 12:12:51 +05:30
parent 96e1d92e48
commit b6fbfe815b
13 changed files with 307 additions and 174 deletions
+2 -2
View File
@@ -114,8 +114,6 @@ var Tables = []string{
CACHE_TABLE_NAME,
ENROLLMENT_KEYS_TABLE_NAME,
HOST_ACTIONS_TABLE_NAME,
PENDING_USERS_TABLE_NAME,
USER_INVITES_TABLE_NAME,
TAG_TABLE_NAME,
ACLS_TABLE_NAME,
PEER_ACK_TABLE,
@@ -127,6 +125,8 @@ var Tables = []string{
USER_PERMISSIONS_TABLE_NAME,
NETWORKS_TABLE_NAME,
HOSTS_TABLE_NAME,
PENDING_USERS_TABLE_NAME,
USER_INVITES_TABLE_NAME,
}
func getCurrentDB() map[string]interface{} {
+9 -56
View File
@@ -94,70 +94,23 @@ func GetSuperAdmin() (models.ReturnUser, error) {
return ToReturnUser(_user), nil
}
func InsertPendingUser(u *models.User) error {
data, err := json.Marshal(u)
if err != nil {
return err
}
return database.Insert(u.UserName, string(data), database.PENDING_USERS_TABLE_NAME)
}
func DeletePendingUser(username string) error {
return database.DeleteRecord(database.PENDING_USERS_TABLE_NAME, username)
return (&schema.PendingUser{
Username: username,
}).Delete(db.WithContext(context.TODO()))
}
func IsPendingUser(username string) bool {
records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME)
if err != nil {
return false
exists, err := (&schema.PendingUser{
Username: username,
}).Exists(db.WithContext(context.TODO()))
if err == nil {
return exists
}
}
for _, record := range records {
u := models.ReturnUser{}
err := json.Unmarshal([]byte(record), &u)
if err == nil && u.UserName == username {
return true
}
}
return false
}
func ListPendingReturnUsers() ([]models.ReturnUser, error) {
pendingUsers := []models.ReturnUser{}
records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return pendingUsers, err
}
for _, record := range records {
user := models.ReturnUser{}
err = json.Unmarshal([]byte(record), &user)
if err == nil {
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
pendingUsers = append(pendingUsers, user)
}
}
return pendingUsers, nil
}
func ListPendingUsers() ([]models.User, error) {
var pendingUsers []models.User
records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return pendingUsers, err
}
for _, record := range records {
var user models.User
err = json.Unmarshal([]byte(record), &user)
if err == nil {
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
pendingUsers = append(pendingUsers, user)
}
}
return pendingUsers, nil
}
func GetUserMap() (map[string]schema.User, error) {
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
+78
View File
@@ -0,0 +1,78 @@
package migrate
import (
"context"
"errors"
"fmt"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/schema"
"gorm.io/gorm"
)
type migrationFunc func(ctx context.Context) error
// ToSQLSchema migrates the data from key-value
// db to sql db.
func ToSQLSchema() error {
// begin a new transaction.
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
// v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables.
// future table migrations should be made below this block,
// with a different version number and a similar check for whether the
// migration was already done.
err := ensureMigrationCompleted(dbctx, "migration-v1.5.1", migrateV1_5_1)
if err != nil {
return err
}
// v1.5.2 migration includes migrating the pending users and user invites tables.
err = ensureMigrationCompleted(dbctx, "migration-v1.5.2", migrateV1_5_2)
if err != nil {
return err
}
commit = true
return nil
}
func ensureMigrationCompleted(ctx context.Context, version string, migrate migrationFunc) error {
migrationJob := &schema.Job{
ID: version,
}
err := migrationJob.Get(ctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID))
// migrate.
err = migrate(ctx)
if err != nil {
return err
}
// mark migration job completed.
err = migrationJob.Create(ctx)
if err != nil {
return err
}
logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID))
} else {
logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID))
}
return nil
}
+1 -53
View File
@@ -3,71 +3,19 @@ package migrate
import (
"context"
"encoding/json"
"errors"
"fmt"
"net"
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// ToSQLSchema migrates the data from key-value
// db to sql db.
func ToSQLSchema() error {
// begin a new transaction.
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
// v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables.
// future table migrations should be made below this block,
// with a different version number and a similar check for whether the
// migration was already done.
migrationJob := &schema.Job{
ID: "migration-v1.5.1",
}
err := migrationJob.Get(dbctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID))
// migrate.
err = migrateV1_5_1(dbctx)
if err != nil {
return err
}
// mark migration job completed.
err = migrationJob.Create(dbctx)
if err != nil {
return err
}
logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID))
commit = true
} else {
logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID))
}
return nil
}
func migrateV1_5_1(ctx context.Context) error {
err := migrateUsers(ctx)
if err != nil {
@@ -116,7 +64,7 @@ func migrateUsers(ctx context.Context) error {
}
}
_user := schema.User{
_user := &schema.User{
ID: "",
Username: user.UserName,
DisplayName: user.DisplayName,
+55
View File
@@ -0,0 +1,55 @@
package migrate
import (
"context"
"encoding/json"
"fmt"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
func migrateV1_5_2(ctx context.Context) error {
err := migratePendingUsers(ctx)
if err != nil {
return err
}
return migrateUserInvites(ctx)
}
func migratePendingUsers(ctx context.Context) error {
records, err := database.FetchRecords(database.PENDING_USERS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var pendingUser models.User
err = json.Unmarshal([]byte(record), &pendingUser)
if err != nil {
return err
}
_pendingUser := &schema.PendingUser{
Username: pendingUser.UserName,
ExternalIdentityProviderID: pendingUser.ExternalIdentityProviderID,
}
logger.Log(4, fmt.Sprintf("migrating pending user %s", _pendingUser.Username))
err = _pendingUser.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating pending user %s failed: %v", _pendingUser.Username, err))
return err
}
}
return nil
}
func migrateUserInvites(ctx context.Context) error {
return nil
}
+4 -4
View File
@@ -116,11 +116,11 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
handleOauthUserNotAllowedToSignUp(w)
return
}
err = logic.InsertPendingUser(&models.User{
UserName: content.UserPrincipalName,
err := (&schema.PendingUser{
Username: content.UserPrincipalName,
ExternalIdentityProviderID: string(content.ID),
AuthType: schema.OAuth,
})
}).Create(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
+3 -4
View File
@@ -136,11 +136,10 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
handleOauthUserNotAllowedToSignUp(w)
return
}
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
err = (&schema.PendingUser{
Username: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: schema.OAuth,
})
}).Create(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
+4 -4
View File
@@ -118,11 +118,11 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
handleOauthUserNotAllowedToSignUp(w)
return
}
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
err = (&schema.PendingUser{
Username: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: schema.OAuth,
})
}).Create(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
+3 -4
View File
@@ -66,11 +66,10 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
err = user.Get(r.Context())
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
err = logic.InsertPendingUser(&models.User{
UserName: userClaims.getUserName(),
err = (&schema.PendingUser{
Username: userClaims.getUserName(),
ExternalIdentityProviderID: string(userClaims.ID),
AuthType: schema.OAuth,
})
}).Create(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
+4 -4
View File
@@ -127,11 +127,11 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
handleOauthUserNotAllowedToSignUp(w)
return
}
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
err = (&schema.PendingUser{
Username: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: schema.OAuth,
})
}).Create(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
+57 -43
View File
@@ -29,6 +29,7 @@ import (
"github.com/gravitl/netmaker/utils"
"golang.org/x/exp/slog"
"gorm.io/datatypes"
"gorm.io/gorm"
)
func UserHandlers(r *mux.Router) {
@@ -1923,16 +1924,18 @@ func getPendingUsers(w http.ResponseWriter, r *http.Request) {
// set header.
w.Header().Set("Content-Type", "application/json")
users, err := logic.ListPendingReturnUsers()
pendingUsers, err := (&schema.PendingUser{}).ListAll(
r.Context(),
dbtypes.InAscOrder("username"),
)
if err != nil {
logger.Log(0, "failed to fetch users: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
logic.SortUsers(users[:])
logger.Log(2, r.Header.Get("user"), "fetched pending users")
json.NewEncoder(w).Encode(users)
json.NewEncoder(w).Encode(pendingUsers)
}
// @Summary Approve a pending user
@@ -1948,37 +1951,41 @@ func approvePendingUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
username := params["username"]
users, err := logic.ListPendingUsers()
pendingUser := &schema.PendingUser{
Username: username,
}
err := pendingUser.Get(r.Context())
if err != nil {
logger.Log(0, "failed to fetch users: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
errType := logic.Internal
if errors.Is(err, gorm.ErrRecordNotFound) {
errType = logic.NotFound
}
err = fmt.Errorf("failed to approve pending user (%s): error fetching pending user: %w", username, err)
logger.Log(0, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType))
return
}
for _, user := range users {
if user.UserName == username {
var newPass, fetchErr = logic.FetchPassValue("")
if fetchErr != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fetchErr, "internal"))
return
}
if err = logic.CreateUser(&schema.User{
Username: user.UserName,
ExternalIdentityProviderID: user.ExternalIdentityProviderID,
Password: newPass,
AuthType: user.AuthType,
PlatformRoleID: schema.ServiceUser,
}); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to create user: %s", err), "internal"))
return
}
err = logic.DeletePendingUser(username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal"))
return
}
break
}
var newPass, fetchErr = logic.FetchPassValue("")
if fetchErr != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fetchErr, "internal"))
return
}
if err = logic.CreateUser(&schema.User{
Username: pendingUser.Username,
ExternalIdentityProviderID: pendingUser.ExternalIdentityProviderID,
Password: newPass,
AuthType: schema.OAuth,
PlatformRoleID: schema.ServiceUser,
}); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to create user: %s", err), "internal"))
return
}
err = logic.DeletePendingUser(username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal"))
return
}
logic.LogEvent(&models.Event{
Action: schema.Create,
@@ -2011,23 +2018,30 @@ func deletePendingUser(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
username := params["username"]
users, err := logic.ListPendingReturnUsers()
pendingUser := &schema.PendingUser{
Username: username,
}
err := pendingUser.Get(r.Context())
if err != nil {
logger.Log(0, "failed to fetch users: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
errType := logic.Internal
if errors.Is(err, gorm.ErrRecordNotFound) {
errType = logic.NotFound
}
err = fmt.Errorf("failed to delete pending user (%s): error fetching pending user: %w", username, err)
logger.Log(0, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, errType))
return
}
for _, user := range users {
if user.UserName == username {
err = logic.DeletePendingUser(username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to delete pending user: %s", err), "internal"))
return
}
break
}
err = pendingUser.Delete(r.Context())
if err != nil {
err = fmt.Errorf("failed to delete pending user (%s): %w", username, err)
logger.Log(0, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
logic.LogEvent(&models.Event{
Action: schema.Delete,
Source: models.Subject{
@@ -2061,7 +2075,7 @@ func deletePendingUser(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} models.ErrorResponse
func deleteAllPendingUsers(w http.ResponseWriter, r *http.Request) {
// set header.
err := database.DeleteAllRecords(database.PENDING_USERS_TABLE_NAME)
err := (&schema.PendingUser{}).DeleteAll(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to delete all pending users "+err.Error()), "internal"))
return
+1
View File
@@ -17,5 +17,6 @@ func ListModels() []interface{} {
&JITRequest{},
&JITGrant{},
&Host{},
&PendingUser{},
}
}
+86
View File
@@ -0,0 +1,86 @@
package schema
import (
"context"
"errors"
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/db"
dbtypes "github.com/gravitl/netmaker/db/types"
)
var (
ErrPendingUserIdentifiersNotProvided = errors.New("pending user identifiers not provided")
)
type PendingUser struct {
ID string `gorm:"primaryKey" json:"id"`
Username string `gorm:"unique" json:"username"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
CreatedAt time.Time `json:"created_at"`
}
func (p *PendingUser) TableName() string {
return "pending_users_v1"
}
func (p *PendingUser) Create(ctx context.Context) error {
if p.ID == "" {
p.ID = uuid.NewString()
}
return db.FromContext(ctx).Model(&PendingUser{}).Create(p).Error
}
func (p *PendingUser) Exists(ctx context.Context) (bool, error) {
if p.ID == "" && p.Username == "" {
return false, ErrPendingUserIdentifiersNotProvided
}
var exists bool
err := db.FromContext(ctx).Raw(
"SELECT EXISTS (SELECT 1 FROM pending_users_v1 WHERE id = ? OR username = ?)",
p.ID,
p.Username,
).Scan(&exists).Error
return exists, err
}
func (p *PendingUser) Get(ctx context.Context) error {
if p.ID == "" && p.Username == "" {
return ErrPendingUserIdentifiersNotProvided
}
return db.FromContext(ctx).Model(&PendingUser{}).
Where("id = ? OR username = ?", p.ID, p.Username).
First(p).
Error
}
func (p *PendingUser) ListAll(ctx context.Context, options ...dbtypes.Option) ([]PendingUser, error) {
var pendingUsers []PendingUser
query := db.FromContext(ctx).Model(&PendingUser{})
for _, option := range options {
query = option(query)
}
err := query.Find(&pendingUsers).Error
return pendingUsers, err
}
func (p *PendingUser) Delete(ctx context.Context) error {
if p.ID == "" && p.Username == "" {
return ErrPendingUserIdentifiersNotProvided
}
return db.FromContext(ctx).Model(&PendingUser{}).
Where("id = ? OR username = ?", p.ID, p.Username).
Delete(p).
Error
}
func (p *PendingUser) DeleteAll(ctx context.Context) error {
return db.FromContext(ctx).Exec("DELETE FROM pending_users_v1").Error
}