diff --git a/controllers/acls.go b/controllers/acls.go index 9599cfd8..fee1e8f3 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -223,6 +223,7 @@ func getAcls(w http.ResponseWriter, r *http.Request) { return } logic.SortAclEntrys(acls[:]) + logic.PopulateAclPolicyTagNames(acls) logic.ReturnSuccessResponseWithJson(w, r, acls, "fetched all acls in the network "+netID) } @@ -254,6 +255,7 @@ func getEgressAcls(w http.ResponseWriter, r *http.Request) { return } logic.SortAclEntrys(acls[:]) + logic.PopulateAclPolicyTagNames(acls) logic.ReturnSuccessResponseWithJson(w, r, acls, "fetched acls for egress"+e.Name) } @@ -329,7 +331,9 @@ func createAcl(w http.ResponseWriter, r *http.Request) { Origin: schema.Dashboard, }) go mq.PublishPeerUpdate(true) - logic.ReturnSuccessResponseWithJson(w, r, acl, "created acl successfully") + acls := []models.Acl{acl} + logic.PopulateAclPolicyTagNames(acls) + logic.ReturnSuccessResponseWithJson(w, r, acls[0], "created acl successfully") } // @Summary Update Acl @@ -395,7 +399,14 @@ func updateAcl(w http.ResponseWriter, r *http.Request) { Origin: schema.Dashboard, }) go mq.PublishPeerUpdate(true) - logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name) + updatedAcl, err := logic.GetAcl(acl.ID) + if err != nil { + logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name) + return + } + acls := []models.Acl{updatedAcl} + logic.PopulateAclPolicyTagNames(acls) + logic.ReturnSuccessResponseWithJson(w, r, acls[0], "updated acl "+acl.Name) } // @Summary Delete Acl diff --git a/controllers/gateway.go b/controllers/gateway.go index 8e3a28fc..9848953b 100644 --- a/controllers/gateway.go +++ b/controllers/gateway.go @@ -338,7 +338,12 @@ func assignGw(w http.ResponseWriter, r *http.Request) { autoAssignGw = false } if autoAssignGw { - + if node.InternetGwID != "" { + logic.ReturnErrorResponse(w, r, logic.FormatError( + errors.New("node is configured to route all traffic via an internet gateway; auto-assign gateway is not allowed"), + "badrequest")) + return + } if node.RelayedBy != "" { gatewayNode, err := logic.GetNodeByID(node.RelayedBy) if err == nil { diff --git a/controllers/hosts.go b/controllers/hosts.go index d9f14e74..2422a996 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -526,7 +526,20 @@ func hostUpdateFallback(w http.ResponseWriter, r *http.Request) { return } case models.UpdateNode: - sendDeletedNodeUpdate, sendPeerUpdate = logic.UpdateHostNode(&hostUpdate.Host, &hostUpdate.Node) + var displacedGwNodes []models.Node + sendDeletedNodeUpdate, sendPeerUpdate, displacedGwNodes = logic.UpdateHostNode(&hostUpdate.Host, &hostUpdate.Node) + if len(displacedGwNodes) > 0 { + go func() { + for _, dNode := range displacedGwNodes { + dHost := &schema.Host{ID: dNode.HostID} + if err := dHost.Get(db.WithContext(context.TODO())); err != nil { + slog.Error("fallback disconnect gw: failed to get host for displaced node", "node", dNode.ID, "error", err) + continue + } + mq.HostUpdate(&models.HostUpdate{Action: models.CheckAutoAssignGw, Host: *dHost, Node: dNode}) + } + }() + } case models.UpdateMetrics: mq.UpdateMetricsFallBack(hostUpdate.Node.ID.String(), hostUpdate.NewMetrics) case models.EgressUpdate: diff --git a/controllers/node.go b/controllers/node.go index 1e35a2e9..ed00b69b 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -703,34 +703,14 @@ func updateNode(w http.ResponseWriter, r *http.Request) { _ = logic.UpdateMetrics(newNode.ID.String(), metrics) } if servercfg.IsPro { - gwNode, err := logic.GetNodeByID(newNode.ID.String()) - if err != nil { - slog.Error("disconnect gw: failed to re-fetch node", "node", newNode.ID, "error", err) - } else if gwNode.IsGw && len(gwNode.RelayedNodes) > 0 { - newRelayedNodes := []string{} - var displacedNodes []models.Node - for _, relayedNodeID := range gwNode.RelayedNodes { - relayedNode, err := logic.GetNodeByID(relayedNodeID) - if err != nil { - continue - } - if relayedNode.AutoAssignGateway && relayedNode.RelayedBy == gwNode.ID.String() { - displacedNodes = append(displacedNodes, relayedNode) - continue - } - newRelayedNodes = append(newRelayedNodes, relayedNodeID) - } - if len(displacedNodes) > 0 { - logic.UpdateRelayNodes(gwNode.ID.String(), gwNode.RelayedNodes, newRelayedNodes) - for _, dNode := range displacedNodes { - dHost := &schema.Host{ID: dNode.HostID} - if err := dHost.Get(db.WithContext(context.TODO())); err != nil { - slog.Error("disconnect gw: failed to get host for displaced node", "node", dNode.ID, "error", err) - continue - } - mq.HostUpdate(&models.HostUpdate{Action: models.CheckAutoAssignGw, Host: *dHost, Node: dNode}) - } + displacedNodes := logic.DisplaceAutoRelayedNodes(newNode.ID.String()) + for _, dNode := range displacedNodes { + dHost := &schema.Host{ID: dNode.HostID} + if err := dHost.Get(db.WithContext(context.TODO())); err != nil { + slog.Error("disconnect gw: failed to get host for displaced node", "node", dNode.ID, "error", err) + continue } + mq.HostUpdate(&models.HostUpdate{Action: models.CheckAutoAssignGw, Host: *dHost, Node: dNode}) } } } diff --git a/logic/acls.go b/logic/acls.go index 9191975e..43bf5bdf 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -1721,6 +1721,60 @@ func SortAclEntrys(acls []models.Acl) { }) } +// PopulateAclPolicyTagNames resolves human-readable names for ACL policy tags +func PopulateAclPolicyTagNames(acls []models.Acl) { + for i := range acls { + populateTagNames(acls[i].Src) + populateTagNames(acls[i].Dst) + } +} + +func populateTagNames(tags []models.AclPolicyTag) { + for i := range tags { + tag := &tags[i] + if tag.Value == "" || tag.Value == "*" { + tag.Name = tag.Value + continue + } + switch tag.ID { + case models.UserAclID: + tag.Name = tag.Value + case models.UserGroupAclID: + grp, err := GetUserGroup(schema.UserGroupID(tag.Value)) + if err == nil { + tag.Name = grp.Name + } else { + tag.Name = tag.Value + } + case models.NodeTagID: + tag.Name = tag.Value + case models.NodeID: + node, err := GetNodeByID(tag.Value) + if err == nil { + host := &schema.Host{ID: node.HostID} + if err := host.Get(db.WithContext(context.TODO())); err == nil { + tag.Name = host.Name + } else { + tag.Name = tag.Value + } + } else { + tag.Name = tag.Value + } + case models.EgressID: + egress := schema.Egress{ID: tag.Value} + if err := egress.Get(db.WithContext(context.TODO())); err == nil { + tag.Name = egress.Name + } else { + tag.Name = tag.Value + } + case models.EgressRange: + tag.Name = tag.Value + default: + tag.Name = tag.Value + } + } +} + // ValidateCreateAclReq - validates create req for acl func ValidateCreateAclReq(req models.Acl) error { // check if acl network exists diff --git a/logic/gateway.go b/logic/gateway.go index ef65b36b..ed5e3b97 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -423,10 +423,20 @@ func SetInternetGw(node *models.Node, req models.InetNodeReq) { if err != nil { continue } + if clientNode.AutoAssignGateway { + clientNode.AutoAssignGateway = false + if clientNode.RelayedBy != "" && clientNode.RelayedBy != node.ID.String() { + currRelay, err := GetNodeByID(clientNode.RelayedBy) + if err == nil { + newRelayed := RemoveAllFromSlice(currRelay.RelayedNodes, clientNode.ID.String()) + UpdateRelayNodes(currRelay.ID.String(), currRelay.RelayedNodes, newRelayed) + } + clientNode.RelayedBy = "" + } + } clientNode.InternetGwID = node.ID.String() UpsertNode(&clientNode) } - } func UnsetInternetGw(node *models.Node) { diff --git a/logic/hosts.go b/logic/hosts.go index b2daf282..838cad10 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -273,7 +273,7 @@ func UpsertHost(h *schema.Host) error { } // UpdateHostNode - handles updates from client nodes -func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool) { +func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool, displacedGwNodes []models.Node) { currentNode, err := GetNodeByID(newNode.ID.String()) if err != nil { return @@ -283,15 +283,43 @@ func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpd UpsertNode(¤tNode) if !newNode.Connected { publishDeletedNodeUpdate = true + if servercfg.IsPro { + displacedGwNodes = DisplaceAutoRelayedNodes(newNode.ID.String()) + } } publishPeerUpdate = true - // reset failover data for this node ResetFailedOverPeer(newNode) ResetAutoRelayedPeer(newNode) return } +// DisplaceAutoRelayedNodes removes auto-assigned nodes from a disconnected gateway +// and returns the displaced nodes that need re-assignment. +func DisplaceAutoRelayedNodes(nodeID string) []models.Node { + gwNode, err := GetNodeByID(nodeID) + if err != nil || !gwNode.IsGw || len(gwNode.RelayedNodes) == 0 { + return nil + } + var newRelayedNodes []string + var displacedNodes []models.Node + for _, relayedNodeID := range gwNode.RelayedNodes { + relayedNode, err := GetNodeByID(relayedNodeID) + if err != nil { + continue + } + if relayedNode.AutoAssignGateway && relayedNode.RelayedBy == gwNode.ID.String() { + displacedNodes = append(displacedNodes, relayedNode) + continue + } + newRelayedNodes = append(newRelayedNodes, relayedNodeID) + } + if len(displacedNodes) > 0 { + UpdateRelayNodes(gwNode.ID.String(), gwNode.RelayedNodes, newRelayedNodes) + } + return displacedNodes +} + // RemoveHost - removes a given host from server func RemoveHost(h *schema.Host, forceDelete bool) error { if !forceDelete && len(h.Nodes) > 0 { diff --git a/models/acl.go b/models/acl.go index 2cc21521..da950dcc 100644 --- a/models/acl.go +++ b/models/acl.go @@ -51,6 +51,7 @@ const ( type AclPolicyTag struct { ID AclGroupType `json:"id"` + Name string `json:"name"` Value string `json:"value"` } diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index 9ee306f0..dd9b2198 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -110,6 +110,7 @@ func createPostureCheck(w http.ResponseWriter, r *http.Request) { go mq.PublishPeerUpdate(false) go proLogic.RunPostureChecks() + proLogic.PopulatePostureCheckGroupNames([]schema.PostureCheck{pc}) logic.ReturnSuccessResponseWithJson(w, r, pc, "created posture check") } @@ -148,6 +149,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) { ) return } + proLogic.PopulatePostureCheckGroupNames([]schema.PostureCheck{pc}) logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check") return } @@ -161,6 +163,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) { ) return } + proLogic.PopulatePostureCheckGroupNames(list) logic.ReturnSuccessResponseWithJson(w, r, list, "fetched posture checks") } @@ -246,6 +249,7 @@ func updatePostureCheck(w http.ResponseWriter, r *http.Request) { logic.LogEvent(event) go mq.PublishPeerUpdate(false) go proLogic.RunPostureChecks() + proLogic.PopulatePostureCheckGroupNames([]schema.PostureCheck{pc}) logic.ReturnSuccessResponseWithJson(w, r, pc, "updated posture check") } diff --git a/pro/controllers/users.go b/pro/controllers/users.go index ef36c179..fca4a527 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -51,6 +51,8 @@ func UserHandlers(r *mux.Router) { r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods(http.MethodPost) r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(updateUserGroup))).Methods(http.MethodPut) r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/groups/network", logic.SecurityCheck(true, http.HandlerFunc(listNetworkUserGroups))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/network", logic.SecurityCheck(true, http.HandlerFunc(listNetworkUsers))).Methods(http.MethodGet) r.HandleFunc("/api/v1/users/add_network_user", logic.SecurityCheck(true, http.HandlerFunc(addUsertoNetwork))).Methods(http.MethodPut) r.HandleFunc("/api/v1/users/remove_network_user", logic.SecurityCheck(true, http.HandlerFunc(removeUserfromNetwork))).Methods(http.MethodPut) r.HandleFunc("/api/v1/users/unassigned_network_users", logic.SecurityCheck(true, http.HandlerFunc(listUnAssignedNetUsers))).Methods(http.MethodGet) @@ -649,6 +651,115 @@ func updateUserGroup(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponseWithJson(w, r, userGroup, "updated user group") } +// @Summary List user groups with access to a network +// @Router /api/v1/users/groups/network [get] +// @Tags Users +// @Security oauth +// @Produce json +// @Param network query string true "Network ID" +// @Success 200 {array} schema.UserGroup +// @Failure 400 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listNetworkUserGroups(w http.ResponseWriter, r *http.Request) { + network := r.URL.Query().Get("network") + if network == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq)) + return + } + if err := (&schema.Network{Name: network}).Get(r.Context()); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("network %s not found", network), logic.BadReq)) + return + } + netID := schema.NetworkID(network) + allGroups, err := (&schema.UserGroup{}).ListAll(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + var networkGroups []schema.UserGroup + for _, grp := range allGroups { + roles := grp.NetworkRoles.Data() + if _, ok := roles[netID]; ok { + networkGroups = append(networkGroups, grp) + continue + } + if _, ok := roles[schema.AllNetworks]; ok { + networkGroups = append(networkGroups, grp) + } + } + if networkGroups == nil { + networkGroups = []schema.UserGroup{} + } + logic.ReturnSuccessResponseWithJson(w, r, networkGroups, "fetched user groups for network "+network) +} + +// @Summary List users with access to a network +// @Router /api/v1/users/network [get] +// @Tags Users +// @Security oauth +// @Produce json +// @Param network query string true "Network ID" +// @Success 200 {array} models.ReturnUser +// @Failure 400 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listNetworkUsers(w http.ResponseWriter, r *http.Request) { + network := r.URL.Query().Get("network") + if network == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq)) + return + } + if err := (&schema.Network{Name: network}).Get(r.Context()); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("network %s not found", network), logic.BadReq)) + return + } + netID := schema.NetworkID(network) + + allUsers, err := logic.GetUsers() + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + allGroupsList, err := (&schema.UserGroup{}).ListAll(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + allGroupsMap := make(map[schema.UserGroupID]schema.UserGroup, len(allGroupsList)) + for _, g := range allGroupsList { + allGroupsMap[g.ID] = g + } + var networkUsers []models.ReturnUser + for _, user := range allUsers { + if user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole { + networkUsers = append(networkUsers, user) + continue + } + hasAccess := false + for groupID := range user.UserGroups { + grp, ok := allGroupsMap[groupID] + if !ok { + continue + } + roles := grp.NetworkRoles.Data() + if _, ok := roles[netID]; ok { + hasAccess = true + break + } + if _, ok := roles[schema.AllNetworks]; ok { + hasAccess = true + break + } + } + if hasAccess { + networkUsers = append(networkUsers, user) + } + } + if networkUsers == nil { + networkUsers = []models.ReturnUser{} + } + logic.ReturnSuccessResponseWithJson(w, r, networkUsers, "fetched users for network "+network) +} + // @Summary List unassigned network users // @Router /api/v1/users/unassigned_network_users [get] // @Tags Users diff --git a/pro/logic/posture_check.go b/pro/logic/posture_check.go index e1bdd8fb..6bf3cdf0 100644 --- a/pro/logic/posture_check.go +++ b/pro/logic/posture_check.go @@ -504,6 +504,24 @@ func compareVersions(a, b string) int { return 0 } +// PopulatePostureCheckGroupNames sets group name as the value for each user group key +func PopulatePostureCheckGroupNames(pcs []schema.PostureCheck) { + for i := range pcs { + for groupID := range pcs[i].UserGroups { + if groupID == "*" { + pcs[i].UserGroups[groupID] = "*" + continue + } + grp, err := logic.GetUserGroup(schema.UserGroupID(groupID)) + if err == nil { + pcs[i].UserGroups[groupID] = grp.Name + } else { + pcs[i].UserGroups[groupID] = groupID + } + } + } +} + func ValidatePostureCheck(pc *schema.PostureCheck) error { if pc.Name == "" { return errors.New("name cannot be empty")