diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index fa839880..ccdf0ade 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -2,8 +2,11 @@ package controller import ( "encoding/json" + "errors" "fmt" "net/http" + "os" + "slices" "time" "github.com/go-playground/validator/v10" @@ -358,6 +361,37 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { ) return } + if newHost.EndpointIP != nil { + newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) + } else if newHost.EndpointIPv6 != nil { + newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) + } + pcviolations := []models.Violation{} + skipViolatedNetworks := []string{} + for _, netI := range enrollmentKey.Networks { + violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ + ClientLocation: newHost.CountryCode, + ClientVersion: newHost.Version, + OS: newHost.OS, + OSFamily: newHost.OSFamily, + OSVersion: newHost.OSVersion, + KernelVersion: newHost.KernelVersion, + AutoUpdate: newHost.AutoUpdate, + }, models.NetworkID(netI)) + pcviolations = append(pcviolations, violations...) + if len(violations) > 0 { + skipViolatedNetworks = append(skipViolatedNetworks, netI) + } + } + if len(skipViolatedNetworks) == len(enrollmentKey.Networks) && len(pcviolations) > 0 { + logic.ReturnErrorResponse(w, r, + logic.FormatError(errors.New("access blocked: this device doesn’t meet security requirements"), logic.Forbidden)) + return + } + // need to remove the networks that were skipped from the enrollment key + enrollmentKey.Networks = slices.DeleteFunc(enrollmentKey.Networks, func(netI string) bool { + return slices.Contains(skipViolatedNetworks, netI) + }) if !hostExists { newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive // register host diff --git a/controllers/ext_client.go b/controllers/ext_client.go index a958dd6f..c293bb40 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -715,13 +715,29 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { } err = errors.New("remote client config already exists on the gateway") slog.Error("failed to create extclient", "user", userName, "error", err) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) return } } } extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient) + if extclient.DeviceID != "" { + // check for violations connecting from desktop app + violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ + ClientLocation: extclient.Country, + ClientVersion: extclient.ClientVersion, + OS: extclient.OS, + OSFamily: extclient.OSFamily, + OSVersion: extclient.OSVersion, + KernelVersion: extclient.KernelVersion, + //AutoUpdate: extclient.AutoUpdate, + }, models.NetworkID(extclient.Network)) + if len(violations) > 0 { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden)) + return + } + } extclient.OwnerID = userName extclient.RemoteAccessClientID = customExtClient.RemoteAccessClientID extclient.IngressGatewayID = nodeid @@ -749,7 +765,6 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if err == nil { // check if parent network default ACL is enabled (yes) or not (no) extclient.Enabled = parentNetwork.DefaultACL == "yes" } - extclient.Os = customExtClient.Os extclient.DeviceID = customExtClient.DeviceID extclient.DeviceName = customExtClient.DeviceName if customExtClient.IsAlreadyConnectedToInetGw { @@ -758,7 +773,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { extclient.PublicEndpoint = customExtClient.PublicEndpoint extclient.Country = customExtClient.Country if customExtClient.RemoteAccessClientID != "" && customExtClient.Location == "" { - extclient.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN")) + extclient.Location, extclient.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN")) } extclient.Location = customExtClient.Location @@ -772,7 +787,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { "error", err, ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } @@ -820,7 +835,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { "error", err, ) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } if err := mq.PublishPeerUpdate(false); err != nil { @@ -903,9 +918,25 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { replacePeers = true } if update.RemoteAccessClientID != "" && update.Location == "" { - update.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN")) + update.Location, update.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN")) } newclient := logic.UpdateExtClient(&oldExtClient, &update) + if newclient.DeviceID != "" && newclient.Enabled { + // check for violations connecting from desktop app + violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ + ClientLocation: newclient.Country, + ClientVersion: newclient.ClientVersion, + OS: newclient.OS, + OSFamily: newclient.OSFamily, + OSVersion: newclient.OSVersion, + KernelVersion: newclient.KernelVersion, + //AutoUpdate: extclient.AutoUpdate, + }, models.NetworkID(newclient.Network)) + if len(violations) > 0 { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden)) + return + } + } if err := logic.DeleteExtClient(oldExtClient.Network, oldExtClient.ClientID, true); err != nil { slog.Error( "failed to delete ext client", diff --git a/controllers/hosts.go b/controllers/hosts.go index 38e454a6..29ebfb2a 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -547,7 +547,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse( w, r, - logic.FormatError(errors.New("hostid or network cannot be empty"), "badrequest"), + logic.FormatError(errors.New("hostid or network cannot be empty"), logic.BadReq), ) return } @@ -555,10 +555,23 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) { currHost, err := logic.GetHost(hostid) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to find host:", hostid, err.Error()) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) return } + violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ + ClientLocation: currHost.CountryCode, + ClientVersion: currHost.Version, + OS: currHost.OS, + OSFamily: currHost.OSFamily, + OSVersion: currHost.OSVersion, + KernelVersion: currHost.KernelVersion, + AutoUpdate: currHost.AutoUpdate, + }, models.NetworkID(network)) + if len(violations) > 0 { + logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq)) + return + } newNode, err := logic.UpdateHostNetwork(currHost, network, true) if err != nil { logger.Log( diff --git a/controllers/node.go b/controllers/node.go index 918d9258..5c28254a 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/gorilla/mux" "github.com/gravitl/netmaker/database" @@ -661,6 +662,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) { logic.ResetAutoRelayedPeer(¤tNode) } } + newNode.PostureChecksViolations, + newNode.PostureCheckVolationSeverityLevel = logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(newNode), + models.NetworkID(newNode.Network)) + newNode.LastEvaluatedAt = time.Now().UTC() logic.UpsertNode(newNode) logic.GetNodeStatus(newNode, false) diff --git a/go.mod b/go.mod index dfb196fd..f0bc052b 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,8 @@ require ( ) require ( + github.com/Masterminds/semver/v3 v3.4.0 + github.com/biter777/countries v1.7.5 github.com/google/go-cmp v0.7.0 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e github.com/guumaster/tablewriter v0.0.10 diff --git a/go.sum b/go.sum index 63cfe2c2..4cd75e77 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,10 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= +github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/biter777/countries v1.7.5 h1:MJ+n3+rSxWQdqVJU8eBy9RqcdH6ePPn4PJHocVWUa+Q= +github.com/biter777/countries v1.7.5/go.mod h1:1HSpZ526mYqKJcpT5Ti1kcGQ0L0SrXWIaptUWjFfv2E= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= diff --git a/logic/errors.go b/logic/errors.go index 2931ed29..40db4bab 100644 --- a/logic/errors.go +++ b/logic/errors.go @@ -77,3 +77,16 @@ func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, er response.WriteHeader(errorMessage.Code) response.Write(jsonResponse) } + +// ReturnErrorResponseWithJson - processes error with body and adds header +func ReturnErrorResponseWithJson(response http.ResponseWriter, request *http.Request, msg interface{}, errorMessage models.ErrorResponse) { + httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message, Response: msg} + jsonResponse, err := json.Marshal(httpResponse) + if err != nil { + panic(err) + } + slog.Debug("processed request error", "err", errorMessage.Message) + response.Header().Set("Content-Type", "application/json") + response.WriteHeader(errorMessage.Code) + response.Write(jsonResponse) +} diff --git a/logic/extpeers.go b/logic/extpeers.go index 3171dd75..039741b3 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -451,11 +451,26 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode new.Location = update.Location } if update.Country != "" && update.Country != old.Country { - new.Country = update.Country + new.Country = strings.ToUpper(update.Country) } if update.DeviceID != "" && old.DeviceID == "" { new.DeviceID = update.DeviceID } + if update.OS != "" { + new.OS = update.OS + } + if update.OSFamily != "" { + new.OSFamily = update.OSFamily + } + if update.OSVersion != "" { + new.OSVersion = update.OSVersion + } + if update.KernelVersion != "" { + new.KernelVersion = update.KernelVersion + } + if update.ClientVersion != "" { + new.ClientVersion = update.ClientVersion + } return new } diff --git a/logic/gateway.go b/logic/gateway.go index 1c5c8791..89c4f449 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -231,6 +231,8 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq node.Tags = make(map[models.TagID]struct{}) } node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.GwTagName))] = struct{}{} + node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network)) + node.LastEvaluatedAt = time.Now().UTC() err = UpsertNode(&node) if err != nil { return models.Node{}, err @@ -282,11 +284,12 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))) node.IngressGatewayRange = "" node.Metadata = "" + node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network)) + node.LastEvaluatedAt = time.Now().UTC() err = UpsertNode(&node) if err != nil { return models.Node{}, removedClients, err } - err = SetNetworkNodesLastModified(node.Network) return node, removedClients, err } diff --git a/logic/hosts.go b/logic/hosts.go index 1037643b..c86cf32c 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -34,7 +34,11 @@ var ( ErrInvalidHostID error = errors.New("invalid host id") ) -var GetHostLocInfo = func(ip, token string) string { return "" } +var GetHostLocInfo = func(ip, token string) (string, string) { return "", "" } + +var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network models.NetworkID) (v []models.Violation, level models.Severity) { + return []models.Violation{}, models.SeverityUnknown +} func getHostsFromCache() (hosts []models.Host) { hostCacheMutex.RLock() @@ -254,9 +258,9 @@ func CreateHost(h *models.Host) error { h.DNS = "no" } if h.EndpointIP != nil { - h.Location = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) + h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) } else if h.EndpointIPv6 != nil { - h.Location = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) + h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) } checkForZombieHosts(h) return UpsertHost(h) diff --git a/logic/nodes.go b/logic/nodes.go index b7af9c75..0891eb90 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -620,7 +620,15 @@ func FindRelay(node *models.Node) *models.Node { func GetAllNodesAPI(nodes []models.Node) []models.ApiNode { apiNodes := []models.ApiNode{} for i := range nodes { - newApiNode := nodes[i].ConvertToAPINode() + node := nodes[i] + if !node.IsStatic { + h, err := GetHost(node.HostID.String()) + if err == nil { + node.Location = h.Location + node.CountryCode = h.CountryCode + } + } + newApiNode := node.ConvertToAPINode() apiNodes = append(apiNodes, *newApiNode) } return apiNodes[:] @@ -874,3 +882,40 @@ func GetAllFailOvers() ([]models.Node, error) { } return igs, nil } + +// GetPostureCheckDeviceInfoByNode retrieves PostureCheckDeviceInfo for a given node +func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDeviceInfo { + var deviceInfo models.PostureCheckDeviceInfo + + if !node.IsStatic { + h, err := GetHost(node.HostID.String()) + if err != nil { + return deviceInfo + } + deviceInfo = models.PostureCheckDeviceInfo{ + ClientLocation: h.CountryCode, + ClientVersion: h.Version, + OS: h.OS, + OSVersion: h.OSVersion, + OSFamily: h.OSFamily, + KernelVersion: h.KernelVersion, + AutoUpdate: h.AutoUpdate, + Tags: node.Tags, + } + } else { + if node.StaticNode.DeviceID == "" && node.StaticNode.RemoteAccessClientID == "" { + return deviceInfo + } + deviceInfo = models.PostureCheckDeviceInfo{ + ClientLocation: node.StaticNode.Country, + ClientVersion: node.StaticNode.ClientVersion, + OS: node.StaticNode.OS, + OSVersion: node.StaticNode.OSVersion, + OSFamily: node.StaticNode.OSFamily, + KernelVersion: node.StaticNode.KernelVersion, + Tags: node.StaticNode.Tags, + } + } + + return deviceInfo +} diff --git a/logic/sysinfo.go b/logic/sysinfo.go new file mode 100644 index 00000000..4b79ab48 --- /dev/null +++ b/logic/sysinfo.go @@ -0,0 +1,227 @@ +package logic + +import ( + "bytes" + "os" + "os/exec" + "runtime" + "strings" +) + +type OSInfo struct { + OS string `json:"os"` // e.g. "ubuntu", "windows", "macos" + OSFamily string `json:"os_family"` // e.g. "linux-debian", "windows" + OSVersion string `json:"os_version"` // e.g. "22.04", "10.0.22631" + KernelVersion string `json:"kernel_version"` // e.g. "6.8.0" +} + +/// --- classification helpers you already had --- + +func NormalizeOSName(raw string) string { + return strings.ToLower(strings.TrimSpace(raw)) +} + +// OSFamily returns a normalized OS family string. +// Examples: "linux-debian", "linux-redhat", "linux-arch", "linux-other", "windows", "darwin" +func OSFamily(osName string) string { + osName = NormalizeOSName(osName) + + // Non-Linux first + if strings.Contains(osName, "windows") { + return "windows" + } + if strings.Contains(osName, "darwin") || strings.Contains(osName, "mac") || strings.Contains(osName, "os x") { + return "darwin" + } + + // Linux families + switch { + // Debian family + case containsAny(osName, + "debian", "ubuntu", "pop", "linuxmint", "kali", "raspbian", "elementary"): + return "linux-debian" + + // Red Hat family + case containsAny(osName, + "rhel", "red hat", "centos", "rocky", "alma", "fedora", "oracle linux", "ol"): + return "linux-redhat" + + // SUSE family + case containsAny(osName, + "suse", "opensuse", "sles"): + return "linux-suse" + + // Arch family + case containsAny(osName, + "arch", "manjaro", "endeavouros", "garuda"): + return "linux-arch" + + // Gentoo + case strings.Contains(osName, "gentoo"): + return "linux-gentoo" + + // Alpine, Amazon, BusyBox, etc. + case containsAny(osName, + "alpine", "amazon", "busybox"): + return "linux-other" + } + + // Fallbacks + if strings.Contains(osName, "linux") { + return "linux-other" + } + + return "unknown" +} + +func containsAny(s string, subs ...string) bool { + for _, sub := range subs { + if strings.Contains(s, sub) { + return true + } + } + return false +} + +/// --- public entrypoint --- + +// GetOSInfo returns OS, OSFamily, OSVersion and KernelVersion for the current platform. +func GetOSInfo() OSInfo { + switch runtime.GOOS { + case "linux": + return getLinuxOSInfo() + case "darwin": + return getDarwinOSInfo() + case "windows": + return getWindowsOSInfo() + default: + // Fallback for other UNIX-likes; best-effort + kernel := strings.TrimSpace(runCmd("uname", "-r")) + name := runtime.GOOS + return OSInfo{ + OS: NormalizeOSName(name), + OSFamily: OSFamily(name), + OSVersion: "", + KernelVersion: CleanVersion(kernel), + } + } +} + +/// --- Linux --- + +func getLinuxOSInfo() OSInfo { + var osName, osVersion string + + data, err := os.ReadFile("/etc/os-release") + if err == nil { + lines := strings.Split(string(data), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + key := parts[0] + value := strings.Trim(parts[1], `"'`) + + switch key { + case "ID": + osName = value + case "VERSION_ID": + osVersion = value + } + } + } + + if osName == "" { + // Fallback + osName = "linux" + } + kernel := strings.TrimSpace(runCmd("uname", "-r")) + // trim extras like -generic + if idx := strings.Index(kernel, "-"); idx > 0 { + kernel = kernel[:idx] + } + + normName := NormalizeOSName(osName) + return OSInfo{ + OS: "linux", + OSFamily: OSFamily(normName), + OSVersion: CleanVersion(osVersion), + KernelVersion: CleanVersion(kernel), + } +} + +/// --- macOS (darwin) --- + +func getDarwinOSInfo() OSInfo { + productName := strings.TrimSpace(runCmd("sw_vers", "-productName")) + productVer := strings.TrimSpace(runCmd("sw_vers", "-productVersion")) + + if productName == "" { + productName = "macos" + } + kernel := strings.TrimSpace(runCmd("uname", "-r")) + if idx := strings.Index(kernel, "-"); idx > 0 { + kernel = kernel[:idx] + } + + normName := NormalizeOSName(productName) + return OSInfo{ + OS: "darwin", + OSFamily: OSFamily(normName), // "darwin" + OSVersion: CleanVersion(productVer), // e.g. "15.0" + KernelVersion: CleanVersion(kernel), + } +} + +/// --- Windows --- + +func getWindowsOSInfo() OSInfo { + // OS name: we just say "windows" + osName := "windows" + + // OS version via "wmic" or "ver" as fallback + var version string + + // Try wmic first (may be missing on newer builds but often still present) + out := runCmd("wmic", "os", "get", "Version", "/value") + for _, line := range strings.Split(out, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "Version=") { + version = strings.TrimPrefix(line, "Version=") + version = strings.TrimSpace(version) + break + } + } + + if version == "" { + // Fallback to "ver" + raw := strings.TrimSpace(runCmd("cmd", "/C", "ver")) + version = raw // you can add better parsing if you need + } + + // On Windows, kernel and OS version are effectively tied; reuse + kernel := version + + normName := NormalizeOSName(osName) + return OSInfo{ + OS: "windows", // "windows" + OSFamily: OSFamily(normName), + OSVersion: CleanVersion(version), // e.g. "10.0.22631" + KernelVersion: CleanVersion(kernel), + } +} + +/// --- small helper to run commands safely --- + +func runCmd(name string, args ...string) string { + cmd := exec.Command(name, args...) + var buf bytes.Buffer + cmd.Stdout = &buf + _ = cmd.Run() // ignore error; best-effort + return buf.String() +} diff --git a/logic/version.go b/logic/version.go index c9cfd331..f80a7a4f 100644 --- a/logic/version.go +++ b/logic/version.go @@ -1,9 +1,11 @@ package logic import ( + "regexp" "strings" "unicode" + "github.com/Masterminds/semver/v3" "github.com/hashicorp/go-version" ) @@ -29,3 +31,48 @@ func IsVersionCompatible(ver string) bool { return constraint.Check(v) } + +// CleanVersion normalizes a version string safely for storage. +// - removes "v" or "V" prefix +// - trims whitespace +// - strips invalid trailing characters +// - preserves semver, prerelease, and build metadata +func CleanVersion(raw string) string { + if raw == "" { + return "" + } + + v := strings.TrimSpace(raw) + + // Remove leading v/V (common in semver) + v = strings.TrimPrefix(v, "v") + v = strings.TrimPrefix(v, "V") + + // Remove trailing commas, quotes, spaces + v = strings.Trim(v, " ,\"'") + + // Remove any characters not allowed in semantic versioning: + // Allowed: 0-9 a-z A-Z . - + + re := regexp.MustCompile(`[^0-9A-Za-z\.\-\+]+`) + v = re.ReplaceAllString(v, "") + + // Collapse multiple dots (e.g., "1..2" → "1.2") + v = strings.ReplaceAll(v, "..", ".") + for strings.Contains(v, "..") { + v = strings.ReplaceAll(v, "..", ".") + } + + return v +} + +// IsValidVersion returns true if the version string can be parsed as semantic version. +func IsValidVersion(raw string) bool { + cleaned := CleanVersion(raw) + + if cleaned == "" { + return false + } + + _, err := semver.NewVersion(cleaned) + return err == nil +} diff --git a/migrate/migrate.go b/migrate/migrate.go index f206fc8e..425529ca 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -425,13 +425,13 @@ func updateHosts() { host.AutoUpdate = true logic.UpsertHost(&host) } - if servercfg.IsPro && host.Location == "" { + if servercfg.IsPro && (host.Location == "" || host.CountryCode == "") { if host.EndpointIP != nil { - host.Location = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) + host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) } else if host.EndpointIPv6 != nil { - host.Location = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) + host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) } - if host.Location != "" { + if host.Location != "" && host.CountryCode != "" { logic.UpsertHost(&host) } } @@ -891,6 +891,9 @@ func migrateSettings() { if settings.PeerConnectionCheckInterval == "" { settings.PeerConnectionCheckInterval = "15" } + if settings.PostureCheckInterval == "" { + settings.PostureCheckInterval = "30" + } if settings.CleanUpInterval == 0 { settings.CleanUpInterval = 60 } diff --git a/models/api_host.go b/models/api_host.go index b3849305..0a869d03 100644 --- a/models/api_host.go +++ b/models/api_host.go @@ -14,6 +14,9 @@ type ApiHost struct { Version string `json:"version"` Name string `json:"name"` OS string `json:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` Debug bool `json:"debug"` IsStaticPort bool `json:"isstaticport"` IsStatic bool `json:"isstatic"` @@ -33,6 +36,7 @@ type ApiHost struct { AutoUpdate bool `json:"autoupdate" yaml:"autoupdate"` DNS string `json:"dns" yaml:"dns"` Location string `json:"location"` + CountryCode string `json:"country_code"` } // ApiIface - the interface struct for API usage @@ -71,6 +75,8 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost { a.MacAddress = h.MacAddress.String() a.Name = h.Name a.OS = h.OS + a.OSFamily = h.OSFamily + a.KernelVersion = h.KernelVersion a.Nodes = h.Nodes a.WgPublicListenPort = h.WgPublicListenPort a.PublicKey = h.PublicKey.String() @@ -82,6 +88,7 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost { a.AutoUpdate = h.AutoUpdate a.DNS = h.DNS a.Location = h.Location + a.CountryCode = h.CountryCode return &a } @@ -122,6 +129,8 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host { h.Nodes = currentHost.Nodes h.TrafficKeyPublic = currentHost.TrafficKeyPublic h.OS = currentHost.OS + h.OSFamily = currentHost.OSFamily + h.KernelVersion = currentHost.KernelVersion h.IsDefault = a.IsDefault h.NatType = currentHost.NatType h.TurnEndpoint = currentHost.TurnEndpoint @@ -129,5 +138,6 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host { h.AutoUpdate = a.AutoUpdate h.DNS = strings.ToLower(a.DNS) h.Location = currentHost.Location + h.CountryCode = currentHost.CountryCode return &h } diff --git a/models/api_node.go b/models/api_node.go index 1d9f0504..b40996c3 100644 --- a/models/api_node.go +++ b/models/api_node.go @@ -53,20 +53,24 @@ type ApiNode struct { PendingDelete bool `json:"pendingdelete"` Metadata string `json:"metadata"` // == PRO == - DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"` - IsFailOver bool `json:"is_fail_over"` - FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` - FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"` - IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` - InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` - InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` - AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"` - Tags map[TagID]struct{} `json:"tags" yaml:"tags"` - IsStatic bool `json:"is_static"` - IsUserNode bool `json:"is_user_node"` - StaticNode ExtClient `json:"static_node"` - Status NodeStatus `json:"status"` - Location string `json:"location"` + DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"` + IsFailOver bool `json:"is_fail_over"` + FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` + FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"` + IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` + InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` + InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` + AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"` + Tags map[TagID]struct{} `json:"tags" yaml:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode ExtClient `json:"static_node"` + Status NodeStatus `json:"status"` + Location string `json:"location"` + Country string `json:"country"` + PostureChecksViolations []Violation `json:"posture_check_violations"` + PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` } // ApiNode.ConvertToServerNode - converts an api node to a server node @@ -227,6 +231,11 @@ func (nm *Node) ConvertToAPINode() *ApiNode { apiNode.IsUserNode = nm.IsUserNode apiNode.StaticNode = nm.StaticNode apiNode.Status = nm.Status + apiNode.PostureChecksViolations = nm.PostureChecksViolations + apiNode.PostureCheckVolationSeverityLevel = nm.PostureCheckVolationSeverityLevel + apiNode.LastEvaluatedAt = nm.LastEvaluatedAt + apiNode.Location = nm.Location + apiNode.Country = nm.CountryCode return &apiNode } diff --git a/models/events.go b/models/events.go index 45b6a37b..017e9845 100644 --- a/models/events.go +++ b/models/events.go @@ -54,6 +54,7 @@ const ( EnrollmentKeySub SubjectType = "ENROLLMENT_KEY" ClientAppSub SubjectType = "CLIENT-APP" NameserverSub SubjectType = "NAMESERVER" + PostureCheckSub SubjectType = "POSTURE_CHECK" ) func (sub SubjectType) String() string { diff --git a/models/extclient.go b/models/extclient.go index 1e24f923..7ede66d3 100644 --- a/models/extclient.go +++ b/models/extclient.go @@ -1,35 +1,45 @@ package models -import "sync" +import ( + "sync" + "time" +) // ExtClient - struct for external clients type ExtClient struct { - ClientID string `json:"clientid" bson:"clientid"` - PrivateKey string `json:"privatekey" bson:"privatekey"` - PublicKey string `json:"publickey" bson:"publickey"` - Network string `json:"network" bson:"network"` - DNS string `json:"dns" bson:"dns"` - Address string `json:"address" bson:"address"` - Address6 string `json:"address6" bson:"address6"` - ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"` - AllowedIPs []string `json:"allowed_ips"` - IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"` - IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` - LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"` - Enabled bool `json:"enabled" bson:"enabled"` - OwnerID string `json:"ownerid" bson:"ownerid"` - DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` - RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine - PostUp string `json:"postup" bson:"postup"` - PostDown string `json:"postdown" bson:"postdown"` - Tags map[TagID]struct{} `json:"tags"` - Os string `json:"os"` - DeviceID string `json:"device_id"` - DeviceName string `json:"device_name"` - PublicEndpoint string `json:"public_endpoint"` - Country string `json:"country"` - Location string `json:"location"` //format: lat,long - Mutex *sync.Mutex `json:"-"` + ClientID string `json:"clientid" bson:"clientid"` + PrivateKey string `json:"privatekey" bson:"privatekey"` + PublicKey string `json:"publickey" bson:"publickey"` + Network string `json:"network" bson:"network"` + DNS string `json:"dns" bson:"dns"` + Address string `json:"address" bson:"address"` + Address6 string `json:"address6" bson:"address6"` + ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"` + AllowedIPs []string `json:"allowed_ips"` + IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"` + IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` + LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"` + Enabled bool `json:"enabled" bson:"enabled"` + OwnerID string `json:"ownerid" bson:"ownerid"` + DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` + RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine + PostUp string `json:"postup" bson:"postup"` + PostDown string `json:"postdown" bson:"postdown"` + Tags map[TagID]struct{} `json:"tags"` + OS string `json:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` + ClientVersion string `json:"client_version"` + DeviceID string `json:"device_id"` + DeviceName string `json:"device_name"` + PublicEndpoint string `json:"public_endpoint"` + Country string `json:"country"` + Location string `json:"location"` //format: lat,long + PostureChecksViolations []Violation `json:"posture_check_violations"` + PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` + Mutex *sync.Mutex `json:"-"` } // CustomExtClient - struct for CustomExtClient params @@ -44,11 +54,15 @@ type CustomExtClient struct { PostUp string `json:"postup" bson:"postup" validate:"max=1024"` PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"` Tags map[TagID]struct{} `json:"tags"` - Os string `json:"os"` DeviceID string `json:"device_id"` DeviceName string `json:"device_name"` IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"` PublicEndpoint string `json:"public_endpoint"` + OS string `json:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` + ClientVersion string `json:"client_version"` Country string `json:"country"` Location string `json:"location"` //format: lat,long } @@ -63,10 +77,15 @@ func (ext *ExtClient) ConvertToStaticNode() Node { Address: ext.AddressIPNet4(), Address6: ext.AddressIPNet6(), }, - Tags: ext.Tags, - IsStatic: true, - StaticNode: *ext, - IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "", - Mutex: ext.Mutex, + Tags: ext.Tags, + IsStatic: true, + StaticNode: *ext, + IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "", + Mutex: ext.Mutex, + CountryCode: ext.Country, + Location: ext.Location, + PostureChecksViolations: ext.PostureChecksViolations, + PostureCheckVolationSeverityLevel: ext.PostureCheckVolationSeverityLevel, + LastEvaluatedAt: ext.LastEvaluatedAt, } } diff --git a/models/host.go b/models/host.go index 636b3996..0b3d579e 100644 --- a/models/host.go +++ b/models/host.go @@ -51,6 +51,9 @@ type Host struct { HostPass string `json:"hostpass" yaml:"hostpass"` Name string `json:"name" yaml:"name"` OS string `json:"os" yaml:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` Interface string `json:"interface" yaml:"interface"` Debug bool `json:"debug" yaml:"debug"` ListenPort int `json:"listenport" yaml:"listenport"` @@ -74,6 +77,7 @@ type Host struct { TurnEndpoint *netip.AddrPort `json:"turn_endpoint,omitempty" yaml:"turn_endpoint,omitempty"` PersistentKeepalive time.Duration `json:"persistentkeepalive" swaggertype:"primitive,integer" format:"int64" yaml:"persistentkeepalive"` Location string `json:"location"` // Format: "lat,lon" + CountryCode string `json:"country_code"` } // FormatBool converts a boolean to a [yes|no] string diff --git a/models/node.go b/models/node.go index 21044655..429112a9 100644 --- a/models/node.go +++ b/models/node.go @@ -113,19 +113,24 @@ type Node struct { //AutoRelayedPeers map[string]struct{} `json:"auto_relayed_peers"` AutoRelayedPeers map[string]string `json:"auto_relayed_peers_v1"` //AutoRelayedBy uuid.UUID `json:"auto_relayed_by"` - FailOverPeers map[string]struct{} `json:"fail_over_peers"` - FailedOverBy uuid.UUID `json:"failed_over_by"` - IsInternetGateway bool `json:"isinternetgateway"` - InetNodeReq InetNodeReq `json:"inet_node_req"` - InternetGwID string `json:"internetgw_node_id"` - AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"` - Tags map[TagID]struct{} `json:"tags"` - IsStatic bool `json:"is_static"` - IsUserNode bool `json:"is_user_node"` - StaticNode ExtClient `json:"static_node"` - Status NodeStatus `json:"node_status"` - Mutex *sync.Mutex `json:"-"` - EgressDetails EgressDetails `json:"-"` + FailOverPeers map[string]struct{} `json:"fail_over_peers"` + FailedOverBy uuid.UUID `json:"failed_over_by"` + IsInternetGateway bool `json:"isinternetgateway"` + InetNodeReq InetNodeReq `json:"inet_node_req"` + InternetGwID string `json:"internetgw_node_id"` + AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"` + Tags map[TagID]struct{} `json:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode ExtClient `json:"static_node"` + Status NodeStatus `json:"node_status"` + Mutex *sync.Mutex `json:"-"` + EgressDetails EgressDetails `json:"-"` + PostureChecksViolations []Violation `json:"posture_check_violations"` + PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` + Location string `json:"location"` // Format: "lat,lon" + CountryCode string `json:"country_code"` } type EgressDetails struct { EgressGatewayNatEnabled bool diff --git a/models/settings.go b/models/settings.go index d2440495..9ba94ad7 100644 --- a/models/settings.go +++ b/models/settings.go @@ -50,6 +50,7 @@ type ServerSettings struct { AuditLogsRetentionPeriodInDays int `json:"audit_logs_retention_period"` OldAClsSupport bool `json:"old_acl_support"` PeerConnectionCheckInterval string `json:"peer_connection_check_interval"` + PostureCheckInterval string `json:"posture_check_interval"` // in minutes CleanUpInterval int `json:"clean_up_interval_in_mins"` } diff --git a/models/structs.go b/models/structs.go index 2f761a98..30c0bd8c 100644 --- a/models/structs.go +++ b/models/structs.go @@ -116,8 +116,9 @@ type SuccessfulLoginResponse struct { // ErrorResponse is struct for error type ErrorResponse struct { - Code int - Message string + Code int + Message string + Response interface{} } // NodeAuth - struct for node auth @@ -457,3 +458,32 @@ type IDPSyncTestRequest struct { OktaOrgURL string `json:"okta_org_url"` OktaAPIToken string `json:"okta_api_token"` } + +type PostureCheckDeviceInfo struct { + ClientLocation string + ClientVersion string + OS string + OSFamily string + OSVersion string + KernelVersion string + AutoUpdate bool + Tags map[TagID]struct{} +} + +type Violation struct { + CheckID string `json:"check_id"` + Name string `json:"name"` + Attribute string `json:"attribute"` + Message string `json:"message"` + Severity Severity `json:"severity"` +} + +type Severity int + +const ( + SeverityUnknown Severity = iota + SeverityLow + SeverityMedium + SeverityHigh + SeverityCritical +) diff --git a/mq/handlers.go b/mq/handlers.go index 7808fb61..199b84de 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -296,13 +296,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool { (len(h.NatType) > 0 && h.NatType != currentHost.NatType) || h.DefaultInterface != currentHost.DefaultInterface || (h.ListenPort != 0 && h.ListenPort != currentHost.ListenPort) || - (h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) || (!h.EndpointIPv6.Equal(currentHost.EndpointIPv6)) + (h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) || + (!h.EndpointIPv6.Equal(currentHost.EndpointIPv6)) || (h.OSFamily != currentHost.OSFamily) || + (h.OSVersion != currentHost.OSVersion) || (h.KernelVersion != currentHost.KernelVersion) if ifaceDelta { // only save if something changes if !h.EndpointIP.Equal(currentHost.EndpointIP) || !h.EndpointIPv6.Equal(currentHost.EndpointIPv6) || currentHost.Location == "" { if h.EndpointIP != nil { - h.Location = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) + h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN")) } else if h.EndpointIPv6 != nil { - h.Location = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) + h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN")) } } currentHost.EndpointIP = h.EndpointIP @@ -310,9 +312,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool { currentHost.Interfaces = h.Interfaces currentHost.DefaultInterface = h.DefaultInterface currentHost.NatType = h.NatType + currentHost.OSFamily = h.OSFamily + currentHost.OSVersion = h.OSVersion + currentHost.KernelVersion = h.KernelVersion if h.Location != "" { currentHost.Location = h.Location } + if h.CountryCode != "" { + currentHost.CountryCode = h.CountryCode + } if h.ListenPort != 0 { currentHost.ListenPort = h.ListenPort } diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go new file mode 100644 index 00000000..edfe47fc --- /dev/null +++ b/pro/controllers/posture_check.go @@ -0,0 +1,343 @@ +package controllers + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/gorilla/mux" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/mq" + proLogic "github.com/gravitl/netmaker/pro/logic" + "github.com/gravitl/netmaker/schema" +) + +func PostureCheckHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/posture_check/attrs", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check/violations", logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes))).Methods(http.MethodGet) +} + +// @Summary List Posture Checks Available Attributes +// @Router /api/v1/posture_check/attrs [get] +// @Tags Auth +// @Accept json +// @Param query network string +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listPostureChecksAttrs(w http.ResponseWriter, r *http.Request) { + + logic.ReturnSuccessResponseWithJson(w, r, schema.PostureCheckAttrValues, "fetched posture checks") +} + +// @Summary Create Posture Check +// @Router /api/v1/posture_check [post] +// @Tags DNS +// @Accept json +// @Param body body schema.PostureCheck +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func createPostureCheck(w http.ResponseWriter, r *http.Request) { + + var req schema.PostureCheck + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + if err := proLogic.ValidatePostureCheck(&req); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + + pc := schema.PostureCheck{ + ID: uuid.New().String(), + Name: req.Name, + NetworkID: req.NetworkID, + Description: req.Description, + Tags: req.Tags, + Attribute: req.Attribute, + Values: req.Values, + Severity: req.Severity, + Status: true, + CreatedBy: r.Header.Get("user"), + CreatedAt: time.Now().UTC(), + } + + err = pc.Create(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error creating posture check "+err.Error()), logic.Internal), + ) + return + } + logic.LogEvent(&models.Event{ + Action: models.Create, + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: models.UserSub, + }, + TriggeredBy: r.Header.Get("user"), + Target: models.Subject{ + ID: pc.ID, + Name: pc.Name, + Type: models.PostureCheckSub, + }, + NetworkID: models.NetworkID(pc.NetworkID), + Origin: models.Dashboard, + }) + + go mq.PublishPeerUpdate(false) + go proLogic.RunPostureChecks() + logic.ReturnSuccessResponseWithJson(w, r, pc, "created posture check") +} + +// @Summary List Posture Checks +// @Router /api/v1/posture_check [get] +// @Tags Auth +// @Accept json +// @Param query network string +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listPostureChecks(w http.ResponseWriter, r *http.Request) { + + network := r.URL.Query().Get("network") + if network == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq)) + return + } + _, err := logic.GetNetwork(network) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network not found"), logic.BadReq)) + return + } + id := r.URL.Query().Get("id") + if id != "" { + pc := schema.PostureCheck{ID: id} + err := pc.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"), + ) + return + } + logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check") + return + } + pc := schema.PostureCheck{NetworkID: network} + list, err := pc.ListByNetwork(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"), + ) + return + } + logic.ReturnSuccessResponseWithJson(w, r, list, "fetched posture checks") +} + +// @Summary Update Posture Check +// @Router /api/v1/posture_check [put] +// @Tags Auth +// @Accept json +// @Param body body schema.PostureCheck +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func updatePostureCheck(w http.ResponseWriter, r *http.Request) { + + var updatePc schema.PostureCheck + err := json.NewDecoder(r.Body).Decode(&updatePc) + if err != nil { + logger.Log(0, "error decoding request body: ", + err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + + if err := proLogic.ValidatePostureCheck(&updatePc); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + + pc := schema.PostureCheck{ID: updatePc.ID} + err = pc.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + return + } + var updateStatus bool + if updatePc.Status != pc.Status { + updateStatus = true + } + event := &models.Event{ + Action: models.Update, + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: models.UserSub, + }, + TriggeredBy: r.Header.Get("user"), + Target: models.Subject{ + ID: pc.ID, + Name: updatePc.Name, + Type: models.PostureCheckSub, + }, + Diff: models.Diff{ + Old: pc, + New: updatePc, + }, + NetworkID: models.NetworkID(pc.NetworkID), + Origin: models.Dashboard, + } + pc.Tags = updatePc.Tags + pc.Attribute = updatePc.Attribute + pc.Values = updatePc.Values + pc.Description = updatePc.Description + pc.Name = updatePc.Name + pc.Severity = updatePc.Severity + pc.Status = updatePc.Status + pc.UpdatedAt = time.Now().UTC() + + err = pc.Update(db.WithContext(context.TODO())) + if err != nil { + logic.ReturnErrorResponse( + w, + r, + logic.FormatError(errors.New("error updating posture check "+err.Error()), "internal"), + ) + return + } + if updateStatus { + pc.UpdateStatus(db.WithContext(context.TODO())) + } + logic.LogEvent(event) + go mq.PublishPeerUpdate(false) + go proLogic.RunPostureChecks() + logic.ReturnSuccessResponseWithJson(w, r, pc, "updated posture check") +} + +// @Summary Delete Posture Check +// @Router /api/v1/posture_check [delete] +// @Tags Auth +// @Accept json +// @Param query id string +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func deletePostureCheck(w http.ResponseWriter, r *http.Request) { + + id := r.URL.Query().Get("id") + if id == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest")) + return + } + pc := schema.PostureCheck{ID: id} + err := pc.Get(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + return + } + err = pc.Delete(db.WithContext(r.Context())) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + logic.LogEvent(&models.Event{ + Action: models.Delete, + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: models.UserSub, + }, + TriggeredBy: r.Header.Get("user"), + Target: models.Subject{ + ID: pc.ID, + Name: pc.Name, + Type: models.PostureCheckSub, + }, + NetworkID: models.NetworkID(pc.NetworkID), + Origin: models.Dashboard, + Diff: models.Diff{ + Old: pc, + New: nil, + }, + }) + + go mq.PublishPeerUpdate(false) + logic.ReturnSuccessResponseWithJson(w, r, pc, "deleted posture check") +} + +// @Summary List Posture Check violated Nodes +// @Router /api/v1/posture_check/violations [get] +// @Tags Auth +// @Accept json +// @Param query network string +// @Success 200 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +// @Failure 500 {object} models.ErrorResponse +func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) { + + network := r.URL.Query().Get("network") + if network == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq)) + return + } + listViolatedusers := r.URL.Query().Get("users") == "true" + violatedNodes := []models.Node{} + if listViolatedusers { + extclients, err := logic.GetNetworkExtClients(network) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + return + } + for _, extclient := range extclients { + if extclient.DeviceID != "" && extclient.Enabled { + if len(extclient.PostureChecksViolations) > 0 { + violatedNodes = append(violatedNodes, extclient.ConvertToStaticNode()) + } + } + } + } else { + nodes, err := logic.GetNetworkNodes(network) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + return + } + + for _, node := range nodes { + if len(node.PostureChecksViolations) > 0 { + violatedNodes = append(violatedNodes, node) + } + } + } + apiNodes := logic.GetAllNodesAPI(violatedNodes) + logic.SortApiNodes(apiNodes[:]) + logic.ReturnSuccessResponseWithJson(w, r, apiNodes, "fetched posture checks violated nodes") +} diff --git a/pro/controllers/tags.go b/pro/controllers/tags.go index 704fd8ac..693ef5df 100644 --- a/pro/controllers/tags.go +++ b/pro/controllers/tags.go @@ -241,7 +241,7 @@ func updateTag(w http.ResponseWriter, r *http.Request) { UsedByCnt: len(updateTag.TaggedNodes), TaggedNodes: updateTag.TaggedNodes, } - + go proLogic.RunPostureChecks() logic.ReturnSuccessResponseWithJson(w, r, res, "updated tags") } diff --git a/pro/initialize.go b/pro/initialize.go index 6137b157..86eef8c1 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -37,6 +37,7 @@ func InitPro() { proControllers.TagHandlers, proControllers.NetworkHandlers, proControllers.AutoRelayHandlers, + proControllers.PostureCheckHandlers, ) controller.ListRoles = proControllers.ListRoles logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { @@ -99,8 +100,8 @@ func InitPro() { auth.ResetIDPSyncHook() email.Init() go proLogic.EventWatcher() - logic.GetMetricsMonitor().Start() + proLogic.AddPostureCheckHook() }) logic.ResetFailOver = proLogic.ResetFailOver logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer @@ -172,6 +173,7 @@ func InitPro() { logic.GetNameserversForNode = proLogic.GetNameserversForNode logic.ValidateNameserverReq = proLogic.ValidateNameserverReq logic.ValidateEgressReq = proLogic.ValidateEgressReq + logic.CheckPostureViolations = proLogic.CheckPostureViolations } diff --git a/pro/logic/metrics.go b/pro/logic/metrics.go index 8cfcd716..a850e80a 100644 --- a/pro/logic/metrics.go +++ b/pro/logic/metrics.go @@ -240,7 +240,7 @@ func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) { slog.Debug("[metrics] node metrics data", "node ID", currentNode.ID, "metrics", newMetrics) } -func GetHostLocInfo(ip, token string) string { +func GetHostLocInfo(ip, token string) (loc, country string) { url := "https://ipinfo.io/" if ip != "" { url += ip @@ -253,15 +253,18 @@ func GetHostLocInfo(ip, token string) string { client := http.Client{Timeout: 3 * time.Second} resp, err := client.Get(url) if err != nil { - return "" + return "", "" } defer resp.Body.Close() var data struct { - Loc string `json:"loc"` // Format: "lat,lon" + Loc string `json:"loc"` // Format: "lat,lon" + Country string `json:"country"` } if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { - return "" + return "", "" } - return data.Loc + loc = data.Loc + country = data.Country + return } diff --git a/pro/logic/posture_check.go b/pro/logic/posture_check.go new file mode 100644 index 00000000..232e94a1 --- /dev/null +++ b/pro/logic/posture_check.go @@ -0,0 +1,464 @@ +package logic + +import ( + "context" + "errors" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/biter777/countries" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" +) + +var postureCheckMutex = &sync.Mutex{} + +func AddPostureCheckHook() { + settings := logic.GetServerSettings() + interval := time.Hour + i, err := strconv.Atoi(settings.PostureCheckInterval) + if err == nil { + interval = time.Minute * time.Duration(i) + } + logic.HookManagerCh <- models.HookDetails{ + Hook: logic.WrapHook(RunPostureChecks), + Interval: interval, + } +} +func RunPostureChecks() error { + postureCheckMutex.Lock() + defer postureCheckMutex.Unlock() + nets, err := logic.GetNetworks() + if err != nil { + return err + } + nodes, err := logic.GetAllNodes() + if err != nil { + return err + } + for _, netI := range nets { + networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID) + if len(networkNodes) == 0 { + continue + } + networkNodes = logic.AddStaticNodestoList(networkNodes) + pcLi, err := (&schema.PostureCheck{NetworkID: netI.NetID}).ListByNetwork(db.WithContext(context.TODO())) + if err != nil || len(pcLi) == 0 { + continue + } + + for _, nodeI := range networkNodes { + postureChecksViolations, postureCheckVolationSeverityLevel := GetPostureCheckViolations(pcLi, logic.GetPostureCheckDeviceInfoByNode(&nodeI)) + if nodeI.IsStatic { + extclient, err := logic.GetExtClient(nodeI.StaticNode.ClientID, nodeI.StaticNode.Network) + if err == nil { + extclient.PostureChecksViolations = postureChecksViolations + extclient.PostureCheckVolationSeverityLevel = postureCheckVolationSeverityLevel + extclient.LastEvaluatedAt = time.Now().UTC() + logic.SaveExtClient(&extclient) + } + } else { + nodeI.PostureChecksViolations, nodeI.PostureCheckVolationSeverityLevel = postureChecksViolations, + postureCheckVolationSeverityLevel + nodeI.LastEvaluatedAt = time.Now().UTC() + logic.UpsertNode(&nodeI) + } + + } + + } + + return nil +} + +func CheckPostureViolations(d models.PostureCheckDeviceInfo, network models.NetworkID) ([]models.Violation, models.Severity) { + pcLi, err := (&schema.PostureCheck{NetworkID: network.String()}).ListByNetwork(db.WithContext(context.TODO())) + if err != nil || len(pcLi) == 0 { + return []models.Violation{}, models.SeverityUnknown + } + violations, level := GetPostureCheckViolations(pcLi, d) + return violations, level +} +func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]models.Violation, models.Severity) { + var violations []models.Violation + highest := models.SeverityUnknown + + // Group checks by attribute + checksByAttribute := make(map[schema.Attribute][]schema.PostureCheck) + for _, c := range checks { + // skip disabled checks + if !c.Status { + continue + } + // Check if tags match + if _, ok := c.Tags["*"]; !ok { + exists := false + for tagID := range c.Tags { + if _, ok := d.Tags[models.TagID(tagID)]; ok { + exists = true + break + } + } + if !exists { + continue + } + } + checksByAttribute[c.Attribute] = append(checksByAttribute[c.Attribute], c) + } + + // Handle OS and OSFamily together with OR logic since they are related + osChecks := checksByAttribute[schema.OS] + osFamilyChecks := checksByAttribute[schema.OSFamily] + if len(osChecks) > 0 || len(osFamilyChecks) > 0 { + osAllowed := evaluateAttributeChecks(osChecks, schema.OS, d) + osFamilyAllowed := evaluateAttributeChecks(osFamilyChecks, schema.OSFamily, d) + + // OR condition: if either OS or OSFamily passes, both are considered passed + if !osAllowed && !osFamilyAllowed { + + // Both failed, add violations for both + osDenied := getDeniedChecks(osChecks, schema.OS, d) + osFamilyDenied := getDeniedChecks(osFamilyChecks, schema.OSFamily, d) + + for _, denied := range osDenied { + sev := denied.check.Severity + if sev > highest { + highest = sev + } + v := models.Violation{ + CheckID: denied.check.ID, + Name: denied.check.Name, + Attribute: string(denied.check.Attribute), + Message: denied.reason, + Severity: sev, + } + violations = append(violations, v) + } + for _, denied := range osFamilyDenied { + sev := denied.check.Severity + if sev > highest { + highest = sev + } + v := models.Violation{ + CheckID: denied.check.ID, + Name: denied.check.Name, + Attribute: string(denied.check.Attribute), + Message: denied.reason, + Severity: sev, + } + violations = append(violations, v) + } + } + } + + // For all other attributes, check if ANY check allows it + for attr, attrChecks := range checksByAttribute { + // Skip OS and OSFamily as they are handled above + if attr == schema.OS || attr == schema.OSFamily { + continue + } + + // Check if any check for this attribute allows the device + allowed := false + var deniedChecks []struct { + check schema.PostureCheck + reason string + } + + for _, c := range attrChecks { + violated, reason := evaluatePostureCheck(&c, d) + if !violated { + // At least one check allows it + allowed = true + break + } + // Track denied checks with their reasons for violation reporting + deniedChecks = append(deniedChecks, struct { + check schema.PostureCheck + reason string + }{check: c, reason: reason}) + } + + // If no check allows it, add violations for all denied checks + if !allowed { + for _, denied := range deniedChecks { + sev := denied.check.Severity + if sev > highest { + highest = sev + } + + v := models.Violation{ + CheckID: denied.check.ID, + Name: denied.check.Name, + Attribute: string(denied.check.Attribute), + Message: denied.reason, + Severity: sev, + } + violations = append(violations, v) + } + } + } + + return violations, highest +} + +// evaluateAttributeChecks evaluates checks for a specific attribute and returns true if any check allows the device +func evaluateAttributeChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) bool { + for _, c := range attrChecks { + violated, _ := evaluatePostureCheck(&c, d) + if !violated { + // At least one check allows it + return true + } + } + return false +} + +// getDeniedChecks returns all checks that denied the device for a specific attribute +func getDeniedChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) []struct { + check schema.PostureCheck + reason string +} { + var deniedChecks []struct { + check schema.PostureCheck + reason string + } + + for _, c := range attrChecks { + violated, reason := evaluatePostureCheck(&c, d) + if violated { + deniedChecks = append(deniedChecks, struct { + check schema.PostureCheck + reason string + }{check: c, reason: reason}) + } + } + return deniedChecks +} + +func evaluatePostureCheck(check *schema.PostureCheck, d models.PostureCheckDeviceInfo) (violated bool, reason string) { + switch check.Attribute { + + // ------------------------ + // 1. Geographic check + // ------------------------ + case schema.ClientLocation: + if !slices.Contains(check.Values, strings.ToUpper(d.ClientLocation)) { + return true, fmt.Sprintf("client location '%s' not allowed", CountryNameFromISO(d.ClientLocation)) + } + + // ------------------------ + // 2. Client version check + // Supports: exact match OR allowed list OR semver rules + // ------------------------ + case schema.ClientVersion: + for _, rule := range check.Values { + ok, err := matchVersionRule(d.ClientVersion, rule) + if err != nil || !ok { + return true, fmt.Sprintf("client version '%s' violation", d.ClientVersion) + } + } + + // ------------------------ + // 3. OS check + // ("windows", "mac", "linux", etc.) + // ------------------------ + case schema.OS: + if !slices.Contains(check.Values, d.OS) { + return true, fmt.Sprintf("client os '%s' not allowed", d.OS) + } + case schema.OSFamily: + if !slices.Contains(check.Values, d.OSFamily) { + return true, fmt.Sprintf("os family '%s' not allowed", d.OSFamily) + } + // ------------------------ + // 4. OS version check + // Supports operators: > >= < <= = + // ------------------------ + case schema.OSVersion: + for _, rule := range check.Values { + ok, err := matchVersionRule(d.OSVersion, rule) + if err != nil || !ok { + return true, fmt.Sprintf("os version '%s' violation", d.OSVersion) + } + } + case schema.KernelVersion: + for _, rule := range check.Values { + ok, err := matchVersionRule(d.KernelVersion, rule) + if err != nil || !ok { + return true, fmt.Sprintf("kernel version '%s' violation", d.KernelVersion) + } + } + // ------------------------ + // 5. Auto-update check + // Values: ["true"] or ["false"] + // ------------------------ + case schema.AutoUpdate: + required := len(check.Values) > 0 && strings.ToLower(check.Values[0]) == "true" + if required && !d.AutoUpdate { + return true, "auto update must be enabled" + } + if !required && d.AutoUpdate { + return true, "auto update must be disabled" + } + } + + return false, "" +} +func cleanVersion(v string) string { + v = strings.TrimSpace(v) + v = strings.TrimPrefix(v, "v") + v = strings.TrimPrefix(v, "V") + v = strings.TrimSuffix(v, ",") + v = strings.TrimSpace(v) + return v +} + +func matchVersionRule(actual, rule string) (bool, error) { + actual = cleanVersion(actual) + rule = strings.TrimSpace(rule) + + op := "=" + + switch { + case strings.HasPrefix(rule, ">="): + op = ">=" + rule = strings.TrimPrefix(rule, ">=") + case strings.HasPrefix(rule, "<="): + op = "<=" + rule = strings.TrimPrefix(rule, "<=") + case strings.HasPrefix(rule, ">"): + op = ">" + rule = strings.TrimPrefix(rule, ">") + case strings.HasPrefix(rule, "<"): + op = "<" + rule = strings.TrimPrefix(rule, "<") + case strings.HasPrefix(rule, "="): + op = "=" + rule = strings.TrimPrefix(rule, "=") + } + + rule = cleanVersion(rule) + + cmp := compareVersions(actual, rule) + + switch op { + case "=": + return cmp == 0, nil + case ">": + return cmp == 1, nil + case "<": + return cmp == -1, nil + case ">=": + return cmp == 1 || cmp == 0, nil + case "<=": + return cmp == -1 || cmp == 0, nil + } + + return false, fmt.Errorf("invalid rule: %s", rule) +} + +func compareVersions(a, b string) int { + pa := strings.Split(a, ".") + pb := strings.Split(b, ".") + + n := len(pa) + if len(pb) > n { + n = len(pb) + } + + for i := 0; i < n; i++ { + ai, bi := 0, 0 + + if i < len(pa) { + ai, _ = strconv.Atoi(pa[i]) + } + if i < len(pb) { + bi, _ = strconv.Atoi(pb[i]) + } + + if ai > bi { + return 1 + } + if ai < bi { + return -1 + } + } + return 0 +} + +func ValidatePostureCheck(pc *schema.PostureCheck) error { + if pc.Name == "" { + return errors.New("name cannot be empty") + } + _, err := logic.GetNetwork(pc.NetworkID) + if err != nil { + return errors.New("invalid network") + } + allowedAttrvaluesMap, ok := schema.PostureCheckAttrValuesMap[pc.Attribute] + if !ok { + return errors.New("unkown attribute") + } + if len(pc.Values) == 0 { + return errors.New("attribute value cannot be empty") + } + for i, valueI := range pc.Values { + pc.Values[i] = strings.ToLower(valueI) + } + if pc.Attribute == schema.ClientLocation { + for i, loc := range pc.Values { + if countries.ByName(loc) == countries.Unknown { + return errors.New("invalid country code") + } + pc.Values[i] = strings.ToUpper(loc) + } + } + if pc.Attribute == schema.AutoUpdate || pc.Attribute == schema.OS || + pc.Attribute == schema.OSFamily { + for _, valueI := range pc.Values { + if _, ok := allowedAttrvaluesMap[valueI]; !ok { + return errors.New("invalid attribute value") + } + } + } + if pc.Attribute == schema.ClientVersion || pc.Attribute == schema.OSVersion || + pc.Attribute == schema.KernelVersion { + for i, valueI := range pc.Values { + if !logic.IsValidVersion(valueI) { + return errors.New("invalid attribute version value") + } + pc.Values[i] = logic.CleanVersion(valueI) + } + } + if len(pc.Tags) > 0 { + for tagID := range pc.Tags { + if tagID == "*" { + continue + } + _, err := GetTag(models.TagID(tagID)) + if err != nil { + return errors.New("unknown tag") + } + } + } else { + pc.Tags = make(datatypes.JSONMap) + pc.Tags["*"] = struct{}{} + } + + return nil +} + +func CountryNameFromISO(code string) string { + c := countries.ByName(code) // works with ISO2, ISO3, full name + if c == countries.Unknown { + return "" + } + return c.Info().Name +} diff --git a/schema/models.go b/schema/models.go index 07047eee..43c31238 100644 --- a/schema/models.go +++ b/schema/models.go @@ -9,5 +9,6 @@ func ListModels() []interface{} { &Event{}, &PendingHost{}, &Nameserver{}, + &PostureCheck{}, } } diff --git a/schema/posture_check.go b/schema/posture_check.go new file mode 100644 index 00000000..f2ac5652 --- /dev/null +++ b/schema/posture_check.go @@ -0,0 +1,123 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/models" + "gorm.io/datatypes" +) + +type Attribute string +type Values string + +const ( + OS Attribute = "os" + OSVersion Attribute = "os_version" + OSFamily Attribute = "os_family" + KernelVersion Attribute = "kernel_version" + AutoUpdate Attribute = "auto_update" + ClientVersion Attribute = "client_version" + ClientLocation Attribute = "client_location" +) + +var PostureCheckAttrs = []Attribute{ + ClientLocation, + ClientVersion, + OS, + OSVersion, + OSFamily, + KernelVersion, + AutoUpdate, +} + +var PostureCheckAttrValuesMap = map[Attribute]map[string]struct{}{ + ClientLocation: { + "any_valid_iso_country_codes": {}, + }, + ClientVersion: { + "any_valid_semantic_version": {}, + }, + OS: { + "linux": {}, + "darwin": {}, + "windows": {}, + "ios": {}, + "android": {}, + }, + OSVersion: { + "any_valid_semantic_version": {}, + }, + OSFamily: { + "linux-debian": {}, + "linux-redhat": {}, + "linux-suse": {}, + "linux-arch": {}, + "linux-gentoo": {}, + "linux-other": {}, + "darwin": {}, + "windows": {}, + "ios": {}, + "android": {}, + }, + KernelVersion: { + "any_valid_semantic_version": {}, + }, + AutoUpdate: { + "true": {}, + "false": {}, + }, +} + +var PostureCheckAttrValues = map[Attribute][]string{ + ClientLocation: {"any_valid_iso_country_codes"}, + ClientVersion: {"any_valid_semantic_version"}, + OS: {"linux", "darwin", "windows", "ios", "android"}, + OSVersion: {"any_valid_semantic_version"}, + OSFamily: {"linux-debian", "linux-redhat", "linux-suse", "linux-arch", "linux-gentoo", "linux-other", "darwin", "windows", "ios", "android"}, + KernelVersion: {"any_valid_semantic_version"}, + AutoUpdate: {"true", "false"}, +} + +type PostureCheck struct { + ID string `gorm:"primaryKey" json:"id"` + Name string `gorm:"name" json:"name"` + NetworkID string `gorm:"network_id" json:"network_id"` + Description string `gorm:"description" json:"description"` + Attribute Attribute `gorm:"attribute" json:"attribute"` + Values datatypes.JSONSlice[string] `gorm:"values" json:"values"` + Severity models.Severity `gorm:"severity" json:"severity"` + Tags datatypes.JSONMap `gorm:"tags" json:"tags"` + Status bool `gorm:"status" json:"status"` + CreatedBy string `gorm:"created_by" json:"created_by"` + CreatedAt time.Time `gorm:"created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` +} + +func (p *PostureCheck) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheck{}).First(&p).Where("id = ?", p.ID).Error +} + +func (p *PostureCheck) Update(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(&p).Error +} + +func (p *PostureCheck) Create(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheck{}).Create(&p).Error +} + +func (p *PostureCheck) ListByNetwork(ctx context.Context) (pcli []PostureCheck, err error) { + err = db.FromContext(ctx).Model(&PostureCheck{}).Where("network_id = ?", p.NetworkID).Find(&pcli).Error + return +} + +func (p *PostureCheck) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Delete(&p).Error +} + +func (p *PostureCheck) UpdateStatus(ctx context.Context) error { + return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(map[string]any{ + "status": p.Status, + }).Error +}