mirror of
https://github.com/gravitl/netmaker.git
synced 2026-04-22 16:07:11 +08:00
Merge branch 'release-v1.5.1' into NM-311
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
@@ -13,6 +14,9 @@ import (
|
||||
// SqliteDB is the db object for sqlite database connections
|
||||
var SqliteDB *sql.DB
|
||||
|
||||
// sqliteWriteMu serializes SQLite write operations to reduce lock contention.
|
||||
var sqliteWriteMu sync.Mutex
|
||||
|
||||
// SQLITE_FUNCTIONS - contains a map of the functions for sqlite
|
||||
var SQLITE_FUNCTIONS = map[string]interface{}{
|
||||
INIT_DB: initSqliteDB,
|
||||
@@ -40,6 +44,9 @@ func initSqliteDB() error {
|
||||
}
|
||||
|
||||
func sqliteCreateTable(tableName string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -54,6 +61,9 @@ func sqliteCreateTable(tableName string) error {
|
||||
|
||||
func sqliteInsert(key string, value string, tableName string) error {
|
||||
if key != "" && value != "" {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
|
||||
statement, err := SqliteDB.Prepare(insertSQL)
|
||||
if err != nil {
|
||||
@@ -81,6 +91,9 @@ func sqliteInsertPeer(key string, value string) error {
|
||||
}
|
||||
|
||||
func sqliteDeleteRecord(tableName string, key string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
deleteSQL := "DELETE FROM " + tableName + " WHERE key = ?"
|
||||
statement, err := SqliteDB.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
@@ -94,6 +107,9 @@ func sqliteDeleteRecord(tableName string, key string) error {
|
||||
}
|
||||
|
||||
func sqliteDeleteAllRecords(tableName string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
deleteSQL := "DELETE FROM " + tableName
|
||||
statement, err := SqliteDB.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
|
||||
@@ -93,7 +93,7 @@ func migrateV1_5_1(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateUsers(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USERS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -147,7 +147,7 @@ func migrateUsers(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateNetworks(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.NETWORKS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -286,7 +286,7 @@ func migrateNetworks(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateUserRoles(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USER_PERMISSIONS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -311,7 +311,7 @@ func migrateUserRoles(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateUserGroups(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USER_GROUPS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -336,7 +336,7 @@ func migrateUserGroups(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateHosts(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.HOSTS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -423,3 +423,22 @@ func migrateHosts(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAll(ctx context.Context, tableName string) (map[string]string, error) {
|
||||
row, err := db.FromContext(ctx).Raw("SELECT * FROM " + tableName + " ORDER BY key").Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records := make(map[string]string)
|
||||
defer row.Close()
|
||||
for row.Next() { // Iterate and fetch the records from result cursor
|
||||
var key string
|
||||
var value string
|
||||
row.Scan(&key, &value)
|
||||
records[key] = value
|
||||
}
|
||||
if len(records) == 0 {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
+27
-4
@@ -9,6 +9,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
@@ -91,8 +92,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user := &schema.User{Username: content.UserPrincipalName}
|
||||
err = user.Get(r.Context())
|
||||
user, err := GetMatchingUser(content)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
if inviteExists {
|
||||
@@ -142,8 +142,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
user = &schema.User{Username: content.UserPrincipalName}
|
||||
err = user.Get(r.Context())
|
||||
user, err = GetMatchingUser(content)
|
||||
if err != nil {
|
||||
handleOauthUserNotFound(w)
|
||||
return
|
||||
@@ -199,6 +198,30 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+user.Username, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func GetMatchingUser(oauthUser *OAuthUser) (*schema.User, error) {
|
||||
user := &schema.User{
|
||||
Username: oauthUser.UserPrincipalName,
|
||||
}
|
||||
err := user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
user = &schema.User{
|
||||
ExternalIdentityProviderID: string(oauthUser.ID),
|
||||
}
|
||||
err = user.GetByExternalID(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
|
||||
oauth_state_string, isValid := logic.IsStateValid(state)
|
||||
if (!isValid || state != oauth_state_string) && !isStateCached(state) {
|
||||
|
||||
@@ -62,8 +62,14 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
|
||||
handleOauthUserSignUpApprovalPending(w)
|
||||
return
|
||||
}
|
||||
user := &schema.User{Username: userClaims.getUserName()}
|
||||
err = user.Get(r.Context())
|
||||
|
||||
var user *schema.User
|
||||
if logic.GetServerSettings().AuthProvider == azure_ad_provider_name {
|
||||
user, err = GetMatchingUser(userClaims)
|
||||
} else {
|
||||
user = &schema.User{Username: userClaims.getUserName()}
|
||||
err = user.Get(r.Context())
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
|
||||
@@ -74,6 +74,17 @@ func (u *User) Get(ctx context.Context) error {
|
||||
Error
|
||||
}
|
||||
|
||||
func (u *User) GetByExternalID(ctx context.Context) error {
|
||||
if u.ExternalIdentityProviderID == "" {
|
||||
return ErrUserIdentifiersNotProvided
|
||||
}
|
||||
|
||||
return db.FromContext(ctx).Model(&User{}).
|
||||
Where("external_identity_provider_id = ?", u.ExternalIdentityProviderID).
|
||||
First(u).
|
||||
Error
|
||||
}
|
||||
|
||||
func (u *User) GetSuperAdmin(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(u).
|
||||
Where("platform_role_id = ?", SuperAdminRole).
|
||||
|
||||
Reference in New Issue
Block a user