Merge pull request #3955 from gravitl/fixes/release-v1.5.1

Fixes: release-v1.5.1
This commit is contained in:
Abhishek Kondur
2026-04-03 16:51:03 +05:30
committed by GitHub
11 changed files with 232 additions and 28 deletions
+1 -1
View File
@@ -56,7 +56,7 @@ func userHandlers(r *mux.Router) {
r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost)
r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings)))).Methods(http.MethodGet)
r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings)))).Methods(http.MethodPut)
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
r.HandleFunc("/api/v2/users", logic.SecurityCheck(true, http.HandlerFunc(listUsers))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/users/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers))).Methods(http.MethodDelete)
+1 -1
View File
@@ -451,7 +451,7 @@ func DeleteUser(user string) error {
return err
}
go RemoveUserFromAclPolicy(user)
RemoveUserFromAclPolicy(user)
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
}
+30
View File
@@ -1,11 +1,13 @@
package logic
import (
"context"
"errors"
"net/http"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/gravitl/netmaker/db"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/models"
@@ -181,3 +183,31 @@ func ContinueIfUserMatch(next http.Handler) http.HandlerFunc {
next.ServeHTTP(w, r)
}
}
func ContinueIfUserMatchOrAdmin(next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusForbidden, Message: Forbidden_Msg,
}
user := &schema.User{
Username: r.Header.Get("user"),
}
err := user.Get(db.WithContext(context.TODO()))
if err == nil && (user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole) {
next.ServeHTTP(w, r)
return
}
var params = mux.Vars(r)
var requestedUser = params["username"]
if requestedUser == "" {
requestedUser = r.URL.Query().Get("username")
}
if requestedUser != r.Header.Get("user") {
ReturnErrorResponse(w, r, errorResponse)
return
}
next.ServeHTTP(w, r)
}
}
+79
View File
@@ -38,6 +38,7 @@ func Run() {
resync()
deleteOldExtclients()
cleanupDeletedUserGroupRefs()
migrateNameservers()
}
func updateNetworks() {
@@ -737,8 +738,10 @@ func cleanupDeletedUserGroupRefs() {
existingGroups[group.ID] = group
}
existingUsers := make(map[string]schema.User)
users, _ := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
for _, user := range users {
existingUsers[user.Username] = user
var update bool
for groupID := range user.UserGroups.Data() {
if _, ok := existingGroups[groupID]; !ok {
@@ -770,6 +773,10 @@ func cleanupDeletedUserGroupRefs() {
newSrc = append(newSrc, src)
}
}
} else if src.ID == models.UserAclID {
if _, ok := existingUsers[src.Value]; ok {
newSrc = append(newSrc, src)
}
} else {
newSrc = append(newSrc, src)
}
@@ -798,3 +805,75 @@ func cleanupDeletedUserGroupRefs() {
}
}
}
func migrateNameservers() {
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, network := range networks {
_ = logic.CreateFallbackNameserver(network.Name)
}
nameservers, _ := (&schema.Nameserver{}).ListAll(db.WithContext(context.TODO()))
for _, nameserver := range nameservers {
if len(nameserver.Domains) != 0 {
for _, matchDomain := range nameserver.MatchDomains {
nameserver.Domains = append(nameserver.Domains, schema.NameserverDomain{
Domain: matchDomain,
})
}
nameserver.MatchDomains = []string{}
_ = nameserver.Update(db.WithContext(context.TODO()))
}
}
superAdmin := &schema.User{}
err := superAdmin.GetSuperAdmin(db.WithContext(context.TODO()))
if err != nil {
return
}
nodes, _ := logic.GetAllNodes()
for _, node := range nodes {
if !node.IsGw {
continue
}
if node.IngressDNS != "" {
if (node.Address.IP != nil && node.Address.IP.String() == node.IngressDNS) ||
(node.Address6.IP != nil && node.Address6.IP.String() == node.IngressDNS) {
continue
}
if node.IngressDNS == "8.8.8.8" || node.IngressDNS == "1.1.1.1" || node.IngressDNS == "9.9.9.9" {
continue
}
host := &schema.Host{
ID: node.HostID,
}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
continue
}
ns := schema.Nameserver{
ID: uuid.NewString(),
Name: fmt.Sprintf("%s gw nameservers", host.Name),
NetworkID: node.Network,
Servers: []string{node.IngressDNS},
MatchAll: true,
Domains: []schema.NameserverDomain{
{
Domain: ".",
},
},
Nodes: datatypes.JSONMap{
node.ID.String(): struct{}{},
},
Tags: make(datatypes.JSONMap),
Status: true,
CreatedBy: superAdmin.Username,
}
_ = ns.Create(db.WithContext(context.TODO()))
node.IngressDNS = ""
_ = logic.UpsertNode(&node)
}
}
}
+63
View File
@@ -5,8 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
@@ -204,6 +206,67 @@ func migrateNetworks(ctx context.Context) error {
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
_, cidr, err := net.ParseCIDR(network.AddressRange)
if err != nil {
err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange, err)
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
_, cidrv6, err := net.ParseCIDR(network.AddressRange6)
if err != nil {
err = fmt.Errorf("error parsing network (%s) cidr (%s): %v", _network.Name, network.AddressRange6, err)
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
superAdmin := &schema.User{}
err = superAdmin.GetSuperAdmin(ctx)
if err != nil {
err = fmt.Errorf("error getting superadmin: %v", err)
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
if len(network.NameServers) > 0 {
ns := schema.Nameserver{
ID: uuid.NewString(),
Name: "upstream nameservers",
NetworkID: _network.Name,
Servers: []string{},
MatchAll: true,
Domains: []schema.NameserverDomain{
{
Domain: ".",
},
},
Tags: datatypes.JSONMap{
"*": struct{}{},
},
Nodes: make(datatypes.JSONMap),
Status: true,
CreatedBy: superAdmin.Username,
}
for _, nsIP := range network.NameServers {
if net.ParseIP(nsIP) == nil {
continue
}
if !cidr.Contains(net.ParseIP(nsIP)) && !cidrv6.Contains(net.ParseIP(nsIP)) {
ns.Servers = append(ns.Servers, nsIP)
}
}
if len(ns.Servers) > 0 {
err = ns.Create(ctx)
if err != nil {
err = fmt.Errorf("error creating upstream nameserver for network (%s): %v", _network.Name, err)
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
}
}
}
return nil
+2 -1
View File
@@ -106,7 +106,8 @@ func listUserActivity(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && caller.PlatformRoleID != schema.AdminRole {
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole &&
caller.PlatformRoleID != schema.AdminRole && caller.PlatformRoleID != schema.Auditor {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusForbidden,
Message: "you are not authorized to view this user's activity",
+25 -24
View File
@@ -75,11 +75,11 @@ type FlowRow struct {
// @Param network_id query string false "Filter by network ID"
// @Param from query string false "Start time in RFC3339 format"
// @Param to query string false "End time in RFC3339 format"
// @Param src_type query string false "Source type filter"
// @Param src_type []query string false "Source type filter"
// @Param src_entity_id query string false "Source entity ID filter"
// @Param dst_type query string false "Destination type filter"
// @Param dst_type []query string false "Destination type filter"
// @Param dst_entity_id query string false "Destination entity ID filter"
// @Param protocol query string false "Protocol filter"
// @Param protocol []query string false "Protocol filter"
// @Param node_id query string false "Node ID filter"
// @Param username query string false "Username filter"
// @Param page query int false "Page number"
@@ -115,7 +115,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
args = append(args, networkID)
}
// 1. Time filtering (version: UInt64 timestamp in ms)
// 1. Time filtering (start_ts: UInt64 timestamp in ms)
fromStr := q.Get("from")
toStr := q.Get("to")
@@ -125,7 +125,7 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'from' timestamp: %v", err), logic.BadReq))
return
}
whereParts = append(whereParts, "version >= ?")
whereParts = append(whereParts, "start_ts >= ?")
args = append(args, fromVal)
}
@@ -135,15 +135,14 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid 'to' timestamp: %v", err), logic.BadReq))
return
}
whereParts = append(whereParts, "version <= ?")
whereParts = append(whereParts, "start_ts <= ?")
args = append(args, toVal)
}
// 2. Source filters
srcTypeStr := q.Get("src_type")
if srcTypeStr != "" {
whereParts = append(whereParts, "src_type = ?")
args = append(args, srcTypeStr)
if q.Get("src_type") != "" {
whereParts = append(whereParts, "src_type IN ?")
args = append(args, q["src_type"])
}
srcEntity := q.Get("src_entity_id")
@@ -153,10 +152,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
}
// 3. Destination filters
dstTypeStr := q.Get("dst_type")
if dstTypeStr != "" {
whereParts = append(whereParts, "dst_type = ?")
args = append(args, dstTypeStr)
if q.Get("dst_type") != "" {
whereParts = append(whereParts, "dst_type IN ?")
args = append(args, q["dst_type"])
}
dstEntity := q.Get("dst_entity_id")
@@ -166,10 +164,9 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
}
// 4. Protocol filter
protoStr := q.Get("protocol")
if protoStr != "" {
whereParts = append(whereParts, "protocol = ?")
args = append(args, protoStr)
if q.Get("protocol") != "" {
whereParts = append(whereParts, "protocol IN ?")
args = append(args, q["protocol"])
}
// 5. Node filter
@@ -202,21 +199,25 @@ func handleListFlows(w http.ResponseWriter, r *http.Request) {
// 6. User filter
username := q.Get("username")
if username != "" {
if srcTypeStr != "" || dstTypeStr != "" ||
srcEntity != "" || dstEntity != "" {
if q.Has("src_type") || q.Has("dst_type") ||
q.Has("src_entity_id") || q.Has("dst_entity_id") {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("cannot provide username filter along with src/dst type and id filters"), logic.BadReq))
return
}
srcTypeStr = "user"
srcEntity = username
dstTypeStr = "user"
dstEntity = username
srcTypeStr := "user"
srcEntity := username
dstTypeStr := "user"
dstEntity := username
whereParts = append(whereParts, "((src_type = ? AND src_entity_id = ?) OR (dst_type = ? AND dst_entity_id = ?))")
args = append(args, srcTypeStr, srcEntity, dstTypeStr, dstEntity)
}
// 7. Ignore flow logs with zero end_ts.
whereParts = append(whereParts, "end_ts <> ?")
args = append(args, time.Unix(0, 0))
// Pagination
page := parseIntOrDefault(q.Get("page"), 1)
perPage := parseIntOrDefault(q.Get("per_page"), 100)
+1
View File
@@ -288,6 +288,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) {
go func() {
proLogic.RemoveDeviceTagFromAclPolicies(tag.ID, tag.Network)
proLogic.RemoveTagFromPostureChecks(tag.ID, tag.Network)
proLogic.RemoveTagFromNameservers(tag.ID, tag.Network)
logic.RemoveTagFromEnrollmentKeys(tag.ID)
mq.PublishPeerUpdate(false)
}()
+5 -1
View File
@@ -2135,7 +2135,11 @@ func testIDPSync(w http.ResponseWriter, r *http.Request) {
return
}
case "azure-ad":
idpClient = azure.NewAzureEntraIDClient(req.ClientID, req.ClientSecret, req.AzureTenantID)
secret := req.ClientSecret
if secret == logic.Mask() {
secret = logic.GetServerSettings().ClientSecret
}
idpClient = azure.NewAzureEntraIDClient(req.ClientID, secret, req.AzureTenantID)
case "okta":
idpClient, err = okta.NewOktaClient(req.OktaOrgURL, req.OktaAPIToken)
if err != nil {
+20
View File
@@ -242,3 +242,23 @@ func GetNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) {
}
return
}
func RemoveTagFromNameservers(tagID models.TagID, netID schema.NetworkID) error {
nameservers, err := (&schema.Nameserver{
NetworkID: netID.String(),
}).ListByNetwork(db.WithContext(context.TODO()))
if err != nil {
return err
}
var multiErr error
for _, nameserver := range nameservers {
delete(nameserver.Tags, tagID.String())
err := nameserver.Update(db.WithContext(context.TODO()))
if err != nil {
multiErr = errors.Join(multiErr, err)
}
}
return multiErr
}
+5
View File
@@ -46,6 +46,11 @@ func (ns *Nameserver) Create(ctx context.Context) error {
return db.FromContext(ctx).Model(&Nameserver{}).Create(&ns).Error
}
func (ns *Nameserver) ListAll(ctx context.Context) (dnsli []Nameserver, err error) {
err = db.FromContext(ctx).Model(&Nameserver{}).Find(&dnsli).Error
return
}
func (ns *Nameserver) ListByNetwork(ctx context.Context) (dnsli []Nameserver, err error) {
err = db.FromContext(ctx).Model(&Nameserver{}).Where("network_id = ?", ns.NetworkID).Find(&dnsli).Error
return