Merge pull request #3878 from gravitl/NM-256-v1.5.0-patch

v1.5.0: DB optimisations, Add Postgresql connection pool limits, add SSO cache cleanup hook
This commit is contained in:
Abhishek Kondur
2026-02-27 13:51:16 +04:00
committed by GitHub
16 changed files with 265 additions and 105 deletions
+18
View File
@@ -106,6 +106,14 @@ func createNs(w http.ResponseWriter, r *http.Request) {
},
}
}
if req.Fallback {
for _, domain := range req.Domains {
if domain.IsADDomain {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("cannot configure ad domain for fallback nameservers"), "badrequest"))
return
}
}
}
ns := schema.Nameserver{
ID: uuid.New().String(),
Name: req.Name,
@@ -208,6 +216,16 @@ func updateNs(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
if updateNs.Fallback {
for _, domain := range updateNs.Domains {
if domain.IsADDomain {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("cannot configure ad domain for fallback nameservers"), "badrequest"))
return
}
}
}
if updateNs.Tags == nil {
updateNs.Tags = make(datatypes.JSONMap)
}
+6 -9
View File
@@ -90,6 +90,8 @@ const (
DELETE_ALL = "deleteall"
// FETCH_ALL - fetch table contents const
FETCH_ALL = "fetchall"
// FETCH_ONE - fetch a single record const
FETCH_ONE = "fetchone"
// CLOSE_DB - graceful close of db const
CLOSE_DB = "closedb"
// isconnected
@@ -203,16 +205,11 @@ func DeleteAllRecords(tableName string) error {
return nil
}
// FetchRecord - fetches a record
// FetchRecord - fetches a single record by key
func FetchRecord(tableName string, key string) (string, error) {
results, err := FetchRecords(tableName)
if err != nil {
return "", err
}
if results[key] == "" {
return "", errors.New(NO_RECORD)
}
return results[key], nil
dbMutex.RLock()
defer dbMutex.RUnlock()
return getCurrentDB()[FETCH_ONE].(func(string, string) (string, error))(tableName, key)
}
// FetchRecords - fetches all records in given table
+13
View File
@@ -20,6 +20,7 @@ var PG_FUNCTIONS = map[string]interface{}{
DELETE: pgDeleteRecord,
DELETE_ALL: pgDeleteAllRecords,
FETCH_ALL: pgFetchRecords,
FETCH_ONE: pgFetchRecord,
CLOSE_DB: pgCloseDB,
isConnected: pgIsConnected,
}
@@ -105,6 +106,18 @@ func pgDeleteAllRecords(tableName string) error {
return nil
}
func pgFetchRecord(tableName string, key string) (string, error) {
var value string
err := PGDB.QueryRow("SELECT value FROM "+tableName+" WHERE key = $1", key).Scan(&value)
if err != nil {
if err == sql.ErrNoRows {
return "", errors.New(NO_RECORD)
}
return "", err
}
return value, nil
}
func pgFetchRecords(tableName string) (map[string]string, error) {
row, err := PGDB.Query("SELECT * FROM " + tableName + " ORDER BY key")
if err != nil {
+14
View File
@@ -19,6 +19,7 @@ var RQLITE_FUNCTIONS = map[string]interface{}{
DELETE: rqliteDeleteRecord,
DELETE_ALL: rqliteDeleteAllRecords,
FETCH_ALL: rqliteFetchRecords,
FETCH_ONE: rqliteFetchRecord,
CLOSE_DB: rqliteCloseDB,
isConnected: rqliteConnected,
}
@@ -84,6 +85,19 @@ func rqliteDeleteAllRecords(tableName string) error {
return nil
}
func rqliteFetchRecord(tableName string, key string) (string, error) {
row, err := RQliteDatabase.QueryOne("SELECT value FROM " + tableName + " WHERE key = '" + key + "'")
if err != nil {
return "", err
}
if row.Next() {
var value string
row.Scan(&value)
return value, nil
}
return "", errors.New(NO_RECORD)
}
func rqliteFetchRecords(tableName string) (map[string]string, error) {
row, err := RQliteDatabase.QueryOne("SELECT * FROM " + tableName + " ORDER BY key")
if err != nil {
+13
View File
@@ -20,6 +20,7 @@ var SQLITE_FUNCTIONS = map[string]interface{}{
DELETE: sqliteDeleteRecord,
DELETE_ALL: sqliteDeleteAllRecords,
FETCH_ALL: sqliteFetchRecords,
FETCH_ONE: sqliteFetchRecord,
CLOSE_DB: sqliteCloseDB,
isConnected: sqliteConnected,
}
@@ -103,6 +104,18 @@ func sqliteDeleteAllRecords(tableName string) error {
return nil
}
func sqliteFetchRecord(tableName string, key string) (string, error) {
var value string
err := SqliteDB.QueryRow("SELECT value FROM "+tableName+" WHERE key = ?", key).Scan(&value)
if err != nil {
if err == sql.ErrNoRows {
return "", errors.New(NO_RECORD)
}
return "", err
}
return value, nil
}
func sqliteFetchRecords(tableName string) (map[string]string, error) {
row, err := SqliteDB.Query("SELECT * FROM " + tableName + " ORDER BY key")
if err != nil {
+17 -2
View File
@@ -2,11 +2,12 @@ package db
import (
"fmt"
"github.com/gravitl/netmaker/servercfg"
"os"
"strconv"
"time"
"github.com/gravitl/netmaker/config"
"github.com/gravitl/netmaker/servercfg"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -30,9 +31,23 @@ func (pg *postgresConnector) connect() (*gorm.DB, error) {
pgConf.SSLMode,
)
return gorm.Open(postgres.Open(dsn), &gorm.Config{
gormDB, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, err
}
sqlDB, err := gormDB.DB()
if err != nil {
return gormDB, err
}
sqlDB.SetMaxOpenConns(25)
sqlDB.SetMaxIdleConns(10)
sqlDB.SetConnMaxLifetime(5 * time.Minute)
sqlDB.SetConnMaxIdleTime(2 * time.Minute)
return gormDB, nil
}
func GetSQLConf() config.SQLConfig {
+86 -53
View File
@@ -18,6 +18,7 @@ import (
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
const (
@@ -36,77 +37,63 @@ var ResetIDPSyncHook = func() {}
// HasSuperAdmin - checks if server has an superadmin/owner
func HasSuperAdmin() (bool, error) {
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
users, err := GetUsersDB()
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
} else {
return true, err
}
return true, err
}
for _, user := range users {
if user.PlatformRoleID == models.SuperAdminRole {
return true, nil
}
}
for _, value := range collection { // filter for isadmin true
return false, nil
}
// GetUsersDB - gets users
func GetUsersDB() ([]models.User, error) {
if servercfg.CacheEnabled() {
users := getUsersFromCache()
if len(users) != 0 {
return users, nil
}
}
var users []models.User
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
cacheMap := make(map[string]models.User, len(collection))
for _, value := range collection {
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
if user.PlatformRoleID == models.SuperAdminRole {
return true, nil
}
}
return false, err
}
// GetUsersDB - gets users
func GetUsersDB() ([]models.User, error) {
var users []models.User
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
for _, value := range collection {
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
users = append(users, user)
cacheMap[user.UserName] = user
}
return users, err
if servercfg.CacheEnabled() {
loadUsersIntoCache(cacheMap)
}
return users, nil
}
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
var users []models.ReturnUser
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
dbUsers, err := GetUsersDB()
if err != nil {
return users, err
return nil, err
}
for _, value := range collection {
var user models.ReturnUser
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue // get users
}
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
users = append(users, user)
users := make([]models.ReturnUser, 0, len(dbUsers))
for _, u := range dbUsers {
users = append(users, ToReturnUser(u))
}
return users, err
return users, nil
}
// IsOauthUser - returns
@@ -192,6 +179,9 @@ func CreateUser(user *models.User) error {
logger.Log(0, "failed to insert user", err.Error())
return err
}
if servercfg.CacheEnabled() {
storeUserInCache(*user)
}
return nil
}
@@ -273,7 +263,9 @@ func UpsertUser(user models.User) error {
slog.Error("error inserting user", "user", user.UserName, "error", err.Error())
return err
}
if servercfg.CacheEnabled() {
storeUserInCache(user)
}
return nil
}
@@ -407,6 +399,12 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return &models.User{}, err
}
if servercfg.CacheEnabled() {
if queryUser != user.UserName {
deleteUserFromCache(queryUser)
}
storeUserInCache(*user)
}
logger.Log(1, "updated user", queryUser)
return user, nil
}
@@ -446,6 +444,9 @@ func DeleteUser(user string) error {
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteUserFromCache(user)
}
go RemoveUserFromAclPolicy(user)
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
}
@@ -535,3 +536,35 @@ func IsStateValid(state string) (string, bool) {
func delState(state string) error {
return database.DeleteRecord(database.SSO_STATE_CACHE, state)
}
// CleanExpiredSSOStates removes expired SSO state entries from the database
// to prevent unbounded table growth that degrades FetchRecord performance.
func CleanExpiredSSOStates() error {
records, err := database.FetchRecords(database.SSO_STATE_CACHE)
if err != nil {
if database.IsEmptyRecord(err) {
return nil
}
return err
}
for key, value := range records {
var s models.SsoState
if err := json.Unmarshal([]byte(value), &s); err != nil {
_ = database.DeleteRecord(database.SSO_STATE_CACHE, key)
continue
}
if s.IsExpired() {
_ = database.DeleteRecord(database.SSO_STATE_CACHE, key)
}
}
return nil
}
// AddSSOStateCleanupHook registers a periodic cleanup of expired SSO states
func AddSSOStateCleanupHook() {
HookManagerCh <- models.HookDetails{
ID: "sso-state-cleanup",
Hook: WrapHook(CleanExpiredSSOStates),
Interval: 15 * time.Minute,
}
}
+4
View File
@@ -566,6 +566,7 @@ func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -584,6 +585,7 @@ func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -639,6 +641,7 @@ func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -657,6 +660,7 @@ func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
+61 -21
View File
@@ -4,14 +4,62 @@ import (
"encoding/json"
"errors"
"sort"
"sync"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
var (
userCacheMutex = &sync.RWMutex{}
usersCacheMap = make(map[string]models.User)
)
func getUserFromCache(username string) (models.User, bool) {
userCacheMutex.RLock()
user, ok := usersCacheMap[username]
userCacheMutex.RUnlock()
return user, ok
}
func getUsersFromCache() []models.User {
userCacheMutex.RLock()
users := make([]models.User, 0, len(usersCacheMap))
for _, user := range usersCacheMap {
users = append(users, user)
}
userCacheMutex.RUnlock()
return users
}
func storeUserInCache(user models.User) {
userCacheMutex.Lock()
usersCacheMap[user.UserName] = user
userCacheMutex.Unlock()
}
func deleteUserFromCache(username string) {
userCacheMutex.Lock()
delete(usersCacheMap, username)
userCacheMutex.Unlock()
}
func loadUsersIntoCache(users map[string]models.User) {
userCacheMutex.Lock()
usersCacheMap = users
userCacheMutex.Unlock()
}
// GetUser - gets a user
// TODO support "masteradmin"
func GetUser(username string) (*models.User, error) {
if servercfg.CacheEnabled() {
if user, ok := getUserFromCache(username); ok {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
return &user, nil
}
}
var user models.User
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
@@ -23,23 +71,19 @@ func GetUser(username string) (*models.User, error) {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
if servercfg.CacheEnabled() {
storeUserInCache(user)
}
return &user, err
}
// GetReturnUser - gets a user
func GetReturnUser(username string) (models.ReturnUser, error) {
var user models.ReturnUser
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
u, err := GetUser(username)
if err != nil {
return user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return models.ReturnUser{}, err
}
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
return user, err
return ToReturnUser(*u), nil
}
// ToReturnUser - gets a user as a return user
@@ -160,19 +204,15 @@ func ListPendingUsers() ([]models.User, error) {
}
func GetUserMap() (map[string]models.User, error) {
userMap := make(map[string]models.User)
records, err := database.FetchRecords(database.USERS_TABLE_NAME)
users, err := GetUsersDB()
if err != nil && !database.IsEmptyRecord(err) {
return userMap, err
return nil, err
}
for _, record := range records {
user := models.User{}
err = json.Unmarshal([]byte(record), &user)
if err == nil {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
userMap[user.UserName] = user
}
userMap := make(map[string]models.User, len(users))
for _, user := range users {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
userMap[user.UserName] = user
}
return userMap, nil
}
+3
View File
@@ -130,6 +130,8 @@ func initialize() { // Client Mode Prereq Check
_, _ = logic.GetAllExtClients()
_ = logic.ListAcls()
_, _ = logic.GetAllEnrollmentKeys()
_, _ = logic.GetUsersDB()
_ = logic.CleanExpiredSSOStates()
migrate.Run()
@@ -193,6 +195,7 @@ func startControllers(wg *sync.WaitGroup, ctx context.Context) {
wg.Add(1)
go logic.StartHookManager(ctx, wg)
logic.InitNetworkHooks()
logic.AddSSOStateCleanupHook()
}
// Should we be using a context vice a waitgroup????????????
+1
View File
@@ -64,6 +64,7 @@ type Nameserver struct {
MatchDomain string `json:"match_domain"`
IsSearchDomain bool `json:"is_search_domain"`
IsFallback bool `json:"is_fallback"`
IsADDomain bool `json:"is_ad_domain"`
}
type OldPeerUpdateFields struct {
+18 -17
View File
@@ -46,23 +46,24 @@ type IngressGwUsers struct {
// UserRemoteGws - struct to hold user's remote gws
type UserRemoteGws struct {
GwID string `json:"remote_access_gw_id"`
GWName string `json:"gw_name"`
Network string `json:"network"`
Connected bool `json:"connected"`
IsInternetGateway bool `json:"is_internet_gateway"`
GwClient ExtClient `json:"gw_client"`
GwPeerPublicKey string `json:"gw_peer_public_key"`
GwListenPort int `json:"gw_listen_port"`
Metadata string `json:"metadata"`
AllowedEndpoints []string `json:"allowed_endpoints"`
NetworkAddresses []string `json:"network_addresses"`
Status NodeStatus `json:"status"`
ManageDNS bool `json:"manage_dns"`
DnsAddress string `json:"dns_address"`
Addresses string `json:"addresses"`
MatchDomains []string `json:"match_domains"`
SearchDomains []string `json:"search_domains"`
GwID string `json:"remote_access_gw_id"`
GWName string `json:"gw_name"`
Network string `json:"network"`
Connected bool `json:"connected"`
IsInternetGateway bool `json:"is_internet_gateway"`
GwClient ExtClient `json:"gw_client"`
GwPeerPublicKey string `json:"gw_peer_public_key"`
GwListenPort int `json:"gw_listen_port"`
Metadata string `json:"metadata"`
AllowedEndpoints []string `json:"allowed_endpoints"`
NetworkAddresses []string `json:"network_addresses"`
Status NodeStatus `json:"status"`
ManageDNS bool `json:"manage_dns"`
DnsAddress string `json:"dns_address"`
Addresses string `json:"addresses"`
MatchDomains []string `json:"match_domains"`
SearchDomains []string `json:"search_domains"`
Nameservers []Nameserver `json:"nameservers"`
}
// UserRAGs - struct for user access gws
+2 -3
View File
@@ -106,9 +106,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
if err != nil {
user.UserName = content.Email
user.ExternalIdentityProviderID = content.Login
database.DeleteRecord(database.USERS_TABLE_NAME, content.Login)
d, _ := json.Marshal(user)
database.Insert(user.UserName, string(d), database.USERS_TABLE_NAME)
_ = logic.DeleteUser(content.Login)
_ = logic.UpsertUser(*user)
}
}
+2
View File
@@ -1681,6 +1681,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
// skip fallback nameservers for user remote access gws.
continue
}
gw.Nameservers = append(gw.Nameservers, nsI)
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
if nsI.IsSearchDomain {
gw.SearchDomains = append(gw.SearchDomains, nsI.MatchDomain)
@@ -1738,6 +1739,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
// skip fallback nameservers for user remote access gws.
continue
}
gw.Nameservers = append(gw.Nameservers, nsI)
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
if nsI.IsSearchDomain {
gw.SearchDomains = append(gw.SearchDomains, nsI.MatchDomain)
+6
View File
@@ -86,6 +86,7 @@ func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -105,6 +106,7 @@ func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -129,6 +131,7 @@ func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
IPs: nsI.Servers,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -184,6 +187,7 @@ func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -203,6 +207,7 @@ func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
IPs: filteredIps,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
@@ -227,6 +232,7 @@ func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
IPs: nsI.Servers,
MatchDomain: domain.Domain,
IsSearchDomain: domain.IsSearchDomain,
IsADDomain: domain.IsADDomain,
})
}
}
+1
View File
@@ -31,6 +31,7 @@ type Nameserver struct {
type NameserverDomain struct {
Domain string `json:"domain"`
IsSearchDomain bool `json:"is_search_domain"`
IsADDomain bool `json:"is_ad_domain"`
}
func (ns *Nameserver) Get(ctx context.Context) error {