diff --git a/controllers/hosts.go b/controllers/hosts.go index 2422a996..154faf9a 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -698,25 +698,25 @@ func bulkDeleteHosts(w http.ResponseWriter, r *http.Request) { for _, idStr := range req.IDs { hostID, err := uuid.Parse(idStr) if err != nil { - slog.Error("bulk host delete: invalid host id", "id", idStr) + slog.Debug("bulk host delete: invalid host id", "id", idStr) continue } currHost := &schema.Host{ID: hostID} if err = currHost.Get(db.WithContext(context.Background())); err != nil { - slog.Error("bulk host delete: host not found", "id", idStr, "error", err) + slog.Debug("bulk host delete: host not found", "id", idStr, "error", err) continue } var hostNodes []models.Node for _, nodeID := range currHost.Nodes { node, err := logic.GetNodeByID(nodeID) if err != nil { - slog.Error("bulk host delete: failed to get node", "nodeid", nodeID, "error", err) + slog.Debug("bulk host delete: failed to get node", "nodeid", nodeID, "error", err) continue } hostNodes = append(hostNodes, node) } if err = logic.RemoveHost(currHost, true); err != nil { - slog.Error("bulk host delete: failed to remove host", "id", idStr, "error", err) + slog.Debug("bulk host delete: failed to remove host", "id", idStr, "error", err) continue } for _, node := range hostNodes { @@ -724,14 +724,14 @@ func bulkDeleteHosts(w http.ResponseWriter, r *http.Request) { } if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { if err := mq.GetEmqxHandler().DeleteEmqxUser(currHost.ID.String()); err != nil { - slog.Error("bulk host delete: failed to remove EMQX credentials", "id", currHost.ID, "error", err) + slog.Debug("bulk host delete: failed to remove EMQX credentials", "id", currHost.ID, "error", err) } } if err = mq.HostUpdate(&models.HostUpdate{ Action: models.DeleteHost, Host: *currHost, }); err != nil { - slog.Error("bulk host delete: failed to send host update", "id", currHost.ID, "error", err) + slog.Debug("bulk host delete: failed to send host update", "id", currHost.ID, "error", err) } (&schema.PendingHost{HostID: currHost.ID.String()}).DeleteAllPendingHosts(db.WithContext(context.TODO())) logic.LogEvent(&models.Event{ diff --git a/controllers/server.go b/controllers/server.go index 649f370b..f6dccfdd 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -336,7 +336,8 @@ func reInit(curr, new models.ServerSettings, force bool) { // On force AutoUpdate change, change AutoUpdate for all hosts. // On force FlowLogs enable, enable FlowLogs for all hosts. // On FlowLogs disable, forced or not, disable FlowLogs for all hosts. - if force || !new.EnableFlowLogs { + // On NetclientAutoUpdate disable, forced or not, disable AutoUpdate for all hosts. + if force || !new.EnableFlowLogs || !new.NetclientAutoUpdate { if curr.NetclientAutoUpdate != new.NetclientAutoUpdate || curr.EnableFlowLogs != new.EnableFlowLogs { hosts, _ := (&schema.Host{}).ListAll(db.WithContext(context.TODO())) diff --git a/controllers/user.go b/controllers/user.go index 381c760b..31ed7637 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -56,7 +56,7 @@ func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost) r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings)))).Methods(http.MethodGet) r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet) r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet) r.HandleFunc("/api/v2/users", logic.SecurityCheck(true, http.HandlerFunc(listUsers))).Methods(http.MethodGet) r.HandleFunc("/api/v1/users/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers))).Methods(http.MethodDelete) diff --git a/database/sqlite.go b/database/sqlite.go index 2b87c6d8..513c9414 100644 --- a/database/sqlite.go +++ b/database/sqlite.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "sync" "time" "github.com/gravitl/netmaker/db" @@ -13,6 +14,9 @@ import ( // SqliteDB is the db object for sqlite database connections var SqliteDB *sql.DB +// sqliteWriteMu serializes SQLite write operations to reduce lock contention. +var sqliteWriteMu sync.Mutex + // SQLITE_FUNCTIONS - contains a map of the functions for sqlite var SQLITE_FUNCTIONS = map[string]interface{}{ INIT_DB: initSqliteDB, @@ -40,6 +44,9 @@ func initSqliteDB() error { } func sqliteCreateTable(tableName string) error { + sqliteWriteMu.Lock() + defer sqliteWriteMu.Unlock() + statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)") if err != nil { return err @@ -54,6 +61,9 @@ func sqliteCreateTable(tableName string) error { func sqliteInsert(key string, value string, tableName string) error { if key != "" && value != "" { + sqliteWriteMu.Lock() + defer sqliteWriteMu.Unlock() + insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)" statement, err := SqliteDB.Prepare(insertSQL) if err != nil { @@ -81,6 +91,9 @@ func sqliteInsertPeer(key string, value string) error { } func sqliteDeleteRecord(tableName string, key string) error { + sqliteWriteMu.Lock() + defer sqliteWriteMu.Unlock() + deleteSQL := "DELETE FROM " + tableName + " WHERE key = ?" statement, err := SqliteDB.Prepare(deleteSQL) if err != nil { @@ -94,6 +107,9 @@ func sqliteDeleteRecord(tableName string, key string) error { } func sqliteDeleteAllRecords(tableName string) error { + sqliteWriteMu.Lock() + defer sqliteWriteMu.Unlock() + deleteSQL := "DELETE FROM " + tableName statement, err := SqliteDB.Prepare(deleteSQL) if err != nil { diff --git a/db/sqlite.go b/db/sqlite.go index 22a6f96a..6e3df11f 100644 --- a/db/sqlite.go +++ b/db/sqlite.go @@ -62,6 +62,7 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) { return nil, err } + sqlDB.SetMaxOpenConns(1) sqlDB.SetMaxIdleConns(1) return db, nil diff --git a/logic/auth.go b/logic/auth.go index 3e08a3fb..3a7af446 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -451,7 +451,7 @@ func DeleteUser(user string) error { return err } - go RemoveUserFromAclPolicy(user) + RemoveUserFromAclPolicy(user) return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO())) } diff --git a/logic/extpeers.go b/logic/extpeers.go index 77ccdc38..54649a05 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -684,6 +684,7 @@ func getExtPeerEgressRoute(node models.Node, extPeer models.ExtClient) (egressRo NodeAddr: node.Address, NodeAddr6: node.Address6, EgressRanges: extPeer.ExtraAllowedIPs, + Network: node.Network, } for _, extraAllowedIP := range extPeer.ExtraAllowedIPs { r.EgressRangesWithMetric = append(r.EgressRangesWithMetric, models.EgressRangeMetric{ diff --git a/logic/peers.go b/logic/peers.go index bbfb5d3d..fbf889cb 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -238,14 +238,14 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node, - deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) { + deletedNode *models.Node, deletedClients []models.ExtClient) (hostPeerUpdate models.HostPeerUpdate, err error) { if host == nil { return models.HostPeerUpdate{}, errors.New("host is nil") } // track which nodes are deleted // after peer calculation, if peer not in list, add delete config of peer - hostPeerUpdate := models.HostPeerUpdate{ + hostPeerUpdate = models.HostPeerUpdate{ Host: *host, Server: servercfg.GetServer(), ServerVersion: servercfg.GetVersion(), @@ -266,6 +266,9 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N GwNodes: make(map[schema.NetworkID][]models.Node), AddressIdentityMap: make(map[string]models.PeerIdentity), } + defer func() { + hostPeerUpdate.EgressRoutes = deduplicateEgressRoutes(hostPeerUpdate.EgressRoutes) + }() if host.DNS == "no" { hostPeerUpdate.ManageDNS = false } @@ -931,6 +934,18 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet { } return allowedips } +func deduplicateEgressRoutes(routes []models.EgressNetworkRoutes) []models.EgressNetworkRoutes { + seen := make(map[string]struct{}, len(routes)) + result := make([]models.EgressNetworkRoutes, 0, len(routes)) + for _, r := range routes { + key := r.PeerKey + "|" + r.Network + if _, exists := seen[key]; !exists { + seen[key] = struct{}{} + result = append(result, r) + } + } + return result +} func getCIDRMaskFromAddr(addr string) net.IPMask { cidr := net.CIDRMask(32, 32) diff --git a/logic/security.go b/logic/security.go index e9ab9717..da67deb0 100644 --- a/logic/security.go +++ b/logic/security.go @@ -1,11 +1,13 @@ package logic import ( + "context" "errors" "net/http" "strings" "github.com/golang-jwt/jwt/v4" + "github.com/gravitl/netmaker/db" "github.com/gorilla/mux" "github.com/gravitl/netmaker/models" @@ -181,3 +183,31 @@ func ContinueIfUserMatch(next http.Handler) http.HandlerFunc { next.ServeHTTP(w, r) } } + +func ContinueIfUserMatchOrAdmin(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var errorResponse = models.ErrorResponse{ + Code: http.StatusForbidden, Message: Forbidden_Msg, + } + + user := &schema.User{ + Username: r.Header.Get("user"), + } + err := user.Get(db.WithContext(context.TODO())) + if err == nil && (user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole) { + next.ServeHTTP(w, r) + return + } + + var params = mux.Vars(r) + var requestedUser = params["username"] + if requestedUser == "" { + requestedUser = r.URL.Query().Get("username") + } + if requestedUser != r.Header.Get("user") { + ReturnErrorResponse(w, r, errorResponse) + return + } + next.ServeHTTP(w, r) + } +} diff --git a/migrate/migrate.go b/migrate/migrate.go index c8b284be..9dcd562b 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -7,6 +7,7 @@ import ( "log" "net" "slices" + "strings" "time" "golang.org/x/exp/slog" @@ -38,6 +39,7 @@ func Run() { resync() deleteOldExtclients() cleanupDeletedUserGroupRefs() + migrateNameservers() } func updateNetworks() { @@ -588,6 +590,9 @@ func migrateToEgressV1() { CreatedBy: user.UserName, CreatedAt: time.Now().UTC(), } + if !e.Nat { + e.Mode = schema.DisabledNAT + } err = e.Create(db.WithContext(context.TODO())) if err == nil { acl := models.Acl{ @@ -737,8 +742,10 @@ func cleanupDeletedUserGroupRefs() { existingGroups[group.ID] = group } + existingUsers := make(map[string]schema.User) users, _ := (&schema.User{}).ListAll(db.WithContext(context.TODO())) for _, user := range users { + existingUsers[user.Username] = user var update bool for groupID := range user.UserGroups.Data() { if _, ok := existingGroups[groupID]; !ok { @@ -770,6 +777,10 @@ func cleanupDeletedUserGroupRefs() { newSrc = append(newSrc, src) } } + } else if src.ID == models.UserAclID && src.Value != "*" { + if _, ok := existingUsers[src.Value]; ok { + newSrc = append(newSrc, src) + } } else { newSrc = append(newSrc, src) } @@ -798,3 +809,86 @@ func cleanupDeletedUserGroupRefs() { } } } + +func migrateNameservers() { + networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) + for _, network := range networks { + _ = logic.CreateFallbackNameserver(network.Name) + } + + nameservers, _ := (&schema.Nameserver{}).ListAll(db.WithContext(context.TODO())) + for _, nameserver := range nameservers { + if len(nameserver.Domains) != 0 { + for _, matchDomain := range nameserver.MatchDomains { + nameserver.Domains = append(nameserver.Domains, schema.NameserverDomain{ + Domain: matchDomain, + }) + } + + nameserver.MatchDomains = []string{} + + _ = nameserver.Update(db.WithContext(context.TODO())) + } + } + + superAdmin := &schema.User{} + err := superAdmin.GetSuperAdmin(db.WithContext(context.TODO())) + if err != nil { + return + } + + nodes, _ := logic.GetAllNodes() + for _, node := range nodes { + if !node.IsGw { + continue + } + + if node.IngressDNS != "" { + var nsIPs []string + for _, nsIP := range strings.Split(node.IngressDNS, ",") { + nsIP = strings.TrimSpace(nsIP) + + if (node.Address.IP != nil && node.Address.IP.String() == nsIP) || + (node.Address6.IP != nil && node.Address6.IP.String() == nsIP) { + continue + } + if nsIP == "8.8.8.8" || nsIP == "1.1.1.1" || nsIP == "9.9.9.9" { + continue + } + + nsIPs = append(nsIPs, nsIP) + } + + if len(nsIPs) > 0 { + host := &schema.Host{ + ID: node.HostID, + } + err := host.Get(db.WithContext(context.TODO())) + if err != nil { + continue + } + ns := schema.Nameserver{ + ID: uuid.NewString(), + Name: fmt.Sprintf("%s gw nameservers", host.Name), + NetworkID: node.Network, + Servers: nsIPs, + MatchAll: true, + Domains: []schema.NameserverDomain{ + { + Domain: ".", + }, + }, + Nodes: datatypes.JSONMap{ + node.ID.String(): struct{}{}, + }, + Tags: make(datatypes.JSONMap), + Status: true, + CreatedBy: superAdmin.Username, + } + _ = ns.Create(db.WithContext(context.TODO())) + node.IngressDNS = "" + _ = logic.UpsertNode(&node) + } + } + } +} diff --git a/migrate/migrate_v1_5_1.go b/migrate/migrate_v1_5_1.go index 0cc42046..297f08f6 100644 --- a/migrate/migrate_v1_5_1.go +++ b/migrate/migrate_v1_5_1.go @@ -5,8 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "net" "time" + "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" @@ -91,7 +93,7 @@ func migrateV1_5_1(ctx context.Context) error { } func migrateUsers(ctx context.Context) error { - records, err := database.FetchRecords(database.USERS_TABLE_NAME) + records, err := FetchAll(ctx, database.USERS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -145,7 +147,7 @@ func migrateUsers(ctx context.Context) error { } func migrateNetworks(ctx context.Context) error { - records, err := database.FetchRecords(database.NETWORKS_TABLE_NAME) + records, err := FetchAll(ctx, database.NETWORKS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -204,13 +206,87 @@ func migrateNetworks(ctx context.Context) error { logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err)) return err } + + var cidr, cidrv6 *net.IPNet + if len(network.AddressRange) != 0 { + _, cidr, err = net.ParseCIDR(network.AddressRange) + if err != nil { + err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange, err) + logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err)) + return err + } + } + + if len(network.AddressRange6) != 0 { + _, cidrv6, err = net.ParseCIDR(network.AddressRange6) + if err != nil { + err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange6, err) + logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err)) + return err + } + } + + superAdmin := &schema.User{} + err = superAdmin.GetSuperAdmin(ctx) + if err != nil { + err = fmt.Errorf("error getting superadmin: %v", err) + logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err)) + return err + } + + if len(network.NameServers) > 0 { + ns := schema.Nameserver{ + ID: uuid.NewString(), + Name: "upstream nameservers", + NetworkID: _network.Name, + Servers: []string{}, + MatchAll: true, + Domains: []schema.NameserverDomain{ + { + Domain: ".", + }, + }, + Tags: datatypes.JSONMap{ + "*": struct{}{}, + }, + Nodes: make(datatypes.JSONMap), + Status: true, + CreatedBy: superAdmin.Username, + } + + for _, nsIP := range network.NameServers { + ip := net.ParseIP(nsIP) + if ip == nil { + continue + } + + if ip.To4() != nil { + if cidr != nil && !cidr.Contains(ip) { + ns.Servers = append(ns.Servers, nsIP) + } + } else { + if cidrv6 != nil && !cidrv6.Contains(ip) { + ns.Servers = append(ns.Servers, nsIP) + } + } + } + + if len(ns.Servers) > 0 { + err = ns.Create(ctx) + if err != nil { + err = fmt.Errorf("error creating upstream nameserver for network (%s): %v", _network.Name, err) + logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err)) + return err + } + } + } } return nil } func migrateUserRoles(ctx context.Context) error { - records, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME) + records, err := FetchAll(ctx, database.USER_PERMISSIONS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -235,7 +311,7 @@ func migrateUserRoles(ctx context.Context) error { } func migrateUserGroups(ctx context.Context) error { - records, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME) + records, err := FetchAll(ctx, database.USER_GROUPS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -260,7 +336,7 @@ func migrateUserGroups(ctx context.Context) error { } func migrateHosts(ctx context.Context) error { - records, err := database.FetchRecords(database.HOSTS_TABLE_NAME) + records, err := FetchAll(ctx, database.HOSTS_TABLE_NAME) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -336,10 +412,6 @@ func migrateHosts(ctx context.Context) error { } } - if _host.IsDefault && !_host.AutoUpdate { - _host.AutoUpdate = true - } - logger.Log(4, fmt.Sprintf("migrating host %s", _host.ID)) err = _host.Create(ctx) @@ -351,3 +423,22 @@ func migrateHosts(ctx context.Context) error { return nil } + +func FetchAll(ctx context.Context, tableName string) (map[string]string, error) { + row, err := db.FromContext(ctx).Raw("SELECT * FROM " + tableName + " ORDER BY key").Rows() + if err != nil { + return nil, err + } + records := make(map[string]string) + defer row.Close() + for row.Next() { // Iterate and fetch the records from result cursor + var key string + var value string + row.Scan(&key, &value) + records[key] = value + } + if len(records) == 0 { + return nil, gorm.ErrRecordNotFound + } + return records, nil +} diff --git a/pro/controllers/events.go b/pro/controllers/events.go index 319b3d9b..954ab69b 100644 --- a/pro/controllers/events.go +++ b/pro/controllers/events.go @@ -106,7 +106,8 @@ func listUserActivity(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && caller.PlatformRoleID != schema.AdminRole { + if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && + caller.PlatformRoleID != schema.AdminRole && caller.PlatformRoleID != schema.Auditor { logic.ReturnErrorResponse(w, r, models.ErrorResponse{ Code: http.StatusForbidden, Message: "you are not authorized to view this user's activity", diff --git a/pro/controllers/flows.go b/pro/controllers/flows.go index 7d8f0637..da75f993 100644 --- a/pro/controllers/flows.go +++ b/pro/controllers/flows.go @@ -75,11 +75,11 @@ type FlowRow struct { // @Param network_id query string false "Filter by network ID" // @Param from query string false "Start time in RFC3339 format" // @Param to query string false "End time in RFC3339 format" -// @Param src_type query string false "Source type filter" +// @Param src_type []query string false "Source type filter" // @Param src_entity_id query string false "Source entity ID filter" -// @Param dst_type query string false "Destination type filter" +// @Param dst_type []query string false "Destination type filter" // @Param dst_entity_id query string false "Destination entity ID filter" -// @Param protocol query string false "Protocol filter" +// @Param protocol []query string false "Protocol filter" // @Param node_id query string false "Node ID filter" // @Param username query string false "Username filter" // @Param page query int false "Page number" @@ -115,7 +115,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { args = append(args, networkID) } - // 1. Time filtering (version: UInt64 timestamp in ms) + // 1. Time filtering (start_ts: UInt64 timestamp in ms) fromStr := q.Get("from") toStr := q.Get("to") @@ -125,7 +125,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'from' timestamp: %v", err), logic.BadReq)) return } - whereParts = append(whereParts, "version >= ?") + whereParts = append(whereParts, "start_ts >= ?") args = append(args, fromVal) } @@ -135,15 +135,14 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'to' timestamp: %v", err), logic.BadReq)) return } - whereParts = append(whereParts, "version <= ?") + whereParts = append(whereParts, "start_ts <= ?") args = append(args, toVal) } // 2. Source filters - srcTypeStr := q.Get("src_type") - if srcTypeStr != "" { - whereParts = append(whereParts, "src_type = ?") - args = append(args, srcTypeStr) + if q.Get("src_type") != "" { + whereParts = append(whereParts, "src_type IN ?") + args = append(args, q["src_type"]) } srcEntity := q.Get("src_entity_id") @@ -153,10 +152,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { } // 3. Destination filters - dstTypeStr := q.Get("dst_type") - if dstTypeStr != "" { - whereParts = append(whereParts, "dst_type = ?") - args = append(args, dstTypeStr) + if q.Get("dst_type") != "" { + whereParts = append(whereParts, "dst_type IN ?") + args = append(args, q["dst_type"]) } dstEntity := q.Get("dst_entity_id") @@ -166,10 +164,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { } // 4. Protocol filter - protoStr := q.Get("protocol") - if protoStr != "" { - whereParts = append(whereParts, "protocol = ?") - args = append(args, protoStr) + if q.Get("protocol") != "" { + whereParts = append(whereParts, "protocol IN ?") + args = append(args, q["protocol"]) } // 5. Node filter @@ -202,21 +199,25 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) { // 6. User filter username := q.Get("username") if username != "" { - if srcTypeStr != "" || dstTypeStr != "" || - srcEntity != "" || dstEntity != "" { + if q.Has("src_type") || q.Has("dst_type") || + q.Has("src_entity_id") || q.Has("dst_entity_id") { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("cannot provide username filter along with src/dst type and id filters"), logic.BadReq)) return } - srcTypeStr = "user" - srcEntity = username - dstTypeStr = "user" - dstEntity = username + srcTypeStr := "user" + srcEntity := username + dstTypeStr := "user" + dstEntity := username whereParts = append(whereParts, "((src_type = ? AND src_entity_id = ?) OR (dst_type = ? AND dst_entity_id = ?))") args = append(args, srcTypeStr, srcEntity, dstTypeStr, dstEntity) } + // 7. Ignore flow logs with zero end_ts. + whereParts = append(whereParts, "end_ts <> ?") + args = append(args, time.Unix(0, 0)) + // Pagination page := parseIntOrDefault(q.Get("page"), 1) perPage := parseIntOrDefault(q.Get("per_page"), 100) diff --git a/pro/controllers/tags.go b/pro/controllers/tags.go index bd65a46b..13cf76ee 100644 --- a/pro/controllers/tags.go +++ b/pro/controllers/tags.go @@ -288,6 +288,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) { go func() { proLogic.RemoveDeviceTagFromAclPolicies(tag.ID, tag.Network) proLogic.RemoveTagFromPostureChecks(tag.ID, tag.Network) + proLogic.RemoveTagFromNameservers(tag.ID, tag.Network) logic.RemoveTagFromEnrollmentKeys(tag.ID) mq.PublishPeerUpdate(false) }() diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 74649b9c..6d80c484 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -2135,7 +2135,11 @@ func testIDPSync(w http.ResponseWriter, r *http.Request) { return } case "azure-ad": - idpClient = azure.NewAzureEntraIDClient(req.ClientID, req.ClientSecret, req.AzureTenantID) + secret := req.ClientSecret + if secret == logic.Mask() { + secret = logic.GetServerSettings().ClientSecret + } + idpClient = azure.NewAzureEntraIDClient(req.ClientID, secret, req.AzureTenantID) case "okta": idpClient, err = okta.NewOktaClient(req.OktaOrgURL, req.OktaAPIToken) if err != nil { diff --git a/pro/license.go b/pro/license.go index ea49a927..b3d87de3 100644 --- a/pro/license.go +++ b/pro/license.go @@ -140,7 +140,8 @@ func ValidateLicense() (err error) { proLogic.SetFeatureFlags(licenseResponse.FeatureFlags) proLogic.SetDeploymentMode(licenseResponse.DeploymentMode) - _ = mq.PublishExporterFeatureFlags() + go mq.PublishExporterFeatureFlags() + go mq.PublishPeerUpdate(false) slog.Info("License validation succeeded!") return nil diff --git a/pro/logic/dns.go b/pro/logic/dns.go index 31bd76da..16b4e69e 100644 --- a/pro/logic/dns.go +++ b/pro/logic/dns.go @@ -242,3 +242,23 @@ func GetNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) { } return } + +func RemoveTagFromNameservers(tagID models.TagID, netID schema.NetworkID) error { + nameservers, err := (&schema.Nameserver{ + NetworkID: netID.String(), + }).ListByNetwork(db.WithContext(context.TODO())) + if err != nil { + return err + } + + var multiErr error + for _, nameserver := range nameservers { + delete(nameserver.Tags, tagID.String()) + err := nameserver.Update(db.WithContext(context.TODO())) + if err != nil { + multiErr = errors.Join(multiErr, err) + } + } + + return multiErr +} diff --git a/schema/dns.go b/schema/dns.go index 9cc1cddc..53a9aea8 100644 --- a/schema/dns.go +++ b/schema/dns.go @@ -46,6 +46,11 @@ func (ns *Nameserver) Create(ctx context.Context) error { return db.FromContext(ctx).Model(&Nameserver{}).Create(&ns).Error } +func (ns *Nameserver) ListAll(ctx context.Context) (dnsli []Nameserver, err error) { + err = db.FromContext(ctx).Model(&Nameserver{}).Find(&dnsli).Error + return +} + func (ns *Nameserver) ListByNetwork(ctx context.Context) (dnsli []Nameserver, err error) { err = db.FromContext(ctx).Model(&Nameserver{}).Where("network_id = ?", ns.NetworkID).Find(&dnsli).Error return diff --git a/schema/egress.go b/schema/egress.go index c9063d3f..628fe6cd 100644 --- a/schema/egress.go +++ b/schema/egress.go @@ -13,8 +13,9 @@ const egressTable = "egresses" type EgressNATMode string const ( - VirtualNAT EgressNATMode = "virtual_nat" - DirectNAT EgressNATMode = "direct_nat" + DisabledNAT EgressNATMode = "disabled" + VirtualNAT EgressNATMode = "virtual_nat" + DirectNAT EgressNATMode = "direct_nat" ) type Egress struct {