mirror of
https://github.com/gravitl/netmaker.git
synced 2026-04-22 16:07:11 +08:00
Merge pull request #3735 from gravitl/NM-166
NM-166: Device Posture Checks
This commit is contained in:
@@ -2,8 +2,11 @@ package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
@@ -358,6 +361,37 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
if newHost.EndpointIP != nil {
|
||||
newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
} else if newHost.EndpointIPv6 != nil {
|
||||
newHost.Location, newHost.CountryCode = logic.GetHostLocInfo(newHost.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
pcviolations := []models.Violation{}
|
||||
skipViolatedNetworks := []string{}
|
||||
for _, netI := range enrollmentKey.Networks {
|
||||
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
|
||||
ClientLocation: newHost.CountryCode,
|
||||
ClientVersion: newHost.Version,
|
||||
OS: newHost.OS,
|
||||
OSFamily: newHost.OSFamily,
|
||||
OSVersion: newHost.OSVersion,
|
||||
KernelVersion: newHost.KernelVersion,
|
||||
AutoUpdate: newHost.AutoUpdate,
|
||||
}, models.NetworkID(netI))
|
||||
pcviolations = append(pcviolations, violations...)
|
||||
if len(violations) > 0 {
|
||||
skipViolatedNetworks = append(skipViolatedNetworks, netI)
|
||||
}
|
||||
}
|
||||
if len(skipViolatedNetworks) == len(enrollmentKey.Networks) && len(pcviolations) > 0 {
|
||||
logic.ReturnErrorResponse(w, r,
|
||||
logic.FormatError(errors.New("access blocked: this device doesn’t meet security requirements"), logic.Forbidden))
|
||||
return
|
||||
}
|
||||
// need to remove the networks that were skipped from the enrollment key
|
||||
enrollmentKey.Networks = slices.DeleteFunc(enrollmentKey.Networks, func(netI string) bool {
|
||||
return slices.Contains(skipViolatedNetworks, netI)
|
||||
})
|
||||
if !hostExists {
|
||||
newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive
|
||||
// register host
|
||||
|
||||
@@ -715,13 +715,29 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
err = errors.New("remote client config already exists on the gateway")
|
||||
slog.Error("failed to create extclient", "user", userName, "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient)
|
||||
if extclient.DeviceID != "" {
|
||||
// check for violations connecting from desktop app
|
||||
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
|
||||
ClientLocation: extclient.Country,
|
||||
ClientVersion: extclient.ClientVersion,
|
||||
OS: extclient.OS,
|
||||
OSFamily: extclient.OSFamily,
|
||||
OSVersion: extclient.OSVersion,
|
||||
KernelVersion: extclient.KernelVersion,
|
||||
//AutoUpdate: extclient.AutoUpdate,
|
||||
}, models.NetworkID(extclient.Network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
|
||||
return
|
||||
}
|
||||
}
|
||||
extclient.OwnerID = userName
|
||||
extclient.RemoteAccessClientID = customExtClient.RemoteAccessClientID
|
||||
extclient.IngressGatewayID = nodeid
|
||||
@@ -749,7 +765,6 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
|
||||
extclient.Enabled = parentNetwork.DefaultACL == "yes"
|
||||
}
|
||||
extclient.Os = customExtClient.Os
|
||||
extclient.DeviceID = customExtClient.DeviceID
|
||||
extclient.DeviceName = customExtClient.DeviceName
|
||||
if customExtClient.IsAlreadyConnectedToInetGw {
|
||||
@@ -758,7 +773,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
extclient.PublicEndpoint = customExtClient.PublicEndpoint
|
||||
extclient.Country = customExtClient.Country
|
||||
if customExtClient.RemoteAccessClientID != "" && customExtClient.Location == "" {
|
||||
extclient.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
|
||||
extclient.Location, extclient.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
extclient.Location = customExtClient.Location
|
||||
|
||||
@@ -772,7 +787,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -820,7 +835,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
if err := mq.PublishPeerUpdate(false); err != nil {
|
||||
@@ -903,9 +918,25 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
replacePeers = true
|
||||
}
|
||||
if update.RemoteAccessClientID != "" && update.Location == "" {
|
||||
update.Location = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
|
||||
update.Location, update.Country = logic.GetHostLocInfo(logic.GetClientIP(r), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
newclient := logic.UpdateExtClient(&oldExtClient, &update)
|
||||
if newclient.DeviceID != "" && newclient.Enabled {
|
||||
// check for violations connecting from desktop app
|
||||
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
|
||||
ClientLocation: newclient.Country,
|
||||
ClientVersion: newclient.ClientVersion,
|
||||
OS: newclient.OS,
|
||||
OSFamily: newclient.OSFamily,
|
||||
OSVersion: newclient.OSVersion,
|
||||
KernelVersion: newclient.KernelVersion,
|
||||
//AutoUpdate: extclient.AutoUpdate,
|
||||
}, models.NetworkID(newclient.Network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := logic.DeleteExtClient(oldExtClient.Network, oldExtClient.ClientID, true); err != nil {
|
||||
slog.Error(
|
||||
"failed to delete ext client",
|
||||
|
||||
+15
-2
@@ -547,7 +547,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("hostid or network cannot be empty"), "badrequest"),
|
||||
logic.FormatError(errors.New("hostid or network cannot be empty"), logic.BadReq),
|
||||
)
|
||||
return
|
||||
}
|
||||
@@ -555,10 +555,23 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
currHost, err := logic.GetHost(hostid)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostid, err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
|
||||
violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{
|
||||
ClientLocation: currHost.CountryCode,
|
||||
ClientVersion: currHost.Version,
|
||||
OS: currHost.OS,
|
||||
OSFamily: currHost.OSFamily,
|
||||
OSVersion: currHost.OSVersion,
|
||||
KernelVersion: currHost.KernelVersion,
|
||||
AutoUpdate: currHost.AutoUpdate,
|
||||
}, models.NetworkID(network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
newNode, err := logic.UpdateHostNetwork(currHost, network, true)
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
@@ -661,6 +662,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ResetAutoRelayedPeer(¤tNode)
|
||||
}
|
||||
}
|
||||
newNode.PostureChecksViolations,
|
||||
newNode.PostureCheckVolationSeverityLevel = logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(newNode),
|
||||
models.NetworkID(newNode.Network))
|
||||
newNode.LastEvaluatedAt = time.Now().UTC()
|
||||
logic.UpsertNode(newNode)
|
||||
logic.GetNodeStatus(newNode, false)
|
||||
|
||||
|
||||
@@ -40,6 +40,8 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Masterminds/semver/v3 v3.4.0
|
||||
github.com/biter777/countries v1.7.5
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e
|
||||
github.com/guumaster/tablewriter v0.0.10
|
||||
|
||||
@@ -6,6 +6,10 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
|
||||
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
|
||||
github.com/biter777/countries v1.7.5 h1:MJ+n3+rSxWQdqVJU8eBy9RqcdH6ePPn4PJHocVWUa+Q=
|
||||
github.com/biter777/countries v1.7.5/go.mod h1:1HSpZ526mYqKJcpT5Ti1kcGQ0L0SrXWIaptUWjFfv2E=
|
||||
github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
|
||||
github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
||||
|
||||
@@ -77,3 +77,16 @@ func ReturnErrorResponse(response http.ResponseWriter, request *http.Request, er
|
||||
response.WriteHeader(errorMessage.Code)
|
||||
response.Write(jsonResponse)
|
||||
}
|
||||
|
||||
// ReturnErrorResponseWithJson - processes error with body and adds header
|
||||
func ReturnErrorResponseWithJson(response http.ResponseWriter, request *http.Request, msg interface{}, errorMessage models.ErrorResponse) {
|
||||
httpResponse := &models.ErrorResponse{Code: errorMessage.Code, Message: errorMessage.Message, Response: msg}
|
||||
jsonResponse, err := json.Marshal(httpResponse)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
slog.Debug("processed request error", "err", errorMessage.Message)
|
||||
response.Header().Set("Content-Type", "application/json")
|
||||
response.WriteHeader(errorMessage.Code)
|
||||
response.Write(jsonResponse)
|
||||
}
|
||||
|
||||
+16
-1
@@ -451,11 +451,26 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode
|
||||
new.Location = update.Location
|
||||
}
|
||||
if update.Country != "" && update.Country != old.Country {
|
||||
new.Country = update.Country
|
||||
new.Country = strings.ToUpper(update.Country)
|
||||
}
|
||||
if update.DeviceID != "" && old.DeviceID == "" {
|
||||
new.DeviceID = update.DeviceID
|
||||
}
|
||||
if update.OS != "" {
|
||||
new.OS = update.OS
|
||||
}
|
||||
if update.OSFamily != "" {
|
||||
new.OSFamily = update.OSFamily
|
||||
}
|
||||
if update.OSVersion != "" {
|
||||
new.OSVersion = update.OSVersion
|
||||
}
|
||||
if update.KernelVersion != "" {
|
||||
new.KernelVersion = update.KernelVersion
|
||||
}
|
||||
if update.ClientVersion != "" {
|
||||
new.ClientVersion = update.ClientVersion
|
||||
}
|
||||
return new
|
||||
}
|
||||
|
||||
|
||||
+4
-1
@@ -231,6 +231,8 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
|
||||
node.Tags = make(map[models.TagID]struct{})
|
||||
}
|
||||
node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.GwTagName))] = struct{}{}
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
|
||||
node.LastEvaluatedAt = time.Now().UTC()
|
||||
err = UpsertNode(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
@@ -282,11 +284,12 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error
|
||||
delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName)))
|
||||
node.IngressGatewayRange = ""
|
||||
node.Metadata = ""
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
|
||||
node.LastEvaluatedAt = time.Now().UTC()
|
||||
err = UpsertNode(&node)
|
||||
if err != nil {
|
||||
return models.Node{}, removedClients, err
|
||||
}
|
||||
|
||||
err = SetNetworkNodesLastModified(node.Network)
|
||||
return node, removedClients, err
|
||||
}
|
||||
|
||||
+7
-3
@@ -34,7 +34,11 @@ var (
|
||||
ErrInvalidHostID error = errors.New("invalid host id")
|
||||
)
|
||||
|
||||
var GetHostLocInfo = func(ip, token string) string { return "" }
|
||||
var GetHostLocInfo = func(ip, token string) (string, string) { return "", "" }
|
||||
|
||||
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network models.NetworkID) (v []models.Violation, level models.Severity) {
|
||||
return []models.Violation{}, models.SeverityUnknown
|
||||
}
|
||||
|
||||
func getHostsFromCache() (hosts []models.Host) {
|
||||
hostCacheMutex.RLock()
|
||||
@@ -254,9 +258,9 @@ func CreateHost(h *models.Host) error {
|
||||
h.DNS = "no"
|
||||
}
|
||||
if h.EndpointIP != nil {
|
||||
h.Location = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
} else if h.EndpointIPv6 != nil {
|
||||
h.Location = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
h.Location, h.CountryCode = GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
checkForZombieHosts(h)
|
||||
return UpsertHost(h)
|
||||
|
||||
+46
-1
@@ -620,7 +620,15 @@ func FindRelay(node *models.Node) *models.Node {
|
||||
func GetAllNodesAPI(nodes []models.Node) []models.ApiNode {
|
||||
apiNodes := []models.ApiNode{}
|
||||
for i := range nodes {
|
||||
newApiNode := nodes[i].ConvertToAPINode()
|
||||
node := nodes[i]
|
||||
if !node.IsStatic {
|
||||
h, err := GetHost(node.HostID.String())
|
||||
if err == nil {
|
||||
node.Location = h.Location
|
||||
node.CountryCode = h.CountryCode
|
||||
}
|
||||
}
|
||||
newApiNode := node.ConvertToAPINode()
|
||||
apiNodes = append(apiNodes, *newApiNode)
|
||||
}
|
||||
return apiNodes[:]
|
||||
@@ -874,3 +882,40 @@ func GetAllFailOvers() ([]models.Node, error) {
|
||||
}
|
||||
return igs, nil
|
||||
}
|
||||
|
||||
// GetPostureCheckDeviceInfoByNode retrieves PostureCheckDeviceInfo for a given node
|
||||
func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDeviceInfo {
|
||||
var deviceInfo models.PostureCheckDeviceInfo
|
||||
|
||||
if !node.IsStatic {
|
||||
h, err := GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
return deviceInfo
|
||||
}
|
||||
deviceInfo = models.PostureCheckDeviceInfo{
|
||||
ClientLocation: h.CountryCode,
|
||||
ClientVersion: h.Version,
|
||||
OS: h.OS,
|
||||
OSVersion: h.OSVersion,
|
||||
OSFamily: h.OSFamily,
|
||||
KernelVersion: h.KernelVersion,
|
||||
AutoUpdate: h.AutoUpdate,
|
||||
Tags: node.Tags,
|
||||
}
|
||||
} else {
|
||||
if node.StaticNode.DeviceID == "" && node.StaticNode.RemoteAccessClientID == "" {
|
||||
return deviceInfo
|
||||
}
|
||||
deviceInfo = models.PostureCheckDeviceInfo{
|
||||
ClientLocation: node.StaticNode.Country,
|
||||
ClientVersion: node.StaticNode.ClientVersion,
|
||||
OS: node.StaticNode.OS,
|
||||
OSVersion: node.StaticNode.OSVersion,
|
||||
OSFamily: node.StaticNode.OSFamily,
|
||||
KernelVersion: node.StaticNode.KernelVersion,
|
||||
Tags: node.StaticNode.Tags,
|
||||
}
|
||||
}
|
||||
|
||||
return deviceInfo
|
||||
}
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type OSInfo struct {
|
||||
OS string `json:"os"` // e.g. "ubuntu", "windows", "macos"
|
||||
OSFamily string `json:"os_family"` // e.g. "linux-debian", "windows"
|
||||
OSVersion string `json:"os_version"` // e.g. "22.04", "10.0.22631"
|
||||
KernelVersion string `json:"kernel_version"` // e.g. "6.8.0"
|
||||
}
|
||||
|
||||
/// --- classification helpers you already had ---
|
||||
|
||||
func NormalizeOSName(raw string) string {
|
||||
return strings.ToLower(strings.TrimSpace(raw))
|
||||
}
|
||||
|
||||
// OSFamily returns a normalized OS family string.
|
||||
// Examples: "linux-debian", "linux-redhat", "linux-arch", "linux-other", "windows", "darwin"
|
||||
func OSFamily(osName string) string {
|
||||
osName = NormalizeOSName(osName)
|
||||
|
||||
// Non-Linux first
|
||||
if strings.Contains(osName, "windows") {
|
||||
return "windows"
|
||||
}
|
||||
if strings.Contains(osName, "darwin") || strings.Contains(osName, "mac") || strings.Contains(osName, "os x") {
|
||||
return "darwin"
|
||||
}
|
||||
|
||||
// Linux families
|
||||
switch {
|
||||
// Debian family
|
||||
case containsAny(osName,
|
||||
"debian", "ubuntu", "pop", "linuxmint", "kali", "raspbian", "elementary"):
|
||||
return "linux-debian"
|
||||
|
||||
// Red Hat family
|
||||
case containsAny(osName,
|
||||
"rhel", "red hat", "centos", "rocky", "alma", "fedora", "oracle linux", "ol"):
|
||||
return "linux-redhat"
|
||||
|
||||
// SUSE family
|
||||
case containsAny(osName,
|
||||
"suse", "opensuse", "sles"):
|
||||
return "linux-suse"
|
||||
|
||||
// Arch family
|
||||
case containsAny(osName,
|
||||
"arch", "manjaro", "endeavouros", "garuda"):
|
||||
return "linux-arch"
|
||||
|
||||
// Gentoo
|
||||
case strings.Contains(osName, "gentoo"):
|
||||
return "linux-gentoo"
|
||||
|
||||
// Alpine, Amazon, BusyBox, etc.
|
||||
case containsAny(osName,
|
||||
"alpine", "amazon", "busybox"):
|
||||
return "linux-other"
|
||||
}
|
||||
|
||||
// Fallbacks
|
||||
if strings.Contains(osName, "linux") {
|
||||
return "linux-other"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func containsAny(s string, subs ...string) bool {
|
||||
for _, sub := range subs {
|
||||
if strings.Contains(s, sub) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/// --- public entrypoint ---
|
||||
|
||||
// GetOSInfo returns OS, OSFamily, OSVersion and KernelVersion for the current platform.
|
||||
func GetOSInfo() OSInfo {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return getLinuxOSInfo()
|
||||
case "darwin":
|
||||
return getDarwinOSInfo()
|
||||
case "windows":
|
||||
return getWindowsOSInfo()
|
||||
default:
|
||||
// Fallback for other UNIX-likes; best-effort
|
||||
kernel := strings.TrimSpace(runCmd("uname", "-r"))
|
||||
name := runtime.GOOS
|
||||
return OSInfo{
|
||||
OS: NormalizeOSName(name),
|
||||
OSFamily: OSFamily(name),
|
||||
OSVersion: "",
|
||||
KernelVersion: CleanVersion(kernel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// --- Linux ---
|
||||
|
||||
func getLinuxOSInfo() OSInfo {
|
||||
var osName, osVersion string
|
||||
|
||||
data, err := os.ReadFile("/etc/os-release")
|
||||
if err == nil {
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
key := parts[0]
|
||||
value := strings.Trim(parts[1], `"'`)
|
||||
|
||||
switch key {
|
||||
case "ID":
|
||||
osName = value
|
||||
case "VERSION_ID":
|
||||
osVersion = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if osName == "" {
|
||||
// Fallback
|
||||
osName = "linux"
|
||||
}
|
||||
kernel := strings.TrimSpace(runCmd("uname", "-r"))
|
||||
// trim extras like -generic
|
||||
if idx := strings.Index(kernel, "-"); idx > 0 {
|
||||
kernel = kernel[:idx]
|
||||
}
|
||||
|
||||
normName := NormalizeOSName(osName)
|
||||
return OSInfo{
|
||||
OS: "linux",
|
||||
OSFamily: OSFamily(normName),
|
||||
OSVersion: CleanVersion(osVersion),
|
||||
KernelVersion: CleanVersion(kernel),
|
||||
}
|
||||
}
|
||||
|
||||
/// --- macOS (darwin) ---
|
||||
|
||||
func getDarwinOSInfo() OSInfo {
|
||||
productName := strings.TrimSpace(runCmd("sw_vers", "-productName"))
|
||||
productVer := strings.TrimSpace(runCmd("sw_vers", "-productVersion"))
|
||||
|
||||
if productName == "" {
|
||||
productName = "macos"
|
||||
}
|
||||
kernel := strings.TrimSpace(runCmd("uname", "-r"))
|
||||
if idx := strings.Index(kernel, "-"); idx > 0 {
|
||||
kernel = kernel[:idx]
|
||||
}
|
||||
|
||||
normName := NormalizeOSName(productName)
|
||||
return OSInfo{
|
||||
OS: "darwin",
|
||||
OSFamily: OSFamily(normName), // "darwin"
|
||||
OSVersion: CleanVersion(productVer), // e.g. "15.0"
|
||||
KernelVersion: CleanVersion(kernel),
|
||||
}
|
||||
}
|
||||
|
||||
/// --- Windows ---
|
||||
|
||||
func getWindowsOSInfo() OSInfo {
|
||||
// OS name: we just say "windows"
|
||||
osName := "windows"
|
||||
|
||||
// OS version via "wmic" or "ver" as fallback
|
||||
var version string
|
||||
|
||||
// Try wmic first (may be missing on newer builds but often still present)
|
||||
out := runCmd("wmic", "os", "get", "Version", "/value")
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "Version=") {
|
||||
version = strings.TrimPrefix(line, "Version=")
|
||||
version = strings.TrimSpace(version)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if version == "" {
|
||||
// Fallback to "ver"
|
||||
raw := strings.TrimSpace(runCmd("cmd", "/C", "ver"))
|
||||
version = raw // you can add better parsing if you need
|
||||
}
|
||||
|
||||
// On Windows, kernel and OS version are effectively tied; reuse
|
||||
kernel := version
|
||||
|
||||
normName := NormalizeOSName(osName)
|
||||
return OSInfo{
|
||||
OS: "windows", // "windows"
|
||||
OSFamily: OSFamily(normName),
|
||||
OSVersion: CleanVersion(version), // e.g. "10.0.22631"
|
||||
KernelVersion: CleanVersion(kernel),
|
||||
}
|
||||
}
|
||||
|
||||
/// --- small helper to run commands safely ---
|
||||
|
||||
func runCmd(name string, args ...string) string {
|
||||
cmd := exec.Command(name, args...)
|
||||
var buf bytes.Buffer
|
||||
cmd.Stdout = &buf
|
||||
_ = cmd.Run() // ignore error; best-effort
|
||||
return buf.String()
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/hashicorp/go-version"
|
||||
)
|
||||
|
||||
@@ -29,3 +31,48 @@ func IsVersionCompatible(ver string) bool {
|
||||
return constraint.Check(v)
|
||||
|
||||
}
|
||||
|
||||
// CleanVersion normalizes a version string safely for storage.
|
||||
// - removes "v" or "V" prefix
|
||||
// - trims whitespace
|
||||
// - strips invalid trailing characters
|
||||
// - preserves semver, prerelease, and build metadata
|
||||
func CleanVersion(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
v := strings.TrimSpace(raw)
|
||||
|
||||
// Remove leading v/V (common in semver)
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
v = strings.TrimPrefix(v, "V")
|
||||
|
||||
// Remove trailing commas, quotes, spaces
|
||||
v = strings.Trim(v, " ,\"'")
|
||||
|
||||
// Remove any characters not allowed in semantic versioning:
|
||||
// Allowed: 0-9 a-z A-Z . - +
|
||||
re := regexp.MustCompile(`[^0-9A-Za-z\.\-\+]+`)
|
||||
v = re.ReplaceAllString(v, "")
|
||||
|
||||
// Collapse multiple dots (e.g., "1..2" → "1.2")
|
||||
v = strings.ReplaceAll(v, "..", ".")
|
||||
for strings.Contains(v, "..") {
|
||||
v = strings.ReplaceAll(v, "..", ".")
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// IsValidVersion returns true if the version string can be parsed as semantic version.
|
||||
func IsValidVersion(raw string) bool {
|
||||
cleaned := CleanVersion(raw)
|
||||
|
||||
if cleaned == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := semver.NewVersion(cleaned)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
+7
-4
@@ -425,13 +425,13 @@ func updateHosts() {
|
||||
host.AutoUpdate = true
|
||||
logic.UpsertHost(&host)
|
||||
}
|
||||
if servercfg.IsPro && host.Location == "" {
|
||||
if servercfg.IsPro && (host.Location == "" || host.CountryCode == "") {
|
||||
if host.EndpointIP != nil {
|
||||
host.Location = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
} else if host.EndpointIPv6 != nil {
|
||||
host.Location = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
host.Location, host.CountryCode = logic.GetHostLocInfo(host.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
if host.Location != "" {
|
||||
if host.Location != "" && host.CountryCode != "" {
|
||||
logic.UpsertHost(&host)
|
||||
}
|
||||
}
|
||||
@@ -891,6 +891,9 @@ func migrateSettings() {
|
||||
if settings.PeerConnectionCheckInterval == "" {
|
||||
settings.PeerConnectionCheckInterval = "15"
|
||||
}
|
||||
if settings.PostureCheckInterval == "" {
|
||||
settings.PostureCheckInterval = "30"
|
||||
}
|
||||
if settings.CleanUpInterval == 0 {
|
||||
settings.CleanUpInterval = 60
|
||||
}
|
||||
|
||||
@@ -14,6 +14,9 @@ type ApiHost struct {
|
||||
Version string `json:"version"`
|
||||
Name string `json:"name"`
|
||||
OS string `json:"os"`
|
||||
OSFamily string `json:"os_family" yaml:"os_family"`
|
||||
OSVersion string `json:"os_version" yaml:"os_version"`
|
||||
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
|
||||
Debug bool `json:"debug"`
|
||||
IsStaticPort bool `json:"isstaticport"`
|
||||
IsStatic bool `json:"isstatic"`
|
||||
@@ -33,6 +36,7 @@ type ApiHost struct {
|
||||
AutoUpdate bool `json:"autoupdate" yaml:"autoupdate"`
|
||||
DNS string `json:"dns" yaml:"dns"`
|
||||
Location string `json:"location"`
|
||||
CountryCode string `json:"country_code"`
|
||||
}
|
||||
|
||||
// ApiIface - the interface struct for API usage
|
||||
@@ -71,6 +75,8 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
|
||||
a.MacAddress = h.MacAddress.String()
|
||||
a.Name = h.Name
|
||||
a.OS = h.OS
|
||||
a.OSFamily = h.OSFamily
|
||||
a.KernelVersion = h.KernelVersion
|
||||
a.Nodes = h.Nodes
|
||||
a.WgPublicListenPort = h.WgPublicListenPort
|
||||
a.PublicKey = h.PublicKey.String()
|
||||
@@ -82,6 +88,7 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
|
||||
a.AutoUpdate = h.AutoUpdate
|
||||
a.DNS = h.DNS
|
||||
a.Location = h.Location
|
||||
a.CountryCode = h.CountryCode
|
||||
return &a
|
||||
}
|
||||
|
||||
@@ -122,6 +129,8 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
|
||||
h.Nodes = currentHost.Nodes
|
||||
h.TrafficKeyPublic = currentHost.TrafficKeyPublic
|
||||
h.OS = currentHost.OS
|
||||
h.OSFamily = currentHost.OSFamily
|
||||
h.KernelVersion = currentHost.KernelVersion
|
||||
h.IsDefault = a.IsDefault
|
||||
h.NatType = currentHost.NatType
|
||||
h.TurnEndpoint = currentHost.TurnEndpoint
|
||||
@@ -129,5 +138,6 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
|
||||
h.AutoUpdate = a.AutoUpdate
|
||||
h.DNS = strings.ToLower(a.DNS)
|
||||
h.Location = currentHost.Location
|
||||
h.CountryCode = currentHost.CountryCode
|
||||
return &h
|
||||
}
|
||||
|
||||
+23
-14
@@ -53,20 +53,24 @@ type ApiNode struct {
|
||||
PendingDelete bool `json:"pendingdelete"`
|
||||
Metadata string `json:"metadata"`
|
||||
// == PRO ==
|
||||
DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
|
||||
IsFailOver bool `json:"is_fail_over"`
|
||||
FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
|
||||
FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"`
|
||||
IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"`
|
||||
InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"`
|
||||
InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"`
|
||||
AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"`
|
||||
Tags map[TagID]struct{} `json:"tags" yaml:"tags"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
IsUserNode bool `json:"is_user_node"`
|
||||
StaticNode ExtClient `json:"static_node"`
|
||||
Status NodeStatus `json:"status"`
|
||||
Location string `json:"location"`
|
||||
DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"`
|
||||
IsFailOver bool `json:"is_fail_over"`
|
||||
FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"`
|
||||
FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"`
|
||||
IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"`
|
||||
InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"`
|
||||
InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"`
|
||||
AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"`
|
||||
Tags map[TagID]struct{} `json:"tags" yaml:"tags"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
IsUserNode bool `json:"is_user_node"`
|
||||
StaticNode ExtClient `json:"static_node"`
|
||||
Status NodeStatus `json:"status"`
|
||||
Location string `json:"location"`
|
||||
Country string `json:"country"`
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
}
|
||||
|
||||
// ApiNode.ConvertToServerNode - converts an api node to a server node
|
||||
@@ -227,6 +231,11 @@ func (nm *Node) ConvertToAPINode() *ApiNode {
|
||||
apiNode.IsUserNode = nm.IsUserNode
|
||||
apiNode.StaticNode = nm.StaticNode
|
||||
apiNode.Status = nm.Status
|
||||
apiNode.PostureChecksViolations = nm.PostureChecksViolations
|
||||
apiNode.PostureCheckVolationSeverityLevel = nm.PostureCheckVolationSeverityLevel
|
||||
apiNode.LastEvaluatedAt = nm.LastEvaluatedAt
|
||||
apiNode.Location = nm.Location
|
||||
apiNode.Country = nm.CountryCode
|
||||
return &apiNode
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ const (
|
||||
EnrollmentKeySub SubjectType = "ENROLLMENT_KEY"
|
||||
ClientAppSub SubjectType = "CLIENT-APP"
|
||||
NameserverSub SubjectType = "NAMESERVER"
|
||||
PostureCheckSub SubjectType = "POSTURE_CHECK"
|
||||
)
|
||||
|
||||
func (sub SubjectType) String() string {
|
||||
|
||||
+52
-33
@@ -1,35 +1,45 @@
|
||||
package models
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExtClient - struct for external clients
|
||||
type ExtClient struct {
|
||||
ClientID string `json:"clientid" bson:"clientid"`
|
||||
PrivateKey string `json:"privatekey" bson:"privatekey"`
|
||||
PublicKey string `json:"publickey" bson:"publickey"`
|
||||
Network string `json:"network" bson:"network"`
|
||||
DNS string `json:"dns" bson:"dns"`
|
||||
Address string `json:"address" bson:"address"`
|
||||
Address6 string `json:"address6" bson:"address6"`
|
||||
ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"`
|
||||
AllowedIPs []string `json:"allowed_ips"`
|
||||
IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"`
|
||||
IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"`
|
||||
LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"`
|
||||
Enabled bool `json:"enabled" bson:"enabled"`
|
||||
OwnerID string `json:"ownerid" bson:"ownerid"`
|
||||
DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"`
|
||||
RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine
|
||||
PostUp string `json:"postup" bson:"postup"`
|
||||
PostDown string `json:"postdown" bson:"postdown"`
|
||||
Tags map[TagID]struct{} `json:"tags"`
|
||||
Os string `json:"os"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
PublicEndpoint string `json:"public_endpoint"`
|
||||
Country string `json:"country"`
|
||||
Location string `json:"location"` //format: lat,long
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
ClientID string `json:"clientid" bson:"clientid"`
|
||||
PrivateKey string `json:"privatekey" bson:"privatekey"`
|
||||
PublicKey string `json:"publickey" bson:"publickey"`
|
||||
Network string `json:"network" bson:"network"`
|
||||
DNS string `json:"dns" bson:"dns"`
|
||||
Address string `json:"address" bson:"address"`
|
||||
Address6 string `json:"address6" bson:"address6"`
|
||||
ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"`
|
||||
AllowedIPs []string `json:"allowed_ips"`
|
||||
IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"`
|
||||
IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"`
|
||||
LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"`
|
||||
Enabled bool `json:"enabled" bson:"enabled"`
|
||||
OwnerID string `json:"ownerid" bson:"ownerid"`
|
||||
DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"`
|
||||
RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine
|
||||
PostUp string `json:"postup" bson:"postup"`
|
||||
PostDown string `json:"postdown" bson:"postdown"`
|
||||
Tags map[TagID]struct{} `json:"tags"`
|
||||
OS string `json:"os"`
|
||||
OSFamily string `json:"os_family" yaml:"os_family"`
|
||||
OSVersion string `json:"os_version" yaml:"os_version"`
|
||||
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
|
||||
ClientVersion string `json:"client_version"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
PublicEndpoint string `json:"public_endpoint"`
|
||||
Country string `json:"country"`
|
||||
Location string `json:"location"` //format: lat,long
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
}
|
||||
|
||||
// CustomExtClient - struct for CustomExtClient params
|
||||
@@ -44,11 +54,15 @@ type CustomExtClient struct {
|
||||
PostUp string `json:"postup" bson:"postup" validate:"max=1024"`
|
||||
PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"`
|
||||
Tags map[TagID]struct{} `json:"tags"`
|
||||
Os string `json:"os"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceName string `json:"device_name"`
|
||||
IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"`
|
||||
PublicEndpoint string `json:"public_endpoint"`
|
||||
OS string `json:"os"`
|
||||
OSFamily string `json:"os_family" yaml:"os_family"`
|
||||
OSVersion string `json:"os_version" yaml:"os_version"`
|
||||
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
|
||||
ClientVersion string `json:"client_version"`
|
||||
Country string `json:"country"`
|
||||
Location string `json:"location"` //format: lat,long
|
||||
}
|
||||
@@ -63,10 +77,15 @@ func (ext *ExtClient) ConvertToStaticNode() Node {
|
||||
Address: ext.AddressIPNet4(),
|
||||
Address6: ext.AddressIPNet6(),
|
||||
},
|
||||
Tags: ext.Tags,
|
||||
IsStatic: true,
|
||||
StaticNode: *ext,
|
||||
IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "",
|
||||
Mutex: ext.Mutex,
|
||||
Tags: ext.Tags,
|
||||
IsStatic: true,
|
||||
StaticNode: *ext,
|
||||
IsUserNode: ext.RemoteAccessClientID != "" || ext.DeviceID != "",
|
||||
Mutex: ext.Mutex,
|
||||
CountryCode: ext.Country,
|
||||
Location: ext.Location,
|
||||
PostureChecksViolations: ext.PostureChecksViolations,
|
||||
PostureCheckVolationSeverityLevel: ext.PostureCheckVolationSeverityLevel,
|
||||
LastEvaluatedAt: ext.LastEvaluatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,9 @@ type Host struct {
|
||||
HostPass string `json:"hostpass" yaml:"hostpass"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
OS string `json:"os" yaml:"os"`
|
||||
OSFamily string `json:"os_family" yaml:"os_family"`
|
||||
OSVersion string `json:"os_version" yaml:"os_version"`
|
||||
KernelVersion string `json:"kernel_version" yaml:"kernel_version"`
|
||||
Interface string `json:"interface" yaml:"interface"`
|
||||
Debug bool `json:"debug" yaml:"debug"`
|
||||
ListenPort int `json:"listenport" yaml:"listenport"`
|
||||
@@ -74,6 +77,7 @@ type Host struct {
|
||||
TurnEndpoint *netip.AddrPort `json:"turn_endpoint,omitempty" yaml:"turn_endpoint,omitempty"`
|
||||
PersistentKeepalive time.Duration `json:"persistentkeepalive" swaggertype:"primitive,integer" format:"int64" yaml:"persistentkeepalive"`
|
||||
Location string `json:"location"` // Format: "lat,lon"
|
||||
CountryCode string `json:"country_code"`
|
||||
}
|
||||
|
||||
// FormatBool converts a boolean to a [yes|no] string
|
||||
|
||||
+18
-13
@@ -113,19 +113,24 @@ type Node struct {
|
||||
//AutoRelayedPeers map[string]struct{} `json:"auto_relayed_peers"`
|
||||
AutoRelayedPeers map[string]string `json:"auto_relayed_peers_v1"`
|
||||
//AutoRelayedBy uuid.UUID `json:"auto_relayed_by"`
|
||||
FailOverPeers map[string]struct{} `json:"fail_over_peers"`
|
||||
FailedOverBy uuid.UUID `json:"failed_over_by"`
|
||||
IsInternetGateway bool `json:"isinternetgateway"`
|
||||
InetNodeReq InetNodeReq `json:"inet_node_req"`
|
||||
InternetGwID string `json:"internetgw_node_id"`
|
||||
AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"`
|
||||
Tags map[TagID]struct{} `json:"tags"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
IsUserNode bool `json:"is_user_node"`
|
||||
StaticNode ExtClient `json:"static_node"`
|
||||
Status NodeStatus `json:"node_status"`
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
EgressDetails EgressDetails `json:"-"`
|
||||
FailOverPeers map[string]struct{} `json:"fail_over_peers"`
|
||||
FailedOverBy uuid.UUID `json:"failed_over_by"`
|
||||
IsInternetGateway bool `json:"isinternetgateway"`
|
||||
InetNodeReq InetNodeReq `json:"inet_node_req"`
|
||||
InternetGwID string `json:"internetgw_node_id"`
|
||||
AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"`
|
||||
Tags map[TagID]struct{} `json:"tags"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
IsUserNode bool `json:"is_user_node"`
|
||||
StaticNode ExtClient `json:"static_node"`
|
||||
Status NodeStatus `json:"node_status"`
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
EgressDetails EgressDetails `json:"-"`
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
Location string `json:"location"` // Format: "lat,lon"
|
||||
CountryCode string `json:"country_code"`
|
||||
}
|
||||
type EgressDetails struct {
|
||||
EgressGatewayNatEnabled bool
|
||||
|
||||
@@ -50,6 +50,7 @@ type ServerSettings struct {
|
||||
AuditLogsRetentionPeriodInDays int `json:"audit_logs_retention_period"`
|
||||
OldAClsSupport bool `json:"old_acl_support"`
|
||||
PeerConnectionCheckInterval string `json:"peer_connection_check_interval"`
|
||||
PostureCheckInterval string `json:"posture_check_interval"` // in minutes
|
||||
CleanUpInterval int `json:"clean_up_interval_in_mins"`
|
||||
}
|
||||
|
||||
|
||||
+32
-2
@@ -116,8 +116,9 @@ type SuccessfulLoginResponse struct {
|
||||
|
||||
// ErrorResponse is struct for error
|
||||
type ErrorResponse struct {
|
||||
Code int
|
||||
Message string
|
||||
Code int
|
||||
Message string
|
||||
Response interface{}
|
||||
}
|
||||
|
||||
// NodeAuth - struct for node auth
|
||||
@@ -457,3 +458,32 @@ type IDPSyncTestRequest struct {
|
||||
OktaOrgURL string `json:"okta_org_url"`
|
||||
OktaAPIToken string `json:"okta_api_token"`
|
||||
}
|
||||
|
||||
type PostureCheckDeviceInfo struct {
|
||||
ClientLocation string
|
||||
ClientVersion string
|
||||
OS string
|
||||
OSFamily string
|
||||
OSVersion string
|
||||
KernelVersion string
|
||||
AutoUpdate bool
|
||||
Tags map[TagID]struct{}
|
||||
}
|
||||
|
||||
type Violation struct {
|
||||
CheckID string `json:"check_id"`
|
||||
Name string `json:"name"`
|
||||
Attribute string `json:"attribute"`
|
||||
Message string `json:"message"`
|
||||
Severity Severity `json:"severity"`
|
||||
}
|
||||
|
||||
type Severity int
|
||||
|
||||
const (
|
||||
SeverityUnknown Severity = iota
|
||||
SeverityLow
|
||||
SeverityMedium
|
||||
SeverityHigh
|
||||
SeverityCritical
|
||||
)
|
||||
|
||||
+11
-3
@@ -296,13 +296,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool {
|
||||
(len(h.NatType) > 0 && h.NatType != currentHost.NatType) ||
|
||||
h.DefaultInterface != currentHost.DefaultInterface ||
|
||||
(h.ListenPort != 0 && h.ListenPort != currentHost.ListenPort) ||
|
||||
(h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) || (!h.EndpointIPv6.Equal(currentHost.EndpointIPv6))
|
||||
(h.WgPublicListenPort != 0 && h.WgPublicListenPort != currentHost.WgPublicListenPort) ||
|
||||
(!h.EndpointIPv6.Equal(currentHost.EndpointIPv6)) || (h.OSFamily != currentHost.OSFamily) ||
|
||||
(h.OSVersion != currentHost.OSVersion) || (h.KernelVersion != currentHost.KernelVersion)
|
||||
if ifaceDelta { // only save if something changes
|
||||
if !h.EndpointIP.Equal(currentHost.EndpointIP) || !h.EndpointIPv6.Equal(currentHost.EndpointIPv6) || currentHost.Location == "" {
|
||||
if h.EndpointIP != nil {
|
||||
h.Location = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIP.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
} else if h.EndpointIPv6 != nil {
|
||||
h.Location = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
h.Location, h.CountryCode = logic.GetHostLocInfo(h.EndpointIPv6.String(), os.Getenv("IP_INFO_TOKEN"))
|
||||
}
|
||||
}
|
||||
currentHost.EndpointIP = h.EndpointIP
|
||||
@@ -310,9 +312,15 @@ func HandleHostCheckin(h, currentHost *models.Host) bool {
|
||||
currentHost.Interfaces = h.Interfaces
|
||||
currentHost.DefaultInterface = h.DefaultInterface
|
||||
currentHost.NatType = h.NatType
|
||||
currentHost.OSFamily = h.OSFamily
|
||||
currentHost.OSVersion = h.OSVersion
|
||||
currentHost.KernelVersion = h.KernelVersion
|
||||
if h.Location != "" {
|
||||
currentHost.Location = h.Location
|
||||
}
|
||||
if h.CountryCode != "" {
|
||||
currentHost.CountryCode = h.CountryCode
|
||||
}
|
||||
if h.ListenPort != 0 {
|
||||
currentHost.ListenPort = h.ListenPort
|
||||
}
|
||||
|
||||
@@ -0,0 +1,343 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
func PostureCheckHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck))).Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/v1/posture_check/attrs", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/posture_check/violations", logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes))).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// @Summary List Posture Checks Available Attributes
|
||||
// @Router /api/v1/posture_check/attrs [get]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query network string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func listPostureChecksAttrs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, schema.PostureCheckAttrValues, "fetched posture checks")
|
||||
}
|
||||
|
||||
// @Summary Create Posture Check
|
||||
// @Router /api/v1/posture_check [post]
|
||||
// @Tags DNS
|
||||
// @Accept json
|
||||
// @Param body body schema.PostureCheck
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var req schema.PostureCheck
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
logger.Log(0, "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
if err := proLogic.ValidatePostureCheck(&req); err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
pc := schema.PostureCheck{
|
||||
ID: uuid.New().String(),
|
||||
Name: req.Name,
|
||||
NetworkID: req.NetworkID,
|
||||
Description: req.Description,
|
||||
Tags: req.Tags,
|
||||
Attribute: req.Attribute,
|
||||
Values: req.Values,
|
||||
Severity: req.Severity,
|
||||
Status: true,
|
||||
CreatedBy: r.Header.Get("user"),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err = pc.Create(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error creating posture check "+err.Error()), logic.Internal),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
go proLogic.RunPostureChecks()
|
||||
logic.ReturnSuccessResponseWithJson(w, r, pc, "created posture check")
|
||||
}
|
||||
|
||||
// @Summary List Posture Checks
|
||||
// @Router /api/v1/posture_check [get]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query network string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func listPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
network := r.URL.Query().Get("network")
|
||||
if network == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
_, err := logic.GetNetwork(network)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network not found"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
id := r.URL.Query().Get("id")
|
||||
if id != "" {
|
||||
pc := schema.PostureCheck{ID: id}
|
||||
err := pc.Get(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check")
|
||||
return
|
||||
}
|
||||
pc := schema.PostureCheck{NetworkID: network}
|
||||
list, err := pc.ListByNetwork(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error listing posture checks "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, list, "fetched posture checks")
|
||||
}
|
||||
|
||||
// @Summary Update Posture Check
|
||||
// @Router /api/v1/posture_check [put]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body schema.PostureCheck
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var updatePc schema.PostureCheck
|
||||
err := json.NewDecoder(r.Body).Decode(&updatePc)
|
||||
if err != nil {
|
||||
logger.Log(0, "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := proLogic.ValidatePostureCheck(&updatePc); err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
pc := schema.PostureCheck{ID: updatePc.ID}
|
||||
err = pc.Get(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
var updateStatus bool
|
||||
if updatePc.Status != pc.Status {
|
||||
updateStatus = true
|
||||
}
|
||||
event := &models.Event{
|
||||
Action: models.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: updatePc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: pc,
|
||||
New: updatePc,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
}
|
||||
pc.Tags = updatePc.Tags
|
||||
pc.Attribute = updatePc.Attribute
|
||||
pc.Values = updatePc.Values
|
||||
pc.Description = updatePc.Description
|
||||
pc.Name = updatePc.Name
|
||||
pc.Severity = updatePc.Severity
|
||||
pc.Status = updatePc.Status
|
||||
pc.UpdatedAt = time.Now().UTC()
|
||||
|
||||
err = pc.Update(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error updating posture check "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
if updateStatus {
|
||||
pc.UpdateStatus(db.WithContext(context.TODO()))
|
||||
}
|
||||
logic.LogEvent(event)
|
||||
go mq.PublishPeerUpdate(false)
|
||||
go proLogic.RunPostureChecks()
|
||||
logic.ReturnSuccessResponseWithJson(w, r, pc, "updated posture check")
|
||||
}
|
||||
|
||||
// @Summary Delete Posture Check
|
||||
// @Router /api/v1/posture_check [delete]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query id string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
pc := schema.PostureCheck{ID: id}
|
||||
err := pc.Get(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
err = pc.Delete(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: pc,
|
||||
New: nil,
|
||||
},
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, pc, "deleted posture check")
|
||||
}
|
||||
|
||||
// @Summary List Posture Check violated Nodes
|
||||
// @Router /api/v1/posture_check/violations [get]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query network string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
network := r.URL.Query().Get("network")
|
||||
if network == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
listViolatedusers := r.URL.Query().Get("users") == "true"
|
||||
violatedNodes := []models.Node{}
|
||||
if listViolatedusers {
|
||||
extclients, err := logic.GetNetworkExtClients(network)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
for _, extclient := range extclients {
|
||||
if extclient.DeviceID != "" && extclient.Enabled {
|
||||
if len(extclient.PostureChecksViolations) > 0 {
|
||||
violatedNodes = append(violatedNodes, extclient.ConvertToStaticNode())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
nodes, err := logic.GetNetworkNodes(network)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if len(node.PostureChecksViolations) > 0 {
|
||||
violatedNodes = append(violatedNodes, node)
|
||||
}
|
||||
}
|
||||
}
|
||||
apiNodes := logic.GetAllNodesAPI(violatedNodes)
|
||||
logic.SortApiNodes(apiNodes[:])
|
||||
logic.ReturnSuccessResponseWithJson(w, r, apiNodes, "fetched posture checks violated nodes")
|
||||
}
|
||||
@@ -241,7 +241,7 @@ func updateTag(w http.ResponseWriter, r *http.Request) {
|
||||
UsedByCnt: len(updateTag.TaggedNodes),
|
||||
TaggedNodes: updateTag.TaggedNodes,
|
||||
}
|
||||
|
||||
go proLogic.RunPostureChecks()
|
||||
logic.ReturnSuccessResponseWithJson(w, r, res, "updated tags")
|
||||
}
|
||||
|
||||
|
||||
+3
-1
@@ -37,6 +37,7 @@ func InitPro() {
|
||||
proControllers.TagHandlers,
|
||||
proControllers.NetworkHandlers,
|
||||
proControllers.AutoRelayHandlers,
|
||||
proControllers.PostureCheckHandlers,
|
||||
)
|
||||
controller.ListRoles = proControllers.ListRoles
|
||||
logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
|
||||
@@ -99,8 +100,8 @@ func InitPro() {
|
||||
auth.ResetIDPSyncHook()
|
||||
email.Init()
|
||||
go proLogic.EventWatcher()
|
||||
|
||||
logic.GetMetricsMonitor().Start()
|
||||
proLogic.AddPostureCheckHook()
|
||||
})
|
||||
logic.ResetFailOver = proLogic.ResetFailOver
|
||||
logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer
|
||||
@@ -172,6 +173,7 @@ func InitPro() {
|
||||
logic.GetNameserversForNode = proLogic.GetNameserversForNode
|
||||
logic.ValidateNameserverReq = proLogic.ValidateNameserverReq
|
||||
logic.ValidateEgressReq = proLogic.ValidateEgressReq
|
||||
logic.CheckPostureViolations = proLogic.CheckPostureViolations
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -240,7 +240,7 @@ func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) {
|
||||
slog.Debug("[metrics] node metrics data", "node ID", currentNode.ID, "metrics", newMetrics)
|
||||
}
|
||||
|
||||
func GetHostLocInfo(ip, token string) string {
|
||||
func GetHostLocInfo(ip, token string) (loc, country string) {
|
||||
url := "https://ipinfo.io/"
|
||||
if ip != "" {
|
||||
url += ip
|
||||
@@ -253,15 +253,18 @@ func GetHostLocInfo(ip, token string) string {
|
||||
client := http.Client{Timeout: 3 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
if err != nil {
|
||||
return ""
|
||||
return "", ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data struct {
|
||||
Loc string `json:"loc"` // Format: "lat,lon"
|
||||
Loc string `json:"loc"` // Format: "lat,lon"
|
||||
Country string `json:"country"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return ""
|
||||
return "", ""
|
||||
}
|
||||
return data.Loc
|
||||
loc = data.Loc
|
||||
country = data.Country
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/biter777/countries"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
var postureCheckMutex = &sync.Mutex{}
|
||||
|
||||
func AddPostureCheckHook() {
|
||||
settings := logic.GetServerSettings()
|
||||
interval := time.Hour
|
||||
i, err := strconv.Atoi(settings.PostureCheckInterval)
|
||||
if err == nil {
|
||||
interval = time.Minute * time.Duration(i)
|
||||
}
|
||||
logic.HookManagerCh <- models.HookDetails{
|
||||
Hook: logic.WrapHook(RunPostureChecks),
|
||||
Interval: interval,
|
||||
}
|
||||
}
|
||||
func RunPostureChecks() error {
|
||||
postureCheckMutex.Lock()
|
||||
defer postureCheckMutex.Unlock()
|
||||
nets, err := logic.GetNetworks()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nodes, err := logic.GetAllNodes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, netI := range nets {
|
||||
networkNodes := logic.GetNetworkNodesMemory(nodes, netI.NetID)
|
||||
if len(networkNodes) == 0 {
|
||||
continue
|
||||
}
|
||||
networkNodes = logic.AddStaticNodestoList(networkNodes)
|
||||
pcLi, err := (&schema.PostureCheck{NetworkID: netI.NetID}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
if err != nil || len(pcLi) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, nodeI := range networkNodes {
|
||||
postureChecksViolations, postureCheckVolationSeverityLevel := GetPostureCheckViolations(pcLi, logic.GetPostureCheckDeviceInfoByNode(&nodeI))
|
||||
if nodeI.IsStatic {
|
||||
extclient, err := logic.GetExtClient(nodeI.StaticNode.ClientID, nodeI.StaticNode.Network)
|
||||
if err == nil {
|
||||
extclient.PostureChecksViolations = postureChecksViolations
|
||||
extclient.PostureCheckVolationSeverityLevel = postureCheckVolationSeverityLevel
|
||||
extclient.LastEvaluatedAt = time.Now().UTC()
|
||||
logic.SaveExtClient(&extclient)
|
||||
}
|
||||
} else {
|
||||
nodeI.PostureChecksViolations, nodeI.PostureCheckVolationSeverityLevel = postureChecksViolations,
|
||||
postureCheckVolationSeverityLevel
|
||||
nodeI.LastEvaluatedAt = time.Now().UTC()
|
||||
logic.UpsertNode(&nodeI)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CheckPostureViolations(d models.PostureCheckDeviceInfo, network models.NetworkID) ([]models.Violation, models.Severity) {
|
||||
pcLi, err := (&schema.PostureCheck{NetworkID: network.String()}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
if err != nil || len(pcLi) == 0 {
|
||||
return []models.Violation{}, models.SeverityUnknown
|
||||
}
|
||||
violations, level := GetPostureCheckViolations(pcLi, d)
|
||||
return violations, level
|
||||
}
|
||||
func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]models.Violation, models.Severity) {
|
||||
var violations []models.Violation
|
||||
highest := models.SeverityUnknown
|
||||
|
||||
// Group checks by attribute
|
||||
checksByAttribute := make(map[schema.Attribute][]schema.PostureCheck)
|
||||
for _, c := range checks {
|
||||
// skip disabled checks
|
||||
if !c.Status {
|
||||
continue
|
||||
}
|
||||
// Check if tags match
|
||||
if _, ok := c.Tags["*"]; !ok {
|
||||
exists := false
|
||||
for tagID := range c.Tags {
|
||||
if _, ok := d.Tags[models.TagID(tagID)]; ok {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
}
|
||||
checksByAttribute[c.Attribute] = append(checksByAttribute[c.Attribute], c)
|
||||
}
|
||||
|
||||
// Handle OS and OSFamily together with OR logic since they are related
|
||||
osChecks := checksByAttribute[schema.OS]
|
||||
osFamilyChecks := checksByAttribute[schema.OSFamily]
|
||||
if len(osChecks) > 0 || len(osFamilyChecks) > 0 {
|
||||
osAllowed := evaluateAttributeChecks(osChecks, schema.OS, d)
|
||||
osFamilyAllowed := evaluateAttributeChecks(osFamilyChecks, schema.OSFamily, d)
|
||||
|
||||
// OR condition: if either OS or OSFamily passes, both are considered passed
|
||||
if !osAllowed && !osFamilyAllowed {
|
||||
|
||||
// Both failed, add violations for both
|
||||
osDenied := getDeniedChecks(osChecks, schema.OS, d)
|
||||
osFamilyDenied := getDeniedChecks(osFamilyChecks, schema.OSFamily, d)
|
||||
|
||||
for _, denied := range osDenied {
|
||||
sev := denied.check.Severity
|
||||
if sev > highest {
|
||||
highest = sev
|
||||
}
|
||||
v := models.Violation{
|
||||
CheckID: denied.check.ID,
|
||||
Name: denied.check.Name,
|
||||
Attribute: string(denied.check.Attribute),
|
||||
Message: denied.reason,
|
||||
Severity: sev,
|
||||
}
|
||||
violations = append(violations, v)
|
||||
}
|
||||
for _, denied := range osFamilyDenied {
|
||||
sev := denied.check.Severity
|
||||
if sev > highest {
|
||||
highest = sev
|
||||
}
|
||||
v := models.Violation{
|
||||
CheckID: denied.check.ID,
|
||||
Name: denied.check.Name,
|
||||
Attribute: string(denied.check.Attribute),
|
||||
Message: denied.reason,
|
||||
Severity: sev,
|
||||
}
|
||||
violations = append(violations, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For all other attributes, check if ANY check allows it
|
||||
for attr, attrChecks := range checksByAttribute {
|
||||
// Skip OS and OSFamily as they are handled above
|
||||
if attr == schema.OS || attr == schema.OSFamily {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if any check for this attribute allows the device
|
||||
allowed := false
|
||||
var deniedChecks []struct {
|
||||
check schema.PostureCheck
|
||||
reason string
|
||||
}
|
||||
|
||||
for _, c := range attrChecks {
|
||||
violated, reason := evaluatePostureCheck(&c, d)
|
||||
if !violated {
|
||||
// At least one check allows it
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
// Track denied checks with their reasons for violation reporting
|
||||
deniedChecks = append(deniedChecks, struct {
|
||||
check schema.PostureCheck
|
||||
reason string
|
||||
}{check: c, reason: reason})
|
||||
}
|
||||
|
||||
// If no check allows it, add violations for all denied checks
|
||||
if !allowed {
|
||||
for _, denied := range deniedChecks {
|
||||
sev := denied.check.Severity
|
||||
if sev > highest {
|
||||
highest = sev
|
||||
}
|
||||
|
||||
v := models.Violation{
|
||||
CheckID: denied.check.ID,
|
||||
Name: denied.check.Name,
|
||||
Attribute: string(denied.check.Attribute),
|
||||
Message: denied.reason,
|
||||
Severity: sev,
|
||||
}
|
||||
violations = append(violations, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return violations, highest
|
||||
}
|
||||
|
||||
// evaluateAttributeChecks evaluates checks for a specific attribute and returns true if any check allows the device
|
||||
func evaluateAttributeChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) bool {
|
||||
for _, c := range attrChecks {
|
||||
violated, _ := evaluatePostureCheck(&c, d)
|
||||
if !violated {
|
||||
// At least one check allows it
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getDeniedChecks returns all checks that denied the device for a specific attribute
|
||||
func getDeniedChecks(attrChecks []schema.PostureCheck, attr schema.Attribute, d models.PostureCheckDeviceInfo) []struct {
|
||||
check schema.PostureCheck
|
||||
reason string
|
||||
} {
|
||||
var deniedChecks []struct {
|
||||
check schema.PostureCheck
|
||||
reason string
|
||||
}
|
||||
|
||||
for _, c := range attrChecks {
|
||||
violated, reason := evaluatePostureCheck(&c, d)
|
||||
if violated {
|
||||
deniedChecks = append(deniedChecks, struct {
|
||||
check schema.PostureCheck
|
||||
reason string
|
||||
}{check: c, reason: reason})
|
||||
}
|
||||
}
|
||||
return deniedChecks
|
||||
}
|
||||
|
||||
func evaluatePostureCheck(check *schema.PostureCheck, d models.PostureCheckDeviceInfo) (violated bool, reason string) {
|
||||
switch check.Attribute {
|
||||
|
||||
// ------------------------
|
||||
// 1. Geographic check
|
||||
// ------------------------
|
||||
case schema.ClientLocation:
|
||||
if !slices.Contains(check.Values, strings.ToUpper(d.ClientLocation)) {
|
||||
return true, fmt.Sprintf("client location '%s' not allowed", CountryNameFromISO(d.ClientLocation))
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// 2. Client version check
|
||||
// Supports: exact match OR allowed list OR semver rules
|
||||
// ------------------------
|
||||
case schema.ClientVersion:
|
||||
for _, rule := range check.Values {
|
||||
ok, err := matchVersionRule(d.ClientVersion, rule)
|
||||
if err != nil || !ok {
|
||||
return true, fmt.Sprintf("client version '%s' violation", d.ClientVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------
|
||||
// 3. OS check
|
||||
// ("windows", "mac", "linux", etc.)
|
||||
// ------------------------
|
||||
case schema.OS:
|
||||
if !slices.Contains(check.Values, d.OS) {
|
||||
return true, fmt.Sprintf("client os '%s' not allowed", d.OS)
|
||||
}
|
||||
case schema.OSFamily:
|
||||
if !slices.Contains(check.Values, d.OSFamily) {
|
||||
return true, fmt.Sprintf("os family '%s' not allowed", d.OSFamily)
|
||||
}
|
||||
// ------------------------
|
||||
// 4. OS version check
|
||||
// Supports operators: > >= < <= =
|
||||
// ------------------------
|
||||
case schema.OSVersion:
|
||||
for _, rule := range check.Values {
|
||||
ok, err := matchVersionRule(d.OSVersion, rule)
|
||||
if err != nil || !ok {
|
||||
return true, fmt.Sprintf("os version '%s' violation", d.OSVersion)
|
||||
}
|
||||
}
|
||||
case schema.KernelVersion:
|
||||
for _, rule := range check.Values {
|
||||
ok, err := matchVersionRule(d.KernelVersion, rule)
|
||||
if err != nil || !ok {
|
||||
return true, fmt.Sprintf("kernel version '%s' violation", d.KernelVersion)
|
||||
}
|
||||
}
|
||||
// ------------------------
|
||||
// 5. Auto-update check
|
||||
// Values: ["true"] or ["false"]
|
||||
// ------------------------
|
||||
case schema.AutoUpdate:
|
||||
required := len(check.Values) > 0 && strings.ToLower(check.Values[0]) == "true"
|
||||
if required && !d.AutoUpdate {
|
||||
return true, "auto update must be enabled"
|
||||
}
|
||||
if !required && d.AutoUpdate {
|
||||
return true, "auto update must be disabled"
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
func cleanVersion(v string) string {
|
||||
v = strings.TrimSpace(v)
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
v = strings.TrimPrefix(v, "V")
|
||||
v = strings.TrimSuffix(v, ",")
|
||||
v = strings.TrimSpace(v)
|
||||
return v
|
||||
}
|
||||
|
||||
func matchVersionRule(actual, rule string) (bool, error) {
|
||||
actual = cleanVersion(actual)
|
||||
rule = strings.TrimSpace(rule)
|
||||
|
||||
op := "="
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(rule, ">="):
|
||||
op = ">="
|
||||
rule = strings.TrimPrefix(rule, ">=")
|
||||
case strings.HasPrefix(rule, "<="):
|
||||
op = "<="
|
||||
rule = strings.TrimPrefix(rule, "<=")
|
||||
case strings.HasPrefix(rule, ">"):
|
||||
op = ">"
|
||||
rule = strings.TrimPrefix(rule, ">")
|
||||
case strings.HasPrefix(rule, "<"):
|
||||
op = "<"
|
||||
rule = strings.TrimPrefix(rule, "<")
|
||||
case strings.HasPrefix(rule, "="):
|
||||
op = "="
|
||||
rule = strings.TrimPrefix(rule, "=")
|
||||
}
|
||||
|
||||
rule = cleanVersion(rule)
|
||||
|
||||
cmp := compareVersions(actual, rule)
|
||||
|
||||
switch op {
|
||||
case "=":
|
||||
return cmp == 0, nil
|
||||
case ">":
|
||||
return cmp == 1, nil
|
||||
case "<":
|
||||
return cmp == -1, nil
|
||||
case ">=":
|
||||
return cmp == 1 || cmp == 0, nil
|
||||
case "<=":
|
||||
return cmp == -1 || cmp == 0, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("invalid rule: %s", rule)
|
||||
}
|
||||
|
||||
func compareVersions(a, b string) int {
|
||||
pa := strings.Split(a, ".")
|
||||
pb := strings.Split(b, ".")
|
||||
|
||||
n := len(pa)
|
||||
if len(pb) > n {
|
||||
n = len(pb)
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
ai, bi := 0, 0
|
||||
|
||||
if i < len(pa) {
|
||||
ai, _ = strconv.Atoi(pa[i])
|
||||
}
|
||||
if i < len(pb) {
|
||||
bi, _ = strconv.Atoi(pb[i])
|
||||
}
|
||||
|
||||
if ai > bi {
|
||||
return 1
|
||||
}
|
||||
if ai < bi {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func ValidatePostureCheck(pc *schema.PostureCheck) error {
|
||||
if pc.Name == "" {
|
||||
return errors.New("name cannot be empty")
|
||||
}
|
||||
_, err := logic.GetNetwork(pc.NetworkID)
|
||||
if err != nil {
|
||||
return errors.New("invalid network")
|
||||
}
|
||||
allowedAttrvaluesMap, ok := schema.PostureCheckAttrValuesMap[pc.Attribute]
|
||||
if !ok {
|
||||
return errors.New("unkown attribute")
|
||||
}
|
||||
if len(pc.Values) == 0 {
|
||||
return errors.New("attribute value cannot be empty")
|
||||
}
|
||||
for i, valueI := range pc.Values {
|
||||
pc.Values[i] = strings.ToLower(valueI)
|
||||
}
|
||||
if pc.Attribute == schema.ClientLocation {
|
||||
for i, loc := range pc.Values {
|
||||
if countries.ByName(loc) == countries.Unknown {
|
||||
return errors.New("invalid country code")
|
||||
}
|
||||
pc.Values[i] = strings.ToUpper(loc)
|
||||
}
|
||||
}
|
||||
if pc.Attribute == schema.AutoUpdate || pc.Attribute == schema.OS ||
|
||||
pc.Attribute == schema.OSFamily {
|
||||
for _, valueI := range pc.Values {
|
||||
if _, ok := allowedAttrvaluesMap[valueI]; !ok {
|
||||
return errors.New("invalid attribute value")
|
||||
}
|
||||
}
|
||||
}
|
||||
if pc.Attribute == schema.ClientVersion || pc.Attribute == schema.OSVersion ||
|
||||
pc.Attribute == schema.KernelVersion {
|
||||
for i, valueI := range pc.Values {
|
||||
if !logic.IsValidVersion(valueI) {
|
||||
return errors.New("invalid attribute version value")
|
||||
}
|
||||
pc.Values[i] = logic.CleanVersion(valueI)
|
||||
}
|
||||
}
|
||||
if len(pc.Tags) > 0 {
|
||||
for tagID := range pc.Tags {
|
||||
if tagID == "*" {
|
||||
continue
|
||||
}
|
||||
_, err := GetTag(models.TagID(tagID))
|
||||
if err != nil {
|
||||
return errors.New("unknown tag")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
pc.Tags = make(datatypes.JSONMap)
|
||||
pc.Tags["*"] = struct{}{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func CountryNameFromISO(code string) string {
|
||||
c := countries.ByName(code) // works with ISO2, ISO3, full name
|
||||
if c == countries.Unknown {
|
||||
return ""
|
||||
}
|
||||
return c.Info().Name
|
||||
}
|
||||
@@ -9,5 +9,6 @@ func ListModels() []interface{} {
|
||||
&Event{},
|
||||
&PendingHost{},
|
||||
&Nameserver{},
|
||||
&PostureCheck{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type Attribute string
|
||||
type Values string
|
||||
|
||||
const (
|
||||
OS Attribute = "os"
|
||||
OSVersion Attribute = "os_version"
|
||||
OSFamily Attribute = "os_family"
|
||||
KernelVersion Attribute = "kernel_version"
|
||||
AutoUpdate Attribute = "auto_update"
|
||||
ClientVersion Attribute = "client_version"
|
||||
ClientLocation Attribute = "client_location"
|
||||
)
|
||||
|
||||
var PostureCheckAttrs = []Attribute{
|
||||
ClientLocation,
|
||||
ClientVersion,
|
||||
OS,
|
||||
OSVersion,
|
||||
OSFamily,
|
||||
KernelVersion,
|
||||
AutoUpdate,
|
||||
}
|
||||
|
||||
var PostureCheckAttrValuesMap = map[Attribute]map[string]struct{}{
|
||||
ClientLocation: {
|
||||
"any_valid_iso_country_codes": {},
|
||||
},
|
||||
ClientVersion: {
|
||||
"any_valid_semantic_version": {},
|
||||
},
|
||||
OS: {
|
||||
"linux": {},
|
||||
"darwin": {},
|
||||
"windows": {},
|
||||
"ios": {},
|
||||
"android": {},
|
||||
},
|
||||
OSVersion: {
|
||||
"any_valid_semantic_version": {},
|
||||
},
|
||||
OSFamily: {
|
||||
"linux-debian": {},
|
||||
"linux-redhat": {},
|
||||
"linux-suse": {},
|
||||
"linux-arch": {},
|
||||
"linux-gentoo": {},
|
||||
"linux-other": {},
|
||||
"darwin": {},
|
||||
"windows": {},
|
||||
"ios": {},
|
||||
"android": {},
|
||||
},
|
||||
KernelVersion: {
|
||||
"any_valid_semantic_version": {},
|
||||
},
|
||||
AutoUpdate: {
|
||||
"true": {},
|
||||
"false": {},
|
||||
},
|
||||
}
|
||||
|
||||
var PostureCheckAttrValues = map[Attribute][]string{
|
||||
ClientLocation: {"any_valid_iso_country_codes"},
|
||||
ClientVersion: {"any_valid_semantic_version"},
|
||||
OS: {"linux", "darwin", "windows", "ios", "android"},
|
||||
OSVersion: {"any_valid_semantic_version"},
|
||||
OSFamily: {"linux-debian", "linux-redhat", "linux-suse", "linux-arch", "linux-gentoo", "linux-other", "darwin", "windows", "ios", "android"},
|
||||
KernelVersion: {"any_valid_semantic_version"},
|
||||
AutoUpdate: {"true", "false"},
|
||||
}
|
||||
|
||||
type PostureCheck struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"name" json:"name"`
|
||||
NetworkID string `gorm:"network_id" json:"network_id"`
|
||||
Description string `gorm:"description" json:"description"`
|
||||
Attribute Attribute `gorm:"attribute" json:"attribute"`
|
||||
Values datatypes.JSONSlice[string] `gorm:"values" json:"values"`
|
||||
Severity models.Severity `gorm:"severity" json:"severity"`
|
||||
Tags datatypes.JSONMap `gorm:"tags" json:"tags"`
|
||||
Status bool `gorm:"status" json:"status"`
|
||||
CreatedBy string `gorm:"created_by" json:"created_by"`
|
||||
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (p *PostureCheck) Get(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&PostureCheck{}).First(&p).Where("id = ?", p.ID).Error
|
||||
}
|
||||
|
||||
func (p *PostureCheck) Update(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(&p).Error
|
||||
}
|
||||
|
||||
func (p *PostureCheck) Create(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&PostureCheck{}).Create(&p).Error
|
||||
}
|
||||
|
||||
func (p *PostureCheck) ListByNetwork(ctx context.Context) (pcli []PostureCheck, err error) {
|
||||
err = db.FromContext(ctx).Model(&PostureCheck{}).Where("network_id = ?", p.NetworkID).Find(&pcli).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (p *PostureCheck) Delete(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Delete(&p).Error
|
||||
}
|
||||
|
||||
func (p *PostureCheck) UpdateStatus(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&PostureCheck{}).Where("id = ?", p.ID).Updates(map[string]any{
|
||||
"status": p.Status,
|
||||
}).Error
|
||||
}
|
||||
Reference in New Issue
Block a user