Merge pull request #3735 from gravitl/NM-166

NM-166: Device Posture Checks
This commit is contained in:
Abhishek Kondur
2025-12-05 10:33:11 +04:00
committed by GitHub
parent 6533b827cf
commit eed32cd2d6
30 changed files with 1559 additions and 90 deletions
+34
View File
@@ -2,8 +2,11 @@ package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"slices"
"time"
"github.com/go-playground/validator/v10"
@@ -358,6 +361,37 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
)
return
}
if newHost.EndpointIP != nil {
newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
} else if newHost.EndpointIPv6 != nil {
newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
}
pcviolations := []models.Violation{}
skipViolatedNetworks := []string{}
for _, netI := range enrollmentKey.Networks {
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
ClientLocation: newHost.CountryCode,
ClientVersion: newHost.Version,
OS: newHost.OS,
OSFamily: newHost.OSFamily,
OSVersion: newHost.OSVersion,
KernelVersion: newHost.KernelVersion,
AutoUpdate: newHost.AutoUpdate,
}, models.NetworkID(netI))
pcviolations = append(pcviolations, violations...)
if len(violations) > 0 {
skipViolatedNetworks = append(skipViolatedNetworks, netI)
}
}
if len(skipViolatedNetworks) == len(enrollmentKey.Networks) && len(pcviolations) > 0 {
logic.ReturnErrorResponse(w, r,
logic.FormatError(errors.New("access blocked: this device doesnt meet security requirements"), logic.Forbidden))
return
}
// need to remove the networks that were skipped from the enrollment key
enrollmentKey.Networks = slices.DeleteFunc(enrollmentKey.Networks, func(netI string) bool {
return slices.Contains(skipViolatedNetworks, netI)
})
if !hostExists {
newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive
// register host
+37 -6
View File
@@ -715,13 +715,29 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
}
err = errors.New("remote client config already exists on the gateway")
slog.Error("failed to create extclient", "user", userName, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
}
}
extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient)
if extclient.DeviceID != "" {
// check for violations connecting from desktop app
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
ClientLocation: extclient.Country,
ClientVersion: extclient.ClientVersion,
OS: extclient.OS,
OSFamily: extclient.OSFamily,
OSVersion: extclient.OSVersion,
KernelVersion: extclient.KernelVersion,
//AutoUpdate: extclient.AutoUpdate,
}, models.NetworkID(extclient.Network))
if len(violations) > 0 {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
return
}
}
extclient.OwnerID = userName
extclient.RemoteAccessClientID = customExtClient.RemoteAccessClientID
extclient.IngressGatewayID = nodeid
@@ -749,7 +765,6 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
extclient.Enabled = parentNetwork.DefaultACL == "yes"
}
extclient.Os = customExtClient.Os
extclient.DeviceID = customExtClient.DeviceID
extclient.DeviceName = customExtClient.DeviceName
if customExtClient.IsAlreadyConnectedToInetGw {
@@ -758,7 +773,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
extclient.PublicEndpoint = customExtClient.PublicEndpoint
extclient.Country = customExtClient.Country
if customExtClient.RemoteAccessClientID != "" && customExtClient.Location == "" {
extclient.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
extclient.Location, extclient.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
}
extclient.Location = customExtClient.Location
@@ -772,7 +787,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
"error",
err,
)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
@@ -820,7 +835,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
"error",
err,
)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
if err := mq.PublishPeerUpdate(false); err != nil {
@@ -903,9 +918,25 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
replacePeers = true
}
if update.RemoteAccessClientID != "" && update.Location == "" {
update.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
update.Location, update.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
}
newclient := logic.UpdateExtClient(&oldExtClient, &update)
if newclient.DeviceID != "" && newclient.Enabled {
// check for violations connecting from desktop app
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
ClientLocation: newclient.Country,
ClientVersion: newclient.ClientVersion,
OS: newclient.OS,
OSFamily: newclient.OSFamily,
OSVersion: newclient.OSVersion,
KernelVersion: newclient.KernelVersion,
//AutoUpdate: extclient.AutoUpdate,
}, models.NetworkID(newclient.Network))
if len(violations) > 0 {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
return
}
}
if err := logic.DeleteExtClient(oldExtClient.Network, oldExtClient.ClientID, true); err != nil {
slog.Error(
"failed to delete ext client",
+15 -2
View File
@@ -547,7 +547,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("hostid or network cannot be empty"), "badrequest"),
logic.FormatError(errors.New("hostid or network cannot be empty"), logic.BadReq),
)
return
}
@@ -555,10 +555,23 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
currHost, err := logic.GetHost(hostid)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostid, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
ClientLocation: currHost.CountryCode,
ClientVersion: currHost.Version,
OS: currHost.OS,
OSFamily: currHost.OSFamily,
OSVersion: currHost.OSVersion,
KernelVersion: currHost.KernelVersion,
AutoUpdate: currHost.AutoUpdate,
}, models.NetworkID(network))
if len(violations) > 0 {
logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq))
return
}
newNode, err := logic.UpdateHostNetwork(currHost, network, true)
if err != nil {
logger.Log(
+5
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
@@ -661,6 +662,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
logic.ResetAutoRelayedPeer(&currentNode)
}
}
newNode.PostureChecksViolations,
newNode.PostureCheckVolationSeverityLevel = logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(newNode),
models.NetworkID(newNode.Network))
newNode.LastEvaluatedAt = time.Now().UTC()
logic.UpsertNode(newNode)
logic.GetNodeStatus(newNode, false)
+2
View File
@@ -40,6 +40,8 @@ require (
)
require (
github.com/Masterminds/semver/v3 v3.4.0
github.com/biter777/countries v1.7.5
github.com/google/go-cmp v0.7.0
github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e
github.com/guumaster/tablewriter v0.0.10
+4
View File
@@ -6,6 +6,10 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/biter777/countries v1.7.5 h1:MJ+n3+rSxWQdqVJU8eBy9RqcdH6ePPn4PJHocVWUa+Q=
github.com/biter777/countries v1.7.5/go.mod h1:1HSpZ526mYqKJcpT5Ti1kcGQ0L0SrXWIaptUWjFfv2E=
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
+13
View File
@@ -77,3 +77,16 @@ func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, er
response.WriteHeader(errorMessage.Code)
response.Write(jsonResponse)
}
// ReturnErrorResponseWithJson - processes error with body and adds header
func ReturnErrorResponseWithJson(response http.ResponseWriter, request *http.Request, msg interface{}, errorMessage models.ErrorResponse) {
httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message, Response: msg}
jsonResponse, err := json.Marshal(httpResponse)
if err != nil {
panic(err)
}
slog.Debug("processed request error", "err", errorMessage.Message)
response.Header().Set("Content-Type", "application/json")
response.WriteHeader(errorMessage.Code)
response.Write(jsonResponse)
}
+16 -1
View File
@@ -451,11 +451,26 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode
new.Location = update.Location
}
if update.Country != "" && update.Country != old.Country {
new.Country = update.Country
new.Country = strings.ToUpper(update.Country)
}
if update.DeviceID != "" && old.DeviceID == "" {
new.DeviceID = update.DeviceID
}
if update.OS != "" {
new.OS = update.OS
}
if update.OSFamily != "" {
new.OSFamily = update.OSFamily
}
if update.OSVersion != "" {
new.OSVersion = update.OSVersion
}
if update.KernelVersion != "" {
new.KernelVersion = update.KernelVersion
}
if update.ClientVersion != "" {
new.ClientVersion = update.ClientVersion
}
return new
}
+4 -1
View File
@@ -231,6 +231,8 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
node.Tags = make(map[models.TagID]struct{})
}
node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.GwTagName))] = struct{}{}
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
node.LastEvaluatedAt = time.Now().UTC()
err = UpsertNode(&node)
if err != nil {
return models.Node{}, err
@@ -282,11 +284,12 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error
delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName)))
node.IngressGatewayRange = ""
node.Metadata = ""
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
node.LastEvaluatedAt = time.Now().UTC()
err = UpsertNode(&node)
if err != nil {
return models.Node{}, removedClients, err
}
err = SetNetworkNodesLastModified(node.Network)
return node, removedClients, err
}
+7 -3
View File
@@ -34,7 +34,11 @@ var (
ErrInvalidHostID error = errors.New("invalid host id")
)
var GetHostLocInfo = func(ip, token string) string { return "" }
var GetHostLocInfo = func(ip, token string) (string, string) { return "", "" }
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network models.NetworkID) (v []models.Violation, level models.Severity) {
return []models.Violation{}, models.SeverityUnknown
}
func getHostsFromCache() (hosts []models.Host) {
hostCacheMutex.RLock()
@@ -254,9 +258,9 @@ func CreateHost(h *models.Host) error {
h.DNS = "no"
}
if h.EndpointIP != nil {
h.Location = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
} else if h.EndpointIPv6 != nil {
h.Location = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
}
checkForZombieHosts(h)
return UpsertHost(h)
+46 -1
View File
@@ -620,7 +620,15 @@ func FindRelay(node *models.Node) *models.Node {
func GetAllNodesAPI(nodes []models.Node) []models.ApiNode {
apiNodes := []models.ApiNode{}
for i := range nodes {
newApiNode := nodes[i].ConvertToAPINode()
node := nodes[i]
if !node.IsStatic {
h, err := GetHost(node.HostID.String())
if err == nil {
node.Location = h.Location
node.CountryCode = h.CountryCode
}
}
newApiNode := node.ConvertToAPINode()
apiNodes = append(apiNodes, *newApiNode)
}
return apiNodes[:]
@@ -874,3 +882,40 @@ func GetAllFailOvers() ([]models.Node, error) {
}
return igs, nil
}
// GetPostureCheckDeviceInfoByNode retrieves PostureCheckDeviceInfo for a given node
func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDeviceInfo {
var deviceInfo models.PostureCheckDeviceInfo
if !node.IsStatic {
h, err := GetHost(node.HostID.String())
if err != nil {
return deviceInfo
}
deviceInfo = models.PostureCheckDeviceInfo{
ClientLocation: h.CountryCode,
ClientVersion: h.Version,
OS: h.OS,
OSVersion: h.OSVersion,
OSFamily: h.OSFamily,
KernelVersion: h.KernelVersion,
AutoUpdate: h.AutoUpdate,
Tags: node.Tags,
}
} else {
if node.StaticNode.DeviceID == "" && node.StaticNode.RemoteAccessClientID == "" {
return deviceInfo
}
deviceInfo = models.PostureCheckDeviceInfo{
ClientLocation: node.StaticNode.Country,
ClientVersion: node.StaticNode.ClientVersion,
OS: node.StaticNode.OS,
OSVersion: node.StaticNode.OSVersion,
OSFamily: node.StaticNode.OSFamily,
KernelVersion: node.StaticNode.KernelVersion,
Tags: node.StaticNode.Tags,
}
}
return deviceInfo
}
+227
View File
@@ -0,0 +1,227 @@
package logic
import (
"bytes"
"os"
"os/exec"
"runtime"
"strings"
)
type OSInfo struct {
OS string `json:"os"` // e.g. "ubuntu", "windows", "macos"
OSFamily string `json:"os_family"` // e.g. "linux-debian", "windows"
OSVersion string `json:"os_version"` // e.g. "22.04", "10.0.22631"
KernelVersion string `json:"kernel_version"` // e.g. "6.8.0"
}
/// --- classification helpers you already had ---
func NormalizeOSName(raw string) string {
return strings.ToLower(strings.TrimSpace(raw))
}
// OSFamily returns a normalized OS family string.
// Examples: "linux-debian", "linux-redhat", "linux-arch", "linux-other", "windows", "darwin"
func OSFamily(osName string) string {
osName = NormalizeOSName(osName)
// Non-Linux first
if strings.Contains(osName, "windows") {
return "windows"
}
if strings.Contains(osName, "darwin") || strings.Contains(osName, "mac") || strings.Contains(osName, "os x") {
return "darwin"
}
// Linux families
switch {
// Debian family
case containsAny(osName,
"debian", "ubuntu", "pop", "linuxmint", "kali", "raspbian", "elementary"):
return "linux-debian"
// Red Hat family
case containsAny(osName,
"rhel", "red hat", "centos", "rocky", "alma", "fedora", "oracle linux", "ol"):
return "linux-redhat"
// SUSE family
case containsAny(osName,
"suse", "opensuse", "sles"):
return "linux-suse"
// Arch family
case containsAny(osName,
"arch", "manjaro", "endeavouros", "garuda"):
return "linux-arch"
// Gentoo
case strings.Contains(osName, "gentoo"):
return "linux-gentoo"
// Alpine, Amazon, BusyBox, etc.
case containsAny(osName,
"alpine", "amazon", "busybox"):
return "linux-other"
}
// Fallbacks
if strings.Contains(osName, "linux") {
return "linux-other"
}
return "unknown"
}
func containsAny(s string, subs ...string) bool {
for _, sub := range subs {
if strings.Contains(s, sub) {
return true
}
}
return false
}
/// --- public entrypoint ---
// GetOSInfo returns OS, OSFamily, OSVersion and KernelVersion for the current platform.
func GetOSInfo() OSInfo {
switch runtime.GOOS {
case "linux":
return getLinuxOSInfo()
case "darwin":
return getDarwinOSInfo()
case "windows":
return getWindowsOSInfo()
default:
// Fallback for other UNIX-likes; best-effort
kernel := strings.TrimSpace(runCmd("uname", "-r"))
name := runtime.GOOS
return OSInfo{
OS: NormalizeOSName(name),
OSFamily: OSFamily(name),
OSVersion: "",
KernelVersion: CleanVersion(kernel),
}
}
}
/// --- Linux ---
func getLinuxOSInfo() OSInfo {
var osName, osVersion string
data, err := os.ReadFile("/etc/os-release")
if err == nil {
lines := strings.Split(string(data), "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
value := strings.Trim(parts[1], `"'`)
switch key {
case "ID":
osName = value
case "VERSION_ID":
osVersion = value
}
}
}
if osName == "" {
// Fallback
osName = "linux"
}
kernel := strings.TrimSpace(runCmd("uname", "-r"))
// trim extras like -generic
if idx := strings.Index(kernel, "-"); idx > 0 {
kernel = kernel[:idx]
}
normName := NormalizeOSName(osName)
return OSInfo{
OS: "linux",
OSFamily: OSFamily(normName),
OSVersion: CleanVersion(osVersion),
KernelVersion: CleanVersion(kernel),
}
}
/// --- macOS (darwin) ---
func getDarwinOSInfo() OSInfo {
productName := strings.TrimSpace(runCmd("sw_vers", "-productName"))
productVer := strings.TrimSpace(runCmd("sw_vers", "-productVersion"))
if productName == "" {
productName = "macos"
}
kernel := strings.TrimSpace(runCmd("uname", "-r"))
if idx := strings.Index(kernel, "-"); idx > 0 {
kernel = kernel[:idx]
}
normName := NormalizeOSName(productName)
return OSInfo{
OS: "darwin",
OSFamily: OSFamily(normName), // "darwin"
OSVersion: CleanVersion(productVer), // e.g. "15.0"
KernelVersion: CleanVersion(kernel),
}
}
/// --- Windows ---
func getWindowsOSInfo() OSInfo {
// OS name: we just say "windows"
osName := "windows"
// OS version via "wmic" or "ver" as fallback
var version string
// Try wmic first (may be missing on newer builds but often still present)
out := runCmd("wmic", "os", "get", "Version", "/value")
for _, line := range strings.Split(out, "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "Version=") {
version = strings.TrimPrefix(line, "Version=")
version = strings.TrimSpace(version)
break
}
}
if version == "" {
// Fallback to "ver"
raw := strings.TrimSpace(runCmd("cmd", "/C", "ver"))
version = raw // you can add better parsing if you need
}
// On Windows, kernel and OS version are effectively tied; reuse
kernel := version
normName := NormalizeOSName(osName)
return OSInfo{
OS: "windows", // "windows"
OSFamily: OSFamily(normName),
OSVersion: CleanVersion(version), // e.g. "10.0.22631"
KernelVersion: CleanVersion(kernel),
}
}
/// --- small helper to run commands safely ---
func runCmd(name string, args ...string) string {
cmd := exec.Command(name, args...)
var buf bytes.Buffer
cmd.Stdout = &buf
_ = cmd.Run() // ignore error; best-effort
return buf.String()
}
+47
View File
@@ -1,9 +1,11 @@
package logic
import (
"regexp"
"strings"
"unicode"
"github.com/Masterminds/semver/v3"
"github.com/hashicorp/go-version"
)
@@ -29,3 +31,48 @@ func IsVersionCompatible(ver string) bool {
return constraint.Check(v)
}
// CleanVersion normalizes a version string safely for storage.
// - removes "v" or "V" prefix
// - trims whitespace
// - strips invalid trailing characters
// - preserves semver, prerelease, and build metadata
func CleanVersion(raw string) string {
if raw == "" {
return ""
}
v := strings.TrimSpace(raw)
// Remove leading v/V (common in semver)
v = strings.TrimPrefix(v, "v")
v = strings.TrimPrefix(v, "V")
// Remove trailing commas, quotes, spaces
v = strings.Trim(v, " ,\"'")
// Remove any characters not allowed in semantic versioning:
// Allowed: 0-9 a-z A-Z . - +
re := regexp.MustCompile(`[^0-9A-Za-z\.\-\+]+`)
v = re.ReplaceAllString(v, "")
// Collapse multiple dots (e.g., "1..2" → "1.2")
v = strings.ReplaceAll(v, "..", ".")
for strings.Contains(v, "..") {
v = strings.ReplaceAll(v, "..", ".")
}
return v
}
// IsValidVersion returns true if the version string can be parsed as semantic version.
func IsValidVersion(raw string) bool {
cleaned := CleanVersion(raw)
if cleaned == "" {
return false
}
_, err := semver.NewVersion(cleaned)
return err == nil
}
+7 -4
View File
@@ -425,13 +425,13 @@ func updateHosts() {
host.AutoUpdate = true
logic.UpsertHost(&host)
}
if servercfg.IsPro && host.Location == "" {
if servercfg.IsPro && (host.Location == "" || host.CountryCode == "") {
if host.EndpointIP != nil {
host.Location = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
} else if host.EndpointIPv6 != nil {
host.Location = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
}
if host.Location != "" {
if host.Location != "" && host.CountryCode != "" {
logic.UpsertHost(&host)
}
}
@@ -891,6 +891,9 @@ func migrateSettings() {
if settings.PeerConnectionCheckInterval == "" {
settings.PeerConnectionCheckInterval = "15"
}
if settings.PostureCheckInterval == "" {
settings.PostureCheckInterval = "30"
}
if settings.CleanUpInterval == 0 {
settings.CleanUpInterval = 60
}
+10
View File
@@ -14,6 +14,9 @@ type ApiHost struct {
Version string `json:"version"`
Name string `json:"name"`
OS string `json:"os"`
OSFamily string `json:"os_family" yaml:"os_family"`
OSVersion string `json:"os_version" yaml:"os_version"`
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
Debug bool `json:"debug"`
IsStaticPort bool `json:"isstaticport"`
IsStatic bool `json:"isstatic"`
@@ -33,6 +36,7 @@ type ApiHost struct {
AutoUpdate bool `json:"autoupdate" yaml:"autoupdate"`
DNS string `json:"dns" yaml:"dns"`
Location string `json:"location"`
CountryCode string `json:"country_code"`
}
// ApiIface - the interface struct for API usage
@@ -71,6 +75,8 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
a.MacAddress = h.MacAddress.String()
a.Name = h.Name
a.OS = h.OS
a.OSFamily = h.OSFamily
a.KernelVersion = h.KernelVersion
a.Nodes = h.Nodes
a.WgPublicListenPort = h.WgPublicListenPort
a.PublicKey = h.PublicKey.String()
@@ -82,6 +88,7 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
a.AutoUpdate = h.AutoUpdate
a.DNS = h.DNS
a.Location = h.Location
a.CountryCode = h.CountryCode
return &a
}
@@ -122,6 +129,8 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
h.Nodes = currentHost.Nodes
h.TrafficKeyPublic = currentHost.TrafficKeyPublic
h.OS = currentHost.OS
h.OSFamily = currentHost.OSFamily
h.KernelVersion = currentHost.KernelVersion
h.IsDefault = a.IsDefault
h.NatType = currentHost.NatType
h.TurnEndpoint = currentHost.TurnEndpoint
@@ -129,5 +138,6 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
h.AutoUpdate = a.AutoUpdate
h.DNS = strings.ToLower(a.DNS)
h.Location = currentHost.Location
h.CountryCode = currentHost.CountryCode
return &h
}
+23 -14
View File
@@ -53,20 +53,24 @@ type ApiNode struct {
PendingDelete bool `json:"pendingdelete"`
Metadata string `json:"metadata"`
// == PRO ==
DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
IsFailOver bool `json:"is_fail_over"`
FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"`
IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"`
InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"`
InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"`
AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"`
Tags map[TagID]struct{} `json:"tags" yaml:"tags"`
IsStatic bool `json:"is_static"`
IsUserNode bool `json:"is_user_node"`
StaticNode ExtClient `json:"static_node"`
Status NodeStatus `json:"status"`
Location string `json:"location"`
DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
IsFailOver bool `json:"is_fail_over"`
FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"`
IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"`
InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"`
InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"`
AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"`
Tags map[TagID]struct{} `json:"tags" yaml:"tags"`
IsStatic bool `json:"is_static"`
IsUserNode bool `json:"is_user_node"`
StaticNode ExtClient `json:"static_node"`
Status NodeStatus `json:"status"`
Location string `json:"location"`
Country string `json:"country"`
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
}
// ApiNode.ConvertToServerNode - converts an api node to a server node
@@ -227,6 +231,11 @@ func (nm *Node) ConvertToAPINode() *ApiNode {
apiNode.IsUserNode = nm.IsUserNode
apiNode.StaticNode = nm.StaticNode
apiNode.Status = nm.Status
apiNode.PostureChecksViolations = nm.PostureChecksViolations
apiNode.PostureCheckVolationSeverityLevel = nm.PostureCheckVolationSeverityLevel
apiNode.LastEvaluatedAt = nm.LastEvaluatedAt
apiNode.Location = nm.Location
apiNode.Country = nm.CountryCode
return &apiNode
}
+1
View File
@@ -54,6 +54,7 @@ const (
EnrollmentKeySub SubjectType = "ENROLLMENT_KEY"
ClientAppSub SubjectType = "CLIENT-APP"
NameserverSub SubjectType = "NAMESERVER"
PostureCheckSub SubjectType = "POSTURE_CHECK"
)
func (sub SubjectType) String() string {
+52 -33
View File
@@ -1,35 +1,45 @@
package models
import "sync"
import (
"sync"
"time"
)
// ExtClient - struct for external clients
type ExtClient struct {
ClientID string `json:"clientid" bson:"clientid"`
PrivateKey string `json:"privatekey" bson:"privatekey"`
PublicKey string `json:"publickey" bson:"publickey"`
Network string `json:"network" bson:"network"`
DNS string `json:"dns" bson:"dns"`
Address string `json:"address" bson:"address"`
Address6 string `json:"address6" bson:"address6"`
ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"`
AllowedIPs []string `json:"allowed_ips"`
IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"`
IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"`
LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"`
Enabled bool `json:"enabled" bson:"enabled"`
OwnerID string `json:"ownerid" bson:"ownerid"`
DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"`
RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine
PostUp string `json:"postup" bson:"postup"`
PostDown string `json:"postdown" bson:"postdown"`
Tags map[TagID]struct{} `json:"tags"`
Os string `json:"os"`
DeviceID string `json:"device_id"`
DeviceName string `json:"device_name"`
PublicEndpoint string `json:"public_endpoint"`
Country string `json:"country"`
Location string `json:"location"` //format: lat,long
Mutex *sync.Mutex `json:"-"`
ClientID string `json:"clientid" bson:"clientid"`
PrivateKey string `json:"privatekey" bson:"privatekey"`
PublicKey string `json:"publickey" bson:"publickey"`
Network string `json:"network" bson:"network"`
DNS string `json:"dns" bson:"dns"`
Address string `json:"address" bson:"address"`
Address6 string `json:"address6" bson:"address6"`
ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"`
AllowedIPs []string `json:"allowed_ips"`
IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"`
IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"`
LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"`
Enabled bool `json:"enabled" bson:"enabled"`
OwnerID string `json:"ownerid" bson:"ownerid"`
DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"`
RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine
PostUp string `json:"postup" bson:"postup"`
PostDown string `json:"postdown" bson:"postdown"`
Tags map[TagID]struct{} `json:"tags"`
OS string `json:"os"`
OSFamily string `json:"os_family" yaml:"os_family"`
OSVersion string `json:"os_version" yaml:"os_version"`
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
ClientVersion string `json:"client_version"`
DeviceID string `json:"device_id"`
DeviceName string `json:"device_name"`
PublicEndpoint string `json:"public_endpoint"`
Country string `json:"country"`
Location string `json:"location"` //format: lat,long
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
Mutex *sync.Mutex `json:"-"`
}
// CustomExtClient - struct for CustomExtClient params
@@ -44,11 +54,15 @@ type CustomExtClient struct {
PostUp string `json:"postup" bson:"postup" validate:"max=1024"`
PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"`
Tags map[TagID]struct{} `json:"tags"`
Os string `json:"os"`
DeviceID string `json:"device_id"`
DeviceName string `json:"device_name"`
IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"`
PublicEndpoint string `json:"public_endpoint"`
OS string `json:"os"`
OSFamily string `json:"os_family" yaml:"os_family"`
OSVersion string `json:"os_version" yaml:"os_version"`
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
ClientVersion string `json:"client_version"`
Country string `json:"country"`
Location string `json:"location"` //format: lat,long
}
@@ -63,10 +77,15 @@ func (ext *ExtClient) ConvertToStaticNode() Node {
Address: ext.AddressIPNet4(),
Address6: ext.AddressIPNet6(),
},
Tags: ext.Tags,
IsStatic: true,
StaticNode: *ext,
IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "",
Mutex: ext.Mutex,
Tags: ext.Tags,
IsStatic: true,
StaticNode: *ext,
IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "",
Mutex: ext.Mutex,
CountryCode: ext.Country,
Location: ext.Location,
PostureChecksViolations: ext.PostureChecksViolations,
PostureCheckVolationSeverityLevel: ext.PostureCheckVolationSeverityLevel,
LastEvaluatedAt: ext.LastEvaluatedAt,
}
}
+4
View File
@@ -51,6 +51,9 @@ type Host struct {
HostPass string `json:"hostpass" yaml:"hostpass"`
Name string `json:"name" yaml:"name"`
OS string `json:"os" yaml:"os"`
OSFamily string `json:"os_family" yaml:"os_family"`
OSVersion string `json:"os_version" yaml:"os_version"`
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
Interface string `json:"interface" yaml:"interface"`
Debug bool `json:"debug" yaml:"debug"`
ListenPort int `json:"listenport" yaml:"listenport"`
@@ -74,6 +77,7 @@ type Host struct {
TurnEndpoint *netip.AddrPort `json:"turn_endpoint,omitempty" yaml:"turn_endpoint,omitempty"`
PersistentKeepalive time.Duration `json:"persistentkeepalive" swaggertype:"primitive,integer" format:"int64" yaml:"persistentkeepalive"`
Location string `json:"location"` // Format: "lat,lon"
CountryCode string `json:"country_code"`
}
// FormatBool converts a boolean to a [yes|no] string
+18 -13
View File
@@ -113,19 +113,24 @@ type Node struct {
//AutoRelayedPeers map[string]struct{} `json:"auto_relayed_peers"`
AutoRelayedPeers map[string]string `json:"auto_relayed_peers_v1"`
//AutoRelayedBy uuid.UUID `json:"auto_relayed_by"`
FailOverPeers map[string]struct{} `json:"fail_over_peers"`
FailedOverBy uuid.UUID `json:"failed_over_by"`
IsInternetGateway bool `json:"isinternetgateway"`
InetNodeReq InetNodeReq `json:"inet_node_req"`
InternetGwID string `json:"internetgw_node_id"`
AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"`
Tags map[TagID]struct{} `json:"tags"`
IsStatic bool `json:"is_static"`
IsUserNode bool `json:"is_user_node"`
StaticNode ExtClient `json:"static_node"`
Status NodeStatus `json:"node_status"`
Mutex *sync.Mutex `json:"-"`
EgressDetails EgressDetails `json:"-"`
FailOverPeers map[string]struct{} `json:"fail_over_peers"`
FailedOverBy uuid.UUID `json:"failed_over_by"`
IsInternetGateway bool `json:"isinternetgateway"`
InetNodeReq InetNodeReq `json:"inet_node_req"`
InternetGwID string `json:"internetgw_node_id"`
AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"`
Tags map[TagID]struct{} `json:"tags"`
IsStatic bool `json:"is_static"`
IsUserNode bool `json:"is_user_node"`
StaticNode ExtClient `json:"static_node"`
Status NodeStatus `json:"node_status"`
Mutex *sync.Mutex `json:"-"`
EgressDetails EgressDetails `json:"-"`
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
Location string `json:"location"` // Format: "lat,lon"
CountryCode string `json:"country_code"`
}
type EgressDetails struct {
EgressGatewayNatEnabled bool
+1
View File
@@ -50,6 +50,7 @@ type ServerSettings struct {
AuditLogsRetentionPeriodInDays int `json:"audit_logs_retention_period"`
OldAClsSupport bool `json:"old_acl_support"`
PeerConnectionCheckInterval string `json:"peer_connection_check_interval"`
PostureCheckInterval string `json:"posture_check_interval"` // in minutes
CleanUpInterval int `json:"clean_up_interval_in_mins"`
}
+32 -2
View File
@@ -116,8 +116,9 @@ type SuccessfulLoginResponse struct {
// ErrorResponse is struct for error
type ErrorResponse struct {
Code int
Message string
Code int
Message string
Response interface{}
}
// NodeAuth - struct for node auth
@@ -457,3 +458,32 @@ type IDPSyncTestRequest struct {
OktaOrgURL string `json:"okta_org_url"`
OktaAPIToken string `json:"okta_api_token"`
}
type PostureCheckDeviceInfo struct {
ClientLocation string
ClientVersion string
OS string
OSFamily string
OSVersion string
KernelVersion string
AutoUpdate bool
Tags map[TagID]struct{}
}
type Violation struct {
CheckID string `json:"check_id"`
Name string `json:"name"`
Attribute string `json:"attribute"`
Message string `json:"message"`
Severity Severity `json:"severity"`
}
type Severity int
const (
SeverityUnknown Severity = iota
SeverityLow
SeverityMedium
SeverityHigh
SeverityCritical
)
+11 -3
View File
@@ -296,13 +296,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool {
(len(h.NatType) > 0 && h.NatType != currentHost.NatType) ||
h.DefaultInterface != currentHost.DefaultInterface ||
(h.ListenPort != 0 && h.ListenPort != currentHost.ListenPort) ||
(h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) || (!h.EndpointIPv6.Equal(currentHost.EndpointIPv6))
(h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) ||
(!h.EndpointIPv6.Equal(currentHost.EndpointIPv6)) || (h.OSFamily != currentHost.OSFamily) ||
(h.OSVersion != currentHost.OSVersion) || (h.KernelVersion != currentHost.KernelVersion)
if ifaceDelta { // only save if something changes
if !h.EndpointIP.Equal(currentHost.EndpointIP) || !h.EndpointIPv6.Equal(currentHost.EndpointIPv6) || currentHost.Location == "" {
if h.EndpointIP != nil {
h.Location = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
} else if h.EndpointIPv6 != nil {
h.Location = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
}
}
currentHost.EndpointIP = h.EndpointIP
@@ -310,9 +312,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool {
currentHost.Interfaces = h.Interfaces
currentHost.DefaultInterface = h.DefaultInterface
currentHost.NatType = h.NatType
currentHost.OSFamily = h.OSFamily
currentHost.OSVersion = h.OSVersion
currentHost.KernelVersion = h.KernelVersion
if h.Location != "" {
currentHost.Location = h.Location
}
if h.CountryCode != "" {
currentHost.CountryCode = h.CountryCode
}
if h.ListenPort != 0 {
currentHost.ListenPort = h.ListenPort
}
+343
View File
@@ -0,0 +1,343 @@
package controllers
import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
)
func PostureCheckHandlers(r *mux.Router) {
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck))).Methods(http.MethodPost)
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck))).Methods(http.MethodPut)
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck))).Methods(http.MethodDelete)
r.HandleFunc("/api/v1/posture_check/attrs", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs))).Methods(http.MethodGet)
r.HandleFunc("/api/v1/posture_check/violations", logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes))).Methods(http.MethodGet)
}
// @Summary List Posture Checks Available Attributes
// @Router /api/v1/posture_check/attrs [get]
// @Tags Auth
// @Accept json
// @Param query network string
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func listPostureChecksAttrs(w http.ResponseWriter, r *http.Request) {
logic.ReturnSuccessResponseWithJson(w, r, schema.PostureCheckAttrValues, "fetched posture checks")
}
// @Summary Create Posture Check
// @Router /api/v1/posture_check [post]
// @Tags DNS
// @Accept json
// @Param body body schema.PostureCheck
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func createPostureCheck(w http.ResponseWriter, r *http.Request) {
var req schema.PostureCheck
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
logger.Log(0, "error decoding request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
if err := proLogic.ValidatePostureCheck(&req); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
pc := schema.PostureCheck{
ID: uuid.New().String(),
Name: req.Name,
NetworkID: req.NetworkID,
Description: req.Description,
Tags: req.Tags,
Attribute: req.Attribute,
Values: req.Values,
Severity: req.Severity,
Status: true,
CreatedBy: r.Header.Get("user"),
CreatedAt: time.Now().UTC(),
}
err = pc.Create(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error creating posture check "+err.Error()), logic.Internal),
)
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: pc.Name,
Type: models.PostureCheckSub,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
})
go mq.PublishPeerUpdate(false)
go proLogic.RunPostureChecks()
logic.ReturnSuccessResponseWithJson(w, r, pc, "created posture check")
}
// @Summary List Posture Checks
// @Router /api/v1/posture_check [get]
// @Tags Auth
// @Accept json
// @Param query network string
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func listPostureChecks(w http.ResponseWriter, r *http.Request) {
network := r.URL.Query().Get("network")
if network == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
return
}
_, err := logic.GetNetwork(network)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network not found"), logic.BadReq))
return
}
id := r.URL.Query().Get("id")
if id != "" {
pc := schema.PostureCheck{ID: id}
err := pc.Get(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"),
)
return
}
logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check")
return
}
pc := schema.PostureCheck{NetworkID: network}
list, err := pc.ListByNetwork(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"),
)
return
}
logic.ReturnSuccessResponseWithJson(w, r, list, "fetched posture checks")
}
// @Summary Update Posture Check
// @Router /api/v1/posture_check [put]
// @Tags Auth
// @Accept json
// @Param body body schema.PostureCheck
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func updatePostureCheck(w http.ResponseWriter, r *http.Request) {
var updatePc schema.PostureCheck
err := json.NewDecoder(r.Body).Decode(&updatePc)
if err != nil {
logger.Log(0, "error decoding request body: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
if err := proLogic.ValidatePostureCheck(&updatePc); err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
pc := schema.PostureCheck{ID: updatePc.ID}
err = pc.Get(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
var updateStatus bool
if updatePc.Status != pc.Status {
updateStatus = true
}
event := &models.Event{
Action: models.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: updatePc.Name,
Type: models.PostureCheckSub,
},
Diff: models.Diff{
Old: pc,
New: updatePc,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
}
pc.Tags = updatePc.Tags
pc.Attribute = updatePc.Attribute
pc.Values = updatePc.Values
pc.Description = updatePc.Description
pc.Name = updatePc.Name
pc.Severity = updatePc.Severity
pc.Status = updatePc.Status
pc.UpdatedAt = time.Now().UTC()
err = pc.Update(db.WithContext(context.TODO()))
if err != nil {
logic.ReturnErrorResponse(
w,
r,
logic.FormatError(errors.New("error updating posture check "+err.Error()), "internal"),
)
return
}
if updateStatus {
pc.UpdateStatus(db.WithContext(context.TODO()))
}
logic.LogEvent(event)
go mq.PublishPeerUpdate(false)
go proLogic.RunPostureChecks()
logic.ReturnSuccessResponseWithJson(w, r, pc, "updated posture check")
}
// @Summary Delete Posture Check
// @Router /api/v1/posture_check [delete]
// @Tags Auth
// @Accept json
// @Param query id string
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func deletePostureCheck(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
if id == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
return
}
pc := schema.PostureCheck{ID: id}
err := pc.Get(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
err = pc.Delete(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: pc.Name,
Type: models.PostureCheckSub,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
Diff: models.Diff{
Old: pc,
New: nil,
},
})
go mq.PublishPeerUpdate(false)
logic.ReturnSuccessResponseWithJson(w, r, pc, "deleted posture check")
}
// @Summary List Posture Check violated Nodes
// @Router /api/v1/posture_check/violations [get]
// @Tags Auth
// @Accept json
// @Param query network string
// @Success 200 {object} models.SuccessResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 401 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) {
network := r.URL.Query().Get("network")
if network == "" {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
return
}
listViolatedusers := r.URL.Query().Get("users") == "true"
violatedNodes := []models.Node{}
if listViolatedusers {
extclients, err := logic.GetNetworkExtClients(network)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
for _, extclient := range extclients {
if extclient.DeviceID != "" && extclient.Enabled {
if len(extclient.PostureChecksViolations) > 0 {
violatedNodes = append(violatedNodes, extclient.ConvertToStaticNode())
}
}
}
} else {
nodes, err := logic.GetNetworkNodes(network)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
for _, node := range nodes {
if len(node.PostureChecksViolations) > 0 {
violatedNodes = append(violatedNodes, node)
}
}
}
apiNodes := logic.GetAllNodesAPI(violatedNodes)
logic.SortApiNodes(apiNodes[:])
logic.ReturnSuccessResponseWithJson(w, r, apiNodes, "fetched posture checks violated nodes")
}
+1 -1
View File
@@ -241,7 +241,7 @@ func updateTag(w http.ResponseWriter, r *http.Request) {
UsedByCnt: len(updateTag.TaggedNodes),
TaggedNodes: updateTag.TaggedNodes,
}
go proLogic.RunPostureChecks()
logic.ReturnSuccessResponseWithJson(w, r, res, "updated tags")
}
+3 -1
View File
@@ -37,6 +37,7 @@ func InitPro() {
proControllers.TagHandlers,
proControllers.NetworkHandlers,
proControllers.AutoRelayHandlers,
proControllers.PostureCheckHandlers,
)
controller.ListRoles = proControllers.ListRoles
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
@@ -99,8 +100,8 @@ func InitPro() {
auth.ResetIDPSyncHook()
email.Init()
go proLogic.EventWatcher()
logic.GetMetricsMonitor().Start()
proLogic.AddPostureCheckHook()
})
logic.ResetFailOver = proLogic.ResetFailOver
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
@@ -172,6 +173,7 @@ func InitPro() {
logic.GetNameserversForNode = proLogic.GetNameserversForNode
logic.ValidateNameserverReq = proLogic.ValidateNameserverReq
logic.ValidateEgressReq = proLogic.ValidateEgressReq
logic.CheckPostureViolations = proLogic.CheckPostureViolations
}
+8 -5
View File
@@ -240,7 +240,7 @@ func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
slog.Debug("[metrics] node metrics data", "node ID", currentNode.ID, "metrics", newMetrics)
}
func GetHostLocInfo(ip, token string) string {
func GetHostLocInfo(ip, token string) (loc, country string) {
url := "https://ipinfo.io/"
if ip != "" {
url += ip
@@ -253,15 +253,18 @@ func GetHostLocInfo(ip, token string) string {
client := http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(url)
if err != nil {
return ""
return "", ""
}
defer resp.Body.Close()
var data struct {
Loc string `json:"loc"` // Format: "lat,lon"
Loc string `json:"loc"` // Format: "lat,lon"
Country string `json:"country"`
}
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return ""
return "", ""
}
return data.Loc
loc = data.Loc
country = data.Country
return
}
+464
View File
@@ -0,0 +1,464 @@
package logic
import (
"context"
"errors"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/biter777/countries"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"gorm.io/datatypes"
)
var postureCheckMutex = &sync.Mutex{}
func AddPostureCheckHook() {
settings := logic.GetServerSettings()
interval := time.Hour
i, err := strconv.Atoi(settings.PostureCheckInterval)
if err == nil {
interval = time.Minute * time.Duration(i)
}
logic.HookManagerCh <- models.HookDetails{
Hook: logic.WrapHook(RunPostureChecks),
Interval: interval,
}
}
func RunPostureChecks() error {
postureCheckMutex.Lock()
defer postureCheckMutex.Unlock()
nets, err := logic.GetNetworks()
if err != nil {
return err
}
nodes, err := logic.GetAllNodes()
if err != nil {
return err
}
for _, netI := range nets {
networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
if len(networkNodes) == 0 {
continue
}
networkNodes = logic.AddStaticNodestoList(networkNodes)
pcLi, err := (&schema.PostureCheck{NetworkID: netI.NetID}).ListByNetwork(db.WithContext(context.TODO()))
if err != nil || len(pcLi) == 0 {
continue
}
for _, nodeI := range networkNodes {
postureChecksViolations, postureCheckVolationSeverityLevel := GetPostureCheckViolations(pcLi, logic.GetPostureCheckDeviceInfoByNode(&nodeI))
if nodeI.IsStatic {
extclient, err := logic.GetExtClient(nodeI.StaticNode.ClientID, nodeI.StaticNode.Network)
if err == nil {
extclient.PostureChecksViolations = postureChecksViolations
extclient.PostureCheckVolationSeverityLevel = postureCheckVolationSeverityLevel
extclient.LastEvaluatedAt = time.Now().UTC()
logic.SaveExtClient(&extclient)
}
} else {
nodeI.PostureChecksViolations, nodeI.PostureCheckVolationSeverityLevel = postureChecksViolations,
postureCheckVolationSeverityLevel
nodeI.LastEvaluatedAt = time.Now().UTC()
logic.UpsertNode(&nodeI)
}
}
}
return nil
}
func CheckPostureViolations(d models.PostureCheckDeviceInfo, network models.NetworkID) ([]models.Violation, models.Severity) {
pcLi, err := (&schema.PostureCheck{NetworkID: network.String()}).ListByNetwork(db.WithContext(context.TODO()))
if err != nil || len(pcLi) == 0 {
return []models.Violation{}, models.SeverityUnknown
}
violations, level := GetPostureCheckViolations(pcLi, d)
return violations, level
}
func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]models.Violation, models.Severity) {
var violations []models.Violation
highest := models.SeverityUnknown
// Group checks by attribute
checksByAttribute := make(map[schema.Attribute][]schema.PostureCheck)
for _, c := range checks {
// skip disabled checks
if !c.Status {
continue
}
// Check if tags match
if _, ok := c.Tags["*"]; !ok {
exists := false
for tagID := range c.Tags {
if _, ok := d.Tags[models.TagID(tagID)]; ok {
exists = true
break
}
}
if !exists {
continue
}
}
checksByAttribute[c.Attribute] = append(checksByAttribute[c.Attribute], c)
}
// Handle OS and OSFamily together with OR logic since they are related
osChecks := checksByAttribute[schema.OS]
osFamilyChecks := checksByAttribute[schema.OSFamily]
if len(osChecks) > 0 || len(osFamilyChecks) > 0 {
osAllowed := evaluateAttributeChecks(osChecks, schema.OS, d)
osFamilyAllowed := evaluateAttributeChecks(osFamilyChecks, schema.OSFamily, d)
// OR condition: if either OS or OSFamily passes, both are considered passed
if !osAllowed && !osFamilyAllowed {
// Both failed, add violations for both
osDenied := getDeniedChecks(osChecks, schema.OS, d)
osFamilyDenied := getDeniedChecks(osFamilyChecks, schema.OSFamily, d)
for _, denied := range osDenied {
sev := denied.check.Severity
if sev > highest {
highest = sev
}
v := models.Violation{
CheckID: denied.check.ID,
Name: denied.check.Name,
Attribute: string(denied.check.Attribute),
Message: denied.reason,
Severity: sev,
}
violations = append(violations, v)
}
for _, denied := range osFamilyDenied {
sev := denied.check.Severity
if sev > highest {
highest = sev
}
v := models.Violation{
CheckID: denied.check.ID,
Name: denied.check.Name,
Attribute: string(denied.check.Attribute),
Message: denied.reason,
Severity: sev,
}
violations = append(violations, v)
}
}
}
// For all other attributes, check if ANY check allows it
for attr, attrChecks := range checksByAttribute {
// Skip OS and OSFamily as they are handled above
if attr == schema.OS || attr == schema.OSFamily {
continue
}
// Check if any check for this attribute allows the device
allowed := false
var deniedChecks []struct {
check schema.PostureCheck
reason string
}
for _, c := range attrChecks {
violated, reason := evaluatePostureCheck(&c, d)
if !violated {
// At least one check allows it
allowed = true
break
}
// Track denied checks with their reasons for violation reporting
deniedChecks = append(deniedChecks, struct {
check schema.PostureCheck
reason string
}{check: c, reason: reason})
}
// If no check allows it, add violations for all denied checks
if !allowed {
for _, denied := range deniedChecks {
sev := denied.check.Severity
if sev > highest {
highest = sev
}
v := models.Violation{
CheckID: denied.check.ID,
Name: denied.check.Name,
Attribute: string(denied.check.Attribute),
Message: denied.reason,
Severity: sev,
}
violations = append(violations, v)
}
}
}
return violations, highest
}
// evaluateAttributeChecks evaluates checks for a specific attribute and returns true if any check allows the device
func evaluateAttributeChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) bool {
for _, c := range attrChecks {
violated, _ := evaluatePostureCheck(&c, d)
if !violated {
// At least one check allows it
return true
}
}
return false
}
// getDeniedChecks returns all checks that denied the device for a specific attribute
func getDeniedChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) []struct {
check schema.PostureCheck
reason string
} {
var deniedChecks []struct {
check schema.PostureCheck
reason string
}
for _, c := range attrChecks {
violated, reason := evaluatePostureCheck(&c, d)
if violated {
deniedChecks = append(deniedChecks, struct {
check schema.PostureCheck
reason string
}{check: c, reason: reason})
}
}
return deniedChecks
}
func evaluatePostureCheck(check *schema.PostureCheck, d models.PostureCheckDeviceInfo) (violated bool, reason string) {
switch check.Attribute {
// ------------------------
// 1. Geographic check
// ------------------------
case schema.ClientLocation:
if !slices.Contains(check.Values, strings.ToUpper(d.ClientLocation)) {
return true, fmt.Sprintf("client location '%s' not allowed", CountryNameFromISO(d.ClientLocation))
}
// ------------------------
// 2. Client version check
// Supports: exact match OR allowed list OR semver rules
// ------------------------
case schema.ClientVersion:
for _, rule := range check.Values {
ok, err := matchVersionRule(d.ClientVersion, rule)
if err != nil || !ok {
return true, fmt.Sprintf("client version '%s' violation", d.ClientVersion)
}
}
// ------------------------
// 3. OS check
// ("windows", "mac", "linux", etc.)
// ------------------------
case schema.OS:
if !slices.Contains(check.Values, d.OS) {
return true, fmt.Sprintf("client os '%s' not allowed", d.OS)
}
case schema.OSFamily:
if !slices.Contains(check.Values, d.OSFamily) {
return true, fmt.Sprintf("os family '%s' not allowed", d.OSFamily)
}
// ------------------------
// 4. OS version check
// Supports operators: > >= < <= =
// ------------------------
case schema.OSVersion:
for _, rule := range check.Values {
ok, err := matchVersionRule(d.OSVersion, rule)
if err != nil || !ok {
return true, fmt.Sprintf("os version '%s' violation", d.OSVersion)
}
}
case schema.KernelVersion:
for _, rule := range check.Values {
ok, err := matchVersionRule(d.KernelVersion, rule)
if err != nil || !ok {
return true, fmt.Sprintf("kernel version '%s' violation", d.KernelVersion)
}
}
// ------------------------
// 5. Auto-update check
// Values: ["true"] or ["false"]
// ------------------------
case schema.AutoUpdate:
required := len(check.Values) > 0 && strings.ToLower(check.Values[0]) == "true"
if required && !d.AutoUpdate {
return true, "auto update must be enabled"
}
if !required && d.AutoUpdate {
return true, "auto update must be disabled"
}
}
return false, ""
}
func cleanVersion(v string) string {
v = strings.TrimSpace(v)
v = strings.TrimPrefix(v, "v")
v = strings.TrimPrefix(v, "V")
v = strings.TrimSuffix(v, ",")
v = strings.TrimSpace(v)
return v
}
func matchVersionRule(actual, rule string) (bool, error) {
actual = cleanVersion(actual)
rule = strings.TrimSpace(rule)
op := "="
switch {
case strings.HasPrefix(rule, ">="):
op = ">="
rule = strings.TrimPrefix(rule, ">=")
case strings.HasPrefix(rule, "<="):
op = "<="
rule = strings.TrimPrefix(rule, "<=")
case strings.HasPrefix(rule, ">"):
op = ">"
rule = strings.TrimPrefix(rule, ">")
case strings.HasPrefix(rule, "<"):
op = "<"
rule = strings.TrimPrefix(rule, "<")
case strings.HasPrefix(rule, "="):
op = "="
rule = strings.TrimPrefix(rule, "=")
}
rule = cleanVersion(rule)
cmp := compareVersions(actual, rule)
switch op {
case "=":
return cmp == 0, nil
case ">":
return cmp == 1, nil
case "<":
return cmp == -1, nil
case ">=":
return cmp == 1 || cmp == 0, nil
case "<=":
return cmp == -1 || cmp == 0, nil
}
return false, fmt.Errorf("invalid rule: %s", rule)
}
func compareVersions(a, b string) int {
pa := strings.Split(a, ".")
pb := strings.Split(b, ".")
n := len(pa)
if len(pb) > n {
n = len(pb)
}
for i := 0; i < n; i++ {
ai, bi := 0, 0
if i < len(pa) {
ai, _ = strconv.Atoi(pa[i])
}
if i < len(pb) {
bi, _ = strconv.Atoi(pb[i])
}
if ai > bi {
return 1
}
if ai < bi {
return -1
}
}
return 0
}
func ValidatePostureCheck(pc *schema.PostureCheck) error {
if pc.Name == "" {
return errors.New("name cannot be empty")
}
_, err := logic.GetNetwork(pc.NetworkID)
if err != nil {
return errors.New("invalid network")
}
allowedAttrvaluesMap, ok := schema.PostureCheckAttrValuesMap[pc.Attribute]
if !ok {
return errors.New("unkown attribute")
}
if len(pc.Values) == 0 {
return errors.New("attribute value cannot be empty")
}
for i, valueI := range pc.Values {
pc.Values[i] = strings.ToLower(valueI)
}
if pc.Attribute == schema.ClientLocation {
for i, loc := range pc.Values {
if countries.ByName(loc) == countries.Unknown {
return errors.New("invalid country code")
}
pc.Values[i] = strings.ToUpper(loc)
}
}
if pc.Attribute == schema.AutoUpdate || pc.Attribute == schema.OS ||
pc.Attribute == schema.OSFamily {
for _, valueI := range pc.Values {
if _, ok := allowedAttrvaluesMap[valueI]; !ok {
return errors.New("invalid attribute value")
}
}
}
if pc.Attribute == schema.ClientVersion || pc.Attribute == schema.OSVersion ||
pc.Attribute == schema.KernelVersion {
for i, valueI := range pc.Values {
if !logic.IsValidVersion(valueI) {
return errors.New("invalid attribute version value")
}
pc.Values[i] = logic.CleanVersion(valueI)
}
}
if len(pc.Tags) > 0 {
for tagID := range pc.Tags {
if tagID == "*" {
continue
}
_, err := GetTag(models.TagID(tagID))
if err != nil {
return errors.New("unknown tag")
}
}
} else {
pc.Tags = make(datatypes.JSONMap)
pc.Tags["*"] = struct{}{}
}
return nil
}
func CountryNameFromISO(code string) string {
c := countries.ByName(code) // works with ISO2, ISO3, full name
if c == countries.Unknown {
return ""
}
return c.Info().Name
}
+1
View File
@@ -9,5 +9,6 @@ func ListModels() []interface{} {
&Event{},
&PendingHost{},
&Nameserver{},
&PostureCheck{},
}
}
+123
View File
@@ -0,0 +1,123 @@
package schema
import (
"context"
"time"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/models"
"gorm.io/datatypes"
)
type Attribute string
type Values string
const (
OS Attribute = "os"
OSVersion Attribute = "os_version"
OSFamily Attribute = "os_family"
KernelVersion Attribute = "kernel_version"
AutoUpdate Attribute = "auto_update"
ClientVersion Attribute = "client_version"
ClientLocation Attribute = "client_location"
)
var PostureCheckAttrs = []Attribute{
ClientLocation,
ClientVersion,
OS,
OSVersion,
OSFamily,
KernelVersion,
AutoUpdate,
}
var PostureCheckAttrValuesMap = map[Attribute]map[string]struct{}{
ClientLocation: {
"any_valid_iso_country_codes": {},
},
ClientVersion: {
"any_valid_semantic_version": {},
},
OS: {
"linux": {},
"darwin": {},
"windows": {},
"ios": {},
"android": {},
},
OSVersion: {
"any_valid_semantic_version": {},
},
OSFamily: {
"linux-debian": {},
"linux-redhat": {},
"linux-suse": {},
"linux-arch": {},
"linux-gentoo": {},
"linux-other": {},
"darwin": {},
"windows": {},
"ios": {},
"android": {},
},
KernelVersion: {
"any_valid_semantic_version": {},
},
AutoUpdate: {
"true": {},
"false": {},
},
}
var PostureCheckAttrValues = map[Attribute][]string{
ClientLocation: {"any_valid_iso_country_codes"},
ClientVersion: {"any_valid_semantic_version"},
OS: {"linux", "darwin", "windows", "ios", "android"},
OSVersion: {"any_valid_semantic_version"},
OSFamily: {"linux-debian", "linux-redhat", "linux-suse", "linux-arch", "linux-gentoo", "linux-other", "darwin", "windows", "ios", "android"},
KernelVersion: {"any_valid_semantic_version"},
AutoUpdate: {"true", "false"},
}
type PostureCheck struct {
ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"name" json:"name"`
NetworkID string `gorm:"network_id" json:"network_id"`
Description string `gorm:"description" json:"description"`
Attribute Attribute `gorm:"attribute" json:"attribute"`
Values datatypes.JSONSlice[string] `gorm:"values" json:"values"`
Severity models.Severity `gorm:"severity" json:"severity"`
Tags datatypes.JSONMap `gorm:"tags" json:"tags"`
Status bool `gorm:"status" json:"status"`
CreatedBy string `gorm:"created_by" json:"created_by"`
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"`
}
func (p *PostureCheck) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(&PostureCheck{}).First(&p).Where("id = ?", p.ID).Error
}
func (p *PostureCheck) Update(ctx context.Context) error {
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(&p).Error
}
func (p *PostureCheck) Create(ctx context.Context) error {
return db.FromContext(ctx).Model(&PostureCheck{}).Create(&p).Error
}
func (p *PostureCheck) ListByNetwork(ctx context.Context) (pcli []PostureCheck, err error) {
err = db.FromContext(ctx).Model(&PostureCheck{}).Where("network_id = ?", p.NetworkID).Find(&pcli).Error
return
}
func (p *PostureCheck) Delete(ctx context.Context) error {
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Delete(&p).Error
}
func (p *PostureCheck) UpdateStatus(ctx context.Context) error {
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(map[string]any{
"status": p.Status,
}).Error
}