mirror of
https://github.com/gravitl/netmaker.git
synced 2026-04-22 16:07:11 +08:00
Merge branch 'release-v1.5.1' into feat/match-azure-user-by-id
This commit is contained in:
@@ -698,25 +698,25 @@ func bulkDeleteHosts(w http.ResponseWriter, r *http.Request) {
|
||||
for _, idStr := range req.IDs {
|
||||
hostID, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
slog.Error("bulk host delete: invalid host id", "id", idStr)
|
||||
slog.Debug("bulk host delete: invalid host id", "id", idStr)
|
||||
continue
|
||||
}
|
||||
currHost := &schema.Host{ID: hostID}
|
||||
if err = currHost.Get(db.WithContext(context.Background())); err != nil {
|
||||
slog.Error("bulk host delete: host not found", "id", idStr, "error", err)
|
||||
slog.Debug("bulk host delete: host not found", "id", idStr, "error", err)
|
||||
continue
|
||||
}
|
||||
var hostNodes []models.Node
|
||||
for _, nodeID := range currHost.Nodes {
|
||||
node, err := logic.GetNodeByID(nodeID)
|
||||
if err != nil {
|
||||
slog.Error("bulk host delete: failed to get node", "nodeid", nodeID, "error", err)
|
||||
slog.Debug("bulk host delete: failed to get node", "nodeid", nodeID, "error", err)
|
||||
continue
|
||||
}
|
||||
hostNodes = append(hostNodes, node)
|
||||
}
|
||||
if err = logic.RemoveHost(currHost, true); err != nil {
|
||||
slog.Error("bulk host delete: failed to remove host", "id", idStr, "error", err)
|
||||
slog.Debug("bulk host delete: failed to remove host", "id", idStr, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, node := range hostNodes {
|
||||
@@ -724,14 +724,14 @@ func bulkDeleteHosts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
|
||||
if err := mq.GetEmqxHandler().DeleteEmqxUser(currHost.ID.String()); err != nil {
|
||||
slog.Error("bulk host delete: failed to remove EMQX credentials", "id", currHost.ID, "error", err)
|
||||
slog.Debug("bulk host delete: failed to remove EMQX credentials", "id", currHost.ID, "error", err)
|
||||
}
|
||||
}
|
||||
if err = mq.HostUpdate(&models.HostUpdate{
|
||||
Action: models.DeleteHost,
|
||||
Host: *currHost,
|
||||
}); err != nil {
|
||||
slog.Error("bulk host delete: failed to send host update", "id", currHost.ID, "error", err)
|
||||
slog.Debug("bulk host delete: failed to send host update", "id", currHost.ID, "error", err)
|
||||
}
|
||||
(&schema.PendingHost{HostID: currHost.ID.String()}).DeleteAllPendingHosts(db.WithContext(context.TODO()))
|
||||
logic.LogEvent(&models.Event{
|
||||
|
||||
@@ -336,7 +336,8 @@ func reInit(curr, new models.ServerSettings, force bool) {
|
||||
// On force AutoUpdate change, change AutoUpdate for all hosts.
|
||||
// On force FlowLogs enable, enable FlowLogs for all hosts.
|
||||
// On FlowLogs disable, forced or not, disable FlowLogs for all hosts.
|
||||
if force || !new.EnableFlowLogs {
|
||||
// On NetclientAutoUpdate disable, forced or not, disable AutoUpdate for all hosts.
|
||||
if force || !new.EnableFlowLogs || !new.NetclientAutoUpdate {
|
||||
if curr.NetclientAutoUpdate != new.NetclientAutoUpdate ||
|
||||
curr.EnableFlowLogs != new.EnableFlowLogs {
|
||||
hosts, _ := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@ func userHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings)))).Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v2/users", logic.SecurityCheck(true, http.HandlerFunc(listUsers))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/users/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers))).Methods(http.MethodDelete)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
@@ -13,6 +14,9 @@ import (
|
||||
// SqliteDB is the db object for sqlite database connections
|
||||
var SqliteDB *sql.DB
|
||||
|
||||
// sqliteWriteMu serializes SQLite write operations to reduce lock contention.
|
||||
var sqliteWriteMu sync.Mutex
|
||||
|
||||
// SQLITE_FUNCTIONS - contains a map of the functions for sqlite
|
||||
var SQLITE_FUNCTIONS = map[string]interface{}{
|
||||
INIT_DB: initSqliteDB,
|
||||
@@ -40,6 +44,9 @@ func initSqliteDB() error {
|
||||
}
|
||||
|
||||
func sqliteCreateTable(tableName string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -54,6 +61,9 @@ func sqliteCreateTable(tableName string) error {
|
||||
|
||||
func sqliteInsert(key string, value string, tableName string) error {
|
||||
if key != "" && value != "" {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
|
||||
statement, err := SqliteDB.Prepare(insertSQL)
|
||||
if err != nil {
|
||||
@@ -81,6 +91,9 @@ func sqliteInsertPeer(key string, value string) error {
|
||||
}
|
||||
|
||||
func sqliteDeleteRecord(tableName string, key string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
deleteSQL := "DELETE FROM " + tableName + " WHERE key = ?"
|
||||
statement, err := SqliteDB.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
@@ -94,6 +107,9 @@ func sqliteDeleteRecord(tableName string, key string) error {
|
||||
}
|
||||
|
||||
func sqliteDeleteAllRecords(tableName string) error {
|
||||
sqliteWriteMu.Lock()
|
||||
defer sqliteWriteMu.Unlock()
|
||||
|
||||
deleteSQL := "DELETE FROM " + tableName
|
||||
statement, err := SqliteDB.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
|
||||
@@ -62,6 +62,7 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
|
||||
return db, nil
|
||||
|
||||
+1
-1
@@ -451,7 +451,7 @@ func DeleteUser(user string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
go RemoveUserFromAclPolicy(user)
|
||||
RemoveUserFromAclPolicy(user)
|
||||
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
|
||||
@@ -684,6 +684,7 @@ func getExtPeerEgressRoute(node models.Node, extPeer models.ExtClient) (egressRo
|
||||
NodeAddr: node.Address,
|
||||
NodeAddr6: node.Address6,
|
||||
EgressRanges: extPeer.ExtraAllowedIPs,
|
||||
Network: node.Network,
|
||||
}
|
||||
for _, extraAllowedIP := range extPeer.ExtraAllowedIPs {
|
||||
r.EgressRangesWithMetric = append(r.EgressRangesWithMetric, models.EgressRangeMetric{
|
||||
|
||||
+17
-2
@@ -238,14 +238,14 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m
|
||||
|
||||
// GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
|
||||
func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node,
|
||||
deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
|
||||
deletedNode *models.Node, deletedClients []models.ExtClient) (hostPeerUpdate models.HostPeerUpdate, err error) {
|
||||
if host == nil {
|
||||
return models.HostPeerUpdate{}, errors.New("host is nil")
|
||||
}
|
||||
|
||||
// track which nodes are deleted
|
||||
// after peer calculation, if peer not in list, add delete config of peer
|
||||
hostPeerUpdate := models.HostPeerUpdate{
|
||||
hostPeerUpdate = models.HostPeerUpdate{
|
||||
Host: *host,
|
||||
Server: servercfg.GetServer(),
|
||||
ServerVersion: servercfg.GetVersion(),
|
||||
@@ -266,6 +266,9 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N
|
||||
GwNodes: make(map[schema.NetworkID][]models.Node),
|
||||
AddressIdentityMap: make(map[string]models.PeerIdentity),
|
||||
}
|
||||
defer func() {
|
||||
hostPeerUpdate.EgressRoutes = deduplicateEgressRoutes(hostPeerUpdate.EgressRoutes)
|
||||
}()
|
||||
if host.DNS == "no" {
|
||||
hostPeerUpdate.ManageDNS = false
|
||||
}
|
||||
@@ -931,6 +934,18 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
|
||||
}
|
||||
return allowedips
|
||||
}
|
||||
func deduplicateEgressRoutes(routes []models.EgressNetworkRoutes) []models.EgressNetworkRoutes {
|
||||
seen := make(map[string]struct{}, len(routes))
|
||||
result := make([]models.EgressNetworkRoutes, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
key := r.PeerKey + "|" + r.Network
|
||||
if _, exists := seen[key]; !exists {
|
||||
seen[key] = struct{}{}
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func getCIDRMaskFromAddr(addr string) net.IPMask {
|
||||
cidr := net.CIDRMask(32, 32)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
@@ -181,3 +183,31 @@ func ContinueIfUserMatch(next http.Handler) http.HandlerFunc {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func ContinueIfUserMatchOrAdmin(next http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var errorResponse = models.ErrorResponse{
|
||||
Code: http.StatusForbidden, Message: Forbidden_Msg,
|
||||
}
|
||||
|
||||
user := &schema.User{
|
||||
Username: r.Header.Get("user"),
|
||||
}
|
||||
err := user.Get(db.WithContext(context.TODO()))
|
||||
if err == nil && (user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var params = mux.Vars(r)
|
||||
var requestedUser = params["username"]
|
||||
if requestedUser == "" {
|
||||
requestedUser = r.URL.Query().Get("username")
|
||||
}
|
||||
if requestedUser != r.Header.Get("user") {
|
||||
ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/slog"
|
||||
@@ -38,6 +39,7 @@ func Run() {
|
||||
resync()
|
||||
deleteOldExtclients()
|
||||
cleanupDeletedUserGroupRefs()
|
||||
migrateNameservers()
|
||||
}
|
||||
|
||||
func updateNetworks() {
|
||||
@@ -588,6 +590,9 @@ func migrateToEgressV1() {
|
||||
CreatedBy: user.UserName,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if !e.Nat {
|
||||
e.Mode = schema.DisabledNAT
|
||||
}
|
||||
err = e.Create(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
acl := models.Acl{
|
||||
@@ -737,8 +742,10 @@ func cleanupDeletedUserGroupRefs() {
|
||||
existingGroups[group.ID] = group
|
||||
}
|
||||
|
||||
existingUsers := make(map[string]schema.User)
|
||||
users, _ := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, user := range users {
|
||||
existingUsers[user.Username] = user
|
||||
var update bool
|
||||
for groupID := range user.UserGroups.Data() {
|
||||
if _, ok := existingGroups[groupID]; !ok {
|
||||
@@ -770,6 +777,10 @@ func cleanupDeletedUserGroupRefs() {
|
||||
newSrc = append(newSrc, src)
|
||||
}
|
||||
}
|
||||
} else if src.ID == models.UserAclID && src.Value != "*" {
|
||||
if _, ok := existingUsers[src.Value]; ok {
|
||||
newSrc = append(newSrc, src)
|
||||
}
|
||||
} else {
|
||||
newSrc = append(newSrc, src)
|
||||
}
|
||||
@@ -798,3 +809,86 @@ func cleanupDeletedUserGroupRefs() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func migrateNameservers() {
|
||||
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, network := range networks {
|
||||
_ = logic.CreateFallbackNameserver(network.Name)
|
||||
}
|
||||
|
||||
nameservers, _ := (&schema.Nameserver{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, nameserver := range nameservers {
|
||||
if len(nameserver.Domains) != 0 {
|
||||
for _, matchDomain := range nameserver.MatchDomains {
|
||||
nameserver.Domains = append(nameserver.Domains, schema.NameserverDomain{
|
||||
Domain: matchDomain,
|
||||
})
|
||||
}
|
||||
|
||||
nameserver.MatchDomains = []string{}
|
||||
|
||||
_ = nameserver.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
|
||||
superAdmin := &schema.User{}
|
||||
err := superAdmin.GetSuperAdmin(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
nodes, _ := logic.GetAllNodes()
|
||||
for _, node := range nodes {
|
||||
if !node.IsGw {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.IngressDNS != "" {
|
||||
var nsIPs []string
|
||||
for _, nsIP := range strings.Split(node.IngressDNS, ",") {
|
||||
nsIP = strings.TrimSpace(nsIP)
|
||||
|
||||
if (node.Address.IP != nil && node.Address.IP.String() == nsIP) ||
|
||||
(node.Address6.IP != nil && node.Address6.IP.String() == nsIP) {
|
||||
continue
|
||||
}
|
||||
if nsIP == "8.8.8.8" || nsIP == "1.1.1.1" || nsIP == "9.9.9.9" {
|
||||
continue
|
||||
}
|
||||
|
||||
nsIPs = append(nsIPs, nsIP)
|
||||
}
|
||||
|
||||
if len(nsIPs) > 0 {
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: fmt.Sprintf("%s gw nameservers", host.Name),
|
||||
NetworkID: node.Network,
|
||||
Servers: nsIPs,
|
||||
MatchAll: true,
|
||||
Domains: []schema.NameserverDomain{
|
||||
{
|
||||
Domain: ".",
|
||||
},
|
||||
},
|
||||
Nodes: datatypes.JSONMap{
|
||||
node.ID.String(): struct{}{},
|
||||
},
|
||||
Tags: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: superAdmin.Username,
|
||||
}
|
||||
_ = ns.Create(db.WithContext(context.TODO()))
|
||||
node.IngressDNS = ""
|
||||
_ = logic.UpsertNode(&node)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+100
-9
@@ -5,8 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
@@ -91,7 +93,7 @@ func migrateV1_5_1(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateUsers(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USERS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -145,7 +147,7 @@ func migrateUsers(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateNetworks(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.NETWORKS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -204,13 +206,87 @@ func migrateNetworks(ctx context.Context) error {
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
|
||||
var cidr, cidrv6 *net.IPNet
|
||||
if len(network.AddressRange) != 0 {
|
||||
_, cidr, err = net.ParseCIDR(network.AddressRange)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange, err)
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(network.AddressRange6) != 0 {
|
||||
_, cidrv6, err = net.ParseCIDR(network.AddressRange6)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange6, err)
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
superAdmin := &schema.User{}
|
||||
err = superAdmin.GetSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error getting superadmin: %v", err)
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
|
||||
if len(network.NameServers) > 0 {
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: "upstream nameservers",
|
||||
NetworkID: _network.Name,
|
||||
Servers: []string{},
|
||||
MatchAll: true,
|
||||
Domains: []schema.NameserverDomain{
|
||||
{
|
||||
Domain: ".",
|
||||
},
|
||||
},
|
||||
Tags: datatypes.JSONMap{
|
||||
"*": struct{}{},
|
||||
},
|
||||
Nodes: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: superAdmin.Username,
|
||||
}
|
||||
|
||||
for _, nsIP := range network.NameServers {
|
||||
ip := net.ParseIP(nsIP)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if ip.To4() != nil {
|
||||
if cidr != nil && !cidr.Contains(ip) {
|
||||
ns.Servers = append(ns.Servers, nsIP)
|
||||
}
|
||||
} else {
|
||||
if cidrv6 != nil && !cidrv6.Contains(ip) {
|
||||
ns.Servers = append(ns.Servers, nsIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ns.Servers) > 0 {
|
||||
err = ns.Create(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error creating upstream nameserver for network (%s): %v", _network.Name, err)
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateUserRoles(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USER_PERMISSIONS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -235,7 +311,7 @@ func migrateUserRoles(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateUserGroups(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.USER_GROUPS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -260,7 +336,7 @@ func migrateUserGroups(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func migrateHosts(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
records, err := FetchAll(ctx, database.HOSTS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -336,10 +412,6 @@ func migrateHosts(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if _host.IsDefault && !_host.AutoUpdate {
|
||||
_host.AutoUpdate = true
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating host %s", _host.ID))
|
||||
|
||||
err = _host.Create(ctx)
|
||||
@@ -351,3 +423,22 @@ func migrateHosts(ctx context.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAll(ctx context.Context, tableName string) (map[string]string, error) {
|
||||
row, err := db.FromContext(ctx).Raw("SELECT * FROM " + tableName + " ORDER BY key").Rows()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records := make(map[string]string)
|
||||
defer row.Close()
|
||||
for row.Next() { // Iterate and fetch the records from result cursor
|
||||
var key string
|
||||
var value string
|
||||
row.Scan(&key, &value)
|
||||
records[key] = value
|
||||
}
|
||||
if len(records) == 0 {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
@@ -106,7 +106,8 @@ func listUserActivity(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && caller.PlatformRoleID != schema.AdminRole {
|
||||
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole &&
|
||||
caller.PlatformRoleID != schema.AdminRole && caller.PlatformRoleID != schema.Auditor {
|
||||
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "you are not authorized to view this user's activity",
|
||||
|
||||
+25
-24
@@ -75,11 +75,11 @@ type FlowRow struct {
|
||||
// @Param network_id query string false "Filter by network ID"
|
||||
// @Param from query string false "Start time in RFC3339 format"
|
||||
// @Param to query string false "End time in RFC3339 format"
|
||||
// @Param src_type query string false "Source type filter"
|
||||
// @Param src_type []query string false "Source type filter"
|
||||
// @Param src_entity_id query string false "Source entity ID filter"
|
||||
// @Param dst_type query string false "Destination type filter"
|
||||
// @Param dst_type []query string false "Destination type filter"
|
||||
// @Param dst_entity_id query string false "Destination entity ID filter"
|
||||
// @Param protocol query string false "Protocol filter"
|
||||
// @Param protocol []query string false "Protocol filter"
|
||||
// @Param node_id query string false "Node ID filter"
|
||||
// @Param username query string false "Username filter"
|
||||
// @Param page query int false "Page number"
|
||||
@@ -115,7 +115,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
args = append(args, networkID)
|
||||
}
|
||||
|
||||
// 1. Time filtering (version: UInt64 timestamp in ms)
|
||||
// 1. Time filtering (start_ts: UInt64 timestamp in ms)
|
||||
fromStr := q.Get("from")
|
||||
toStr := q.Get("to")
|
||||
|
||||
@@ -125,7 +125,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'from' timestamp: %v", err), logic.BadReq))
|
||||
return
|
||||
}
|
||||
whereParts = append(whereParts, "version >= ?")
|
||||
whereParts = append(whereParts, "start_ts >= ?")
|
||||
args = append(args, fromVal)
|
||||
}
|
||||
|
||||
@@ -135,15 +135,14 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'to' timestamp: %v", err), logic.BadReq))
|
||||
return
|
||||
}
|
||||
whereParts = append(whereParts, "version <= ?")
|
||||
whereParts = append(whereParts, "start_ts <= ?")
|
||||
args = append(args, toVal)
|
||||
}
|
||||
|
||||
// 2. Source filters
|
||||
srcTypeStr := q.Get("src_type")
|
||||
if srcTypeStr != "" {
|
||||
whereParts = append(whereParts, "src_type = ?")
|
||||
args = append(args, srcTypeStr)
|
||||
if q.Get("src_type") != "" {
|
||||
whereParts = append(whereParts, "src_type IN ?")
|
||||
args = append(args, q["src_type"])
|
||||
}
|
||||
|
||||
srcEntity := q.Get("src_entity_id")
|
||||
@@ -153,10 +152,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 3. Destination filters
|
||||
dstTypeStr := q.Get("dst_type")
|
||||
if dstTypeStr != "" {
|
||||
whereParts = append(whereParts, "dst_type = ?")
|
||||
args = append(args, dstTypeStr)
|
||||
if q.Get("dst_type") != "" {
|
||||
whereParts = append(whereParts, "dst_type IN ?")
|
||||
args = append(args, q["dst_type"])
|
||||
}
|
||||
|
||||
dstEntity := q.Get("dst_entity_id")
|
||||
@@ -166,10 +164,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// 4. Protocol filter
|
||||
protoStr := q.Get("protocol")
|
||||
if protoStr != "" {
|
||||
whereParts = append(whereParts, "protocol = ?")
|
||||
args = append(args, protoStr)
|
||||
if q.Get("protocol") != "" {
|
||||
whereParts = append(whereParts, "protocol IN ?")
|
||||
args = append(args, q["protocol"])
|
||||
}
|
||||
|
||||
// 5. Node filter
|
||||
@@ -202,21 +199,25 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
|
||||
// 6. User filter
|
||||
username := q.Get("username")
|
||||
if username != "" {
|
||||
if srcTypeStr != "" || dstTypeStr != "" ||
|
||||
srcEntity != "" || dstEntity != "" {
|
||||
if q.Has("src_type") || q.Has("dst_type") ||
|
||||
q.Has("src_entity_id") || q.Has("dst_entity_id") {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("cannot provide username filter along with src/dst type and id filters"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
srcTypeStr = "user"
|
||||
srcEntity = username
|
||||
dstTypeStr = "user"
|
||||
dstEntity = username
|
||||
srcTypeStr := "user"
|
||||
srcEntity := username
|
||||
dstTypeStr := "user"
|
||||
dstEntity := username
|
||||
|
||||
whereParts = append(whereParts, "((src_type = ? AND src_entity_id = ?) OR (dst_type = ? AND dst_entity_id = ?))")
|
||||
args = append(args, srcTypeStr, srcEntity, dstTypeStr, dstEntity)
|
||||
}
|
||||
|
||||
// 7. Ignore flow logs with zero end_ts.
|
||||
whereParts = append(whereParts, "end_ts <> ?")
|
||||
args = append(args, time.Unix(0, 0))
|
||||
|
||||
// Pagination
|
||||
page := parseIntOrDefault(q.Get("page"), 1)
|
||||
perPage := parseIntOrDefault(q.Get("per_page"), 100)
|
||||
|
||||
@@ -288,6 +288,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) {
|
||||
go func() {
|
||||
proLogic.RemoveDeviceTagFromAclPolicies(tag.ID, tag.Network)
|
||||
proLogic.RemoveTagFromPostureChecks(tag.ID, tag.Network)
|
||||
proLogic.RemoveTagFromNameservers(tag.ID, tag.Network)
|
||||
logic.RemoveTagFromEnrollmentKeys(tag.ID)
|
||||
mq.PublishPeerUpdate(false)
|
||||
}()
|
||||
|
||||
@@ -2135,7 +2135,11 @@ func testIDPSync(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
case "azure-ad":
|
||||
idpClient = azure.NewAzureEntraIDClient(req.ClientID, req.ClientSecret, req.AzureTenantID)
|
||||
secret := req.ClientSecret
|
||||
if secret == logic.Mask() {
|
||||
secret = logic.GetServerSettings().ClientSecret
|
||||
}
|
||||
idpClient = azure.NewAzureEntraIDClient(req.ClientID, secret, req.AzureTenantID)
|
||||
case "okta":
|
||||
idpClient, err = okta.NewOktaClient(req.OktaOrgURL, req.OktaAPIToken)
|
||||
if err != nil {
|
||||
|
||||
+2
-1
@@ -140,7 +140,8 @@ func ValidateLicense() (err error) {
|
||||
proLogic.SetFeatureFlags(licenseResponse.FeatureFlags)
|
||||
proLogic.SetDeploymentMode(licenseResponse.DeploymentMode)
|
||||
|
||||
_ = mq.PublishExporterFeatureFlags()
|
||||
go mq.PublishExporterFeatureFlags()
|
||||
go mq.PublishPeerUpdate(false)
|
||||
|
||||
slog.Info("License validation succeeded!")
|
||||
return nil
|
||||
|
||||
@@ -242,3 +242,23 @@ func GetNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func RemoveTagFromNameservers(tagID models.TagID, netID schema.NetworkID) error {
|
||||
nameservers, err := (&schema.Nameserver{
|
||||
NetworkID: netID.String(),
|
||||
}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var multiErr error
|
||||
for _, nameserver := range nameservers {
|
||||
delete(nameserver.Tags, tagID.String())
|
||||
err := nameserver.Update(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
multiErr = errors.Join(multiErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
return multiErr
|
||||
}
|
||||
|
||||
@@ -46,6 +46,11 @@ func (ns *Nameserver) Create(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Create(&ns).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) ListAll(ctx context.Context) (dnsli []Nameserver, err error) {
|
||||
err = db.FromContext(ctx).Model(&Nameserver{}).Find(&dnsli).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (ns *Nameserver) ListByNetwork(ctx context.Context) (dnsli []Nameserver, err error) {
|
||||
err = db.FromContext(ctx).Model(&Nameserver{}).Where("network_id = ?", ns.NetworkID).Find(&dnsli).Error
|
||||
return
|
||||
|
||||
@@ -13,6 +13,7 @@ const egressTable = "egresses"
|
||||
type EgressNATMode string
|
||||
|
||||
const (
|
||||
DisabledNAT EgressNATMode = "disabled"
|
||||
VirtualNAT EgressNATMode = "virtual_nat"
|
||||
DirectNAT EgressNATMode = "direct_nat"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user