mirror of
https://github.com/gravitl/netmaker.git
synced 2026-04-22 16:07:11 +08:00
NM-163: Users, Groups, Roles, Networks and Hosts Table Migration (#3910)
* feat(go): add user schema; * feat(go): migrate to user schema; * feat(go): add audit fields; * feat(go): remove unused fields from the network model; * feat(go): add network schema; * feat(go): migrate to network schema; * refactor(go): add comment to clarify migration logic; * fix(go): test failures; * fix(go): test failures; * feat(go): change membership table to store memberships at all scopes; * feat(go): add schema for access grants; * feat(go): remove nameservers from new networks table; ensure db passed for schema functions; * feat(go): set max conns for sqlite to 1; * fix(go): issues updating user account status; * refactor(go): remove converters and access grants; * refactor(go): add json tags in schema models; * refactor(go): rename file to migrate_v1_6_0.go; * refactor(go): add user groups and user roles tables; use schema tables; * refactor(go): inline get and list from schema package; * refactor(go): inline get network and list users from schema package; * fix(go): staticcheck issues; * fix(go): remove test not in use; fix test case; * fix(go): validate network; * fix(go): resolve static checks; * fix(go): new models errors; * fix(go): test errors; * fix(go): handle no records; * fix(go): add validations for user object; * fix(go): set correct extclient status; * fix(go): test error; * feat(go): make schema the base package; * feat(go): add host schema; * feat(go): use schema host everywhere; * feat(go): inline get host, list hosts and delete host; * feat(go): use non-ptr value; * feat(go): use save to upsert all fields; * feat(go): use save to upsert all fields; * feat(go): save turn endpoint as string; * feat(go): check for gorm error record not found; * fix(go): test failures; * fix(go): update all network fields; * fix(go): update all network fields; * feat(go): add paginated list networks api; * feat(go): add paginated list users api; * feat(go): add paginated list hosts api; * feat(go): add pagination to list groups api; * fix(go): comment; * fix(go): implement marshal and unmarshal text for custom types; * fix(go): implement marshal and unmarshal json for custom types; * fix(go): just use the old model for unmarshalling; * fix(go): implement marshal and unmarshal json for custom types; * feat(go): remove paginated list networks api; * feat(go): use custom paginated response object; * fix(go): ensure default values for page and per_page are used when not passed; * fix(go): rename v1.6.0 to v1.5.1; * fix(go): check for gorm.ErrRecordNotFound instead of database.IsEmptyRecord; * fix(go): use host id, not pending host id; * feat(go): add filters to paginated apis; * feat(go): add filters to paginated apis; * feat(go): remove check for max username length; * feat(go): add filters to count as well; * feat(go): use library to check email address validity; * feat(go): ignore pagination if params not passed; * fix(go): pagination issues; * fix(go): check exists before using; * fix(go): remove debug log; * fix(go): use gorm err record not found; * fix(go): use gorm err record not found; * fix(go): use user principal name when creating pending user; * fix(go): use schema package for consts; * fix(go): prevent disabling superadmin user; Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): swap is admin and is superadmin; Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): remove dead code block; https://github.com/gravitl/netmaker/pull/3910#discussion_r2928837937 * fix(go): incorrect message when trying to disable self; https://github.com/gravitl/netmaker/pull/3910#discussion_r2928837934 * fix(go): use correct header; Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): return after error response; Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): use correct order of params; https://github.com/gravitl/netmaker/pull/3910#discussion_r2929593036 * fix(go): set default values for page and page size; use v2 instead of /list; * Update logic/auth.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * Update schema/user_roles.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): syntax error; * fix(go): set default values when page and per_page are not passed or 0; * fix(go): use uuid.parse instead of uuid.must parse; * fix(go): review errors; * fix(go): review errors; * Update controllers/user.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * Update controllers/user.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * NM-163: fix errors: * Update db/types/options.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * fix(go): persist return user in event; * Update db/types/options.go Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com> * NM-163: duplicate lines of code * NM-163: fix(go): fix missing return and filter parsing in user controller - Add missing return after error response in updateUserAccountStatus to prevent double-response and spurious ext-client side-effects - Use switch statements in listUsers to skip unrecognized account_status and mfa_status filter values * fix(go): check for both min and max page size; * fix(go): enclose transfer superadmin in transaction; * fix(go): review errors; * fix(go): remove free tier checks; * fix(go): review fixes; --------- Co-authored-by: VishalDalwadi <dalwadivishal26@gmail.com> Co-authored-by: Vishal Dalwadi <51291657+VishalDalwadi@users.noreply.github.com> Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>
This commit is contained in:
+8
-5
@@ -1,8 +1,10 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"context"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// == consts ==
|
||||
@@ -10,11 +12,12 @@ const (
|
||||
node_signin_length = 64
|
||||
)
|
||||
|
||||
func isUserIsAllowed(username, network string) (*models.User, error) {
|
||||
func isUserIsAllowed(username, network string) (*schema.User, error) {
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil { // user must not exist, so try to make one
|
||||
return &models.User{}, err
|
||||
return &schema.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
|
||||
+25
-21
@@ -180,23 +180,26 @@ func SessionHandler(conn *websocket.Conn) {
|
||||
handleHostRegErr(conn, err)
|
||||
return
|
||||
}
|
||||
currHost, err := logic.GetHost(result.Host.ID.String())
|
||||
currHost := &schema.Host{
|
||||
ID: result.Host.ID,
|
||||
}
|
||||
err = currHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
handleHostRegErr(conn, err)
|
||||
return
|
||||
}
|
||||
var currentNetworks = []string{}
|
||||
var currentNetworks []string
|
||||
if result.ALL {
|
||||
currentNets, err := logic.GetNetworks()
|
||||
if err == nil && len(currentNets) > 0 {
|
||||
for i := range currentNets {
|
||||
currentNetworks = append(currentNetworks, currentNets[i].NetID)
|
||||
_networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err == nil && len(_networks) > 0 {
|
||||
for i := range _networks {
|
||||
currentNetworks = append(currentNetworks, _networks[i].Name)
|
||||
}
|
||||
}
|
||||
} else if len(result.Network) > 0 {
|
||||
currentNetworks = append(currentNetworks, result.Network)
|
||||
}
|
||||
var netsToAdd = []string{} // track the networks not currently owned by host
|
||||
var netsToAdd []string // track the networks not currently owned by host
|
||||
hostNets := logic.GetHostNetworks(currHost.ID.String())
|
||||
for _, newNet := range currentNetworks {
|
||||
if !logic.StringSliceContains(hostNets, newNet) {
|
||||
@@ -240,13 +243,14 @@ func SessionHandler(conn *websocket.Conn) {
|
||||
}
|
||||
|
||||
// CheckNetRegAndHostUpdate - run through networks and send a host update
|
||||
func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username string) {
|
||||
func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *schema.Host, username string) {
|
||||
// publish host update through MQ
|
||||
featureFlags := logic.GetFeatureFlags()
|
||||
for _, netID := range key.Networks {
|
||||
if network, err := logic.GetNetwork(netID); err == nil {
|
||||
if featureFlags.EnableDeviceApproval && network.AutoJoin == "false" {
|
||||
if logic.DoesHostExistinTheNetworkAlready(h, models.NetworkID(netID)) {
|
||||
network := &schema.Network{Name: netID}
|
||||
if err := network.Get(db.WithContext(context.TODO())); err == nil {
|
||||
if featureFlags.EnableDeviceApproval && !network.AutoJoin {
|
||||
if logic.DoesHostExistinTheNetworkAlready(h, schema.NetworkID(netID)) {
|
||||
continue
|
||||
}
|
||||
if err := (&schema.PendingHost{
|
||||
@@ -275,37 +279,37 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username
|
||||
|
||||
if len(username) > 0 {
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.JoinHostToNet,
|
||||
Action: schema.JoinHostToNet,
|
||||
Source: models.Subject{
|
||||
ID: username,
|
||||
Name: username,
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: username,
|
||||
Target: models.Subject{
|
||||
ID: h.ID.String(),
|
||||
Name: h.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(netID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(netID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
} else {
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.JoinHostToNet,
|
||||
Action: schema.JoinHostToNet,
|
||||
Source: models.Subject{
|
||||
ID: key.Value,
|
||||
Name: key.Tags[0],
|
||||
Type: models.EnrollmentKeySub,
|
||||
Type: schema.EnrollmentKeySub,
|
||||
},
|
||||
TriggeredBy: username,
|
||||
Target: models.Subject{
|
||||
ID: h.ID.String(),
|
||||
Name: h.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(netID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(netID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/gravitl/netmaker/cli/functions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ var networkCreateCmd = &cobra.Command{
|
||||
Long: `Create a Network`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
network := &models.Network{}
|
||||
network := &schema.Network{}
|
||||
if networkDefinitionFilePath != "" {
|
||||
content, err := os.ReadFile(networkDefinitionFilePath)
|
||||
if err != nil {
|
||||
@@ -26,28 +26,15 @@ var networkCreateCmd = &cobra.Command{
|
||||
log.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
network.NetID = netID
|
||||
network.Name = name
|
||||
network.AddressRange = address
|
||||
if address6 != "" {
|
||||
network.AddressRange6 = address6
|
||||
network.IsIPv6 = "yes"
|
||||
}
|
||||
if address == "" {
|
||||
network.IsIPv4 = "no"
|
||||
}
|
||||
if udpHolePunch {
|
||||
network.DefaultUDPHolePunch = "yes"
|
||||
}
|
||||
if defaultACL {
|
||||
network.DefaultACL = "yes"
|
||||
}
|
||||
network.DefaultInterface = defaultInterface
|
||||
network.DefaultListenPort = int32(defaultListenPort)
|
||||
network.NodeLimit = int32(nodeLimit)
|
||||
network.DefaultKeepalive = int32(defaultKeepalive)
|
||||
if allowManualSignUp {
|
||||
network.AllowManualSignUp = "yes"
|
||||
}
|
||||
network.DefaultKeepAlive = defaultKeepalive
|
||||
network.DefaultMTU = int32(defaultMTU)
|
||||
}
|
||||
functions.PrettyPrint(functions.CreateNetwork(network))
|
||||
@@ -56,17 +43,12 @@ var networkCreateCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
networkCreateCmd.Flags().StringVar(&networkDefinitionFilePath, "file", "", "Path to network_definition.json")
|
||||
networkCreateCmd.Flags().StringVar(&netID, "name", "", "Name of the network")
|
||||
networkCreateCmd.Flags().StringVar(&name, "name", "", "Name of the network")
|
||||
networkCreateCmd.MarkFlagsMutuallyExclusive("file", "name")
|
||||
networkCreateCmd.Flags().StringVar(&address, "ipv4_addr", "", "IPv4 address of the network")
|
||||
networkCreateCmd.Flags().StringVar(&address6, "ipv6_addr", "", "IPv6 address of the network")
|
||||
networkCreateCmd.Flags().BoolVar(&udpHolePunch, "udp_hole_punch", false, "Enable UDP Hole Punching ?")
|
||||
networkCreateCmd.Flags().BoolVar(&defaultACL, "default_acl", false, "Enable default Access Control List ?")
|
||||
networkCreateCmd.Flags().StringVar(&defaultInterface, "interface", "", "Name of the network interface")
|
||||
networkCreateCmd.Flags().IntVar(&defaultListenPort, "listen_port", 51821, "Default wireguard port each node will attempt to use")
|
||||
networkCreateCmd.Flags().IntVar(&nodeLimit, "node_limit", 999999999, "Maximum number of nodes that can be associated with this network")
|
||||
networkCreateCmd.Flags().IntVar(&defaultKeepalive, "keep_alive", 20, "Keep Alive in seconds")
|
||||
networkCreateCmd.Flags().IntVar(&defaultMTU, "mtu", 1280, "MTU size")
|
||||
networkCreateCmd.Flags().BoolVar(&allowManualSignUp, "manual_signup", false, "Allow manual signup ?")
|
||||
rootCmd.AddCommand(networkCreateCmd)
|
||||
}
|
||||
|
||||
@@ -2,15 +2,10 @@ package network
|
||||
|
||||
var (
|
||||
networkDefinitionFilePath string
|
||||
netID string
|
||||
name string
|
||||
address string
|
||||
address6 string
|
||||
udpHolePunch bool
|
||||
defaultACL bool
|
||||
defaultInterface string
|
||||
defaultListenPort int
|
||||
nodeLimit int
|
||||
defaultKeepalive int
|
||||
allowManualSignUp bool
|
||||
defaultMTU int
|
||||
)
|
||||
|
||||
@@ -24,9 +24,9 @@ var networkListCmd = &cobra.Command{
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
table.SetHeader([]string{"NetId", "Address Range (IPv4)", "Address Range (IPv6)", "Network Last Modified", "Nodes Last Modified"})
|
||||
for _, n := range *networks {
|
||||
networkLastModified := time.Unix(n.NetworkLastModified, 0).Format(time.RFC3339)
|
||||
nodesLastModified := time.Unix(n.NodesLastModified, 0).Format(time.RFC3339)
|
||||
table.Append([]string{n.NetID, n.AddressRange, n.AddressRange6, networkLastModified, nodesLastModified})
|
||||
networkLastModified := n.UpdatedAt.Format(time.RFC3339)
|
||||
nodesLastModified := n.NodesUpdatedAt.Format(time.RFC3339)
|
||||
table.Append([]string{n.Name, n.AddressRange, n.AddressRange6, networkLastModified, nodesLastModified})
|
||||
}
|
||||
table.Render()
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strconv"
|
||||
|
||||
"github.com/gravitl/netmaker/cli/functions"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var networkNodeLimitCmd = &cobra.Command{
|
||||
Use: "node_limit [NETWORK NAME] [NEW LIMIT]",
|
||||
Short: "Update network nodel limit",
|
||||
Long: `Update network nodel limit`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
nodelimit, err := strconv.ParseInt(args[1], 10, 32)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
functions.PrettyPrint(functions.UpdateNetworkNodeLimit(args[0], int32(nodelimit)))
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(networkNodeLimitCmd)
|
||||
}
|
||||
+7
-19
@@ -1,11 +1,10 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/cli/functions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/spf13/cobra"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
var userCreateCmd = &cobra.Command{
|
||||
@@ -14,24 +13,13 @@ var userCreateCmd = &cobra.Command{
|
||||
Short: "Create a new user",
|
||||
Long: `Create a new user`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
user := &models.User{UserName: username, Password: password, PlatformRoleID: models.UserRoleID(platformID)}
|
||||
if len(networkRoles) > 0 {
|
||||
netRolesMap := make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
for netID, netRoles := range networkRoles {
|
||||
roleMap := make(map[models.UserRoleID]struct{})
|
||||
for _, roleID := range strings.Split(netRoles, " ") {
|
||||
roleMap[models.UserRoleID(roleID)] = struct{}{}
|
||||
}
|
||||
netRolesMap[models.NetworkID(netID)] = roleMap
|
||||
}
|
||||
user.NetworkRoles = netRolesMap
|
||||
}
|
||||
user := &schema.User{Username: username, Password: password, PlatformRoleID: schema.UserRoleID(platformID)}
|
||||
if len(groups) > 0 {
|
||||
grMap := make(map[models.UserGroupID]struct{})
|
||||
grMap := make(map[schema.UserGroupID]struct{})
|
||||
for _, groupID := range groups {
|
||||
grMap[models.UserGroupID(groupID)] = struct{}{}
|
||||
grMap[schema.UserGroupID(groupID)] = struct{}{}
|
||||
}
|
||||
user.UserGroups = grMap
|
||||
user.UserGroups = datatypes.NewJSONType(grMap)
|
||||
}
|
||||
|
||||
functions.PrettyPrint(functions.CreateUser(user))
|
||||
@@ -42,7 +30,7 @@ func init() {
|
||||
|
||||
userCreateCmd.Flags().StringVar(&username, "name", "", "Name of the user")
|
||||
userCreateCmd.Flags().StringVar(&password, "password", "", "Password of the user")
|
||||
userCreateCmd.Flags().StringVarP(&platformID, "platform-role", "r", models.ServiceUser.String(),
|
||||
userCreateCmd.Flags().StringVarP(&platformID, "platform-role", "r", schema.ServiceUser.String(),
|
||||
"Platform Role of the user; run `nmctl roles list` to see available user roles")
|
||||
userCreateCmd.MarkFlagRequired("name")
|
||||
userCreateCmd.MarkFlagRequired("password")
|
||||
|
||||
@@ -35,7 +35,7 @@ var userGroupListCmd = &cobra.Command{
|
||||
for _, d := range data {
|
||||
|
||||
roleInfoStr := ""
|
||||
for netID, netRoleMap := range d.NetworkRoles {
|
||||
for netID, netRoleMap := range d.NetworkRoles.Data() {
|
||||
roleList := []string{}
|
||||
for roleID := range netRoleMap {
|
||||
roleList = append(roleList, roleID.String())
|
||||
@@ -88,7 +88,7 @@ var userGroupGetCmd = &cobra.Command{
|
||||
h := []string{"ID", "MetaData", "Network Roles"}
|
||||
table.SetHeader(h)
|
||||
roleInfoStr := ""
|
||||
for netID, netRoleMap := range data.NetworkRoles {
|
||||
for netID, netRoleMap := range data.NetworkRoles.Data() {
|
||||
roleList := []string{}
|
||||
for roleID := range netRoleMap {
|
||||
roleList = append(roleList, roleID.String())
|
||||
|
||||
+7
-19
@@ -1,11 +1,10 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/cli/functions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/spf13/cobra"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
var userUpdateCmd = &cobra.Command{
|
||||
@@ -14,27 +13,16 @@ var userUpdateCmd = &cobra.Command{
|
||||
Short: "Update a user",
|
||||
Long: `Update a user`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
user := &models.User{UserName: args[0]}
|
||||
user := &schema.User{Username: args[0]}
|
||||
if platformID != "" {
|
||||
user.PlatformRoleID = models.UserRoleID(platformID)
|
||||
}
|
||||
if len(networkRoles) > 0 {
|
||||
netRolesMap := make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
for netID, netRoles := range networkRoles {
|
||||
roleMap := make(map[models.UserRoleID]struct{})
|
||||
for _, roleID := range strings.Split(netRoles, ",") {
|
||||
roleMap[models.UserRoleID(roleID)] = struct{}{}
|
||||
}
|
||||
netRolesMap[models.NetworkID(netID)] = roleMap
|
||||
}
|
||||
user.NetworkRoles = netRolesMap
|
||||
user.PlatformRoleID = schema.UserRoleID(platformID)
|
||||
}
|
||||
if len(groups) > 0 {
|
||||
grMap := make(map[models.UserGroupID]struct{})
|
||||
grMap := make(map[schema.UserGroupID]struct{})
|
||||
for _, groupID := range groups {
|
||||
grMap[models.UserGroupID(groupID)] = struct{}{}
|
||||
grMap[schema.UserGroupID(groupID)] = struct{}{}
|
||||
}
|
||||
user.UserGroups = grMap
|
||||
user.UserGroups = datatypes.NewJSONType(grMap)
|
||||
}
|
||||
functions.PrettyPrint(functions.UpdateUser(user))
|
||||
},
|
||||
|
||||
@@ -1,37 +1,24 @@
|
||||
package functions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// CreateNetwork - creates a network
|
||||
func CreateNetwork(payload *models.Network) *models.Network {
|
||||
return request[models.Network](http.MethodPost, "/api/networks", payload)
|
||||
}
|
||||
|
||||
// UpdateNetwork - updates a network
|
||||
func UpdateNetwork(name string, payload *models.Network) *models.Network {
|
||||
return request[models.Network](http.MethodPut, "/api/networks/"+name, payload)
|
||||
}
|
||||
|
||||
// UpdateNetworkNodeLimit - updates a network
|
||||
func UpdateNetworkNodeLimit(name string, nodeLimit int32) *models.Network {
|
||||
return request[models.Network](http.MethodPut, fmt.Sprintf("/api/networks/%s/nodelimit", name), &models.Network{
|
||||
NodeLimit: nodeLimit,
|
||||
})
|
||||
func CreateNetwork(payload *schema.Network) *schema.Network {
|
||||
return request[schema.Network](http.MethodPost, "/api/networks", payload)
|
||||
}
|
||||
|
||||
// GetNetworks - fetch all networks
|
||||
func GetNetworks() *[]models.Network {
|
||||
return request[[]models.Network](http.MethodGet, "/api/networks", nil)
|
||||
func GetNetworks() *[]schema.Network {
|
||||
return request[[]schema.Network](http.MethodGet, "/api/networks", nil)
|
||||
}
|
||||
|
||||
// GetNetwork - fetch a single network
|
||||
func GetNetwork(name string) *models.Network {
|
||||
return request[models.Network](http.MethodGet, "/api/networks/"+name, nil)
|
||||
func GetNetwork(name string) *schema.Network {
|
||||
return request[schema.Network](http.MethodGet, "/api/networks/"+name, nil)
|
||||
}
|
||||
|
||||
// DeleteNetwork - delete a network
|
||||
|
||||
+11
-10
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// HasAdmin - check if server has an admin user
|
||||
@@ -14,13 +15,13 @@ func HasAdmin() *bool {
|
||||
}
|
||||
|
||||
// CreateUser - create a user
|
||||
func CreateUser(payload *models.User) *models.User {
|
||||
return request[models.User](http.MethodPost, "/api/users/"+payload.UserName, payload)
|
||||
func CreateUser(payload *schema.User) *schema.User {
|
||||
return request[schema.User](http.MethodPost, "/api/users/"+payload.Username, payload)
|
||||
}
|
||||
|
||||
// UpdateUser - update a user
|
||||
func UpdateUser(payload *models.User) *models.User {
|
||||
return request[models.User](http.MethodPut, "/api/users/"+payload.UserName, payload)
|
||||
func UpdateUser(payload *schema.User) *schema.User {
|
||||
return request[schema.User](http.MethodPut, "/api/users/"+payload.Username, payload)
|
||||
}
|
||||
|
||||
// DeleteUser - delete a user
|
||||
@@ -29,8 +30,8 @@ func DeleteUser(username string) *string {
|
||||
}
|
||||
|
||||
// GetUser - fetch a single user
|
||||
func GetUser(username string) *models.User {
|
||||
return request[models.User](http.MethodGet, "/api/users/"+username, nil)
|
||||
func GetUser(username string) *schema.User {
|
||||
return request[schema.User](http.MethodGet, "/api/users/"+username, nil)
|
||||
}
|
||||
|
||||
// ListUsers - fetch all users
|
||||
@@ -38,7 +39,7 @@ func ListUsers() *[]models.ReturnUser {
|
||||
return request[[]models.ReturnUser](http.MethodGet, "/api/users", nil)
|
||||
}
|
||||
|
||||
func ListUserRoles() (roles []models.UserRolePermissionTemplate) {
|
||||
func ListUserRoles() (roles []schema.UserRole) {
|
||||
resp := request[models.SuccessResponse](http.MethodGet, "/api/v1/users/roles", nil)
|
||||
d, _ := json.Marshal(resp.Response)
|
||||
json.Unmarshal(d, &roles)
|
||||
@@ -48,14 +49,14 @@ func ListUserRoles() (roles []models.UserRolePermissionTemplate) {
|
||||
func DeleteUserRole(roleID string) *models.SuccessResponse {
|
||||
return request[models.SuccessResponse](http.MethodDelete, fmt.Sprintf("/api/v1/users/role?role_id=%s", roleID), nil)
|
||||
}
|
||||
func GetUserRole(roleID string) (role models.UserRolePermissionTemplate) {
|
||||
func GetUserRole(roleID string) (role schema.UserRole) {
|
||||
resp := request[models.SuccessResponse](http.MethodGet, fmt.Sprintf("/api/v1/users/role?role_id=%s", roleID), nil)
|
||||
d, _ := json.Marshal(resp.Response)
|
||||
json.Unmarshal(d, &role)
|
||||
return
|
||||
}
|
||||
|
||||
func ListUserGrps() (groups []models.UserGroup) {
|
||||
func ListUserGrps() (groups []schema.UserGroup) {
|
||||
resp := request[models.SuccessResponse](http.MethodGet, "/api/v1/users/groups", nil)
|
||||
d, _ := json.Marshal(resp.Response)
|
||||
json.Unmarshal(d, &groups)
|
||||
@@ -66,7 +67,7 @@ func DeleteUserGrp(grpID string) *models.SuccessResponse {
|
||||
return request[models.SuccessResponse](http.MethodDelete, fmt.Sprintf("/api/v1/users/group?group_id=%s", grpID), nil)
|
||||
}
|
||||
|
||||
func GetUserGrp(grpID string) (group models.UserGroup) {
|
||||
func GetUserGrp(grpID string) (group schema.UserGroup) {
|
||||
resp := request[models.SuccessResponse](http.MethodGet, fmt.Sprintf("/api/v1/users/group?group_id=%s", grpID), nil)
|
||||
d, _ := json.Marshal(resp.Response)
|
||||
json.Unmarshal(d, &group)
|
||||
|
||||
+17
-16
@@ -211,12 +211,12 @@ func getAcls(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// check if network exists
|
||||
_, err := logic.GetNetwork(netID)
|
||||
err := (&schema.Network{Name: netID}).Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
acls, err := logic.ListAclsByNetwork(models.NetworkID(netID))
|
||||
acls, err := logic.ListAclsByNetwork(schema.NetworkID(netID))
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to get all network acl entries: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -276,7 +276,8 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
user, err := logic.GetUser(r.Header.Get("user"))
|
||||
user := &schema.User{Username: r.Header.Get("user")}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -289,7 +290,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
acl := req
|
||||
acl.ID = uuid.New().String()
|
||||
acl.CreatedBy = user.UserName
|
||||
acl.CreatedBy = user.Username
|
||||
acl.CreatedAt = time.Now().UTC()
|
||||
acl.Default = false
|
||||
if acl.ServiceType == models.Any {
|
||||
@@ -312,20 +313,20 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: acl.ID,
|
||||
Name: acl.Name,
|
||||
Type: models.AclSub,
|
||||
Type: schema.AclSub,
|
||||
},
|
||||
NetworkID: acl.NetworkID,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
go mq.PublishPeerUpdate(true)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, acl, "created acl successfully")
|
||||
@@ -374,24 +375,24 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: acl.ID,
|
||||
Name: acl.Name,
|
||||
Type: models.AclSub,
|
||||
Type: schema.AclSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: acl,
|
||||
New: updateAcl.Acl,
|
||||
},
|
||||
NetworkID: acl.NetworkID,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
go mq.PublishPeerUpdate(true)
|
||||
logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name)
|
||||
@@ -428,20 +429,20 @@ func deleteAcl(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: acl.ID,
|
||||
Name: acl.Name,
|
||||
Type: models.AclSub,
|
||||
Type: schema.AclSub,
|
||||
},
|
||||
NetworkID: acl.NetworkID,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: acl,
|
||||
New: nil,
|
||||
|
||||
+15
-15
@@ -140,20 +140,20 @@ func createNs(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: ns.Name,
|
||||
Type: models.NameserverSub,
|
||||
Type: schema.NameserverSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(ns.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
@@ -252,24 +252,24 @@ func updateNs(w http.ResponseWriter, r *http.Request) {
|
||||
updateFallback = true
|
||||
}
|
||||
event := &models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: updateNs.Name,
|
||||
Type: models.NameserverSub,
|
||||
Type: schema.NameserverSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: ns,
|
||||
New: updateNs,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(ns.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
}
|
||||
|
||||
if !ns.Default {
|
||||
@@ -352,20 +352,20 @@ func deleteNs(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: ns.Name,
|
||||
Type: models.NameserverSub,
|
||||
Type: schema.NameserverSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(ns.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: ns,
|
||||
New: nil,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/txn2/txeh"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@@ -15,7 +17,7 @@ import (
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
var dnsHost models.Host
|
||||
var dnsHost schema.Host
|
||||
|
||||
func TestGetAllDNS(t *testing.T) {
|
||||
deleteAllDNS(t)
|
||||
@@ -425,14 +427,17 @@ func TestValidateDNSCreate(t *testing.T) {
|
||||
|
||||
func createHost() {
|
||||
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
|
||||
dnsHost = models.Host{
|
||||
dnsHost = schema.Host{
|
||||
ID: uuid.New(),
|
||||
PublicKey: k.PublicKey(),
|
||||
PublicKey: schema.WgKey{Key: k.PublicKey()},
|
||||
HostPass: "password",
|
||||
OS: "linux",
|
||||
Name: "dnshost",
|
||||
}
|
||||
_ = logic.CreateHost(&dnsHost)
|
||||
err := logic.CreateHost(&dnsHost)
|
||||
if err != nil {
|
||||
fmt.Println("ERROR CREATING HOST", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func deleteAllDNS(t *testing.T) {
|
||||
|
||||
+34
-26
@@ -70,7 +70,8 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
|
||||
egressRange = "*"
|
||||
req.Domain = ""
|
||||
}
|
||||
network, err := logic.GetNetwork(req.Network)
|
||||
network := &schema.Network{Name: req.Network}
|
||||
err = network.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -91,7 +92,7 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
|
||||
CreatedBy: r.Header.Get("user"),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if err := logic.AssignVirtualRangeToEgress(&network, &e); err != nil {
|
||||
if err := logic.AssignVirtualRangeToEgress(network, &e); err != nil {
|
||||
logger.Log(0, "error assigning virtual range to egress: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -121,20 +122,20 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: e.ID,
|
||||
Name: e.Name,
|
||||
Type: models.EgressSub,
|
||||
Type: schema.EgressSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(e.Network),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(e.Network),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
// for nodeID := range e.Nodes {
|
||||
// node, err := logic.GetNodeByID(nodeID)
|
||||
@@ -151,8 +152,11 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
host, _ := logic.GetHost(node.HostID.String())
|
||||
if host == nil {
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
mq.HostUpdate(&models.HostUpdate{
|
||||
@@ -227,7 +231,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
network, err := logic.GetNetwork(req.Network)
|
||||
network := &schema.Network{Name: req.Network}
|
||||
err = network.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -271,8 +276,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Update mode and NAT before calling AssignVirtualRangeToEgress
|
||||
// This ensures the function sees the new values
|
||||
if req.Mode != models.VirtualNAT || !req.Nat {
|
||||
e.Mode = models.DirectNAT
|
||||
if req.Mode != schema.VirtualNAT || !req.Nat {
|
||||
e.Mode = schema.DirectNAT
|
||||
if !req.Nat {
|
||||
e.Mode = ""
|
||||
}
|
||||
@@ -284,8 +289,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
|
||||
e.Nat = req.Nat
|
||||
// Assign virtual range if switching to virtual NAT mode from a different mode,
|
||||
// or if already in virtual NAT mode but virtual range is empty
|
||||
if (oldMode != models.VirtualNAT) || (e.VirtualRange == "") {
|
||||
if err := logic.AssignVirtualRangeToEgress(&network, &e); err != nil {
|
||||
if (oldMode != schema.VirtualNAT) || (e.VirtualRange == "") {
|
||||
if err := logic.AssignVirtualRangeToEgress(network, &e); err != nil {
|
||||
logger.Log(0, "error assigning virtual range to egress: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -293,23 +298,23 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
event := &models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: e.ID,
|
||||
Name: e.Name,
|
||||
Type: models.EgressSub,
|
||||
Type: schema.EgressSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: e,
|
||||
},
|
||||
NetworkID: models.NetworkID(e.Network),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(e.Network),
|
||||
Origin: schema.Dashboard,
|
||||
}
|
||||
e.Nodes = make(datatypes.JSONMap)
|
||||
e.Tags = make(datatypes.JSONMap)
|
||||
@@ -374,8 +379,11 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
host, _ := logic.GetHost(node.HostID.String())
|
||||
if host == nil {
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
mq.HostUpdate(&models.HostUpdate{
|
||||
@@ -426,20 +434,20 @@ func deleteEgress(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: e.ID,
|
||||
Name: e.Name,
|
||||
Type: models.EgressSub,
|
||||
Type: schema.EgressSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(e.Network),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(e.Network),
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: e,
|
||||
New: nil,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
|
||||
"github.com/gravitl/netmaker/auth"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
@@ -87,19 +88,19 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: keyID,
|
||||
Name: key.Tags[0],
|
||||
Type: models.EnrollmentKeySub,
|
||||
Type: schema.EnrollmentKeySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: key,
|
||||
New: nil,
|
||||
@@ -204,19 +205,19 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: newEnrollmentKey.Value,
|
||||
Name: newEnrollmentKey.Tags[0],
|
||||
Type: models.EnrollmentKeySub,
|
||||
Type: schema.EnrollmentKeySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(2, r.Header.Get("user"), "created enrollment key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -269,23 +270,23 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: newEnrollmentKey.Value,
|
||||
Name: newEnrollmentKey.Tags[0],
|
||||
Type: models.EnrollmentKeySub,
|
||||
Type: schema.EnrollmentKeySub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: currKey,
|
||||
New: newEnrollmentKey,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
slog.Info("updated enrollment key", "id", keyId)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -298,7 +299,7 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param token path string true "Enrollment Key Token"
|
||||
// @Param body body models.Host true "Host registration parameters"
|
||||
// @Param body body schema.Host true "Host registration parameters"
|
||||
// @Success 200 {object} models.RegisterResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
@@ -314,7 +315,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// get the host
|
||||
var newHost models.Host
|
||||
var newHost schema.Host
|
||||
if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
|
||||
err.Error())
|
||||
@@ -383,7 +384,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
KernelVersion: newHost.KernelVersion,
|
||||
AutoUpdate: newHost.AutoUpdate,
|
||||
Tags: keyTags,
|
||||
}, models.NetworkID(netI))
|
||||
}, schema.NetworkID(netI))
|
||||
pcviolations = append(pcviolations, violations...)
|
||||
if len(violations) > 0 {
|
||||
skipViolatedNetworks = append(skipViolatedNetworks, netI)
|
||||
@@ -433,7 +434,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
// }
|
||||
// }
|
||||
// enrollmentKey.Networks = networksToAdd
|
||||
currHost, err := logic.GetHost(newHost.ID.String())
|
||||
currHost := &schema.Host{
|
||||
ID: newHost.ID,
|
||||
}
|
||||
err := currHost.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed registration", "hostID", newHost.ID.String(), "hostName", newHost.Name, "error", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -447,7 +451,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
host, err := logic.GetHost(newHost.ID.String())
|
||||
host := &schema.Host{
|
||||
ID: newHost.ID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
|
||||
+48
-28
@@ -46,7 +46,7 @@ func extClientHandlers(r *mux.Router) {
|
||||
Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient))).
|
||||
Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, checkFreeTierLimits(limitChoiceMachines, http.HandlerFunc(createExtClient)))).
|
||||
r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, http.HandlerFunc(createExtClient))).
|
||||
Methods(http.MethodPost)
|
||||
// unused API
|
||||
//r.HandleFunc("/api/v1/client_conf/{network}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientHAConf))).Methods(http.MethodGet)
|
||||
@@ -85,9 +85,15 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
username := r.Header.Get("user")
|
||||
if r.Header.Get("ismaster") != "yes" {
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{
|
||||
Username: username,
|
||||
}
|
||||
err := user.Get(r.Context())
|
||||
if err == nil {
|
||||
userRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userRole := &schema.UserRole{
|
||||
ID: user.PlatformRoleID,
|
||||
}
|
||||
err := userRole.Get(r.Context())
|
||||
if err != nil || !userRole.FullAccess {
|
||||
filtered := []models.ExtClient{}
|
||||
for _, ec := range extclients {
|
||||
@@ -218,9 +224,12 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: gwnode.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := logic.ListAclsByNetwork(models.NetworkID(client.Network))
|
||||
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(client.Network))
|
||||
logic.GetNodeEgressInfo(&gwnode, eli, acls)
|
||||
host, err := logic.GetHost(gwnode.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: gwnode.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
0,
|
||||
@@ -235,7 +244,8 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
network, err := logic.GetParentNetwork(client.Network)
|
||||
network := &schema.Network{Name: client.Network}
|
||||
err = network.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
1,
|
||||
@@ -288,8 +298,8 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
keepalive := ""
|
||||
if network.DefaultKeepalive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepalive))
|
||||
if network.DefaultKeepAlive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive))
|
||||
}
|
||||
if gwnode.IngressPersistentKeepalive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive))
|
||||
@@ -428,7 +438,8 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var params = mux.Vars(r)
|
||||
networkid := params["network"]
|
||||
network, err := logic.GetParentNetwork(networkid)
|
||||
network := &schema.Network{Name: networkid}
|
||||
err := network.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
1,
|
||||
@@ -441,7 +452,7 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// fetch client based on availability
|
||||
nodes, _ := logic.GetNetworkNodes(networkid)
|
||||
defaultPolicy, _ := logic.GetDefaultPolicy(models.NetworkID(networkid), models.DevicePolicy)
|
||||
defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(networkid), models.DevicePolicy)
|
||||
var targetGwID string
|
||||
var connectionCnt int = -1
|
||||
for _, nodeI := range nodes {
|
||||
@@ -475,7 +486,10 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(gwnode.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: gwnode.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
fmt.Sprintf("failed to get ingress gateway host for node [%s] info: %v", gwnode.ID, err))
|
||||
@@ -487,12 +501,13 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("ismaster") == "yes" {
|
||||
userName = logic.MasterUser
|
||||
} else {
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
caller := &schema.User{Username: r.Header.Get("user")}
|
||||
err = caller.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
userName = caller.UserName
|
||||
userName = caller.Username
|
||||
}
|
||||
// create client
|
||||
var extclient models.ExtClient
|
||||
@@ -540,8 +555,8 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
keepalive := ""
|
||||
if network.DefaultKeepalive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepalive))
|
||||
if network.DefaultKeepAlive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive))
|
||||
}
|
||||
if gwnode.IngressPersistentKeepalive != 0 {
|
||||
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive))
|
||||
@@ -711,12 +726,13 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("ismaster") == "yes" {
|
||||
userName = logic.MasterUser
|
||||
} else {
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
caller := &schema.User{Username: r.Header.Get("user")}
|
||||
err = caller.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
userName = caller.UserName
|
||||
userName = caller.Username
|
||||
// check if user has a config already for remote access client
|
||||
extclients, err := logic.GetNetworkExtClients(node.Network)
|
||||
if err != nil {
|
||||
@@ -742,7 +758,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
// let's first confirm that none of the user's extclients for this gw have device id.
|
||||
for _, extclient := range extclients {
|
||||
if extclient.DeviceID == customExtClient.DeviceID &&
|
||||
extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
|
||||
extclient.OwnerID == caller.Username && nodeid == extclient.IngressGatewayID {
|
||||
if jitGrant != nil {
|
||||
extclient.JITExpiresAt = &jitGrant.ExpiresAt
|
||||
_ = logic.SaveExtClient(&extclient)
|
||||
@@ -758,7 +774,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
for _, extclient := range extclients {
|
||||
if extclient.RemoteAccessClientID != "" &&
|
||||
extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID &&
|
||||
extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
|
||||
extclient.OwnerID == caller.Username && nodeid == extclient.IngressGatewayID {
|
||||
if customExtClient.DeviceID != "" && extclient.DeviceID == "" {
|
||||
// This extclient doesn’t include a device ID (and neither do the others).
|
||||
// We patch it by assigning the device ID from the incoming request.
|
||||
@@ -793,7 +809,10 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
dns := gwDNS
|
||||
extclient.DNS = dns
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
fmt.Sprintf("failed to get ingress gateway host for node [%s] info: %v", nodeid, err))
|
||||
@@ -803,7 +822,8 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
listenPort := logic.GetPeerListenPort(host)
|
||||
extclient.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
|
||||
extclient.Enabled = true
|
||||
parentNetwork, err := logic.GetNetwork(node.Network)
|
||||
parentNetwork := &schema.Network{Name: node.Network}
|
||||
err = parentNetwork.Get(r.Context())
|
||||
if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
|
||||
extclient.Enabled = parentNetwork.DefaultACL == "yes"
|
||||
}
|
||||
@@ -835,7 +855,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
if extclient.DeviceID != "" {
|
||||
// check for violations connecting from desktop app
|
||||
staticNode := extclient.ConvertToStaticNode()
|
||||
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), models.NetworkID(extclient.Network))
|
||||
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(extclient.Network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
|
||||
return
|
||||
@@ -885,21 +905,21 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
if extclient.RemoteAccessClientID != "" {
|
||||
// if created by user from client app, log event
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Connect,
|
||||
Action: schema.Connect,
|
||||
Source: models.Subject{
|
||||
ID: userName,
|
||||
Name: userName,
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: userName,
|
||||
Target: models.Subject{
|
||||
ID: extclient.Network,
|
||||
Name: extclient.Network,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
Info: extclient,
|
||||
},
|
||||
NetworkID: models.NetworkID(extclient.Network),
|
||||
Origin: models.ClientApp,
|
||||
NetworkID: schema.NetworkID(extclient.Network),
|
||||
Origin: schema.ClientApp,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1008,7 +1028,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
||||
if newclient.DeviceID != "" && newclient.Enabled {
|
||||
// check for violations connecting from desktop app
|
||||
staticNode := newclient.ConvertToStaticNode()
|
||||
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), models.NetworkID(newclient.Network))
|
||||
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(newclient.Network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
|
||||
return
|
||||
|
||||
+44
-23
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -9,16 +10,18 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func gwHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createGateway)))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", logic.SecurityCheck(true, http.HandlerFunc(assignGw))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", logic.SecurityCheck(true, http.HandlerFunc(unassignGw))).Methods(http.MethodPost)
|
||||
@@ -48,7 +51,10 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -143,19 +149,19 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
|
||||
logic.GetNodeStatus(&relayNode, false)
|
||||
apiNode := relayNode.ConvertToAPINode()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: node.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.GatewaySub,
|
||||
Type: schema.GatewaySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
host.IsStaticPort = true
|
||||
logic.UpsertHost(host)
|
||||
@@ -210,7 +216,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -225,7 +234,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Log(1, r.Header.Get("user"), "deleted gw", nodeid, "on network", netid)
|
||||
|
||||
go func() {
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
allNodes, err := logic.GetAllNodes()
|
||||
if err != nil {
|
||||
@@ -246,7 +258,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
}
|
||||
h, err := logic.GetHost(relayedNode.HostID.String())
|
||||
h := &schema.Host{
|
||||
ID: relayedNode.HostID,
|
||||
}
|
||||
err = h.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
if h.OS == models.OS_Types.IoT {
|
||||
nodes, err := logic.GetAllNodes()
|
||||
@@ -283,19 +298,19 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: node.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.GatewaySub,
|
||||
Type: schema.GatewaySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: node,
|
||||
New: node,
|
||||
@@ -405,7 +420,10 @@ func assignGw(w http.ResponseWriter, r *http.Request) {
|
||||
newNodes = logic.UniqueStrings(newNodes)
|
||||
logic.UpdateRelayNodes(gatewayNode.ID.String(), gatewayNode.RelayedNodes, newNodes)
|
||||
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -416,19 +434,19 @@ func assignGw(w http.ResponseWriter, r *http.Request) {
|
||||
nodeid, netid))
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.GatewayAssign,
|
||||
Action: schema.GatewayAssign,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: node.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.GatewaySub,
|
||||
Type: schema.GatewaySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.GetNodeStatus(&node, false)
|
||||
@@ -465,7 +483,10 @@ func unassignGw(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -536,19 +557,19 @@ func unassignGw(w http.ResponseWriter, r *http.Request) {
|
||||
nodeid, netid))
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.GatewayUnAssign,
|
||||
Action: schema.GatewayUnAssign,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: node.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.GatewaySub,
|
||||
Type: schema.GatewaySub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.GetNodeStatus(&node, false)
|
||||
|
||||
+304
-104
@@ -1,16 +1,19 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
dbtypes "github.com/gravitl/netmaker/db/types"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
@@ -24,6 +27,8 @@ import (
|
||||
func hostHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/hosts", logic.SecurityCheck(true, http.HandlerFunc(getHosts))).
|
||||
Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/hosts", logic.SecurityCheck(true, http.HandlerFunc(listHosts))).
|
||||
Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/hosts/keys", logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys))).
|
||||
Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/hosts/sync", logic.SecurityCheck(true, http.HandlerFunc(syncHosts))).
|
||||
@@ -36,10 +41,10 @@ func hostHandlers(r *mux.Router) {
|
||||
Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(updateHost))).
|
||||
Methods(http.MethodPut)
|
||||
// used by netclient
|
||||
// used by netclient
|
||||
r.HandleFunc("/api/hosts/{hostid}", AuthorizeHost(http.HandlerFunc(deleteHost))).
|
||||
Methods(http.MethodDelete)
|
||||
// used by UI
|
||||
// used by UI
|
||||
r.HandleFunc("/api/v1/ui/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(deleteHost))).
|
||||
Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/hosts/{hostid}/upgrade", logic.SecurityCheck(true, http.HandlerFunc(upgradeHost))).
|
||||
@@ -88,14 +93,14 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
|
||||
go func() {
|
||||
slog.Info("requesting all hosts to upgrade", "user", user)
|
||||
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to retrieve all hosts", "user", user, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
go func(host models.Host) {
|
||||
go func(host schema.Host) {
|
||||
hostUpdate := models.HostUpdate{
|
||||
Action: action,
|
||||
Host: host,
|
||||
@@ -109,19 +114,19 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.UpgradeAll,
|
||||
Action: schema.UpgradeAll,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: "All Hosts",
|
||||
Name: "All Hosts",
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
slog.Info("upgrade all hosts request received", "user", user)
|
||||
logic.ReturnSuccessResponse(w, r, "upgrade all hosts request received")
|
||||
@@ -137,7 +142,18 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
// upgrade host is a handler to send upgrade message to a host
|
||||
func upgradeHost(w http.ResponseWriter, r *http.Request) {
|
||||
host, err := logic.GetHost(mux.Vars(r)["hostid"])
|
||||
hostIDStr := mux.Vars(r)["hostid"]
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to find host", "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "notfound"))
|
||||
@@ -168,7 +184,7 @@ func upgradeHost(w http.ResponseWriter, r *http.Request) {
|
||||
func getHosts(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
currentHosts, err := logic.GetAllHosts()
|
||||
currentHosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -182,6 +198,73 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(apiHosts)
|
||||
}
|
||||
|
||||
// @Summary List all hosts
|
||||
// @Router /api/v1/hosts [get]
|
||||
// @Tags Hosts
|
||||
// @Security oauth
|
||||
// @Produce json
|
||||
// @Param os query []string false "Filter by OS" Enums(windows, linux, darwin)
|
||||
// @Param page query int false "Page number"
|
||||
// @Param per_page query int false "Items per page"
|
||||
// @Success 200 {array} models.ApiHost
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func listHosts(w http.ResponseWriter, r *http.Request) {
|
||||
var osFilters []interface{}
|
||||
for _, filter := range r.URL.Query()["os"] {
|
||||
osFilters = append(osFilters, filter)
|
||||
}
|
||||
|
||||
var page, pageSize int
|
||||
page, _ = strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if page == 0 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
pageSize, _ = strconv.Atoi(r.URL.Query().Get("per_page"))
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
currentHosts, err := (&schema.Host{}).ListAll(
|
||||
r.Context(),
|
||||
dbtypes.WithFilter("os", osFilters...),
|
||||
dbtypes.InAscOrder("name"),
|
||||
dbtypes.WithPagination(page, pageSize),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
apiHosts := logic.GetAllHostsAPI(currentHosts[:])
|
||||
logger.Log(2, r.Header.Get("user"), "fetched all hosts")
|
||||
|
||||
total, err := (&schema.Host{}).Count(
|
||||
r.Context(),
|
||||
dbtypes.WithFilter("os", osFilters...),
|
||||
)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
|
||||
totalPages := (total + pageSize - 1) / pageSize
|
||||
if totalPages == 0 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
response := models.PaginatedResponse{
|
||||
Data: apiHosts,
|
||||
Page: page,
|
||||
PerPage: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
}
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, response, "fetched hosts")
|
||||
}
|
||||
|
||||
// @Summary Used by clients for "pull" command
|
||||
// @Router /api/v1/host [get]
|
||||
// @Tags Hosts
|
||||
@@ -190,9 +273,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
|
||||
// @Success 200 {object} models.HostPull
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func pull(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
hostID := r.Header.Get(hostIDHeader) // return JSON/API formatted keys
|
||||
if len(hostID) == 0 {
|
||||
hostIDStr := r.Header.Get(hostIDHeader) // return JSON/API formatted keys
|
||||
if len(hostIDStr) == 0 {
|
||||
logger.Log(0, "no host authorized to pull")
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
@@ -201,9 +283,20 @@ func pull(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(hostID)
|
||||
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
logger.Log(0, "no host found during pull", hostID)
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, "no host found during pull", hostIDStr)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
@@ -228,13 +321,13 @@ func pull(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
allNodes, err := logic.GetAllNodes()
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to get nodes: ", hostID)
|
||||
logger.Log(0, "failed to get nodes: ", hostIDStr)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
hPU, err := logic.GetPeerUpdateForHost("", host, allNodes, nil, nil)
|
||||
if err != nil {
|
||||
logger.Log(0, "could not pull peers for host", hostID, err.Error())
|
||||
logger.Log(0, "could not pull peers for host", hostIDStr, err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
@@ -269,7 +362,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
|
||||
AddressIdentityMap: hPU.AddressIdentityMap,
|
||||
}
|
||||
|
||||
logger.Log(1, hostID, host.Name, "completed a pull")
|
||||
logger.Log(1, hostIDStr, host.Name, "completed a pull")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(&response)
|
||||
}
|
||||
@@ -293,8 +386,18 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
hostID, err := uuid.Parse(newHostData.ID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
// confirm host exists
|
||||
currHost, err := logic.GetHost(newHostData.ID)
|
||||
currHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = currHost.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to update a host:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -350,25 +453,25 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
|
||||
}()
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: currHost.ID.String(),
|
||||
Name: newHost.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: currHost,
|
||||
New: newHost,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
apiHostData := newHost.ConvertNMHostToAPI()
|
||||
apiHostData := models.NewApiHostFromSchemaHost(newHost)
|
||||
logger.Log(2, r.Header.Get("user"), "updated host", newHost.ID.String())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(apiHostData)
|
||||
@@ -384,10 +487,19 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func hostUpdateFallback(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
currentHost, err := logic.GetHost(hostid)
|
||||
hostIDStr := params["hostid"]
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
slog.Error("error getting host", "id", hostid, "error", err)
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
currentHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = currentHost.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("error getting host", "id", hostIDStr, "error", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
@@ -466,11 +578,19 @@ func hostUpdateFallback(w http.ResponseWriter, r *http.Request) {
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func deleteHost(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
hostIDStr := params["hostid"]
|
||||
forceDelete := r.URL.Query().Get("force") == "true"
|
||||
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
// confirm host exists
|
||||
currHost, err := logic.GetHost(hostid)
|
||||
currHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = currHost.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to delete a host:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -519,25 +639,25 @@ func deleteHost(w http.ResponseWriter, r *http.Request) {
|
||||
HostID: currHost.ID.String(),
|
||||
}).DeleteAllPendingHosts(db.WithContext(r.Context()))
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: currHost.ID.String(),
|
||||
Name: currHost.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: currHost,
|
||||
New: nil,
|
||||
},
|
||||
})
|
||||
apiHostData := currHost.ConvertNMHostToAPI()
|
||||
apiHostData := models.NewApiHostFromSchemaHost(currHost)
|
||||
logger.Log(2, r.Header.Get("user"), "removed host", currHost.Name)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, apiHostData, "deleted host "+currHost.Name)
|
||||
}
|
||||
@@ -553,9 +673,9 @@ func deleteHost(w http.ResponseWriter, r *http.Request) {
|
||||
func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
hostIDStr := params["hostid"]
|
||||
network := params["network"]
|
||||
if hostid == "" || network == "" {
|
||||
if hostIDStr == "" || network == "" {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
@@ -563,10 +683,20 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
// confirm host exists
|
||||
currHost, err := logic.GetHost(hostid)
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostid, err.Error())
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
// confirm host exists
|
||||
currHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = currHost.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostIDStr, err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
@@ -579,7 +709,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
OSVersion: currHost.OSVersion,
|
||||
KernelVersion: currHost.KernelVersion,
|
||||
AutoUpdate: currHost.AutoUpdate,
|
||||
}, models.NetworkID(network))
|
||||
}, schema.NetworkID(network))
|
||||
if len(violations) > 0 {
|
||||
logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq))
|
||||
return
|
||||
@@ -590,7 +720,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
0,
|
||||
r.Header.Get("user"),
|
||||
"failed to add host to network:",
|
||||
hostid,
|
||||
hostIDStr,
|
||||
network,
|
||||
err.Error(),
|
||||
)
|
||||
@@ -623,20 +753,20 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Sprintf("added host %s to network %s", currHost.Name, network),
|
||||
)
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.JoinHostToNet,
|
||||
Action: schema.JoinHostToNet,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: currHost.ID.String(),
|
||||
Name: currHost.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(network),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(network),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
@@ -653,10 +783,10 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
hostIDStr := params["hostid"]
|
||||
network := params["network"]
|
||||
forceDelete := r.URL.Query().Get("force") == "true"
|
||||
if hostid == "" || network == "" {
|
||||
if hostIDStr == "" || network == "" {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
@@ -664,17 +794,26 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
// confirm host exists
|
||||
currHost, err := logic.GetHost(hostid)
|
||||
currHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = currHost.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
// check if there is any daemon nodes that needs to be deleted
|
||||
node, err := logic.GetNodeByHostRef(hostid, network)
|
||||
node, err := logic.GetNodeByHostRef(hostIDStr, network)
|
||||
if err != nil {
|
||||
slog.Error(
|
||||
"couldn't get node for host",
|
||||
"hostid",
|
||||
hostid,
|
||||
hostIDStr,
|
||||
"network",
|
||||
network,
|
||||
"error",
|
||||
@@ -685,7 +824,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
if err = logic.DeleteNodeByID(&node); err != nil {
|
||||
slog.Error("failed to force delete daemon node",
|
||||
"nodeid", node.ID.String(), "hostid", hostid, "network", network, "error", err)
|
||||
"nodeid", node.ID.String(), "hostid", hostIDStr, "network", network, "error", err)
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
@@ -709,12 +848,12 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
if node == nil && forceDelete {
|
||||
// force cleanup the node
|
||||
node, err := logic.GetNodeByHostRef(hostid, network)
|
||||
node, err := logic.GetNodeByHostRef(hostIDStr, network)
|
||||
if err != nil {
|
||||
slog.Error(
|
||||
"couldn't get node for host",
|
||||
"hostid",
|
||||
hostid,
|
||||
hostIDStr,
|
||||
"network",
|
||||
network,
|
||||
"error",
|
||||
@@ -725,7 +864,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
if err = logic.DeleteNodeByID(&node); err != nil {
|
||||
slog.Error("failed to force delete daemon node",
|
||||
"nodeid", node.ID.String(), "hostid", hostid, "network", network, "error", err)
|
||||
"nodeid", node.ID.String(), "hostid", hostIDStr, "network", network, "error", err)
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
@@ -743,7 +882,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
0,
|
||||
r.Header.Get("user"),
|
||||
"failed to remove host from network:",
|
||||
hostid,
|
||||
hostIDStr,
|
||||
network,
|
||||
err.Error(),
|
||||
)
|
||||
@@ -766,20 +905,20 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.RemoveHostFromNet,
|
||||
Action: schema.RemoveHostFromNet,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: currHost.ID.String(),
|
||||
Name: currHost.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(network),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(network),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(
|
||||
2,
|
||||
@@ -829,7 +968,17 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) {
|
||||
logic.ReturnErrorResponse(response, request, errorResponse)
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(authRequest.ID)
|
||||
hostID, err := uuid.Parse(authRequest.ID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(response, request, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(request.Context())
|
||||
if err != nil {
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
errorResponse.Message = err.Error()
|
||||
@@ -900,9 +1049,17 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) {
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
func signalPeer(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
hostIDStr := params["hostid"]
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
// confirm host exists
|
||||
_, err := logic.GetHost(hostid)
|
||||
err = (&schema.Host{
|
||||
ID: hostID,
|
||||
}).Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to get host:", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
@@ -923,7 +1080,16 @@ func signalPeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
signal.IsPro = servercfg.IsPro
|
||||
peerHost, err := logic.GetHost(signal.ToHostID)
|
||||
hostID, err = uuid.Parse(signal.ToHostID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
peerHost := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = peerHost.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
@@ -962,7 +1128,7 @@ func signalPeer(w http.ResponseWriter, r *http.Request) {
|
||||
func updateAllKeys(w http.ResponseWriter, r *http.Request) {
|
||||
var errorResponse = models.ErrorResponse{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
errorResponse.Message = err.Error()
|
||||
@@ -988,19 +1154,19 @@ func updateAllKeys(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.RefreshAllKeys,
|
||||
Action: schema.RefreshAllKeys,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: "All Devices",
|
||||
Name: "All Devices",
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(2, r.Header.Get("user"), "updated keys for all hosts")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -1017,10 +1183,19 @@ func updateKeys(w http.ResponseWriter, r *http.Request) {
|
||||
var errorResponse = models.ErrorResponse{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
hostid := params["hostid"]
|
||||
host, err := logic.GetHost(hostid)
|
||||
hostIDStr := params["hostid"]
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to retrieve host", hostid, err.Error())
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to retrieve host", hostIDStr, err.Error())
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
errorResponse.Message = err.Error()
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
@@ -1038,19 +1213,19 @@ func updateKeys(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.RefreshKey,
|
||||
Action: schema.RefreshKey,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: host.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(2, r.Header.Get("user"), "updated key on host", host.Name)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -1069,14 +1244,14 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
|
||||
go func() {
|
||||
slog.Info("requesting all hosts to sync", "user", user)
|
||||
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to retrieve all hosts", "user", user, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
go func(host models.Host) {
|
||||
go func(host schema.Host) {
|
||||
hostUpdate := models.HostUpdate{
|
||||
Action: models.RequestPull,
|
||||
Host: host,
|
||||
@@ -1091,19 +1266,19 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.SyncAll,
|
||||
Action: schema.SyncAll,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: "All Devices",
|
||||
Name: "All Devices",
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
slog.Info("sync all hosts request received", "user", user)
|
||||
logic.ReturnSuccessResponse(w, r, "sync all hosts request received")
|
||||
@@ -1117,12 +1292,20 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
|
||||
// @Success 200 {string} string "OK"
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
func syncHost(w http.ResponseWriter, r *http.Request) {
|
||||
hostId := mux.Vars(r)["hostid"]
|
||||
hostIDStr := mux.Vars(r)["hostid"]
|
||||
|
||||
var errorResponse = models.ErrorResponse{}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
host, err := logic.GetHost(hostId)
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to retrieve host", "user", r.Header.Get("user"), "error", err)
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
@@ -1141,26 +1324,26 @@ func syncHost(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Sync,
|
||||
Action: schema.Sync,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: host.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.DeviceSub,
|
||||
Type: schema.DeviceSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
slog.Info("requested host pull", "user", r.Header.Get("user"), "host", host.ID.String())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
|
||||
currentHosts, err := logic.GetAllHosts()
|
||||
currentHosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -1194,10 +1377,18 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
|
||||
// @Success 200 {object} models.HostPeerInfo
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
|
||||
hostId := mux.Vars(r)["hostid"]
|
||||
hostIDStr := mux.Vars(r)["hostid"]
|
||||
var errorResponse = models.ErrorResponse{}
|
||||
|
||||
host, err := logic.GetHost(hostId)
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
host := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to retrieve host", "error", err)
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
@@ -1263,7 +1454,16 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
h, err := logic.GetHost(p.HostID)
|
||||
hostID, err := uuid.Parse(p.HostID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse host id: %w", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
h := &schema.Host{
|
||||
ID: hostID,
|
||||
}
|
||||
err = h.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
@@ -1364,22 +1564,22 @@ func rejectPendingHost(w http.ResponseWriter, r *http.Request) {
|
||||
// addDefaultHostToNetworks enrolls a newly-made-default host into every
|
||||
// existing network it is not already part of, applying the standard default
|
||||
// host operations for each network.
|
||||
func addDefaultHostToNetworks(host *models.Host) {
|
||||
networks, err := logic.GetNetworks()
|
||||
func addDefaultHostToNetworks(host *schema.Host) {
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to get networks for default host ops:", err.Error())
|
||||
return
|
||||
}
|
||||
for _, network := range networks {
|
||||
if network.AutoJoin != "true" {
|
||||
if !network.AutoJoin {
|
||||
continue
|
||||
}
|
||||
newNode, err := logic.UpdateHostNetwork(host, network.NetID, true)
|
||||
newNode, err := logic.UpdateHostNetwork(host, network.Name, true)
|
||||
if err != nil {
|
||||
logger.Log(2, "skipping network", network.NetID, "for default host", host.Name, ":", err.Error())
|
||||
logger.Log(2, "skipping network", network.Name, "for default host", host.Name, ":", err.Error())
|
||||
continue
|
||||
}
|
||||
logger.Log(1, "added default host", host.Name, "to network", network.NetID)
|
||||
logger.Log(1, "added default host", host.Name, "to network", network.Name)
|
||||
if len(host.Nodes) == 1 {
|
||||
mq.HostUpdate(&models.HostUpdate{
|
||||
Action: models.RequestPull,
|
||||
@@ -1393,10 +1593,10 @@ func addDefaultHostToNetworks(host *models.Host) {
|
||||
Node: *newNode,
|
||||
})
|
||||
}
|
||||
logic.CreateIngressGateway(network.NetID, newNode.ID.String(), models.IngressRequest{})
|
||||
logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{})
|
||||
logic.CreateRelay(models.RelayRequest{
|
||||
NodeID: newNode.ID.String(),
|
||||
NetID: network.NetID,
|
||||
NetID: network.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
@@ -33,7 +34,10 @@ func createInternetGw(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
)
|
||||
|
||||
// limit consts
|
||||
const (
|
||||
limitChoiceNetworks = iota
|
||||
limitChoiceUsers
|
||||
limitChoiceMachines
|
||||
limitChoiceIngress
|
||||
limitChoiceEgress
|
||||
)
|
||||
|
||||
func checkFreeTierLimits(limitChoice int, next http.Handler) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var errorResponse = models.ErrorResponse{
|
||||
Code: http.StatusForbidden, Message: "free tier limits exceeded on ",
|
||||
}
|
||||
|
||||
if logic.FreeTier { // check that free tier limits not exceeded
|
||||
switch limitChoice {
|
||||
case limitChoiceNetworks:
|
||||
currentNetworks, err := logic.GetNetworks()
|
||||
if (err != nil && !database.IsEmptyRecord(err)) ||
|
||||
len(currentNetworks) >= logic.NetworksLimit {
|
||||
errorResponse.Message += "networks"
|
||||
logic.ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
case limitChoiceUsers:
|
||||
users, err := logic.GetUsers()
|
||||
if (err != nil && !database.IsEmptyRecord(err)) ||
|
||||
len(users) >= logic.UsersLimit {
|
||||
errorResponse.Message += "users"
|
||||
logic.ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
case limitChoiceMachines:
|
||||
hosts, hErr := logic.GetAllHosts()
|
||||
clients, cErr := logic.GetAllExtClients()
|
||||
if (hErr != nil && !database.IsEmptyRecord(hErr)) ||
|
||||
(cErr != nil && !database.IsEmptyRecord(cErr)) ||
|
||||
len(hosts)+len(clients) >= logic.MachinesLimit {
|
||||
errorResponse.Message += "machines"
|
||||
logic.ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
case limitChoiceIngress:
|
||||
ingresses, err := logic.GetAllIngresses()
|
||||
if (err != nil && !database.IsEmptyRecord(err)) ||
|
||||
len(ingresses) >= logic.IngressesLimit {
|
||||
errorResponse.Message += "ingresses"
|
||||
logic.ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
case limitChoiceEgress:
|
||||
egresses, err := logic.GetAllEgresses()
|
||||
if (err != nil && !database.IsEmptyRecord(err)) ||
|
||||
len(egresses) >= logic.EgressesLimit {
|
||||
errorResponse.Message += "egresses"
|
||||
logic.ReturnErrorResponse(w, r, errorResponse)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
+27
-27
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
func userMiddleWare(handler http.Handler) http.Handler {
|
||||
@@ -33,79 +33,79 @@ func userMiddleWare(handler http.Handler) http.Handler {
|
||||
r.Header.Set("NET_ID", r.URL.Query().Get("network"))
|
||||
}
|
||||
if strings.Contains(route, "hosts") || strings.Contains(route, "nodes") {
|
||||
r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.HostRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "dns") {
|
||||
r.Header.Set("TARGET_RSRC", models.DnsRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.DnsRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "rac") {
|
||||
r.Header.Set("RAC", "true")
|
||||
}
|
||||
if strings.Contains(route, "users") {
|
||||
r.Header.Set("TARGET_RSRC", models.UserRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.UserRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "ingress") {
|
||||
r.Header.Set("TARGET_RSRC", models.RemoteAccessGwRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.RemoteAccessGwRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "createrelay") || strings.Contains(route, "deleterelay") {
|
||||
r.Header.Set("TARGET_RSRC", models.RelayRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.RelayRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "gateway") {
|
||||
r.Header.Set("TARGET_RSRC", models.GatewayRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.GatewayRsrc.String())
|
||||
}
|
||||
|
||||
if strings.Contains(route, "egress") {
|
||||
r.Header.Set("TARGET_RSRC", models.EgressGwRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.EgressGwRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "networks") {
|
||||
r.Header.Set("TARGET_RSRC", models.NetworkRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.NetworkRsrc.String())
|
||||
}
|
||||
// check 'graph' after 'networks', otherwise the
|
||||
// header will be overwritten.
|
||||
if strings.Contains(route, "graph") {
|
||||
r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.HostRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "acls") {
|
||||
r.Header.Set("TARGET_RSRC", models.AclRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.AclRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "tags") {
|
||||
r.Header.Set("TARGET_RSRC", models.TagRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.TagRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "extclients") || strings.Contains(route, "client_conf") {
|
||||
r.Header.Set("TARGET_RSRC", models.ExtClientsRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.ExtClientsRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "enrollment-keys") {
|
||||
r.Header.Set("TARGET_RSRC", models.EnrollmentKeysRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.EnrollmentKeysRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "posture_check") {
|
||||
r.Header.Set("TARGET_RSRC", models.PostureCheckRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.PostureCheckRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "activity") {
|
||||
r.Header.Set("TARGET_RSRC", models.UserActivityRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.UserActivityRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "nameserver") {
|
||||
r.Header.Set("TARGET_RSRC", models.NameserverRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.NameserverRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "jit") {
|
||||
r.Header.Set("TARGET_RSRC", models.JitAdminRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.JitAdminRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "jit_user") {
|
||||
r.Header.Set("TARGET_RSRC", models.JitUserRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.JitUserRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "metrics") {
|
||||
r.Header.Set("TARGET_RSRC", models.MetricRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.MetricRsrc.String())
|
||||
}
|
||||
if strings.Contains(route, "flows") {
|
||||
r.Header.Set("TARGET_RSRC", models.TrafficFlow.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.TrafficFlow.String())
|
||||
}
|
||||
if keyID, ok := params["keyID"]; ok {
|
||||
r.Header.Set("TARGET_RSRC_ID", keyID)
|
||||
}
|
||||
if nodeID, ok := params["nodeid"]; ok && r.Header.Get("TARGET_RSRC") != models.ExtClientsRsrc.String() {
|
||||
if nodeID, ok := params["nodeid"]; ok && r.Header.Get("TARGET_RSRC") != schema.ExtClientsRsrc.String() {
|
||||
r.Header.Set("TARGET_RSRC_ID", nodeID)
|
||||
}
|
||||
if strings.Contains(route, "failover") {
|
||||
r.Header.Set("TARGET_RSRC", models.FailOverRsrc.String())
|
||||
r.Header.Set("TARGET_RSRC", schema.FailOverRsrc.String())
|
||||
nodeID := r.Header.Get("TARGET_RSRC_ID")
|
||||
node, _ := logic.GetNodeByID(nodeID)
|
||||
r.Header.Set("NET_ID", node.Network)
|
||||
@@ -133,10 +133,10 @@ func userMiddleWare(handler http.Handler) http.Handler {
|
||||
}
|
||||
}
|
||||
if r.Header.Get("NET_ID") == "" && (r.Header.Get("TARGET_RSRC_ID") == "" ||
|
||||
r.Header.Get("TARGET_RSRC") == models.EnrollmentKeysRsrc.String() ||
|
||||
r.Header.Get("TARGET_RSRC") == models.UserRsrc.String()) ||
|
||||
(r.Header.Get("TARGET_RSRC") == models.UserActivityRsrc.String() && route != "/api/v1/network/activity") ||
|
||||
r.Header.Get("TARGET_RSRC") == models.TrafficFlow.String() {
|
||||
r.Header.Get("TARGET_RSRC") == schema.EnrollmentKeysRsrc.String() ||
|
||||
r.Header.Get("TARGET_RSRC") == schema.UserRsrc.String()) ||
|
||||
(r.Header.Get("TARGET_RSRC") == schema.UserActivityRsrc.String() && route != "/api/v1/network/activity") ||
|
||||
r.Header.Get("TARGET_RSRC") == schema.TrafficFlow.String() {
|
||||
r.Header.Set("IS_GLOBAL_ACCESS", "yes")
|
||||
}
|
||||
r.Header.Set("RSRC_TYPE", r.Header.Get("TARGET_RSRC"))
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slog"
|
||||
@@ -30,9 +31,9 @@ import (
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
func migrate(w http.ResponseWriter, r *http.Request) {
|
||||
data := models.MigrationData{}
|
||||
host := models.Host{}
|
||||
host := schema.Host{}
|
||||
node := models.Node{}
|
||||
nodes := []models.Node{}
|
||||
var nodes []models.Node
|
||||
server := models.ServerConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(&data)
|
||||
if err != nil {
|
||||
@@ -127,9 +128,9 @@ func migrate(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertLegacyHostNode(legacy models.LegacyNode) (models.Host, models.Node) {
|
||||
func convertLegacyHostNode(legacy models.LegacyNode) (schema.Host, models.Node) {
|
||||
//convert host
|
||||
host := models.Host{}
|
||||
host := schema.Host{}
|
||||
host.ID = uuid.New()
|
||||
host.IPForwarding = models.ParseBool(legacy.IPForwarding)
|
||||
host.AutoUpdate = logic.AutoUpdateEnabled()
|
||||
@@ -139,7 +140,8 @@ func convertLegacyHostNode(legacy models.LegacyNode) (models.Host, models.Node)
|
||||
host.ListenPort = 51821
|
||||
}
|
||||
host.MTU = int(legacy.MTU)
|
||||
host.PublicKey, _ = wgtypes.ParseKey(legacy.PublicKey)
|
||||
pubKey, _ := wgtypes.ParseKey(legacy.PublicKey)
|
||||
host.PublicKey = schema.WgKey{Key: pubKey}
|
||||
host.MacAddress = net.HardwareAddr(legacy.MacAddress)
|
||||
host.TrafficKeyPublic = legacy.TrafficKeys.Mine
|
||||
host.Nodes = append([]string{}, legacy.ID)
|
||||
|
||||
+69
-60
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
@@ -27,7 +28,7 @@ func networkHandlers(r *mux.Router) {
|
||||
Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/networks/stats", logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats))).
|
||||
Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceNetworks, http.HandlerFunc(createNetwork)))).
|
||||
r.HandleFunc("/api/networks", logic.SecurityCheck(true, http.HandlerFunc(createNetwork))).
|
||||
Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(getNetwork))).
|
||||
Methods(http.MethodGet)
|
||||
@@ -52,26 +53,26 @@ func networkHandlers(r *mux.Router) {
|
||||
// @Tags Networks
|
||||
// @Security oauth
|
||||
// @Produce json
|
||||
// @Success 200 {object} models.Network
|
||||
// @Success 200 {array} schema.Network
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func getNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
|
||||
allnetworks, err := logic.GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
allnetworks, err := (&schema.Network{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch networks", "error", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if r.Header.Get("ismaster") != "yes" {
|
||||
username := r.Header.Get("user")
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
allnetworks = logic.FilterNetworksByRole(allnetworks, *user)
|
||||
allnetworks = logic.FilterNetworksByRole(allnetworks, user)
|
||||
}
|
||||
|
||||
logger.Log(2, r.Header.Get("user"), "fetched networks.")
|
||||
@@ -90,32 +91,38 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
func getNetworksStats(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var err error
|
||||
allnetworks, err := logic.GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
allnetworks, err := (&schema.Network{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("failed to fetch networks", "error", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if r.Header.Get("ismaster") != "yes" {
|
||||
username := r.Header.Get("user")
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
allnetworks = logic.FilterNetworksByRole(allnetworks, *user)
|
||||
allnetworks = logic.FilterNetworksByRole(allnetworks, user)
|
||||
}
|
||||
allNodes, err := logic.GetAllNodes()
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
netstats := []models.NetworkStatResp{}
|
||||
type networkStatResp struct {
|
||||
Network *schema.Network
|
||||
Hosts int
|
||||
}
|
||||
|
||||
var netstats []networkStatResp
|
||||
logic.SortNetworks(allnetworks[:])
|
||||
for _, network := range allnetworks {
|
||||
netstats = append(netstats, models.NetworkStatResp{
|
||||
Network: network,
|
||||
Hosts: len(logic.GetNetworkNodesMemory(allNodes, network.NetID)),
|
||||
netstats = append(netstats, networkStatResp{
|
||||
Network: &network,
|
||||
Hosts: len(logic.GetNetworkNodesMemory(allNodes, network.Name)),
|
||||
})
|
||||
}
|
||||
logger.Log(2, r.Header.Get("user"), "fetched networks.")
|
||||
@@ -128,7 +135,7 @@ func getNetworksStats(w http.ResponseWriter, r *http.Request) {
|
||||
// @Security oauth
|
||||
// @Param networkname path string true "Network name"
|
||||
// @Produce json
|
||||
// @Success 200 {object} models.Network
|
||||
// @Success 200 {object} schema.Network
|
||||
// @Failure 404 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -136,7 +143,8 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var params = mux.Vars(r)
|
||||
netname := params["networkname"]
|
||||
network, err := logic.GetNetwork(netname)
|
||||
network := &schema.Network{Name: netname}
|
||||
err := network.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v",
|
||||
netname, err))
|
||||
@@ -383,7 +391,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// update ingress gateways of associated clients
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
slog.Error(
|
||||
"failed to fetch hosts after network ACL update. skipping publish extclients ACL",
|
||||
@@ -392,7 +400,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
hostsMap := make(map[uuid.UUID]models.Host)
|
||||
hostsMap := make(map[uuid.UUID]schema.Host)
|
||||
for _, host := range hosts {
|
||||
hostsMap[host.ID] = host
|
||||
}
|
||||
@@ -486,14 +494,14 @@ func getNetworkEgressRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
netname := params["networkname"]
|
||||
// check if network exists
|
||||
_, err := logic.GetNetwork(netname)
|
||||
err := (&schema.Network{Name: netname}).Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
nodeEgressRoutes, _, err := logic.GetEgressRanges(models.NetworkID(netname))
|
||||
nodeEgressRoutes, _, err := logic.GetEgressRanges(schema.NetworkID(netname))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -538,8 +546,8 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
go logic.UnlinkNetworkAndTagsFromEnrollmentKeys(network, true)
|
||||
go logic.DeleteNetworkRoles(network)
|
||||
go logic.DeleteAllNetworkTags(models.NetworkID(network))
|
||||
go logic.DeleteNetworkPolicies(models.NetworkID(network))
|
||||
go logic.DeleteAllNetworkTags(schema.NetworkID(network))
|
||||
go logic.DeleteNetworkPolicies(schema.NetworkID(network))
|
||||
//delete network from allocated ip map
|
||||
go logic.RemoveNetworkFromAllocatedIpMap(network)
|
||||
go func() {
|
||||
@@ -561,19 +569,19 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: network,
|
||||
Name: network,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: network,
|
||||
New: nil,
|
||||
@@ -588,15 +596,15 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
// @Router /api/networks [post]
|
||||
// @Tags Networks
|
||||
// @Security oauth
|
||||
// @Param body body models.Network true "Network details"
|
||||
// @Param body body schema.Network true "Network details"
|
||||
// @Produce json
|
||||
// @Success 200 {object} models.Network
|
||||
// @Success 200 {object} schema.Network
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var network models.Network
|
||||
var network schema.Network
|
||||
|
||||
// we decode our body request params
|
||||
err := json.NewDecoder(r.Body).Decode(&network)
|
||||
@@ -608,9 +616,9 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
featureFlags := logic.GetFeatureFlags()
|
||||
if !featureFlags.EnableDeviceApproval {
|
||||
network.AutoJoin = "true"
|
||||
network.AutoJoin = true
|
||||
}
|
||||
if len(network.NetID) > 32 {
|
||||
if len(network.Name) > 32 {
|
||||
err := errors.New("network name shouldn't exceed 32 characters")
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create network: ",
|
||||
err.Error())
|
||||
@@ -664,7 +672,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if network.AutoRemove == "true" {
|
||||
if network.AutoRemove {
|
||||
if network.AutoRemoveThreshold == 0 {
|
||||
network.AutoRemoveThreshold = 60
|
||||
}
|
||||
@@ -672,23 +680,23 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
if network.AutoRemoveTags == nil {
|
||||
network.AutoRemoveTags = []string{}
|
||||
}
|
||||
network, err = logic.CreateNetwork(network)
|
||||
err = logic.CreateNetwork(&network)
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to create network: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(network.NetID))
|
||||
logic.CreateDefaultAclNetworkPolicies(models.NetworkID(network.NetID))
|
||||
logic.CreateDefaultTags(models.NetworkID(network.NetID))
|
||||
logic.AddNetworkToAllocatedIpMap(network.NetID)
|
||||
logic.CreateFallbackNameserver(network.NetID)
|
||||
logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(network.Name))
|
||||
logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name))
|
||||
logic.CreateDefaultTags(schema.NetworkID(network.Name))
|
||||
logic.AddNetworkToAllocatedIpMap(network.Name)
|
||||
logic.CreateFallbackNameserver(network.Name)
|
||||
if featureFlags.EnableOverlappingEgressRanges {
|
||||
// assign virtual NAT pool fields
|
||||
network.AssignVirtualNATDefaults(network.AddressRange, network.NetID)
|
||||
logic.AssignVirtualNATDefaults(&network, network.AddressRange)
|
||||
// Update network with virtual NAT settings
|
||||
if err := logic.UpsertNetwork(network); err != nil {
|
||||
if err := logic.UpsertNetwork(&network); err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to update network with virtual NAT settings:", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -696,14 +704,14 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
defaultHosts := logic.GetDefaultHosts()
|
||||
for i := range defaultHosts {
|
||||
currHost := &defaultHosts[i]
|
||||
newNode, err := logic.UpdateHostNetwork(currHost, network.NetID, true)
|
||||
newNode, err := logic.UpdateHostNetwork(currHost, network.Name, true)
|
||||
if err != nil {
|
||||
logger.Log(
|
||||
0,
|
||||
r.Header.Get("user"),
|
||||
"failed to add host to network:",
|
||||
currHost.ID.String(),
|
||||
network.NetID,
|
||||
network.Name,
|
||||
err.Error(),
|
||||
)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -721,7 +729,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
r.Header.Get("user"),
|
||||
"failed to add host to network:",
|
||||
currHost.ID.String(),
|
||||
network.NetID,
|
||||
network.Name,
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
@@ -736,7 +744,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
r.Header.Get("user"),
|
||||
"failed to add host to network:",
|
||||
currHost.ID.String(),
|
||||
network.NetID,
|
||||
network.Name,
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
@@ -745,10 +753,10 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
// make host failover
|
||||
logic.CreateFailOver(*newNode)
|
||||
// make host remote access gateway
|
||||
logic.CreateIngressGateway(network.NetID, newNode.ID.String(), models.IngressRequest{})
|
||||
logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{})
|
||||
logic.CreateRelay(models.RelayRequest{
|
||||
NodeID: newNode.ID.String(),
|
||||
NetID: network.NetID,
|
||||
NetID: network.Name,
|
||||
})
|
||||
}
|
||||
// send peer updates
|
||||
@@ -757,22 +765,22 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: network.NetID,
|
||||
Name: network.NetID,
|
||||
Type: models.NetworkSub,
|
||||
ID: network.Name,
|
||||
Name: network.Name,
|
||||
Type: schema.NetworkSub,
|
||||
Info: network,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
|
||||
logger.Log(1, r.Header.Get("user"), "created network", network.Name)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(network)
|
||||
}
|
||||
@@ -782,15 +790,15 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
// @Tags Networks
|
||||
// @Security oauth
|
||||
// @Param networkname path string true "Network name"
|
||||
// @Param body body models.Network true "Network details"
|
||||
// @Param body body schema.Network true "Network details"
|
||||
// @Produce json
|
||||
// @Success 200 {object} models.Network
|
||||
// @Success 200 {object} schema.Network
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
func updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
var payload models.Network
|
||||
var payload schema.Network
|
||||
|
||||
// we decode our body request params
|
||||
err := json.NewDecoder(r.Body).Decode(&payload)
|
||||
@@ -800,20 +808,21 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
netOld, err := logic.GetNetwork(payload.NetID)
|
||||
netOld := &schema.Network{Name: payload.Name}
|
||||
err = netOld.Get(r.Context())
|
||||
if err != nil {
|
||||
slog.Info("error fetching network", "user", r.Header.Get("user"), "err", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
err = logic.UpdateNetwork(&netOld, &payload)
|
||||
err = logic.UpdateNetwork(netOld, &payload)
|
||||
if err != nil {
|
||||
slog.Info("failed to update network", "user", r.Header.Get("user"), "err", err)
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
go mq.PublishPeerUpdate(false)
|
||||
slog.Info("updated network", "network", payload.NetID, "user", r.Header.Get("user"))
|
||||
slog.Info("updated network", "network", payload.Name, "user", r.Header.Get("user"))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
+54
-65
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
@@ -19,11 +20,11 @@ import (
|
||||
|
||||
type NetworkValidationTestCase struct {
|
||||
testname string
|
||||
network models.Network
|
||||
network schema.Network
|
||||
errMessage string
|
||||
}
|
||||
|
||||
var netHost models.Host
|
||||
var netHost schema.Host
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
db.InitializeDB(schema.ListModels()...)
|
||||
@@ -31,10 +32,10 @@ func TestMain(m *testing.M) {
|
||||
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
logic.CreateSuperAdmin(&models.User{
|
||||
UserName: "admin",
|
||||
logic.CreateSuperAdmin(&schema.User{
|
||||
Username: "admin",
|
||||
Password: "password",
|
||||
PlatformRoleID: models.SuperAdminRole,
|
||||
PlatformRoleID: schema.SuperAdminRole,
|
||||
})
|
||||
peerUpdate := make(chan *models.Node)
|
||||
go logic.ManageZombies(context.Background())
|
||||
@@ -51,27 +52,29 @@ func TestMain(m *testing.M) {
|
||||
func TestCreateNetwork(t *testing.T) {
|
||||
deleteAllNetworks()
|
||||
|
||||
var network models.Network
|
||||
network.NetID = "skynet1"
|
||||
var network schema.Network
|
||||
network.Name = "skynet1"
|
||||
network.AddressRange = "10.10.0.1/24"
|
||||
// if tests break - check here (removed displayname)
|
||||
//network.DisplayName = "mynetwork"
|
||||
|
||||
_, err := logic.CreateNetwork(network)
|
||||
err := logic.CreateNetwork(&network)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
func TestGetNetwork(t *testing.T) {
|
||||
createNet()
|
||||
|
||||
t.Run("GetExistingNetwork", func(t *testing.T) {
|
||||
network, err := logic.GetNetwork("skynet")
|
||||
network := &schema.Network{Name: "skynet"}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "skynet", network.NetID)
|
||||
assert.Equal(t, "skynet", network.Name)
|
||||
})
|
||||
t.Run("GetNonExistantNetwork", func(t *testing.T) {
|
||||
network, err := logic.GetNetwork("doesnotexist")
|
||||
assert.EqualError(t, err, "no result found")
|
||||
assert.Equal(t, "", network.NetID)
|
||||
network := &schema.Network{Name: "doesnotexist"}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
assert.EqualError(t, err, gorm.ErrRecordNotFound.Error())
|
||||
assert.Equal(t, "", network.ID)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -128,65 +131,49 @@ func TestValidateNetwork(t *testing.T) {
|
||||
cases := []NetworkValidationTestCase{
|
||||
{
|
||||
testname: "InvalidAddress",
|
||||
network: models.Network{
|
||||
NetID: "skynet",
|
||||
network: schema.Network{
|
||||
Name: "skynet",
|
||||
AddressRange: "10.0.0.256",
|
||||
},
|
||||
errMessage: "Field validation for 'AddressRange' failed on the 'cidrv4' tag",
|
||||
errMessage: "invalid CIDR address: 10.0.0.256",
|
||||
},
|
||||
{
|
||||
testname: "InvalidAddress6",
|
||||
network: models.Network{
|
||||
NetID: "skynet1",
|
||||
network: schema.Network{
|
||||
Name: "skynet1",
|
||||
AddressRange6: "2607::ffff/130",
|
||||
},
|
||||
errMessage: "Field validation for 'AddressRange6' failed on the 'cidrv6' tag",
|
||||
errMessage: "invalid CIDR address: 2607::ffff/130",
|
||||
},
|
||||
{
|
||||
testname: "InvalidNetID",
|
||||
network: models.Network{
|
||||
NetID: "with spaces",
|
||||
network: schema.Network{
|
||||
Name: "with spaces",
|
||||
},
|
||||
errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
|
||||
errMessage: "invalid character(s) in network name",
|
||||
},
|
||||
{
|
||||
testname: "NetIDTooLong",
|
||||
network: models.Network{
|
||||
NetID: "LongNetIDNameForMaxCharactersTest",
|
||||
network: schema.Network{
|
||||
Name: "LongNetIDNameForMaxCharactersTest",
|
||||
},
|
||||
errMessage: "Field validation for 'NetID' failed on the 'max' tag",
|
||||
},
|
||||
{
|
||||
testname: "ListenPortTooLow",
|
||||
network: models.Network{
|
||||
NetID: "skynet",
|
||||
DefaultListenPort: 1023,
|
||||
},
|
||||
errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
|
||||
},
|
||||
{
|
||||
testname: "ListenPortTooHigh",
|
||||
network: models.Network{
|
||||
NetID: "skynet",
|
||||
DefaultListenPort: 65536,
|
||||
},
|
||||
errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
|
||||
errMessage: "network name cannot be longer than 32 characters",
|
||||
},
|
||||
{
|
||||
testname: "KeepAliveTooBig",
|
||||
network: models.Network{
|
||||
NetID: "skynet",
|
||||
DefaultKeepalive: 1010,
|
||||
network: schema.Network{
|
||||
Name: "skynet",
|
||||
DefaultKeepAlive: 1010,
|
||||
},
|
||||
errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
|
||||
errMessage: "default keep alive must be less than 1000",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.testname, func(t *testing.T) {
|
||||
t.Log(tc.testname)
|
||||
network := models.Network(tc.network)
|
||||
network.SetDefaults()
|
||||
network := tc.network
|
||||
err := logic.ValidateNetwork(&network, false)
|
||||
|
||||
assert.NotNil(t, err)
|
||||
assert.Contains(t, err.Error(), tc.errMessage) // test passes if err.Error() contains the expected errMessage.
|
||||
})
|
||||
@@ -200,7 +187,8 @@ func TestIpv6Network(t *testing.T) {
|
||||
deleteAllNetworks()
|
||||
createNet()
|
||||
createNetDualStack()
|
||||
network, err := logic.GetNetwork("skynet6")
|
||||
network := &schema.Network{Name: "skynet6"}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
t.Run("Test Network Create IPv6", func(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64")
|
||||
@@ -216,46 +204,47 @@ func TestIpv6Network(t *testing.T) {
|
||||
|
||||
func deleteAllNetworks() {
|
||||
deleteAllNodes()
|
||||
database.DeleteAllRecords(database.NETWORKS_TABLE_NAME)
|
||||
_networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, _network := range _networks {
|
||||
_ = _network.Delete(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
|
||||
func createNet() {
|
||||
var network models.Network
|
||||
network.NetID = "skynet"
|
||||
var network schema.Network
|
||||
network.Name = "skynet"
|
||||
network.AddressRange = "10.0.0.1/24"
|
||||
_, err := logic.GetNetwork("skynet")
|
||||
err := (&schema.Network{Name: "skynet"}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logic.CreateNetwork(network)
|
||||
logic.CreateNetwork(&network)
|
||||
}
|
||||
}
|
||||
func createNetv1(netId string) {
|
||||
var network models.Network
|
||||
network.NetID = netId
|
||||
var network schema.Network
|
||||
network.Name = netId
|
||||
network.AddressRange = "100.0.0.1/24"
|
||||
_, err := logic.GetNetwork(netId)
|
||||
err := (&schema.Network{Name: netId}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logic.CreateNetwork(network)
|
||||
logic.CreateNetwork(&network)
|
||||
}
|
||||
}
|
||||
|
||||
func createNetDualStack() {
|
||||
var network models.Network
|
||||
network.NetID = "skynet6"
|
||||
var network schema.Network
|
||||
network.Name = "skynet6"
|
||||
network.AddressRange = "10.1.2.0/24"
|
||||
network.AddressRange6 = "fde6:be04:fa5e:d076::/64"
|
||||
network.IsIPv4 = "yes"
|
||||
network.IsIPv6 = "yes"
|
||||
_, err := logic.GetNetwork("skynet6")
|
||||
err := (&schema.Network{Name: "skynet6"}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logic.CreateNetwork(network)
|
||||
logic.CreateNetwork(&network)
|
||||
}
|
||||
}
|
||||
|
||||
func createNetHost() {
|
||||
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
|
||||
netHost = models.Host{
|
||||
netHost = schema.Host{
|
||||
ID: uuid.New(),
|
||||
PublicKey: k.PublicKey(),
|
||||
PublicKey: schema.WgKey{Key: k.PublicKey()},
|
||||
HostPass: "password",
|
||||
OS: "linux",
|
||||
Name: "nethost",
|
||||
|
||||
+26
-14
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slog"
|
||||
@@ -27,9 +28,9 @@ func nodeHandlers(r *mux.Router) {
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(getNode))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(updateNode))).Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(deleteNode))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceEgress, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createGateway)))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/nodes/{network}/status", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus))).Methods(http.MethodGet)
|
||||
@@ -81,7 +82,10 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
host, err := logic.GetHost(result.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: result.HostID,
|
||||
}
|
||||
err = host.Get(request.Context())
|
||||
if err != nil {
|
||||
errorResponse.Code = http.StatusBadRequest
|
||||
errorResponse.Message = err.Error()
|
||||
@@ -234,16 +238,18 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
username := r.Header.Get("user")
|
||||
if r.Header.Get("ismaster") == "no" {
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
userPlatformRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userPlatformRole := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err = userPlatformRole.Get(r.Context())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !userPlatformRole.FullAccess {
|
||||
nodes = logic.GetFilteredNodesByUserAccess(*user, nodes)
|
||||
nodes = logic.GetFilteredNodesByUserAccess(user, nodes)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -272,7 +278,7 @@ func getNetworkNodeStatus(w http.ResponseWriter, r *http.Request) {
|
||||
var params = mux.Vars(r)
|
||||
netID := params["network"]
|
||||
// validate network
|
||||
_, err := logic.GetNetwork(netID)
|
||||
err := (&schema.Network{Name: netID}).Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to get network %v", err), "badrequest"))
|
||||
return
|
||||
@@ -313,7 +319,10 @@ func getNode(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
fmt.Sprintf("error fetching host for node [ %s ] info: %v", nodeid, err))
|
||||
@@ -549,7 +558,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
host, err := logic.GetHost(newNode.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: newNode.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"),
|
||||
fmt.Sprintf("failed to get host for node [ %s ] info: %v", nodeid, err))
|
||||
@@ -621,7 +633,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
newNode.PostureChecksViolations,
|
||||
newNode.PostureCheckVolationSeverityLevel = logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(newNode),
|
||||
models.NetworkID(newNode.Network))
|
||||
schema.NetworkID(newNode.Network))
|
||||
newNode.LastEvaluatedAt = time.Now().UTC()
|
||||
logic.UpsertNode(newNode)
|
||||
logic.GetNodeStatus(newNode, false)
|
||||
@@ -636,23 +648,23 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
|
||||
currentNode.Network,
|
||||
)
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: newNode.ID.String(),
|
||||
Name: host.Name,
|
||||
Type: models.NodeSub,
|
||||
Type: schema.NodeSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: currentNode,
|
||||
New: newNode,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(apiNode)
|
||||
|
||||
@@ -4,19 +4,23 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/acls"
|
||||
"github.com/gravitl/netmaker/logic/acls/nodeacls"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
var nonLinuxHost models.Host
|
||||
var linuxHost models.Host
|
||||
var nonLinuxHost schema.Host
|
||||
var linuxHost schema.Host
|
||||
|
||||
func TestGetNetworkNodes(t *testing.T) {
|
||||
deleteAllNetworks()
|
||||
@@ -95,10 +99,11 @@ func TestNodeACLs(t *testing.T) {
|
||||
t.Run("node acls correct after add new node not allowed", func(t *testing.T) {
|
||||
node3 := createNodeWithParams("", "10.0.0.100/32")
|
||||
createNodeHosts()
|
||||
n, e := logic.GetNetwork(node3.Network)
|
||||
n := &schema.Network{Name: node3.Network}
|
||||
e := n.Get(db.WithContext(context.TODO()))
|
||||
assert.Nil(t, e)
|
||||
n.DefaultACL = "no"
|
||||
e = logic.SaveNetwork(&n)
|
||||
e = logic.SaveNetwork(n)
|
||||
assert.Nil(t, e)
|
||||
err := logic.AssociateNodeToHost(node3, &linuxHost)
|
||||
assert.Nil(t, err)
|
||||
@@ -159,18 +164,18 @@ func createNodeWithParams(network, address string) *models.Node {
|
||||
|
||||
func createNodeHosts() {
|
||||
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
|
||||
linuxHost = models.Host{
|
||||
linuxHost = schema.Host{
|
||||
ID: uuid.New(),
|
||||
PublicKey: k.PublicKey(),
|
||||
PublicKey: schema.WgKey{Key: k.PublicKey()},
|
||||
HostPass: "password",
|
||||
OS: "linux",
|
||||
Name: "linuxhost",
|
||||
}
|
||||
_ = logic.CreateHost(&linuxHost)
|
||||
nonLinuxHost = models.Host{
|
||||
nonLinuxHost = schema.Host{
|
||||
ID: uuid.New(),
|
||||
OS: "windows",
|
||||
PublicKey: k.PublicKey(),
|
||||
PublicKey: schema.WgKey{Key: k.PublicKey()},
|
||||
Name: "windowshost",
|
||||
HostPass: "password",
|
||||
}
|
||||
|
||||
+30
-24
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
ch "github.com/gravitl/netmaker/clickhouse"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
@@ -38,8 +41,11 @@ func serverHandlers(r *mux.Router) {
|
||||
"/api/server/shutdown", logic.SecurityCheck(true,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("ismaster") != "yes" {
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
if err != nil || caller.PlatformRoleID != models.SuperAdminRole {
|
||||
caller := &schema.User{
|
||||
Username: r.Header.Get("user"),
|
||||
}
|
||||
err := caller.Get(r.Context())
|
||||
if err != nil || caller.PlatformRoleID != schema.SuperAdminRole {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only a super-admin can shut down the server"), "forbidden"))
|
||||
return
|
||||
}
|
||||
@@ -255,7 +261,7 @@ func updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if superAdmin.AuthType == models.OAuth {
|
||||
if superAdmin.AuthType == schema.OAuth {
|
||||
err := fmt.Errorf(
|
||||
"cannot remove IdP integration because an OAuth user has the super-admin role; transfer the super-admin role to another user first",
|
||||
)
|
||||
@@ -293,19 +299,19 @@ func updateSettings(w http.ResponseWriter, r *http.Request) {
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: models.SettingSub.String(),
|
||||
Name: models.SettingSub.String(),
|
||||
Type: models.SettingSub,
|
||||
ID: schema.SettingSub.String(),
|
||||
Name: schema.SettingSub.String(),
|
||||
Type: schema.SettingSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: currSettings,
|
||||
New: req,
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
go reInit(currSettings, req, force == "true")
|
||||
logic.ReturnSuccessResponseWithJson(w, r, req, "updated server settings successfully")
|
||||
@@ -333,7 +339,7 @@ func reInit(curr, new models.ServerSettings, force bool) {
|
||||
if force || !new.EnableFlowLogs {
|
||||
if curr.NetclientAutoUpdate != new.NetclientAutoUpdate ||
|
||||
curr.EnableFlowLogs != new.EnableFlowLogs {
|
||||
hosts, _ := logic.GetAllHosts()
|
||||
hosts, _ := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, host := range hosts {
|
||||
if curr.NetclientAutoUpdate != new.NetclientAutoUpdate {
|
||||
host.AutoUpdate = new.NetclientAutoUpdate
|
||||
@@ -355,7 +361,7 @@ func reInit(curr, new models.ServerSettings, force bool) {
|
||||
go mq.PublishPeerUpdate(false)
|
||||
}
|
||||
|
||||
func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action {
|
||||
func identifySettingsUpdateAction(old, new models.ServerSettings) schema.Action {
|
||||
// TODO: here we are relying on the dashboard to only
|
||||
// make singular updates, but it's possible that the
|
||||
// API can be called to make multiple changes to the
|
||||
@@ -363,33 +369,33 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
|
||||
// events or create singular update APIs.
|
||||
if old.MFAEnforced != new.MFAEnforced {
|
||||
if new.MFAEnforced {
|
||||
return models.EnforceMFA
|
||||
return schema.EnforceMFA
|
||||
} else {
|
||||
return models.UnenforceMFA
|
||||
return schema.UnenforceMFA
|
||||
}
|
||||
}
|
||||
|
||||
if old.BasicAuth != new.BasicAuth {
|
||||
if new.BasicAuth {
|
||||
return models.EnableBasicAuth
|
||||
return schema.EnableBasicAuth
|
||||
} else {
|
||||
return models.DisableBasicAuth
|
||||
return schema.DisableBasicAuth
|
||||
}
|
||||
}
|
||||
|
||||
if old.Telemetry != new.Telemetry {
|
||||
if new.Telemetry == "off" {
|
||||
return models.DisableTelemetry
|
||||
return schema.DisableTelemetry
|
||||
} else {
|
||||
return models.EnableTelemetry
|
||||
return schema.EnableTelemetry
|
||||
}
|
||||
}
|
||||
|
||||
if old.EnableFlowLogs != new.EnableFlowLogs {
|
||||
if new.EnableFlowLogs {
|
||||
return models.EnableFlowLogs
|
||||
return schema.EnableFlowLogs
|
||||
} else {
|
||||
return models.DisableFlowLogs
|
||||
return schema.DisableFlowLogs
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,19 +404,19 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
|
||||
old.ManageDNS != new.ManageDNS ||
|
||||
old.DefaultDomain != new.DefaultDomain ||
|
||||
old.EndpointDetection != new.EndpointDetection {
|
||||
return models.UpdateClientSettings
|
||||
return schema.UpdateClientSettings
|
||||
}
|
||||
|
||||
if old.AllowedEmailDomains != new.AllowedEmailDomains ||
|
||||
old.JwtValidityDuration != new.JwtValidityDuration {
|
||||
return models.UpdateAuthenticationSecuritySettings
|
||||
return schema.UpdateAuthenticationSecuritySettings
|
||||
}
|
||||
|
||||
if old.Verbosity != new.Verbosity ||
|
||||
old.MetricsPort != new.MetricsPort ||
|
||||
old.MetricInterval != new.MetricInterval ||
|
||||
old.AuditLogsRetentionPeriodInDays != new.AuditLogsRetentionPeriodInDays {
|
||||
return models.UpdateMonitoringAndDebuggingSettings
|
||||
return schema.UpdateMonitoringAndDebuggingSettings
|
||||
}
|
||||
|
||||
if old.EmailSenderAddr != new.EmailSenderAddr ||
|
||||
@@ -418,7 +424,7 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
|
||||
old.EmailSenderPassword != new.EmailSenderPassword ||
|
||||
old.SmtpHost != new.SmtpHost ||
|
||||
old.SmtpPort != new.SmtpPort {
|
||||
return models.UpdateSMTPSettings
|
||||
return schema.UpdateSMTPSettings
|
||||
}
|
||||
|
||||
if old.AuthProvider != new.AuthProvider ||
|
||||
@@ -432,10 +438,10 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
|
||||
old.AzureTenant != new.AzureTenant ||
|
||||
!cmp.Equal(old.GroupFilters, new.GroupFilters) ||
|
||||
cmp.Equal(old.UserFilters, new.UserFilters) {
|
||||
return models.UpdateIDPSettings
|
||||
return schema.UpdateIDPSettings
|
||||
}
|
||||
|
||||
return models.Update
|
||||
return schema.Update
|
||||
}
|
||||
|
||||
// @Summary Get feature flags for this server
|
||||
|
||||
+426
-350
File diff suppressed because it is too large
Load Diff
@@ -101,11 +101,9 @@ const (
|
||||
var dbMutex sync.RWMutex
|
||||
|
||||
var Tables = []string{
|
||||
NETWORKS_TABLE_NAME,
|
||||
NODES_TABLE_NAME,
|
||||
CERTS_TABLE_NAME,
|
||||
DELETED_NODES_TABLE_NAME,
|
||||
USERS_TABLE_NAME,
|
||||
DNS_TABLE_NAME,
|
||||
EXT_CLIENT_TABLE_NAME,
|
||||
PEERS_TABLE_NAME,
|
||||
@@ -116,18 +114,22 @@ var Tables = []string{
|
||||
SSO_STATE_CACHE,
|
||||
METRICS_TABLE_NAME,
|
||||
NETWORK_USER_TABLE_NAME,
|
||||
USER_GROUPS_TABLE_NAME,
|
||||
CACHE_TABLE_NAME,
|
||||
HOSTS_TABLE_NAME,
|
||||
ENROLLMENT_KEYS_TABLE_NAME,
|
||||
HOST_ACTIONS_TABLE_NAME,
|
||||
PENDING_USERS_TABLE_NAME,
|
||||
USER_PERMISSIONS_TABLE_NAME,
|
||||
USER_INVITES_TABLE_NAME,
|
||||
TAG_TABLE_NAME,
|
||||
ACLS_TABLE_NAME,
|
||||
PEER_ACK_TABLE,
|
||||
SERVER_SETTINGS,
|
||||
// The following tables are to be migrated, but we still need them so that the migration function
|
||||
// doesn't fail with table does not exist.
|
||||
USERS_TABLE_NAME,
|
||||
USER_GROUPS_TABLE_NAME,
|
||||
USER_PERMISSIONS_TABLE_NAME,
|
||||
NETWORKS_TABLE_NAME,
|
||||
HOSTS_TABLE_NAME,
|
||||
}
|
||||
|
||||
func getCurrentDB() map[string]interface{} {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
_ "github.com/mattn/go-sqlite3" // need to blank import this package
|
||||
)
|
||||
|
||||
+7
-2
@@ -1,11 +1,16 @@
|
||||
package database
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// IsEmptyRecord - checks for if it's an empty record error or not
|
||||
func IsEmptyRecord(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
|
||||
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS) || errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
|
||||
+13
-1
@@ -49,7 +49,19 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return gorm.Open(sqlite.Open(dbFilePath), &gorm.Config{
|
||||
db, err := gorm.Open(sqlite.Open(dbFilePath), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Option func(db *gorm.DB) *gorm.DB
|
||||
|
||||
func WithPagination(page, pageSize int) Option {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
return db.Offset(offset).Limit(pageSize)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilter applies a WHERE clause for the given column.
|
||||
// IMPORTANT: `field` MUST be a trusted, hardcoded column name.
|
||||
// NEVER pass user-supplied strings as `field`.
|
||||
func WithFilter(field string, value ...interface{}) Option {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
if len(value) == 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
if len(value) == 1 {
|
||||
return db.Where(fmt.Sprintf("%s = ?", field), value[0])
|
||||
}
|
||||
|
||||
return db.Where(fmt.Sprintf("%s IN ?", field), value)
|
||||
}
|
||||
}
|
||||
|
||||
func InAscOrder(fields ...string) Option {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
for _, field := range fields {
|
||||
db = db.Order(fmt.Sprintf("%s ASC", field))
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
}
|
||||
|
||||
func InDescOrder(fields ...string) Option {
|
||||
return func(db *gorm.DB) *gorm.DB {
|
||||
for _, field := range fields {
|
||||
db = db.Order(fmt.Sprintf("%s DESC", field))
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
}
|
||||
+11
-11
@@ -17,8 +17,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
testNetwork = &models.Network{
|
||||
NetID: "not-a-network",
|
||||
_testNetwork = &schema.Network{
|
||||
Name: "not-a-network",
|
||||
}
|
||||
testExternalClient = &models.ExtClient{
|
||||
ClientID: "testExtClient",
|
||||
@@ -31,10 +31,10 @@ func TestMain(m *testing.M) {
|
||||
|
||||
database.InitializeDatabase()
|
||||
defer database.CloseDB()
|
||||
logic.CreateSuperAdmin(&models.User{
|
||||
UserName: "superadmin",
|
||||
logic.CreateSuperAdmin(&schema.User{
|
||||
Username: "superadmin",
|
||||
Password: "password",
|
||||
PlatformRoleID: models.SuperAdminRole,
|
||||
PlatformRoleID: schema.SuperAdminRole,
|
||||
})
|
||||
peerUpdate := make(chan *models.Node)
|
||||
go logic.ManageZombies(context.Background())
|
||||
@@ -48,18 +48,18 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestNetworkExists(t *testing.T) {
|
||||
database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
|
||||
exists, err := logic.NetworkExists(testNetwork.NetID)
|
||||
assert.NotNil(t, err)
|
||||
_ = _testNetwork.Delete(db.WithContext(context.TODO()))
|
||||
exists, err := logic.NetworkExists(_testNetwork.Name)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = logic.SaveNetwork(testNetwork)
|
||||
err = logic.SaveNetwork(_testNetwork)
|
||||
assert.Nil(t, err)
|
||||
exists, err = logic.NetworkExists(testNetwork.NetID)
|
||||
exists, err = logic.NetworkExists(_testNetwork.Name)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
|
||||
err = _testNetwork.Delete(db.WithContext(context.TODO()))
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
|
||||
+31
-31
@@ -43,9 +43,9 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
|
||||
return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
|
||||
})
|
||||
}()
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
nodes, _ := GetNetworkNodes(node.Network)
|
||||
nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
|
||||
nodes = append(nodes, GetStaticNodesByNetwork(schema.NetworkID(node.Network), true)...)
|
||||
rules = GetFwRulesForUserNodesOnGw(node, nodes)
|
||||
if defaultDevicePolicy.Enabled {
|
||||
return
|
||||
@@ -417,10 +417,10 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
|
||||
defer func() {
|
||||
sortIPs(ips)
|
||||
}()
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
|
||||
extclients := GetStaticNodesByNetwork(models.NetworkID(node.Network), false)
|
||||
extclients := GetStaticNodesByNetwork(schema.NetworkID(node.Network), false)
|
||||
for _, extclient := range extclients {
|
||||
if extclient.IsUserNode && defaultUserPolicy.Enabled {
|
||||
continue
|
||||
@@ -516,11 +516,11 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu
|
||||
}
|
||||
var taggedNodes map[models.TagID][]models.Node
|
||||
if targetnode.IsIngressGateway {
|
||||
taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false)
|
||||
taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), false)
|
||||
} else {
|
||||
taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
|
||||
taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true)
|
||||
}
|
||||
acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
|
||||
acls := ListDevicePolicies(schema.NetworkID(targetnode.Network))
|
||||
var targetNodeTags = make(map[models.TagID]struct{})
|
||||
if targetnode.Mutex != nil {
|
||||
targetnode.Mutex.Lock()
|
||||
@@ -796,9 +796,9 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
|
||||
defer func() {
|
||||
rules = GetEgressUserRulesForNode(&targetnode, rules)
|
||||
}()
|
||||
taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
|
||||
taggedNodes := GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true)
|
||||
|
||||
acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
|
||||
acls := ListDevicePolicies(schema.NetworkID(targetnode.Network))
|
||||
var targetNodeTags = make(map[models.TagID]struct{})
|
||||
targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{}
|
||||
targetNodeTags["*"] = struct{}{}
|
||||
@@ -1070,7 +1070,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool {
|
||||
}
|
||||
if checkDefaultPolicy {
|
||||
// check default policy if all allowed return true
|
||||
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
if err == nil {
|
||||
if defaultPolicy.Enabled {
|
||||
return true
|
||||
@@ -1079,7 +1079,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool {
|
||||
|
||||
}
|
||||
// list device policies
|
||||
policies := ListDevicePolicies(models.NetworkID(peer.Network))
|
||||
policies := ListDevicePolicies(schema.NetworkID(peer.Network))
|
||||
srcMap := make(map[string]struct{})
|
||||
dstMap := make(map[string]struct{})
|
||||
defer func() {
|
||||
@@ -1176,9 +1176,9 @@ func CheckTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
|
||||
}
|
||||
|
||||
var (
|
||||
CreateDefaultTags = func(netID models.NetworkID) {}
|
||||
CreateDefaultTags = func(netID schema.NetworkID) {}
|
||||
|
||||
DeleteAllNetworkTags = func(networkID models.NetworkID) {}
|
||||
DeleteAllNetworkTags = func(networkID schema.NetworkID) {}
|
||||
|
||||
IsUserAllowedToCommunicate = func(userName string, peer models.Node) (bool, []models.Acl) {
|
||||
return false, []models.Acl{}
|
||||
@@ -1213,7 +1213,7 @@ func MigrateAclPolicies() {
|
||||
|
||||
func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool {
|
||||
// check default policy if all allowed return true
|
||||
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
if err == nil {
|
||||
if defaultPolicy.Enabled {
|
||||
return true
|
||||
@@ -1244,7 +1244,7 @@ func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool {
|
||||
node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{}
|
||||
}
|
||||
// list device policies
|
||||
policies := ListDevicePolicies(models.NetworkID(node.Network))
|
||||
policies := ListDevicePolicies(schema.NetworkID(node.Network))
|
||||
srcMap := make(map[string]struct{})
|
||||
dstMap := make(map[string]struct{})
|
||||
defer func() {
|
||||
@@ -1332,7 +1332,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
|
||||
peerTags[models.TagID(peerId)] = struct{}{}
|
||||
if checkDefaultPolicy {
|
||||
// check default policy if all allowed return true
|
||||
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
if err == nil {
|
||||
if defaultPolicy.Enabled {
|
||||
return true, []models.Acl{defaultPolicy}
|
||||
@@ -1344,7 +1344,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
|
||||
allowedPolicies = UniquePolicies(allowedPolicies)
|
||||
}()
|
||||
// list device policies
|
||||
policies := ListDevicePolicies(models.NetworkID(peer.Network))
|
||||
policies := ListDevicePolicies(schema.NetworkID(peer.Network))
|
||||
srcMap := make(map[string]struct{})
|
||||
dstMap := make(map[string]struct{})
|
||||
defer func() {
|
||||
@@ -1462,7 +1462,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
|
||||
}
|
||||
|
||||
// GetDefaultPolicy - fetches default policy in the network by ruleType
|
||||
func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
|
||||
func GetDefaultPolicy(netID schema.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
|
||||
aclID := "all-users"
|
||||
if ruleType == models.DevicePolicy {
|
||||
aclID = "all-nodes"
|
||||
@@ -1505,7 +1505,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
|
||||
}
|
||||
|
||||
// ListAcls - lists all acl policies
|
||||
func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) {
|
||||
func ListAclsByNetwork(netID schema.NetworkID) ([]models.Acl, error) {
|
||||
|
||||
allAcls := ListAcls()
|
||||
netAcls := []models.Acl{}
|
||||
@@ -1538,7 +1538,7 @@ func ListEgressAcls(eID string) ([]models.Acl, error) {
|
||||
}
|
||||
|
||||
// ListDevicePolicies - lists all device policies in a network
|
||||
func ListDevicePolicies(netID models.NetworkID) []models.Acl {
|
||||
func ListDevicePolicies(netID schema.NetworkID) []models.Acl {
|
||||
allAcls := ListAcls()
|
||||
deviceAcls := []models.Acl{}
|
||||
for _, acl := range allAcls {
|
||||
@@ -1550,7 +1550,7 @@ func ListDevicePolicies(netID models.NetworkID) []models.Acl {
|
||||
}
|
||||
|
||||
// ListUserPolicies - lists all user policies in a network
|
||||
func ListUserPolicies(netID models.NetworkID) []models.Acl {
|
||||
func ListUserPolicies(netID schema.NetworkID) []models.Acl {
|
||||
allAcls := ListAcls()
|
||||
userAcls := []models.Acl{}
|
||||
for _, acl := range allAcls {
|
||||
@@ -1697,7 +1697,7 @@ func UniquePolicies(items []models.Acl) []models.Acl {
|
||||
}
|
||||
|
||||
// DeleteNetworkPolicies - deletes all default network acl policies
|
||||
func DeleteNetworkPolicies(netId models.NetworkID) {
|
||||
func DeleteNetworkPolicies(netId schema.NetworkID) {
|
||||
acls, _ := ListAclsByNetwork(netId)
|
||||
for _, acl := range acls {
|
||||
if acl.NetworkID == netId {
|
||||
@@ -1716,7 +1716,7 @@ func SortAclEntrys(acls []models.Acl) {
|
||||
// ValidateCreateAclReq - validates create req for acl
|
||||
func ValidateCreateAclReq(req models.Acl) error {
|
||||
// check if acl network exists
|
||||
_, err := GetNetwork(req.NetworkID.String())
|
||||
err := (&schema.Network{Name: req.NetworkID.String()}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return errors.New("failed to get network details for " + req.NetworkID.String())
|
||||
}
|
||||
@@ -1726,17 +1726,17 @@ func ValidateCreateAclReq(req models.Acl) error {
|
||||
// }
|
||||
for _, src := range req.Src {
|
||||
if src.ID == models.UserGroupAclID {
|
||||
userGroup, err := GetUserGroup(models.UserGroupID(src.Value))
|
||||
userGroup, err := GetUserGroup(schema.UserGroupID(src.Value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, ok := userGroup.NetworkRoles[models.AllNetworks]
|
||||
_, ok := userGroup.NetworkRoles.Data()[schema.AllNetworks]
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
|
||||
_, ok = userGroup.NetworkRoles[req.NetworkID]
|
||||
_, ok = userGroup.NetworkRoles.Data()[req.NetworkID]
|
||||
if !ok {
|
||||
return fmt.Errorf("user group %s does not have access to network %s", src.Value, req.NetworkID)
|
||||
}
|
||||
@@ -1824,7 +1824,7 @@ func RemoveNodeFromAclPolicy(node models.Node) {
|
||||
} else {
|
||||
nodeID = node.ID.String()
|
||||
}
|
||||
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
for _, acl := range acls {
|
||||
delete := false
|
||||
update := false
|
||||
@@ -1891,7 +1891,7 @@ func RemoveNodeFromAclPolicy(node models.Node) {
|
||||
}
|
||||
|
||||
// CreateDefaultAclNetworkPolicies - create default acl network policies
|
||||
func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
|
||||
func CreateDefaultAclNetworkPolicies(netID schema.NetworkID) {
|
||||
if netID.String() == "" {
|
||||
return
|
||||
}
|
||||
@@ -1957,7 +1957,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
|
||||
CreateDefaultUserPolicies(netID)
|
||||
}
|
||||
|
||||
func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) {
|
||||
func getTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) {
|
||||
tagNodesMap = make(map[models.TagID][]models.Node)
|
||||
nodes, _ := GetNetworkNodes(netID.String())
|
||||
netGwTag := models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.GwTagName))
|
||||
@@ -1974,7 +1974,7 @@ func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
|
||||
return addTagMapWithStaticNodes(netID, tagNodesMap)
|
||||
}
|
||||
|
||||
func addTagMapWithStaticNodes(netID models.NetworkID,
|
||||
func addTagMapWithStaticNodes(netID schema.NetworkID,
|
||||
tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node {
|
||||
extclients, err := GetNetworkExtClients(netID.String())
|
||||
if err != nil {
|
||||
|
||||
+221
-217
@@ -6,19 +6,21 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,67 +39,25 @@ var ResetIDPSyncHook = func() {}
|
||||
|
||||
// HasSuperAdmin - checks if server has an superadmin/owner
|
||||
func HasSuperAdmin() (bool, error) {
|
||||
users, err := GetUsersDB()
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
return false, nil
|
||||
}
|
||||
return true, err
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.PlatformRoleID == models.SuperAdminRole {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUsersDB - gets users
|
||||
func GetUsersDB() ([]models.User, error) {
|
||||
if servercfg.CacheEnabled() {
|
||||
users := getUsersFromCache()
|
||||
if len(users) != 0 {
|
||||
return users, nil
|
||||
}
|
||||
}
|
||||
|
||||
var users []models.User
|
||||
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
|
||||
cacheMap := make(map[string]models.User, len(collection))
|
||||
for _, value := range collection {
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(value), &user)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, user)
|
||||
cacheMap[user.UserName] = user
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
loadUsersIntoCache(cacheMap)
|
||||
}
|
||||
return users, nil
|
||||
return (&schema.User{}).SuperAdminExists(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// GetUsers - gets users
|
||||
func GetUsers() ([]models.ReturnUser, error) {
|
||||
dbUsers, err := GetUsersDB()
|
||||
_users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]models.ReturnUser, 0, len(dbUsers))
|
||||
for _, u := range dbUsers {
|
||||
users = append(users, ToReturnUser(u))
|
||||
|
||||
users := make([]models.ReturnUser, len(_users))
|
||||
for i, _user := range _users {
|
||||
users[i] = ToReturnUser(&_user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// IsOauthUser - returns
|
||||
func IsOauthUser(user *models.User) error {
|
||||
func IsOauthUser(user *schema.User) error {
|
||||
var currentValue, err = FetchPassValue("")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -130,63 +90,63 @@ func FetchPassValue(newValue string) (string, error) {
|
||||
}
|
||||
|
||||
// CreateUser - creates a user
|
||||
func CreateUser(user *models.User) error {
|
||||
func CreateUser(_user *schema.User) error {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err == nil {
|
||||
userCheck := &schema.User{Username: _user.Username}
|
||||
if err := userCheck.Get(db.WithContext(context.TODO())); err == nil {
|
||||
return errors.New("user exists")
|
||||
}
|
||||
SetUserDefaults(user)
|
||||
if err := IsGroupsValid(user.UserGroups); err != nil {
|
||||
SetUserDefaults(_user)
|
||||
if err := IsGroupsValid(_user.UserGroups.Data()); err != nil {
|
||||
return errors.New("invalid groups: " + err.Error())
|
||||
}
|
||||
if err := IsNetworkRolesValid(user.NetworkRoles); err != nil {
|
||||
return errors.New("invalid network roles: " + err.Error())
|
||||
}
|
||||
|
||||
var err = ValidateUser(user)
|
||||
var err = ValidateUser(_user)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to validate user", err.Error())
|
||||
return err
|
||||
}
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(_user.Password), 5)
|
||||
if err != nil {
|
||||
logger.Log(0, "error encrypting pass", err.Error())
|
||||
return err
|
||||
}
|
||||
// set password to encrypted password
|
||||
user.Password = string(hash)
|
||||
user.AuthType = models.BasicAuth
|
||||
if IsOauthUser(user) == nil {
|
||||
user.AuthType = models.OAuth
|
||||
_user.Password = string(hash)
|
||||
_user.AuthType = schema.BasicAuth
|
||||
if IsOauthUser(_user) == nil {
|
||||
_user.AuthType = schema.OAuth
|
||||
}
|
||||
AddGlobalNetRolesToAdmins(user)
|
||||
AddGlobalNetRolesToAdmins(_user)
|
||||
// create user will always be called either from API or Dashboard.
|
||||
_, err = CreateUserJWT(user.UserName, user.PlatformRoleID, DashboardApp)
|
||||
_, err = CreateUserJWT(_user.Username, _user.PlatformRoleID, DashboardApp)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to generate token", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// connect db
|
||||
data, err := json.Marshal(user)
|
||||
dbctx := db.BeginTx(context.TODO())
|
||||
commit := false
|
||||
defer func() {
|
||||
if commit {
|
||||
db.FromContext(dbctx).Commit()
|
||||
} else {
|
||||
db.FromContext(dbctx).Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
err = _user.Create(dbctx)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to marshal", err.Error())
|
||||
return err
|
||||
}
|
||||
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to insert user", err.Error())
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeUserInCache(*user)
|
||||
return fmt.Errorf("failed to create user %s: %v", _user.Username, err)
|
||||
}
|
||||
|
||||
commit = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSuperAdmin - creates an super admin user
|
||||
func CreateSuperAdmin(u *models.User) error {
|
||||
func CreateSuperAdmin(u *schema.User) error {
|
||||
hassuperadmin, err := HasSuperAdmin()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -194,37 +154,34 @@ func CreateSuperAdmin(u *models.User) error {
|
||||
if hassuperadmin {
|
||||
return errors.New("superadmin user already exists")
|
||||
}
|
||||
u.IsSuperAdmin = true
|
||||
u.IsAdmin = true
|
||||
u.PlatformRoleID = models.SuperAdminRole
|
||||
u.PlatformRoleID = schema.SuperAdminRole
|
||||
return CreateUser(u)
|
||||
}
|
||||
|
||||
// VerifyAuthRequest - verifies an auth request
|
||||
func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (string, error) {
|
||||
var result models.User
|
||||
if authRequest.UserName == "" {
|
||||
return "", errors.New("username can't be empty")
|
||||
} else if authRequest.Password == "" {
|
||||
return "", errors.New("password can't be empty")
|
||||
}
|
||||
// Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
|
||||
_user := &schema.User{
|
||||
Username: authRequest.UserName,
|
||||
}
|
||||
err := _user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &result); err != nil {
|
||||
return "", errors.New("error unmarshalling user json: " + err.Error())
|
||||
}
|
||||
|
||||
// compare password from request to stored password in database
|
||||
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
|
||||
// TODO: Consider a way of hashing the password client side before sending, or using certificates
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(_user.Password), []byte(authRequest.Password)); err != nil {
|
||||
return "", errors.New("incorrect credentials")
|
||||
}
|
||||
|
||||
if result.IsMFAEnabled {
|
||||
if _user.IsMFAEnabled {
|
||||
tokenString, err := CreatePreAuthToken(authRequest.UserName)
|
||||
if err != nil {
|
||||
slog.Error("error creating jwt", "error", err)
|
||||
@@ -234,15 +191,15 @@ func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (strin
|
||||
return tokenString, nil
|
||||
} else {
|
||||
// Create a new JWT for the node
|
||||
tokenString, err := CreateUserJWT(authRequest.UserName, result.PlatformRoleID, appName)
|
||||
tokenString, err := CreateUserJWT(authRequest.UserName, schema.UserRoleID(_user.PlatformRoleID), appName)
|
||||
if err != nil {
|
||||
slog.Error("error creating jwt", "error", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// update last login time
|
||||
result.LastLoginTime = time.Now().UTC()
|
||||
err = UpsertUser(result)
|
||||
_user.LastLoginAt = time.Now().UTC()
|
||||
err = _user.Update(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
slog.Error("error upserting user", "error", err)
|
||||
return "", err
|
||||
@@ -253,44 +210,42 @@ func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (strin
|
||||
}
|
||||
|
||||
// UpsertUser - updates user in the db
|
||||
func UpsertUser(user models.User) error {
|
||||
data, err := json.Marshal(&user)
|
||||
if err != nil {
|
||||
slog.Error("error marshalling user", "user", user.UserName, "error", err.Error())
|
||||
return err
|
||||
func UpsertUser(_user schema.User) error {
|
||||
_existingUser := schema.User{Username: _user.Username}
|
||||
// Check if user exists to preserve ID
|
||||
err := _existingUser.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
_user.ID = _existingUser.ID
|
||||
return _user.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
slog.Error("error inserting user", "user", user.UserName, "error", err.Error())
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeUserInCache(user)
|
||||
}
|
||||
return nil
|
||||
|
||||
return _user.Create(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// UpdateUser - updates a given user
|
||||
func UpdateUser(userchange, user *models.User) (*models.User, error) {
|
||||
func UpdateUser(userchange, _user *schema.User) (*schema.User, error) {
|
||||
// check if user exists
|
||||
if _, err := GetUser(user.UserName); err != nil {
|
||||
return &models.User{}, err
|
||||
userCheck := &schema.User{Username: _user.Username}
|
||||
if err := userCheck.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return &schema.User{}, err
|
||||
}
|
||||
|
||||
queryUser := user.UserName
|
||||
if userchange.UserName != "" && user.UserName != userchange.UserName {
|
||||
queryUser := _user.Username
|
||||
if userchange.Username != "" && _user.Username != userchange.Username {
|
||||
// check if username is available
|
||||
if _, err := GetUser(userchange.UserName); err == nil {
|
||||
return &models.User{}, errors.New("username exists already")
|
||||
userCheck := &schema.User{Username: userchange.Username}
|
||||
if err := userCheck.Get(db.WithContext(context.TODO())); err == nil {
|
||||
return &schema.User{}, errors.New("username exists already")
|
||||
}
|
||||
if userchange.UserName == MasterUser {
|
||||
return &models.User{}, errors.New("username not allowed")
|
||||
if userchange.Username == MasterUser {
|
||||
return &schema.User{}, errors.New("username not allowed")
|
||||
}
|
||||
|
||||
user.UserName = userchange.UserName
|
||||
_user.Username = userchange.Username
|
||||
}
|
||||
if userchange.Password != "" {
|
||||
if len(userchange.Password) < 5 {
|
||||
return &models.User{}, errors.New("password requires min 5 characters")
|
||||
return &schema.User{}, errors.New("password requires min 5 characters")
|
||||
}
|
||||
// encrypt that password so we never see it again
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
|
||||
@@ -301,152 +256,201 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
|
||||
// set password to encrypted password
|
||||
userchange.Password = string(hash)
|
||||
|
||||
user.Password = userchange.Password
|
||||
_user.Password = userchange.Password
|
||||
}
|
||||
|
||||
validUserGroups := make(map[models.UserGroupID]struct{})
|
||||
for userGroupID := range userchange.UserGroups {
|
||||
validUserGroups := make(map[schema.UserGroupID]struct{})
|
||||
for userGroupID := range userchange.UserGroups.Data() {
|
||||
_, err := GetUserGroup(userGroupID)
|
||||
if err == nil {
|
||||
validUserGroups[userGroupID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
userchange.UserGroups = validUserGroups
|
||||
|
||||
if err := IsNetworkRolesValid(userchange.NetworkRoles); err != nil {
|
||||
return userchange, errors.New("invalid network roles: " + err.Error())
|
||||
}
|
||||
userchange.UserGroups = datatypes.NewJSONType(validUserGroups)
|
||||
|
||||
if userchange.DisplayName != "" {
|
||||
if user.ExternalIdentityProviderID != "" &&
|
||||
user.DisplayName != userchange.DisplayName {
|
||||
if _user.ExternalIdentityProviderID != "" &&
|
||||
_user.DisplayName != userchange.DisplayName {
|
||||
return userchange, errors.New("display name cannot be updated for external user")
|
||||
}
|
||||
|
||||
user.DisplayName = userchange.DisplayName
|
||||
_user.DisplayName = userchange.DisplayName
|
||||
}
|
||||
|
||||
if user.ExternalIdentityProviderID != "" &&
|
||||
userchange.AccountDisabled != user.AccountDisabled {
|
||||
if _user.ExternalIdentityProviderID != "" &&
|
||||
userchange.AccountDisabled != _user.AccountDisabled {
|
||||
return userchange, errors.New("account status cannot be updated for external user")
|
||||
}
|
||||
|
||||
// Reset Gw Access for service users
|
||||
go UpdateUserGwAccess(*user, *userchange)
|
||||
go UpdateUserGwAccess(_user, userchange)
|
||||
if userchange.PlatformRoleID != "" {
|
||||
user.PlatformRoleID = userchange.PlatformRoleID
|
||||
// TODO: remove once NMUI stops using these fields.
|
||||
if user.PlatformRoleID == models.SuperAdminRole {
|
||||
user.IsSuperAdmin = true
|
||||
user.IsAdmin = true
|
||||
} else if user.PlatformRoleID == models.AdminRole {
|
||||
user.IsSuperAdmin = false
|
||||
user.IsAdmin = true
|
||||
_user.PlatformRoleID = userchange.PlatformRoleID
|
||||
}
|
||||
|
||||
for groupID := range userchange.UserGroups.Data() {
|
||||
_, ok := _user.UserGroups.Data()[groupID]
|
||||
if !ok {
|
||||
group, err := GetUserGroup(groupID)
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
|
||||
if group.ExternalIdentityProviderID != "" {
|
||||
return userchange, errors.New("cannot modify membership of external groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for groupID := range _user.UserGroups.Data() {
|
||||
_, ok := userchange.UserGroups.Data()[groupID]
|
||||
if !ok {
|
||||
group, err := GetUserGroup(groupID)
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
|
||||
if group.ExternalIdentityProviderID != "" {
|
||||
return userchange, errors.New("cannot modify membership of external groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var updateMFA bool
|
||||
if _user.IsMFAEnabled != userchange.IsMFAEnabled {
|
||||
updateMFA = true
|
||||
}
|
||||
|
||||
_user.IsMFAEnabled = userchange.IsMFAEnabled
|
||||
|
||||
var updateAccountStatus bool
|
||||
if _user.AccountDisabled != userchange.AccountDisabled {
|
||||
updateAccountStatus = true
|
||||
}
|
||||
|
||||
_user.IsMFAEnabled = userchange.IsMFAEnabled
|
||||
if !_user.IsMFAEnabled {
|
||||
_user.TOTPSecret = ""
|
||||
}
|
||||
|
||||
_user.UserGroups = userchange.UserGroups
|
||||
AddGlobalNetRolesToAdmins(_user)
|
||||
err := ValidateUser(_user)
|
||||
if err != nil {
|
||||
return &schema.User{}, err
|
||||
}
|
||||
|
||||
dbctx := db.BeginTx(context.TODO())
|
||||
commit := false
|
||||
defer func() {
|
||||
if commit {
|
||||
db.FromContext(dbctx).Commit()
|
||||
logger.Log(1, "updated user", queryUser)
|
||||
} else {
|
||||
user.IsSuperAdmin = false
|
||||
user.IsAdmin = false
|
||||
db.FromContext(dbctx).Rollback()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for groupID := range userchange.UserGroups {
|
||||
_, ok := user.UserGroups[groupID]
|
||||
if !ok {
|
||||
group, err := GetUserGroup(groupID)
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
|
||||
if group.ExternalIdentityProviderID != "" {
|
||||
return userchange, errors.New("cannot modify membership of external groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for groupID := range user.UserGroups {
|
||||
_, ok := userchange.UserGroups[groupID]
|
||||
if !ok {
|
||||
group, err := GetUserGroup(groupID)
|
||||
if err != nil {
|
||||
return userchange, err
|
||||
}
|
||||
|
||||
if group.ExternalIdentityProviderID != "" {
|
||||
return userchange, errors.New("cannot modify membership of external groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
user.IsMFAEnabled = userchange.IsMFAEnabled
|
||||
if !user.IsMFAEnabled {
|
||||
user.TOTPSecret = ""
|
||||
}
|
||||
|
||||
user.UserGroups = userchange.UserGroups
|
||||
user.NetworkRoles = userchange.NetworkRoles
|
||||
AddGlobalNetRolesToAdmins(user)
|
||||
err := ValidateUser(user)
|
||||
// Fetch existing user to get ID
|
||||
_schemaUser := schema.User{Username: queryUser}
|
||||
err = _schemaUser.Get(dbctx)
|
||||
if err != nil {
|
||||
return &models.User{}, err
|
||||
return &schema.User{}, err
|
||||
}
|
||||
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
|
||||
return &models.User{}, err
|
||||
}
|
||||
data, err := json.Marshal(&user)
|
||||
|
||||
_user.ID = _schemaUser.ID
|
||||
|
||||
err = _user.Update(dbctx)
|
||||
if err != nil {
|
||||
return &models.User{}, err
|
||||
return &schema.User{}, err
|
||||
}
|
||||
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
|
||||
return &models.User{}, err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
if queryUser != user.UserName {
|
||||
deleteUserFromCache(queryUser)
|
||||
|
||||
if updateAccountStatus {
|
||||
err = _user.UpdateAccountStatus(dbctx)
|
||||
if err != nil {
|
||||
return &schema.User{}, err
|
||||
}
|
||||
storeUserInCache(*user)
|
||||
}
|
||||
logger.Log(1, "updated user", queryUser)
|
||||
return user, nil
|
||||
|
||||
if updateMFA {
|
||||
err = _user.UpdateMFA(dbctx)
|
||||
if err != nil {
|
||||
return &schema.User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
commit = true
|
||||
return _user, nil
|
||||
}
|
||||
|
||||
func validateUserName(user *schema.User) error {
|
||||
var validationErr error
|
||||
|
||||
if len(user.Username) == 0 {
|
||||
validationErr = errors.Join(validationErr, errors.New("username cannot be empty"))
|
||||
} else if len(user.Username) <= 3 {
|
||||
validationErr = errors.Join(validationErr, errors.New("username must have more than 3 characters"))
|
||||
}
|
||||
|
||||
var isValidEmail bool
|
||||
_, err := mail.ParseAddress(user.Username)
|
||||
if err == nil {
|
||||
isValidEmail = true
|
||||
}
|
||||
|
||||
if !isValidEmail {
|
||||
charset := "abcdefghijklmnopqrstuvwxyz1234567890-."
|
||||
for _, char := range user.Username {
|
||||
if !strings.Contains(charset, strings.ToLower(string(char))) {
|
||||
validationErr = errors.Join(validationErr, errors.New("invalid character(s) in username"))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return validationErr
|
||||
}
|
||||
|
||||
// ValidateUser - validates a user model
|
||||
func ValidateUser(user *models.User) error {
|
||||
|
||||
func ValidateUser(user *schema.User) error {
|
||||
var validationErr error
|
||||
// check if role is valid
|
||||
_, err := GetRole(user.PlatformRoleID)
|
||||
roleCheck := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err := roleCheck.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return errors.New("failed to fetch platform role " + user.PlatformRoleID.String())
|
||||
}
|
||||
v := validator.New()
|
||||
_ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool {
|
||||
isgood := user.NameInCharSet()
|
||||
return isgood
|
||||
})
|
||||
err = v.Struct(user)
|
||||
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
logger.Log(2, e.Error())
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
validationErr = errors.Join(validationErr, fmt.Errorf("invalid user role %s", user.PlatformRoleID))
|
||||
}
|
||||
|
||||
return err
|
||||
err = validateUserName(user)
|
||||
if err != nil {
|
||||
validationErr = errors.Join(validationErr, err)
|
||||
}
|
||||
|
||||
if len(user.Password) < 5 {
|
||||
validationErr = errors.Join(validationErr, errors.New("password must have a minimum of 5 characters"))
|
||||
}
|
||||
|
||||
return validationErr
|
||||
}
|
||||
|
||||
// DeleteUser - deletes a given user
|
||||
func DeleteUser(user string) error {
|
||||
|
||||
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
|
||||
return errors.New("user does not exist")
|
||||
_user := schema.User{
|
||||
Username: user,
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
|
||||
err := _user.Delete(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("user does not exist")
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
deleteUserFromCache(user)
|
||||
}
|
||||
|
||||
go RemoveUserFromAclPolicy(user)
|
||||
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
+19
-13
@@ -116,14 +116,14 @@ func SetDNS() error {
|
||||
return err
|
||||
}
|
||||
var corefilestring string
|
||||
networks, err := GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, net := range networks {
|
||||
corefilestring = corefilestring + net.NetID + " "
|
||||
dns, err := GetDNS(net.NetID)
|
||||
corefilestring = corefilestring + net.Name + " "
|
||||
dns, err := GetDNS(net.Name)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
@@ -229,7 +229,10 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {
|
||||
if node.Network != network {
|
||||
continue
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -256,7 +259,10 @@ func GetGwDNS(node *models.Node) string {
|
||||
if !servercfg.GetManageDNS() {
|
||||
return ""
|
||||
}
|
||||
h, err := GetHost(node.HostID.String())
|
||||
h := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := h.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
@@ -340,12 +346,12 @@ func SetCorefile(domains string) error {
|
||||
// GetAllDNS - gets all dns entries
|
||||
func GetAllDNS() ([]models.DNSEntry, error) {
|
||||
var dns []models.DNSEntry
|
||||
networks, err := GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return []models.DNSEntry{}, err
|
||||
}
|
||||
for _, net := range networks {
|
||||
netdns, err := GetDNS(net.NetID)
|
||||
netdns, err := GetDNS(net.Name)
|
||||
if err != nil {
|
||||
return []models.DNSEntry{}, nil
|
||||
}
|
||||
@@ -405,7 +411,7 @@ func ValidateDNSCreate(entry models.DNSEntry) error {
|
||||
})
|
||||
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := GetParentNetwork(entry.Network)
|
||||
err := (&schema.Network{Name: entry.Network}).Get(db.WithContext(context.TODO()))
|
||||
return err == nil
|
||||
})
|
||||
|
||||
@@ -437,7 +443,7 @@ func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
|
||||
return err == nil && num == 0
|
||||
})
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := GetParentNetwork(change.Network)
|
||||
err := (&schema.Network{Name: change.Network}).Get(db.WithContext(context.TODO()))
|
||||
return err == nil
|
||||
})
|
||||
|
||||
@@ -488,7 +494,7 @@ func validateNameserverReq(ns *schema.Nameserver) error {
|
||||
if len(ns.Servers) == 0 {
|
||||
return errors.New("atleast one nameserver should be specified")
|
||||
}
|
||||
_, err := GetNetwork(ns.NetworkID)
|
||||
err := (&schema.Network{Name: ns.NetworkID}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return errors.New("invalid network id")
|
||||
}
|
||||
@@ -595,7 +601,7 @@ func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
|
||||
return
|
||||
}
|
||||
|
||||
func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
|
||||
func getNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) {
|
||||
if h.DNS != "yes" {
|
||||
return
|
||||
}
|
||||
|
||||
+14
-11
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
var ValidateEgressReq = validateEgressReq
|
||||
|
||||
var AssignVirtualRangeToEgress = func(nw *models.Network, eg *schema.Egress) error {
|
||||
var AssignVirtualRangeToEgress = func(nw *schema.Network, eg *schema.Egress) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,12 +25,12 @@ func validateEgressReq(e *schema.Egress) error {
|
||||
return errors.New("network id is empty")
|
||||
}
|
||||
if e.Nat {
|
||||
e.Mode = models.DirectNAT
|
||||
e.Mode = schema.DirectNAT
|
||||
} else {
|
||||
e.Mode = ""
|
||||
e.VirtualRange = ""
|
||||
}
|
||||
_, err := GetNetwork(e.Network)
|
||||
err := (&schema.Network{Name: e.Network}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return errors.New("failed to get network " + err.Error())
|
||||
}
|
||||
@@ -50,7 +50,7 @@ func validateEgressReq(e *schema.Egress) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DoesUserHaveAccessToEgress(user *models.User, e *schema.Egress, acls []models.Acl) bool {
|
||||
func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []models.Acl) bool {
|
||||
if !e.Status {
|
||||
return false
|
||||
}
|
||||
@@ -64,11 +64,11 @@ func DoesUserHaveAccessToEgress(user *models.User, e *schema.Egress, acls []mode
|
||||
if _, ok := dstTags[e.ID]; ok || all {
|
||||
// get all src tags
|
||||
for _, srcAcl := range acl.Src {
|
||||
if srcAcl.ID == models.UserAclID && srcAcl.Value == user.UserName {
|
||||
if srcAcl.ID == models.UserAclID && srcAcl.Value == user.Username {
|
||||
return true
|
||||
} else if srcAcl.ID == models.UserGroupAclID {
|
||||
// fetch all users in the group
|
||||
if _, ok := user.UserGroups[models.UserGroupID(srcAcl.Value)]; ok {
|
||||
if _, ok := user.UserGroups.Data()[schema.UserGroupID(srcAcl.Value)]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -253,7 +253,7 @@ func AddEgressInfoToPeerByAccess(node, targetNode *models.Node, eli []schema.Egr
|
||||
}
|
||||
}
|
||||
|
||||
func GetEgressDomainsByAccessForUser(user *models.User, network models.NetworkID) (domains []string) {
|
||||
func GetEgressDomainsByAccessForUser(user *schema.User, network schema.NetworkID) (domains []string) {
|
||||
acls := ListUserPolicies(network)
|
||||
eli, _ := (&schema.Egress{Network: network.String()}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(network, models.UserPolicy)
|
||||
@@ -276,9 +276,9 @@ func GetEgressDomainsByAccessForUser(user *models.User, network models.NetworkID
|
||||
}
|
||||
|
||||
func GetEgressDomainNSForNode(node *models.Node) (returnNsLi []models.Nameserver) {
|
||||
acls := ListDevicePolicies(models.NetworkID(node.Network))
|
||||
acls := ListDevicePolicies(schema.NetworkID(node.Network))
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
isDefaultPolicyActive := defaultDevicePolicy.Enabled
|
||||
for _, e := range eli {
|
||||
if !e.Status || e.Network != node.Network {
|
||||
@@ -458,7 +458,7 @@ func RemoveNodeFromEgress(node models.Node) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]struct{}, error) {
|
||||
func GetEgressRanges(netID schema.NetworkID) (map[string][]string, map[string]struct{}, error) {
|
||||
|
||||
resultMap := make(map[string]struct{})
|
||||
nodeEgressMap := make(map[string][]string)
|
||||
@@ -496,7 +496,10 @@ func ListAllByRoutingNodeWithDomain(egs []schema.Egress, nodeID string) (egWithD
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
+22
-16
@@ -72,8 +72,8 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
|
||||
var result []string
|
||||
eli, _ := (&schema.Egress{Network: client.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
staticNode := client.ConvertToStaticNode()
|
||||
userPolicies := ListUserPolicies(models.NetworkID(client.Network))
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(client.Network), models.UserPolicy)
|
||||
userPolicies := ListUserPolicies(schema.NetworkID(client.Network))
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(client.Network), models.UserPolicy)
|
||||
|
||||
for _, eI := range eli {
|
||||
if !eI.Status {
|
||||
@@ -100,7 +100,8 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
|
||||
result = append(result, rangesToBeAdded...)
|
||||
} else {
|
||||
if staticNode.IsUserNode && staticNode.StaticNode.OwnerID != "" {
|
||||
user, err := GetUser(staticNode.StaticNode.OwnerID)
|
||||
user := &schema.User{Username: staticNode.StaticNode.OwnerID}
|
||||
err := user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return []string{}, errors.New("user not found")
|
||||
}
|
||||
@@ -173,21 +174,21 @@ func DeleteExtClient(network string, clientid string, isUpdate bool) error {
|
||||
}
|
||||
if !isUpdate && extClient.RemoteAccessClientID != "" {
|
||||
LogEvent(&models.Event{
|
||||
Action: models.Disconnect,
|
||||
Action: schema.Disconnect,
|
||||
Source: models.Subject{
|
||||
ID: extClient.OwnerID,
|
||||
Name: extClient.OwnerID,
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: extClient.OwnerID,
|
||||
Target: models.Subject{
|
||||
ID: extClient.Network,
|
||||
Name: extClient.Network,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
Info: extClient,
|
||||
},
|
||||
NetworkID: models.NetworkID(extClient.Network),
|
||||
Origin: models.ClientApp,
|
||||
NetworkID: schema.NetworkID(extClient.Network),
|
||||
Origin: schema.ClientApp,
|
||||
})
|
||||
}
|
||||
go RemoveNodeFromAclPolicy(extClient.ConvertToStaticNode())
|
||||
@@ -343,12 +344,13 @@ func CreateExtClient(extclient *models.ExtClient) error {
|
||||
extclient.ExtraAllowedIPs = []string{}
|
||||
}
|
||||
|
||||
parentNetwork, err := GetNetwork(extclient.Network)
|
||||
parentNetwork := &schema.Network{Name: extclient.Network}
|
||||
err := parentNetwork.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if extclient.Address == "" {
|
||||
if parentNetwork.IsIPv4 == "yes" {
|
||||
if parentNetwork.AddressRange != "" {
|
||||
newAddress, err := UniqueAddress(extclient.Network, true)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -358,7 +360,7 @@ func CreateExtClient(extclient *models.ExtClient) error {
|
||||
}
|
||||
|
||||
if extclient.Address6 == "" {
|
||||
if parentNetwork.IsIPv6 == "yes" {
|
||||
if parentNetwork.AddressRange6 != "" {
|
||||
addr6, err := UniqueAddress6(extclient.Network, true)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -497,7 +499,7 @@ func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) {
|
||||
// GetAllExtClients - gets all ext clients from DB
|
||||
func GetAllExtClients() ([]models.ExtClient, error) {
|
||||
var clients = []models.ExtClient{}
|
||||
currentNetworks, err := GetNetworks()
|
||||
currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil && database.IsEmptyRecord(err) {
|
||||
return clients, nil
|
||||
} else if err != nil {
|
||||
@@ -505,7 +507,7 @@ func GetAllExtClients() ([]models.ExtClient, error) {
|
||||
}
|
||||
|
||||
for i := range currentNetworks {
|
||||
netName := currentNetworks[i].NetID
|
||||
netName := currentNetworks[i].Name
|
||||
netClients, err := GetNetworkExtClients(netName)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -575,7 +577,10 @@ func GetExtPeers(node, peer *models.Node, addressIdentityMap map[string]models.P
|
||||
if err != nil {
|
||||
return peers, idsAndAddr, egressRoutes, err
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return peers, idsAndAddr, egressRoutes, err
|
||||
}
|
||||
@@ -764,7 +769,8 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) {
|
||||
return
|
||||
}
|
||||
|
||||
network, err := GetParentNetwork(client.Network)
|
||||
network := &schema.Network{Name: client.Network}
|
||||
err = network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "Could not retrieve Ingress Gateway Network", client.Network)
|
||||
return
|
||||
@@ -788,7 +794,7 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetStaticNodesByNetwork(network models.NetworkID, onlyWg bool) (staticNode []models.Node) {
|
||||
func GetStaticNodesByNetwork(network schema.NetworkID, onlyWg bool) (staticNode []models.Node) {
|
||||
extClients, err := GetAllExtClients()
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
+36
-12
@@ -8,10 +8,14 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
@@ -77,7 +81,10 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
@@ -184,7 +191,10 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
|
||||
if node.IsRelayed {
|
||||
return models.Node{}, errors.New("gateway cannot be created on a relayed node")
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
@@ -192,7 +202,8 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
|
||||
return models.Node{}, errors.New("gateway can only be created on linux based node")
|
||||
}
|
||||
|
||||
network, err := GetParentNetwork(netid)
|
||||
network := &schema.Network{Name: netid}
|
||||
err = network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.Node{}, err
|
||||
}
|
||||
@@ -231,7 +242,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
|
||||
node.Tags = make(map[models.TagID]struct{})
|
||||
}
|
||||
node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.GwTagName))] = struct{}{}
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), schema.NetworkID(node.Network))
|
||||
node.LastEvaluatedAt = time.Now().UTC()
|
||||
err = UpsertNode(&node)
|
||||
if err != nil {
|
||||
@@ -253,7 +264,7 @@ func GetIngressGwUsers(node models.Node) (models.IngressGwUsers, error) {
|
||||
return gwUsers, err
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.PlatformRoleID != models.SuperAdminRole && user.PlatformRoleID != models.AdminRole {
|
||||
if user.PlatformRoleID != schema.SuperAdminRole && user.PlatformRoleID != schema.AdminRole {
|
||||
gwUsers.Users = append(gwUsers.Users, user)
|
||||
}
|
||||
}
|
||||
@@ -284,7 +295,7 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error
|
||||
delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName)))
|
||||
node.IngressGatewayRange = ""
|
||||
node.Metadata = ""
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
|
||||
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), schema.NetworkID(node.Network))
|
||||
node.LastEvaluatedAt = time.Now().UTC()
|
||||
err = UpsertNode(&node)
|
||||
if err != nil {
|
||||
@@ -319,18 +330,22 @@ func IsUserAllowedAccessToExtClient(username string, client models.ExtClient) bo
|
||||
if username == MasterUser {
|
||||
return true
|
||||
}
|
||||
user, err := GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if user.UserName != client.OwnerID {
|
||||
if user.Username != client.OwnerID {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool) error {
|
||||
inetHost, err := GetHost(inetNode.HostID.String())
|
||||
inetHost := &schema.Host{
|
||||
ID: inetNode.HostID,
|
||||
}
|
||||
err := inetHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -352,7 +367,10 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool
|
||||
if clientNode.IsFailOver || clientNode.IsAutoRelay {
|
||||
return errors.New("failover node cannot be set to use internet gateway")
|
||||
}
|
||||
clientHost, err := GetHost(clientNode.HostID.String())
|
||||
clientHost := &schema.Host{
|
||||
ID: clientNode.HostID,
|
||||
}
|
||||
err = clientHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -431,7 +449,10 @@ func UnsetInternetGw(node *models.Node) {
|
||||
|
||||
func SetDefaultGwForRelayedUpdate(relayed, relay models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
|
||||
if relay.InternetGwID != "" {
|
||||
relayedHost, err := GetHost(relayed.HostID.String())
|
||||
relayedHost := &schema.Host{
|
||||
ID: relayed.HostID,
|
||||
}
|
||||
err := relayedHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return peerUpdate
|
||||
}
|
||||
@@ -452,7 +473,10 @@ func SetDefaultGw(node models.Node, peerUpdate models.HostPeerUpdate) models.Hos
|
||||
if err != nil {
|
||||
return peerUpdate
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return peerUpdate
|
||||
}
|
||||
|
||||
+2
-2
@@ -35,12 +35,12 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestCheckPorts(t *testing.T) {
|
||||
h := models.Host{
|
||||
h := schema.Host{
|
||||
ID: uuid.New(),
|
||||
EndpointIP: net.ParseIP("192.168.1.1"),
|
||||
ListenPort: 51821,
|
||||
}
|
||||
testHost := models.Host{
|
||||
testHost := schema.Host{
|
||||
ID: uuid.New(),
|
||||
EndpointIP: net.ParseIP("192.168.1.1"),
|
||||
ListenPort: 51830,
|
||||
|
||||
+55
-218
@@ -1,8 +1,8 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -11,10 +11,12 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/exp/slog"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
@@ -22,9 +24,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
hostCacheMutex = &sync.RWMutex{}
|
||||
hostsCacheMap = make(map[string]models.Host)
|
||||
hostPortMutex = &sync.Mutex{}
|
||||
hostPortMutex = &sync.Mutex{}
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -34,99 +34,28 @@ var (
|
||||
ErrInvalidHostID error = errors.New("invalid host id")
|
||||
)
|
||||
|
||||
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network models.NetworkID) (v []models.Violation, level models.Severity) {
|
||||
return []models.Violation{}, models.SeverityUnknown
|
||||
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network schema.NetworkID) (v []models.Violation, level schema.Severity) {
|
||||
return []models.Violation{}, schema.SeverityUnknown
|
||||
}
|
||||
|
||||
var GetPostureCheckDeviceInfoByNode = func(node *models.Node) (d models.PostureCheckDeviceInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
func getHostsFromCache() (hosts []models.Host) {
|
||||
hostCacheMutex.RLock()
|
||||
for _, host := range hostsCacheMap {
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
hostCacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func getHostsMapFromCache() (hostsMap map[string]models.Host) {
|
||||
hostCacheMutex.RLock()
|
||||
hostsMap = hostsCacheMap
|
||||
hostCacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func getHostFromCache(hostID string) (host models.Host, ok bool) {
|
||||
hostCacheMutex.RLock()
|
||||
host, ok = hostsCacheMap[hostID]
|
||||
hostCacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func storeHostInCache(h models.Host) {
|
||||
hostCacheMutex.Lock()
|
||||
hostsCacheMap[h.ID.String()] = h
|
||||
hostCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func deleteHostFromCache(hostID string) {
|
||||
hostCacheMutex.Lock()
|
||||
delete(hostsCacheMap, hostID)
|
||||
hostCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func loadHostsIntoCache(hMap map[string]models.Host) {
|
||||
hostCacheMutex.Lock()
|
||||
hostsCacheMap = hMap
|
||||
hostCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
const (
|
||||
maxPort = 1<<16 - 1
|
||||
minPort = 1025
|
||||
)
|
||||
|
||||
// GetAllHosts - returns all hosts in flat list or error
|
||||
func GetAllHosts() ([]models.Host, error) {
|
||||
var currHosts []models.Host
|
||||
if servercfg.CacheEnabled() {
|
||||
currHosts := getHostsFromCache()
|
||||
if len(currHosts) != 0 {
|
||||
return currHosts, nil
|
||||
}
|
||||
}
|
||||
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return nil, err
|
||||
}
|
||||
currHostsMap := make(map[string]models.Host)
|
||||
if servercfg.CacheEnabled() {
|
||||
defer loadHostsIntoCache(currHostsMap)
|
||||
}
|
||||
for k := range records {
|
||||
var h models.Host
|
||||
err = json.Unmarshal([]byte(records[k]), &h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currHosts = append(currHosts, h)
|
||||
currHostsMap[h.ID.String()] = h
|
||||
}
|
||||
|
||||
return currHosts, nil
|
||||
}
|
||||
|
||||
// GetAllHostsWithStatus - returns all hosts with at least one
|
||||
// node with given status.
|
||||
func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {
|
||||
hosts, err := GetAllHosts()
|
||||
func GetAllHostsWithStatus(status models.NodeStatus) ([]schema.Host, error) {
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var validHosts []models.Host
|
||||
var validHosts []schema.Host
|
||||
for _, host := range hosts {
|
||||
if len(host.Nodes) == 0 {
|
||||
continue
|
||||
@@ -146,44 +75,16 @@ func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {
|
||||
}
|
||||
|
||||
// GetAllHostsAPI - get's all the hosts in an API usable format
|
||||
func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
|
||||
func GetAllHostsAPI(hosts []schema.Host) []models.ApiHost {
|
||||
apiHosts := []models.ApiHost{}
|
||||
for i := range hosts {
|
||||
newApiHost := hosts[i].ConvertNMHostToAPI()
|
||||
newApiHost := models.NewApiHostFromSchemaHost(&hosts[i])
|
||||
apiHosts = append(apiHosts, *newApiHost)
|
||||
}
|
||||
return apiHosts[:]
|
||||
}
|
||||
|
||||
// GetHostsMap - gets all the current hosts on machine in a map
|
||||
func GetHostsMap() (map[string]models.Host, error) {
|
||||
if servercfg.CacheEnabled() {
|
||||
hostsMap := getHostsMapFromCache()
|
||||
if len(hostsMap) != 0 {
|
||||
return hostsMap, nil
|
||||
}
|
||||
}
|
||||
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return nil, err
|
||||
}
|
||||
currHostMap := make(map[string]models.Host)
|
||||
if servercfg.CacheEnabled() {
|
||||
defer loadHostsIntoCache(currHostMap)
|
||||
}
|
||||
for k := range records {
|
||||
var h models.Host
|
||||
err = json.Unmarshal([]byte(records[k]), &h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
currHostMap[h.ID.String()] = h
|
||||
}
|
||||
|
||||
return currHostMap, nil
|
||||
}
|
||||
|
||||
func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID) bool {
|
||||
func DoesHostExistinTheNetworkAlready(h *schema.Host, network schema.NetworkID) bool {
|
||||
if len(h.Nodes) > 0 {
|
||||
for _, nodeID := range h.Nodes {
|
||||
node, err := GetNodeByID(nodeID)
|
||||
@@ -195,54 +96,11 @@ func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID)
|
||||
return false
|
||||
}
|
||||
|
||||
// GetHost - gets a host from db given id
|
||||
func GetHost(hostid string) (*models.Host, error) {
|
||||
if servercfg.CacheEnabled() {
|
||||
if host, ok := getHostFromCache(hostid); ok {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var h models.Host
|
||||
if err = json.Unmarshal([]byte(record), &h); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeHostInCache(h)
|
||||
}
|
||||
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
// GetHostByPubKey - gets a host from db given pubkey
|
||||
func GetHostByPubKey(hostPubKey string) (*models.Host, error) {
|
||||
hosts, err := GetAllHosts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, host := range hosts {
|
||||
if host.PublicKey.String() == hostPubKey {
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("host not found")
|
||||
}
|
||||
|
||||
// CreateHost - creates a host if not exist
|
||||
func CreateHost(h *models.Host) error {
|
||||
hosts, hErr := GetAllHosts()
|
||||
clients, cErr := GetAllExtClients()
|
||||
if (hErr != nil && !database.IsEmptyRecord(hErr)) ||
|
||||
(cErr != nil && !database.IsEmptyRecord(cErr)) ||
|
||||
len(hosts)+len(clients) >= MachinesLimit {
|
||||
return errors.New("free tier limits exceeded on machines")
|
||||
}
|
||||
_, err := GetHost(h.ID.String())
|
||||
if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
|
||||
func CreateHost(h *schema.Host) error {
|
||||
_host := &schema.Host{ID: h.ID}
|
||||
err := _host.Get(db.WithContext(context.TODO()))
|
||||
if (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil) {
|
||||
return ErrHostExists
|
||||
}
|
||||
|
||||
@@ -269,7 +127,7 @@ func CreateHost(h *models.Host) error {
|
||||
}
|
||||
|
||||
// UpdateHost - updates host data by field
|
||||
func UpdateHost(newHost, currentHost *models.Host) {
|
||||
func UpdateHost(newHost, currentHost *schema.Host) {
|
||||
// unchangeable fields via API here
|
||||
newHost.DaemonInstalled = currentHost.DaemonInstalled
|
||||
newHost.OS = currentHost.OS
|
||||
@@ -311,7 +169,7 @@ func UpdateHost(newHost, currentHost *models.Host) {
|
||||
}
|
||||
|
||||
// UpdateHostFromClient - used for updating host on server with update recieved from client
|
||||
func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool) {
|
||||
func UpdateHostFromClient(newHost, currHost *schema.Host) (sendPeerUpdate bool) {
|
||||
if newHost.PublicKey != currHost.PublicKey {
|
||||
currHost.PublicKey = newHost.PublicKey
|
||||
sendPeerUpdate = true
|
||||
@@ -401,24 +259,12 @@ func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool)
|
||||
}
|
||||
|
||||
// UpsertHost - upserts into DB a given host model, does not check for existence*
|
||||
func UpsertHost(h *models.Host) error {
|
||||
data, err := json.Marshal(h)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeHostInCache(*h)
|
||||
}
|
||||
|
||||
return nil
|
||||
func UpsertHost(h *schema.Host) error {
|
||||
return h.Upsert(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// UpdateHostNode - handles updates from client nodes
|
||||
func UpdateHostNode(h *models.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool) {
|
||||
func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool) {
|
||||
currentNode, err := GetNodeByID(newNode.ID.String())
|
||||
if err != nil {
|
||||
return
|
||||
@@ -442,7 +288,7 @@ func UpdateHostNode(h *models.Host, newNode *models.Node) (publishDeletedNodeUpd
|
||||
}
|
||||
|
||||
// RemoveHost - removes a given host from server
|
||||
func RemoveHost(h *models.Host, forceDelete bool) error {
|
||||
func RemoveHost(h *schema.Host, forceDelete bool) error {
|
||||
if !forceDelete && len(h.Nodes) > 0 {
|
||||
return fmt.Errorf("host still has associated nodes")
|
||||
}
|
||||
@@ -453,13 +299,10 @@ func RemoveHost(h *models.Host, forceDelete bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
err := database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())
|
||||
err := h.Delete(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
deleteHostFromCache(h.ID.String())
|
||||
}
|
||||
go func() {
|
||||
if servercfg.IsDNSMode() {
|
||||
SetDNS()
|
||||
@@ -469,21 +312,8 @@ func RemoveHost(h *models.Host, forceDelete bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveHostByID - removes a given host by id from server
|
||||
func RemoveHostByID(hostID string) error {
|
||||
|
||||
err := database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
deleteHostFromCache(hostID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateHostNetwork - adds/deletes host from a network
|
||||
func UpdateHostNetwork(h *models.Host, network string, add bool) (*models.Node, error) {
|
||||
func UpdateHostNetwork(h *schema.Host, network string, add bool) (*models.Node, error) {
|
||||
for _, nodeID := range h.Nodes {
|
||||
node, err := GetNodeByID(nodeID)
|
||||
if err != nil || node.PendingDelete {
|
||||
@@ -513,7 +343,7 @@ func UpdateHostNetwork(h *models.Host, network string, add bool) (*models.Node,
|
||||
|
||||
// AssociateNodeToHost - associates and creates a node with a given host
|
||||
// should be the only way nodes get created as of 0.18
|
||||
func AssociateNodeToHost(n *models.Node, h *models.Host) error {
|
||||
func AssociateNodeToHost(n *models.Node, h *schema.Host) error {
|
||||
if len(h.ID.String()) == 0 || h.ID == uuid.Nil {
|
||||
return ErrInvalidHostID
|
||||
}
|
||||
@@ -522,8 +352,8 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentHost, err := GetHost(h.ID.String())
|
||||
if err != nil {
|
||||
currentHost := &schema.Host{ID: h.ID}
|
||||
if err = currentHost.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return err
|
||||
}
|
||||
h.HostPass = currentHost.HostPass
|
||||
@@ -533,7 +363,7 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
|
||||
|
||||
// DissasociateNodeFromHost - deletes a node and removes from host nodes
|
||||
// should be the only way nodes are deleted as of 0.18
|
||||
func DissasociateNodeFromHost(n *models.Node, h *models.Host) error {
|
||||
func DissasociateNodeFromHost(n *models.Node, h *schema.Host) error {
|
||||
if len(h.ID.String()) == 0 || h.ID == uuid.Nil {
|
||||
return ErrInvalidHostID
|
||||
}
|
||||
@@ -566,11 +396,16 @@ func DissasociateNodeFromHost(n *models.Node, h *models.Host) error {
|
||||
}
|
||||
|
||||
// DisassociateAllNodesFromHost - deletes all nodes of the host
|
||||
func DisassociateAllNodesFromHost(hostID string) error {
|
||||
host, err := GetHost(hostID)
|
||||
func DisassociateAllNodesFromHost(hostIDStr string) error {
|
||||
hostID, err := uuid.Parse(hostIDStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
host := &schema.Host{ID: hostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, nodeID := range host.Nodes {
|
||||
node, err := GetNodeByID(nodeID)
|
||||
if err != nil {
|
||||
@@ -588,9 +423,9 @@ func DisassociateAllNodesFromHost(hostID string) error {
|
||||
}
|
||||
|
||||
// GetDefaultHosts - retrieve all hosts marked as default from DB
|
||||
func GetDefaultHosts() []models.Host {
|
||||
defaultHostList := []models.Host{}
|
||||
hosts, err := GetAllHosts()
|
||||
func GetDefaultHosts() []schema.Host {
|
||||
defaultHostList := []schema.Host{}
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return defaultHostList
|
||||
}
|
||||
@@ -604,8 +439,8 @@ func GetDefaultHosts() []models.Host {
|
||||
|
||||
// GetHostNetworks - fetches all the networks
|
||||
func GetHostNetworks(hostID string) []string {
|
||||
currHost, err := GetHost(hostID)
|
||||
if err != nil {
|
||||
currHost := &schema.Host{ID: uuid.MustParse(hostID)}
|
||||
if err := currHost.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return nil
|
||||
}
|
||||
nets := []string{}
|
||||
@@ -620,14 +455,14 @@ func GetHostNetworks(hostID string) []string {
|
||||
}
|
||||
|
||||
// GetRelatedHosts - fetches related hosts of a given host
|
||||
func GetRelatedHosts(hostID string) []models.Host {
|
||||
relatedHosts := []models.Host{}
|
||||
func GetRelatedHosts(hostID string) []schema.Host {
|
||||
relatedHosts := []schema.Host{}
|
||||
networks := GetHostNetworks(hostID)
|
||||
networkMap := make(map[string]struct{})
|
||||
for _, network := range networks {
|
||||
networkMap[network] = struct{}{}
|
||||
}
|
||||
hosts, err := GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
for _, host := range hosts {
|
||||
if host.ID.String() == hostID {
|
||||
@@ -648,7 +483,7 @@ func GetRelatedHosts(hostID string) []models.Host {
|
||||
// CheckHostPort checks host endpoints to ensures that hosts on the same server
|
||||
// with the same endpoint have different listen ports
|
||||
// in the case of 64535 hosts or more with same endpoint, ports will not be changed
|
||||
func CheckHostPorts(h *models.Host) (changed bool) {
|
||||
func CheckHostPorts(h *schema.Host) (changed bool) {
|
||||
if h.IsStaticPort {
|
||||
return false
|
||||
}
|
||||
@@ -658,7 +493,8 @@ func CheckHostPorts(h *models.Host) (changed bool) {
|
||||
|
||||
// Get the current host from database to check if it already has a valid port assigned
|
||||
// This check happens before the mutex to avoid unnecessary locking
|
||||
currentHost, err := GetHost(h.ID.String())
|
||||
currentHost := &schema.Host{ID: h.ID}
|
||||
err := currentHost.Get(db.WithContext(context.TODO()))
|
||||
if err == nil && currentHost.ListenPort > 0 {
|
||||
// If the host already has a port in the database, use that instead of the incoming port
|
||||
// This prevents the host from being reassigned when the client sends the old port
|
||||
@@ -679,7 +515,7 @@ func CheckHostPorts(h *models.Host) (changed bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
hosts, err := GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -735,7 +571,7 @@ func CheckHostPorts(h *models.Host) (changed bool) {
|
||||
|
||||
// Re-read hosts to get the latest state (in case another host just changed its port)
|
||||
// This is important to avoid conflicts when multiple hosts are being processed
|
||||
latestHosts, err := GetAllHosts()
|
||||
latestHosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
// Update portsInUse with latest state
|
||||
for _, host := range latestHosts {
|
||||
@@ -770,14 +606,15 @@ func CheckHostPorts(h *models.Host) (changed bool) {
|
||||
}
|
||||
|
||||
// HostExists - checks if given host already exists
|
||||
func HostExists(h *models.Host) bool {
|
||||
_, err := GetHost(h.ID.String())
|
||||
return (err != nil && !database.IsEmptyRecord(err)) || (err == nil)
|
||||
func HostExists(h *schema.Host) bool {
|
||||
_host := &schema.Host{ID: h.ID}
|
||||
err := _host.Get(db.WithContext(context.TODO()))
|
||||
return (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil)
|
||||
}
|
||||
|
||||
// GetHostByNodeID - returns a host if found to have a node's ID, else nil
|
||||
func GetHostByNodeID(id string) *models.Host {
|
||||
hosts, err := GetAllHosts()
|
||||
func GetHostByNodeID(id string) *schema.Host {
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
+11
-11
@@ -59,7 +59,7 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
|
||||
}
|
||||
|
||||
// CreateUserJWT - creates a user jwt token
|
||||
func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Time, tokenID string) (response string, err error) {
|
||||
func CreateUserAccessJwtToken(username string, role schema.UserRoleID, d time.Time, tokenID string) (response string, err error) {
|
||||
claims := &models.UserClaims{
|
||||
UserName: username,
|
||||
Role: role,
|
||||
@@ -83,7 +83,7 @@ func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Ti
|
||||
}
|
||||
|
||||
// CreateUserJWT - creates a user jwt token
|
||||
func CreateUserJWT(username string, role models.UserRoleID, appName string) (response string, err error) {
|
||||
func CreateUserJWT(username string, role schema.UserRoleID, appName string) (response string, err error) {
|
||||
duration := GetJwtValidityDuration()
|
||||
if appName == NetclientApp || appName == NetmakerDesktopApp {
|
||||
duration = GetJwtValidityDurationForClients()
|
||||
@@ -187,14 +187,14 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
|
||||
}
|
||||
|
||||
if token != nil && token.Valid {
|
||||
var user *models.User
|
||||
// check that user exists
|
||||
user, err = GetUser(claims.UserName)
|
||||
user := &schema.User{Username: claims.UserName}
|
||||
err = user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if user.UserName != "" {
|
||||
return user.UserName, nil
|
||||
if user.Username != "" {
|
||||
return user.Username, nil
|
||||
}
|
||||
if user.PlatformRoleID != claims.Role {
|
||||
return "", Unauthorized_Err
|
||||
@@ -232,15 +232,15 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
|
||||
}
|
||||
}
|
||||
if token != nil && token.Valid {
|
||||
var user *models.User
|
||||
// check that user exists
|
||||
user, err = GetUser(claims.UserName)
|
||||
user := &schema.User{Username: claims.UserName}
|
||||
err = user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return "", false, false, err
|
||||
}
|
||||
if user.UserName != "" {
|
||||
return user.UserName, user.PlatformRoleID == models.SuperAdminRole,
|
||||
user.PlatformRoleID == models.AdminRole, nil
|
||||
if user.Username != "" {
|
||||
return user.Username, user.PlatformRoleID == schema.SuperAdminRole,
|
||||
user.PlatformRoleID == schema.AdminRole, nil
|
||||
}
|
||||
err = errors.New("user does not exist")
|
||||
}
|
||||
|
||||
+193
-278
@@ -1,7 +1,7 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/c-robinson/iplib"
|
||||
validator "github.com/go-playground/validator/v10"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic/acls/nodeacls"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/gravitl/netmaker/validation"
|
||||
"golang.org/x/exp/slog"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
networkCacheMutex = &sync.RWMutex{}
|
||||
networkCacheMap = make(map[string]models.Network)
|
||||
allocatedIpMap = make(map[string]map[string]net.IP)
|
||||
)
|
||||
|
||||
@@ -38,14 +38,14 @@ func SetAllocatedIpMap() error {
|
||||
allocatedIpMap = map[string]map[string]net.IP{}
|
||||
}
|
||||
|
||||
currentNetworks, err := GetNetworks()
|
||||
currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, v := range currentNetworks {
|
||||
pMap := map[string]net.IP{}
|
||||
netName := v.NetID
|
||||
netName := v.Name
|
||||
|
||||
//nodes
|
||||
nodes, err := GetNetworkNodes(netName)
|
||||
@@ -132,77 +132,16 @@ func RemoveNetworkFromAllocatedIpMap(networkName string) {
|
||||
networkCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func getNetworksFromCache() (networks []models.Network) {
|
||||
networkCacheMutex.RLock()
|
||||
for _, network := range networkCacheMap {
|
||||
networks = append(networks, network)
|
||||
}
|
||||
networkCacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func deleteNetworkFromCache(key string) {
|
||||
networkCacheMutex.Lock()
|
||||
delete(networkCacheMap, key)
|
||||
networkCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func getNetworkFromCache(key string) (network models.Network, ok bool) {
|
||||
networkCacheMutex.RLock()
|
||||
network, ok = networkCacheMap[key]
|
||||
networkCacheMutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func storeNetworkInCache(key string, network models.Network) {
|
||||
networkCacheMutex.Lock()
|
||||
networkCacheMap[key] = network
|
||||
networkCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
// GetNetworks - returns all networks from database
|
||||
func GetNetworks() ([]models.Network, error) {
|
||||
var networks []models.Network
|
||||
if servercfg.CacheEnabled() {
|
||||
networks := getNetworksFromCache()
|
||||
if len(networks) != 0 {
|
||||
return networks, nil
|
||||
}
|
||||
}
|
||||
collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return networks, err
|
||||
}
|
||||
|
||||
for _, value := range collection {
|
||||
var network models.Network
|
||||
if err := json.Unmarshal([]byte(value), &network); err != nil {
|
||||
return networks, err
|
||||
}
|
||||
// add network our array
|
||||
networks = append(networks, network)
|
||||
if servercfg.CacheEnabled() {
|
||||
storeNetworkInCache(network.NetID, network)
|
||||
}
|
||||
}
|
||||
|
||||
return networks, err
|
||||
}
|
||||
|
||||
// DeleteNetwork - deletes a network
|
||||
func DeleteNetwork(network string, force bool, done chan struct{}) error {
|
||||
|
||||
nodeCount, err := GetNetworkNonServerNodeCount(network)
|
||||
if nodeCount == 0 || database.IsEmptyRecord(err) {
|
||||
_network := &schema.Network{
|
||||
Name: network,
|
||||
}
|
||||
// delete server nodes first then db records
|
||||
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
deleteNetworkFromCache(network)
|
||||
}
|
||||
return nil
|
||||
return _network.Delete(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// Remove All Nodes
|
||||
@@ -211,8 +150,8 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
|
||||
if err == nil {
|
||||
for _, node := range nodes {
|
||||
node := node
|
||||
host, err := GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
host := &schema.Host{ID: node.HostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err != nil {
|
||||
continue
|
||||
}
|
||||
if node.IsGw {
|
||||
@@ -228,13 +167,13 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
|
||||
logger.Log(1, "failed to remove the node acls during network delete for network,", network)
|
||||
}
|
||||
// delete server nodes first then db records
|
||||
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
|
||||
_network := &schema.Network{
|
||||
Name: network,
|
||||
}
|
||||
err = _network.Delete(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
deleteNetworkFromCache(network)
|
||||
}
|
||||
done <- struct{}{}
|
||||
close(done)
|
||||
}()
|
||||
@@ -254,54 +193,89 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AssignVirtualNATDefaults determines safe defaults based on VPN CIDR
|
||||
func AssignVirtualNATDefaults(network *schema.Network, vpnCIDR string) {
|
||||
const (
|
||||
cgnatCIDR = "100.64.0.0/10"
|
||||
fallbackIPv4Pool = "198.18.0.0/15"
|
||||
|
||||
defaultIPv4SitePrefix = 24
|
||||
)
|
||||
|
||||
// Parse CGNAT CIDR (should always succeed, but check for safety)
|
||||
_, cgnatNet, err := net.ParseCIDR(cgnatCIDR)
|
||||
if err != nil {
|
||||
// Fallback to default pool if CGNAT parsing fails (shouldn't happen)
|
||||
network.VirtualNATPoolIPv4 = fallbackIPv4Pool
|
||||
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
|
||||
return
|
||||
}
|
||||
|
||||
var virtualIPv4Pool string
|
||||
// Parse VPN CIDR - if it fails or is empty, use fallback
|
||||
if vpnCIDR == "" {
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
} else {
|
||||
_, vpnNet, err := net.ParseCIDR(vpnCIDR)
|
||||
if err != nil || vpnNet == nil {
|
||||
// Invalid VPN CIDR, use fallback
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
} else if !cidrOverlaps(vpnNet, cgnatNet) {
|
||||
// Safe to reuse VPN CIDR for Virtual NAT
|
||||
virtualIPv4Pool = vpnCIDR
|
||||
} else {
|
||||
// VPN is CGNAT — must not reuse
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
}
|
||||
}
|
||||
|
||||
network.VirtualNATPoolIPv4 = virtualIPv4Pool
|
||||
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
|
||||
}
|
||||
|
||||
// cidrOverlaps checks if two CIDR blocks overlap
|
||||
func cidrOverlaps(a, b *net.IPNet) bool {
|
||||
return a.Contains(b.IP) || b.Contains(a.IP)
|
||||
}
|
||||
|
||||
// CreateNetwork - creates a network in database
|
||||
func CreateNetwork(network models.Network) (models.Network, error) {
|
||||
|
||||
if network.AddressRange != "" {
|
||||
normalizedRange, err := NormalizeCIDR(network.AddressRange)
|
||||
func CreateNetwork(_network *schema.Network) error {
|
||||
if _network.AddressRange != "" {
|
||||
normalizedRange, err := NormalizeCIDR(_network.AddressRange)
|
||||
if err != nil {
|
||||
return models.Network{}, err
|
||||
return err
|
||||
}
|
||||
network.AddressRange = normalizedRange
|
||||
_network.AddressRange = normalizedRange
|
||||
}
|
||||
if network.AddressRange6 != "" {
|
||||
normalizedRange, err := NormalizeCIDR(network.AddressRange6)
|
||||
if _network.AddressRange6 != "" {
|
||||
normalizedRange, err := NormalizeCIDR(_network.AddressRange6)
|
||||
if err != nil {
|
||||
return models.Network{}, err
|
||||
return err
|
||||
}
|
||||
network.AddressRange6 = normalizedRange
|
||||
_network.AddressRange6 = normalizedRange
|
||||
}
|
||||
if !IsNetworkCIDRUnique(network.GetNetworkNetworkCIDR4(), network.GetNetworkNetworkCIDR6()) {
|
||||
return models.Network{}, errors.New("network cidr already in use")
|
||||
if !IsNetworkCIDRUnique(GetNetworkNetworkCIDR4(_network), GetNetworkNetworkCIDR6(_network)) {
|
||||
return errors.New("network cidr already in use")
|
||||
}
|
||||
|
||||
network.SetDefaults()
|
||||
network.SetNodesLastModified()
|
||||
network.SetNetworkLastModified()
|
||||
_network.NodesUpdatedAt = time.Now().UTC()
|
||||
|
||||
err := ValidateNetwork(&network, false)
|
||||
err := ValidateNetwork(_network, false)
|
||||
if err != nil {
|
||||
//logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return models.Network{}, err
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(&network)
|
||||
err = _network.Create(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
|
||||
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeNetworkInCache(network.NetID, network)
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = CreateEnrollmentKey(
|
||||
0,
|
||||
time.Time{},
|
||||
[]string{network.NetID},
|
||||
[]string{network.NetID},
|
||||
[]string{_network.Name},
|
||||
[]string{_network.Name},
|
||||
[]models.TagID{},
|
||||
true,
|
||||
uuid.Nil,
|
||||
@@ -310,7 +284,22 @@ func CreateNetwork(network models.Network) (models.Network, error) {
|
||||
false,
|
||||
)
|
||||
|
||||
return network, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetNetworkNetworkCIDR4(network *schema.Network) *net.IPNet {
|
||||
if network.AddressRange == "" {
|
||||
return nil
|
||||
}
|
||||
_, netCidr, _ := net.ParseCIDR(network.AddressRange)
|
||||
return netCidr
|
||||
}
|
||||
func GetNetworkNetworkCIDR6(network *schema.Network) *net.IPNet {
|
||||
if network.AddressRange6 == "" {
|
||||
return nil
|
||||
}
|
||||
_, netCidr, _ := net.ParseCIDR(network.AddressRange6)
|
||||
return netCidr
|
||||
}
|
||||
|
||||
// GetNetworkNonServerNodeCount - get number of network non server nodes
|
||||
@@ -320,13 +309,13 @@ func GetNetworkNonServerNodeCount(networkName string) (int, error) {
|
||||
}
|
||||
|
||||
func IsNetworkCIDRUnique(cidr4 *net.IPNet, cidr6 *net.IPNet) bool {
|
||||
networks, err := GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return database.IsEmptyRecord(err)
|
||||
return errors.Is(err, gorm.ErrRecordNotFound)
|
||||
}
|
||||
for _, network := range networks {
|
||||
if intersect(network.GetNetworkNetworkCIDR4(), cidr4) ||
|
||||
intersect(network.GetNetworkNetworkCIDR6(), cidr6) {
|
||||
if intersect(GetNetworkNetworkCIDR4(&network), cidr4) ||
|
||||
intersect(GetNetworkNetworkCIDR6(&network), cidr6) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -340,55 +329,17 @@ func intersect(n1, n2 *net.IPNet) bool {
|
||||
return n2.Contains(n1.IP) || n1.Contains(n2.IP)
|
||||
}
|
||||
|
||||
// GetParentNetwork - get parent network
|
||||
func GetParentNetwork(networkname string) (models.Network, error) {
|
||||
|
||||
var network models.Network
|
||||
if servercfg.CacheEnabled() {
|
||||
if network, ok := getNetworkFromCache(networkname); ok {
|
||||
return network, nil
|
||||
}
|
||||
}
|
||||
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
|
||||
if err != nil {
|
||||
return network, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// GetNetworkSettings - get parent network
|
||||
func GetNetworkSettings(networkname string) (models.Network, error) {
|
||||
|
||||
var network models.Network
|
||||
if servercfg.CacheEnabled() {
|
||||
if network, ok := getNetworkFromCache(networkname); ok {
|
||||
return network, nil
|
||||
}
|
||||
}
|
||||
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
|
||||
if err != nil {
|
||||
return network, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// UniqueAddress - get a unique ipv4 address
|
||||
func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
|
||||
add := net.IP{}
|
||||
var network models.Network
|
||||
network, err := GetParentNetwork(networkName)
|
||||
network := &schema.Network{Name: networkName}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(0, "UniqueAddressServer encountered an error")
|
||||
return add, err
|
||||
}
|
||||
|
||||
if network.IsIPv4 == "no" {
|
||||
if network.AddressRange == "" {
|
||||
return add, fmt.Errorf("IPv4 not active on network %s", networkName)
|
||||
}
|
||||
//ensure AddressRange is valid
|
||||
@@ -424,14 +375,14 @@ func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
|
||||
// UniqueAddress - get a unique ipv4 address
|
||||
func UniqueAddressDB(networkName string, reverse bool) (net.IP, error) {
|
||||
add := net.IP{}
|
||||
var network models.Network
|
||||
network, err := GetParentNetwork(networkName)
|
||||
network := &schema.Network{Name: networkName}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(0, "UniqueAddressServer encountered an error")
|
||||
return add, err
|
||||
}
|
||||
|
||||
if network.IsIPv4 == "no" {
|
||||
if network.AddressRange == "" {
|
||||
return add, fmt.Errorf("IPv4 not active on network %s", networkName)
|
||||
}
|
||||
//ensure AddressRange is valid
|
||||
@@ -524,12 +475,12 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) {
|
||||
// UniqueAddress6DB - see if ipv6 address is unique
|
||||
func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
|
||||
add := net.IP{}
|
||||
var network models.Network
|
||||
network, err := GetParentNetwork(networkName)
|
||||
network := &schema.Network{Name: networkName}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return add, err
|
||||
}
|
||||
if network.IsIPv6 == "no" {
|
||||
if network.AddressRange6 == "" {
|
||||
return add, fmt.Errorf("IPv6 not active on network %s", networkName)
|
||||
}
|
||||
|
||||
@@ -568,12 +519,12 @@ func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
|
||||
// UniqueAddress6Cache - see if ipv6 address is unique using cache
|
||||
func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) {
|
||||
add := net.IP{}
|
||||
var network models.Network
|
||||
network, err := GetParentNetwork(networkName)
|
||||
network := &schema.Network{Name: networkName}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return add, err
|
||||
}
|
||||
if network.IsIPv6 == "no" {
|
||||
if network.AddressRange6 == "" {
|
||||
return add, fmt.Errorf("IPv6 not active on network %s", networkName)
|
||||
}
|
||||
|
||||
@@ -610,60 +561,44 @@ func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) {
|
||||
}
|
||||
|
||||
// IsNetworkNameUnique - checks to see if any other networks have the same name (id)
|
||||
func IsNetworkNameUnique(network *models.Network) (bool, error) {
|
||||
func IsNetworkNameUnique(network *schema.Network) (bool, error) {
|
||||
_network := &schema.Network{
|
||||
Name: network.Name,
|
||||
}
|
||||
err := _network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
isunique := true
|
||||
|
||||
dbs, err := GetNetworks()
|
||||
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for i := 0; i < len(dbs); i++ {
|
||||
|
||||
if network.NetID == dbs[i].NetID {
|
||||
isunique = false
|
||||
}
|
||||
}
|
||||
|
||||
return isunique, nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func UpsertNetwork(network models.Network) error {
|
||||
netData, err := json.Marshal(network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.Insert(network.NetID, string(netData), database.NETWORKS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeNetworkInCache(network.NetID, network)
|
||||
}
|
||||
return nil
|
||||
func UpsertNetwork(_network *schema.Network) error {
|
||||
return _network.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// UpdateNetwork - updates a network with another network's fields
|
||||
func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) error {
|
||||
func UpdateNetwork(currentNetwork, newNetwork *schema.Network) error {
|
||||
if err := ValidateNetwork(newNetwork, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if newNetwork.NetID != currentNetwork.NetID {
|
||||
return errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.")
|
||||
if newNetwork.Name != currentNetwork.Name {
|
||||
return errors.New("failed to update network " + newNetwork.Name + ", cannot change netid.")
|
||||
}
|
||||
featureFlags := GetFeatureFlags()
|
||||
if featureFlags.EnableDeviceApproval {
|
||||
currentNetwork.AutoJoin = newNetwork.AutoJoin
|
||||
} else {
|
||||
currentNetwork.AutoJoin = "true"
|
||||
currentNetwork.AutoJoin = true
|
||||
}
|
||||
currentNetwork.AutoRemove = newNetwork.AutoRemove
|
||||
currentNetwork.AutoRemoveThreshold = newNetwork.AutoRemoveThreshold
|
||||
currentNetwork.AutoRemoveTags = newNetwork.AutoRemoveTags
|
||||
currentNetwork.DefaultACL = newNetwork.DefaultACL
|
||||
currentNetwork.NameServers = newNetwork.NameServers
|
||||
|
||||
// Validate and update Virtual NAT IPv4 settings
|
||||
if newNetwork.VirtualNATPoolIPv4 != "" {
|
||||
@@ -705,124 +640,104 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) e
|
||||
currentNetwork.VirtualNATPoolIPv4 = newNetwork.VirtualNATPoolIPv4
|
||||
currentNetwork.VirtualNATSitePrefixLenIPv4 = newNetwork.VirtualNATSitePrefixLenIPv4
|
||||
}
|
||||
data, err := json.Marshal(currentNetwork)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newNetwork.SetNetworkLastModified()
|
||||
err = database.Insert(currentNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
|
||||
if err == nil {
|
||||
if servercfg.CacheEnabled() {
|
||||
storeNetworkInCache(newNetwork.NetID, *currentNetwork)
|
||||
}
|
||||
}
|
||||
return err
|
||||
return currentNetwork.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// GetNetwork - gets a network from database
|
||||
func GetNetwork(networkname string) (models.Network, error) {
|
||||
// validateNetName - checks if a netid of a network uses valid characters
|
||||
func validateNetName(network *schema.Network) error {
|
||||
var validationErr error
|
||||
|
||||
var network models.Network
|
||||
if servercfg.CacheEnabled() {
|
||||
if network, ok := getNetworkFromCache(networkname); ok {
|
||||
return network, nil
|
||||
}
|
||||
if len(network.Name) == 0 {
|
||||
validationErr = errors.Join(validationErr, errors.New("network name cannot be empty"))
|
||||
}
|
||||
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
|
||||
if err != nil {
|
||||
return network, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// NetIDInNetworkCharSet - checks if a netid of a network uses valid characters
|
||||
func NetIDInNetworkCharSet(network *models.Network) bool {
|
||||
if len(network.Name) > 32 {
|
||||
validationErr = errors.Join(validationErr, errors.New("network name cannot be longer than 32 characters"))
|
||||
}
|
||||
|
||||
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_"
|
||||
|
||||
for _, char := range network.NetID {
|
||||
for _, char := range network.Name {
|
||||
if !strings.Contains(charset, string(char)) {
|
||||
return false
|
||||
validationErr = errors.Join(validationErr, errors.New("invalid character(s) in network name"))
|
||||
break
|
||||
}
|
||||
}
|
||||
return true
|
||||
|
||||
return validationErr
|
||||
}
|
||||
|
||||
// Validate - validates fields of an network struct
|
||||
func ValidateNetwork(network *models.Network, isUpdate bool) error {
|
||||
v := validator.New()
|
||||
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
|
||||
inCharSet := NetIDInNetworkCharSet(network)
|
||||
if isUpdate {
|
||||
return inCharSet
|
||||
}
|
||||
isFieldUnique, _ := IsNetworkNameUnique(network)
|
||||
return isFieldUnique && inCharSet
|
||||
})
|
||||
//
|
||||
_ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool {
|
||||
return validation.CheckYesOrNo(fl)
|
||||
})
|
||||
err := v.Struct(network)
|
||||
func ValidateNetwork(network *schema.Network, isUpdate bool) error {
|
||||
var validationErr error
|
||||
err := validateNetName(network)
|
||||
if err != nil {
|
||||
for _, e := range err.(validator.ValidationErrors) {
|
||||
fmt.Println(e)
|
||||
validationErr = errors.Join(validationErr, err)
|
||||
}
|
||||
|
||||
if !isUpdate {
|
||||
nameUnique, _ := IsNetworkNameUnique(network)
|
||||
if !nameUnique {
|
||||
validationErr = errors.Join(validationErr, errors.New("invalid network name"))
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
if network.AddressRange != "" {
|
||||
_, _, err = net.ParseCIDR(network.AddressRange)
|
||||
if err != nil {
|
||||
validationErr = errors.Join(validationErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseNetwork - parses a network into a model
|
||||
func ParseNetwork(value string) (models.Network, error) {
|
||||
var network models.Network
|
||||
err := json.Unmarshal([]byte(value), &network)
|
||||
return network, err
|
||||
if network.AddressRange6 != "" {
|
||||
_, _, err = net.ParseCIDR(network.AddressRange6)
|
||||
if err != nil {
|
||||
validationErr = errors.Join(validationErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if network.DefaultKeepAlive > 1000 {
|
||||
validationErr = errors.Join(validationErr, errors.New("default keep alive must be less than 1000"))
|
||||
}
|
||||
|
||||
return validationErr
|
||||
}
|
||||
|
||||
// SaveNetwork - save network struct to database
|
||||
func SaveNetwork(network *models.Network) error {
|
||||
data, err := json.Marshal(network)
|
||||
if err != nil {
|
||||
return err
|
||||
func SaveNetwork(_network *schema.Network) error {
|
||||
_existingNetwork := schema.Network{Name: _network.Name}
|
||||
// Check if network exists to preserve ID
|
||||
err := _existingNetwork.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
_network.ID = _existingNetwork.ID
|
||||
return _network.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
|
||||
return err
|
||||
}
|
||||
if servercfg.CacheEnabled() {
|
||||
storeNetworkInCache(network.NetID, *network)
|
||||
}
|
||||
return nil
|
||||
|
||||
return _network.Create(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// NetworkExists - check if network exists
|
||||
func NetworkExists(name string) (bool, error) {
|
||||
|
||||
var network string
|
||||
var err error
|
||||
if servercfg.CacheEnabled() {
|
||||
if _, ok := getNetworkFromCache(name); ok {
|
||||
return ok, nil
|
||||
err := (&schema.Network{Name: name}).Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
|
||||
|
||||
return false, err
|
||||
}
|
||||
return len(network) > 0, nil
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// SortNetworks - Sorts slice of Networks by their NetID alphabetically with numbers first
|
||||
func SortNetworks(unsortedNetworks []models.Network) {
|
||||
func SortNetworks(unsortedNetworks []schema.Network) {
|
||||
sort.Slice(unsortedNetworks, func(i, j int) bool {
|
||||
return unsortedNetworks[i].NetID < unsortedNetworks[j].NetID
|
||||
return unsortedNetworks[i].Name < unsortedNetworks[j].Name
|
||||
})
|
||||
}
|
||||
|
||||
var NetworkHook models.HookFunc = func(params ...interface{}) error {
|
||||
networks, err := GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -831,10 +746,10 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error {
|
||||
return err
|
||||
}
|
||||
for _, network := range networks {
|
||||
if network.AutoRemove == "false" || network.AutoRemoveThreshold == 0 {
|
||||
if !network.AutoRemove || network.AutoRemoveThreshold == 0 {
|
||||
continue
|
||||
}
|
||||
nodes := GetNetworkNodesMemory(allNodes, network.NetID)
|
||||
nodes := GetNetworkNodesMemory(allNodes, network.Name)
|
||||
for _, node := range nodes {
|
||||
if !node.Connected {
|
||||
continue
|
||||
@@ -860,9 +775,9 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error {
|
||||
node.PendingDelete = true
|
||||
node.Action = models.NODE_DELETE
|
||||
DeleteNodesCh <- &node
|
||||
host, err := GetHost(node.HostID.String())
|
||||
if err == nil && len(host.Nodes) == 0 {
|
||||
RemoveHostByID(host.ID.String())
|
||||
host := &schema.Host{ID: node.HostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err == nil && len(host.Nodes) == 0 {
|
||||
(&schema.Host{ID: host.ID}).Delete(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+30
-29
@@ -131,7 +131,7 @@ func GetNetworkNodes(network string) ([]models.Node, error) {
|
||||
}
|
||||
|
||||
// GetHostNodes - fetches all nodes part of the host
|
||||
func GetHostNodes(host *models.Host) []models.Node {
|
||||
func GetHostNodes(host *schema.Host) []models.Node {
|
||||
nodes := []models.Node{}
|
||||
for _, nodeID := range host.Nodes {
|
||||
node, err := GetNodeByID(nodeID)
|
||||
@@ -201,7 +201,8 @@ func UpsertNode(newNode *models.Node) error {
|
||||
// UpdateNode - takes a node and updates another node with it's values
|
||||
func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
|
||||
if newNode.Address.IP.String() != currentNode.Address.IP.String() {
|
||||
if network, err := GetParentNetwork(newNode.Network); err == nil {
|
||||
network := &schema.Network{Name: newNode.Network}
|
||||
if err := network.Get(db.WithContext(context.TODO())); err == nil {
|
||||
if !IsAddressInCIDR(newNode.Address.IP, network.AddressRange) {
|
||||
return fmt.Errorf("invalid address provided; out of network range for node %s", newNode.ID)
|
||||
}
|
||||
@@ -319,7 +320,10 @@ func DeleteNode(node *models.Node, purge bool) error {
|
||||
if alreadyDeleted {
|
||||
logger.Log(1, "forcibly deleting node", node.ID.String())
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "no host found for node", node.ID.String(), "deleting..")
|
||||
if delErr := DeleteNodeByID(node); delErr != nil {
|
||||
@@ -428,7 +432,7 @@ func ValidateNode(node *models.Node, isUpdate bool) error {
|
||||
return isFieldUnique
|
||||
})
|
||||
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
|
||||
_, err := GetNetworkByNode(node)
|
||||
err := (&schema.Network{Name: node.Network}).Get(db.WithContext(context.TODO()))
|
||||
return err == nil
|
||||
})
|
||||
_ = v.RegisterValidation("checkyesornoorunset", func(f1 validator.FieldLevel) bool {
|
||||
@@ -485,7 +489,7 @@ func AddStaticNodestoList(nodes []models.Node) []models.Node {
|
||||
continue
|
||||
}
|
||||
if node.IsIngressGateway {
|
||||
nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), false)...)
|
||||
nodes = append(nodes, GetStaticNodesByNetwork(schema.NetworkID(node.Network), false)...)
|
||||
netMap[node.Network] = struct{}{}
|
||||
}
|
||||
}
|
||||
@@ -497,7 +501,7 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
|
||||
for _, node := range nodes {
|
||||
if _, ok := aclDefaultPolicyStatusMap[node.Network]; !ok {
|
||||
// check default policy if all allowed return true
|
||||
defaultPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
aclDefaultPolicyStatusMap[node.Network] = defaultPolicy.Enabled
|
||||
}
|
||||
if statusCall {
|
||||
@@ -511,24 +515,10 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
|
||||
return
|
||||
}
|
||||
|
||||
// GetNetworkByNode - gets the network model from a node
|
||||
func GetNetworkByNode(node *models.Node) (models.Network, error) {
|
||||
|
||||
var network = models.Network{}
|
||||
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, node.Network)
|
||||
if err != nil {
|
||||
return network, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
|
||||
return models.Network{}, err
|
||||
}
|
||||
return network, nil
|
||||
}
|
||||
|
||||
// SetNodeDefaults - sets the defaults of a node to avoid empty fields
|
||||
func SetNodeDefaults(node *models.Node, resetConnected bool) {
|
||||
|
||||
parentNetwork, _ := GetNetworkByNode(node)
|
||||
parentNetwork := &schema.Network{Name: node.Network}
|
||||
_ = parentNetwork.Get(db.WithContext(context.TODO()))
|
||||
_, cidr, err := net.ParseCIDR(parentNetwork.AddressRange)
|
||||
if err == nil {
|
||||
node.NetworkRange = *cidr
|
||||
@@ -622,7 +612,10 @@ func GetAllNodesAPI(nodes []models.Node) []models.ApiNode {
|
||||
for i := range nodes {
|
||||
node := nodes[i]
|
||||
if !node.IsStatic {
|
||||
h, err := GetHost(node.HostID.String())
|
||||
h := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := h.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
node.Location = h.Location
|
||||
node.CountryCode = h.CountryCode
|
||||
@@ -643,7 +636,10 @@ func GetAllNodesAPIWithLocation(nodes []models.Node) []models.ApiNode {
|
||||
if node.IsStatic {
|
||||
newApiNode.Location = node.StaticNode.Location
|
||||
} else {
|
||||
host, _ := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
_ = host.Get(db.WithContext(context.TODO()))
|
||||
newApiNode.Location = host.Location
|
||||
}
|
||||
|
||||
@@ -694,7 +690,10 @@ func createNode(node *models.Node) error {
|
||||
addressLock.Lock()
|
||||
defer addressLock.Unlock()
|
||||
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -702,7 +701,8 @@ func createNode(node *models.Node) error {
|
||||
SetNodeDefaults(node, true)
|
||||
|
||||
defaultACLVal := acls.Allowed
|
||||
parentNetwork, err := GetNetwork(node.Network)
|
||||
parentNetwork := &schema.Network{Name: node.Network}
|
||||
err = parentNetwork.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
if parentNetwork.DefaultACL != "yes" {
|
||||
defaultACLVal = acls.NotAllowed
|
||||
@@ -714,7 +714,7 @@ func createNode(node *models.Node) error {
|
||||
}
|
||||
|
||||
if node.Address.IP == nil {
|
||||
if parentNetwork.IsIPv4 == "yes" {
|
||||
if parentNetwork.AddressRange != "" {
|
||||
if node.Address.IP, err = UniqueAddress(node.Network, false); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -728,7 +728,7 @@ func createNode(node *models.Node) error {
|
||||
return fmt.Errorf("invalid address: ipv4 %s is not unique", node.Address.String())
|
||||
}
|
||||
if node.Address6.IP == nil {
|
||||
if parentNetwork.IsIPv6 == "yes" {
|
||||
if parentNetwork.AddressRange6 != "" {
|
||||
if node.Address6.IP, err = UniqueAddress6(node.Network, false); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -836,7 +836,8 @@ func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error {
|
||||
}
|
||||
|
||||
func ValidateEgressRange(netID string, ranges []string) error {
|
||||
network, err := GetNetworkSettings(netID)
|
||||
network := &schema.Network{Name: netID}
|
||||
err := network.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
slog.Error("error getting network with netid", "error", netID, err.Error)
|
||||
return errors.New("error getting network with netid: " + netID + " " + err.Error())
|
||||
|
||||
+44
-29
@@ -64,9 +64,9 @@ var (
|
||||
)
|
||||
|
||||
// GetHostPeerInfo - fetches required peer info per network
|
||||
func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
|
||||
func GetHostPeerInfo(host *schema.Host) (models.HostPeerInfo, error) {
|
||||
peerInfo := models.HostPeerInfo{
|
||||
NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap),
|
||||
NetworkPeerIDs: make(map[schema.NetworkID]models.PeerMap),
|
||||
}
|
||||
allNodes, err := GetAllNodes()
|
||||
if err != nil {
|
||||
@@ -84,7 +84,7 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
|
||||
continue
|
||||
}
|
||||
networkPeersInfo := make(models.PeerMap)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
|
||||
currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
|
||||
for _, peer := range currentPeers {
|
||||
@@ -94,7 +94,10 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
peerHost, err := GetHost(peer.HostID.String())
|
||||
peerHost := &schema.Host{
|
||||
ID: peer.HostID,
|
||||
}
|
||||
err := peerHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(4, "no peer host", peer.HostID.String(), err.Error())
|
||||
continue
|
||||
@@ -134,13 +137,13 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
|
||||
}
|
||||
}
|
||||
|
||||
peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
|
||||
peerInfo.NetworkPeerIDs[schema.NetworkID(node.Network)] = networkPeersInfo
|
||||
}
|
||||
return peerInfo, nil
|
||||
}
|
||||
|
||||
// GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
|
||||
func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
|
||||
func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node,
|
||||
deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
|
||||
if host == nil {
|
||||
return models.HostPeerUpdate{}, errors.New("host is nil")
|
||||
@@ -165,8 +168,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
HostNetworkInfo: models.HostInfoMap{},
|
||||
ServerConfig: GetServerInfo(),
|
||||
DnsNameservers: GetNameserversForHost(host),
|
||||
AutoRelayNodes: make(map[models.NetworkID][]models.Node),
|
||||
GwNodes: make(map[models.NetworkID][]models.Node),
|
||||
AutoRelayNodes: make(map[schema.NetworkID][]models.Node),
|
||||
GwNodes: make(map[schema.NetworkID][]models.Node),
|
||||
AddressIdentityMap: make(map[string]models.PeerIdentity),
|
||||
}
|
||||
if host.DNS == "no" {
|
||||
@@ -231,7 +234,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
}
|
||||
|
||||
hostPeerUpdate.Nodes = append(hostPeerUpdate.Nodes, node)
|
||||
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
GetNodeEgressInfo(&node, eli, acls)
|
||||
egsWithDomain := ListAllByRoutingNodeWithDomain(eli, node.ID.String())
|
||||
@@ -243,8 +246,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
hostPeerUpdate.IsInternetGw = IsInternetGw(node)
|
||||
}
|
||||
hostPeerUpdate.DnsNameservers = append(hostPeerUpdate.DnsNameservers, GetEgressDomainNSForNode(&node)...)
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy)
|
||||
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) ||
|
||||
(!CheckIfAnyPolicyisUniDirectional(node, acls) &&
|
||||
!(node.EgressDetails.IsEgressGateway && len(node.EgressDetails.EgressGatewayRanges) > 0)) {
|
||||
@@ -273,11 +276,6 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
}
|
||||
}
|
||||
}
|
||||
networkSettings, err := GetNetwork(node.Network)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
hostPeerUpdate.NameServers = append(hostPeerUpdate.NameServers, networkSettings.NameServers...)
|
||||
currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
|
||||
for _, peer := range currentPeers {
|
||||
if peer.ID.String() == node.ID.String() {
|
||||
@@ -285,13 +283,16 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
continue
|
||||
}
|
||||
|
||||
peerHost, err := GetHost(peer.HostID.String())
|
||||
peerHost := &schema.Host{
|
||||
ID: peer.HostID,
|
||||
}
|
||||
err := peerHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(4, "no peer host", peer.HostID.String(), err.Error())
|
||||
continue
|
||||
}
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: peerHost.PublicKey,
|
||||
PublicKey: peerHost.PublicKey.Key,
|
||||
PersistentKeepaliveInterval: &peerHost.PersistentKeepalive,
|
||||
ReplaceAllowedIPs: true,
|
||||
}
|
||||
@@ -314,7 +315,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
// get relay host
|
||||
failOverNode, err := GetNodeByID(peer.FailedOverBy.String())
|
||||
if err == nil {
|
||||
relayHost, err := GetHost(failOverNode.HostID.String())
|
||||
relayHost := &schema.Host{
|
||||
ID: failOverNode.HostID,
|
||||
}
|
||||
err := relayHost.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
peerKey = relayHost.PublicKey.String()
|
||||
}
|
||||
@@ -324,7 +328,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
// get relay host
|
||||
autoRelayNode, err := GetNodeByID(peerAutoRelayID)
|
||||
if err == nil {
|
||||
relayHost, err := GetHost(autoRelayNode.HostID.String())
|
||||
relayHost := &schema.Host{
|
||||
ID: autoRelayNode.HostID,
|
||||
}
|
||||
err = relayHost.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
peerKey = relayHost.PublicKey.String()
|
||||
}
|
||||
@@ -334,7 +341,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
// get relay host
|
||||
relayNode, err := GetNodeByID(peer.RelayedBy)
|
||||
if err == nil {
|
||||
relayHost, err := GetHost(relayNode.HostID.String())
|
||||
relayHost := &schema.Host{
|
||||
ID: relayNode.HostID,
|
||||
}
|
||||
err := relayHost.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
peerKey = relayHost.PublicKey.String()
|
||||
}
|
||||
@@ -363,11 +373,11 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
}
|
||||
if allowedToComm {
|
||||
if peer.IsAutoRelay {
|
||||
hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)],
|
||||
hostPeerUpdate.AutoRelayNodes[schema.NetworkID(peer.Network)] = append(hostPeerUpdate.AutoRelayNodes[schema.NetworkID(peer.Network)],
|
||||
peer)
|
||||
}
|
||||
if node.AutoAssignGateway && peer.IsGw {
|
||||
hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)],
|
||||
hostPeerUpdate.GwNodes[schema.NetworkID(peer.Network)] = append(hostPeerUpdate.GwNodes[schema.NetworkID(peer.Network)],
|
||||
peer)
|
||||
}
|
||||
}
|
||||
@@ -587,7 +597,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
Network: rangeI,
|
||||
RouteMetric: 256,
|
||||
Nat: true,
|
||||
Mode: models.DirectNAT,
|
||||
Mode: schema.DirectNAT,
|
||||
})
|
||||
}
|
||||
inetEgressInfo := models.EgressInfo{
|
||||
@@ -626,11 +636,14 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
hostPeerUpdate.Peers[i] = peer
|
||||
}
|
||||
if deletedNode != nil && host.OS != models.OS_Types.IoT {
|
||||
peerHost, err := GetHost(deletedNode.HostID.String())
|
||||
peerHost := &schema.Host{
|
||||
ID: deletedNode.HostID,
|
||||
}
|
||||
err := peerHost.Get(db.WithContext(context.TODO()))
|
||||
if err == nil && host.ID != peerHost.ID {
|
||||
if _, ok := peerIndexMap[peerHost.PublicKey.String()]; !ok {
|
||||
hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, wgtypes.PeerConfig{
|
||||
PublicKey: peerHost.PublicKey,
|
||||
PublicKey: peerHost.PublicKey.Key,
|
||||
Remove: true,
|
||||
})
|
||||
}
|
||||
@@ -662,7 +675,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
||||
}
|
||||
|
||||
// GetPeerListenPort - given a host, retrieve it's appropriate listening port
|
||||
func GetPeerListenPort(host *models.Host) int {
|
||||
func GetPeerListenPort(host *schema.Host) int {
|
||||
peerPort := host.ListenPort
|
||||
if !host.IsStaticPort && host.WgPublicListenPort != 0 {
|
||||
peerPort = host.WgPublicListenPort
|
||||
@@ -742,8 +755,10 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
|
||||
}
|
||||
|
||||
func GetEgressIPs(peer *models.Node) []net.IPNet {
|
||||
|
||||
peerHost, err := GetHost(peer.HostID.String())
|
||||
peerHost := &schema.Host{
|
||||
ID: peer.HostID,
|
||||
}
|
||||
err := peerHost.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(0, "error retrieving host for peer", peer.ID.String(), "host id", peer.HostID.String(), err.Error())
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -17,7 +17,7 @@ const (
|
||||
type CValue struct {
|
||||
Network string `json:"network,omitempty"`
|
||||
Value string `json:"value"`
|
||||
Host models.Host `json:"host"`
|
||||
Host schema.Host `json:"host"`
|
||||
Pass string `json:"pass,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
ALL bool `json:"all,omitempty"`
|
||||
|
||||
+8
-5
@@ -37,7 +37,10 @@ func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error)
|
||||
if err != nil {
|
||||
return returnnodes, models.Node{}, err
|
||||
}
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return returnnodes, models.Node{}, err
|
||||
}
|
||||
@@ -120,7 +123,7 @@ func ValidateRelay(relay models.RelayRequest, update bool) error {
|
||||
return errors.New("node is already acting as a relay")
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
for _, relayedNodeID := range relay.RelayedNodes {
|
||||
relayedNode, err := GetNodeByID(relayedNodeID)
|
||||
if err != nil {
|
||||
@@ -203,7 +206,7 @@ func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) {
|
||||
func RelayedAllowedIPs(peer, node *models.Node) []net.IPNet {
|
||||
var allowedIPs = []net.IPNet{}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
for _, relayedNodeID := range peer.RelayedNodes {
|
||||
if node.ID.String() == relayedNodeID {
|
||||
continue
|
||||
@@ -237,9 +240,9 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe
|
||||
return
|
||||
}
|
||||
serverSettings := GetServerSettings()
|
||||
acls, _ := ListAclsByNetwork(models.NetworkID(relay.Network))
|
||||
acls, _ := ListAclsByNetwork(schema.NetworkID(relay.Network))
|
||||
eli, _ := (&schema.Egress{Network: relay.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
defaultPolicy, _ := GetDefaultPolicy(models.NetworkID(relay.Network), models.DevicePolicy)
|
||||
defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(relay.Network), models.DevicePolicy)
|
||||
for _, peer := range peers {
|
||||
if peer.ID == relayed.ID || peer.ID == relay.ID {
|
||||
continue
|
||||
|
||||
+3
-1
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
@@ -35,7 +36,8 @@ func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
|
||||
return
|
||||
}
|
||||
if username != MasterUser {
|
||||
user, err := GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
|
||||
return
|
||||
|
||||
@@ -5,21 +5,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
var (
|
||||
// NetworksLimit - dummy var for community
|
||||
NetworksLimit = 1000000000
|
||||
// UsersLimit - dummy var for community
|
||||
UsersLimit = 1000000000
|
||||
// MachinesLimit - dummy var for community
|
||||
MachinesLimit = 1000000000
|
||||
// IngressesLimit - dummy var for community
|
||||
IngressesLimit = 1000000000
|
||||
// EgressesLimit - dummy var for community
|
||||
EgressesLimit = 1000000000
|
||||
// FreeTier - specifies if free tier
|
||||
FreeTier = false
|
||||
// DefaultTrialEndDate - is a placeholder date for not applicable trial end dates
|
||||
DefaultTrialEndDate, _ = time.Parse("2006-Jan-02", "2021-Apr-01")
|
||||
@@ -61,13 +49,3 @@ func StoreJWTSecret(privateKey string) error {
|
||||
}
|
||||
return database.Insert("nm-jwt-secret", string(data), database.SERVERCONF_TABLE_NAME)
|
||||
}
|
||||
|
||||
// SetFreeTierLimits - sets limits for free tier
|
||||
func SetFreeTierLimits() {
|
||||
FreeTier = true
|
||||
UsersLimit = servercfg.GetUserLimit()
|
||||
NetworksLimit = servercfg.GetNetworkLimit()
|
||||
MachinesLimit = servercfg.GetMachinesLimit()
|
||||
IngressesLimit = servercfg.GetIngressLimit()
|
||||
EgressesLimit = servercfg.GetEgressLimit()
|
||||
}
|
||||
|
||||
+8
-4
@@ -11,11 +11,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/gravitl/netmaker/config"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic/acls"
|
||||
"github.com/gravitl/netmaker/logic/acls/nodeacls"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
@@ -87,13 +91,13 @@ func UpsertServerSettings(s models.ServerSettings) error {
|
||||
}
|
||||
|
||||
func setDefaultsforOldAclCfg() {
|
||||
nets, _ := GetNetworks()
|
||||
nets, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, netI := range nets {
|
||||
if netI.DefaultACL != "yes" {
|
||||
netI.DefaultACL = "yes"
|
||||
UpsertNetwork(netI)
|
||||
UpsertNetwork(&netI)
|
||||
}
|
||||
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.NetID))
|
||||
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.Name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -105,7 +109,7 @@ func setDefaultsforOldAclCfg() {
|
||||
}
|
||||
networkACL.UpdateACL(id, aclNode)
|
||||
}
|
||||
networkACL.Save(acls.ContainerID(netI.NetID))
|
||||
networkACL.Save(acls.ContainerID(netI.Name))
|
||||
}
|
||||
nodes, _ := GetAllNodes()
|
||||
for _, node := range nodes {
|
||||
|
||||
+10
-4
@@ -1,12 +1,15 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
@@ -92,9 +95,9 @@ func FetchTelemetryData() telemetryData {
|
||||
|
||||
data.IsPro = servercfg.IsPro
|
||||
data.ExtClients = getDBLength(database.EXT_CLIENT_TABLE_NAME)
|
||||
data.Users = getDBLength(database.USERS_TABLE_NAME)
|
||||
data.Networks = getDBLength(database.NETWORKS_TABLE_NAME)
|
||||
data.Hosts = getDBLength(database.HOSTS_TABLE_NAME)
|
||||
data.Users, _ = (&schema.User{}).Count(db.WithContext(context.TODO()))
|
||||
data.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO()))
|
||||
data.Hosts, _ = (&schema.Host{}).Count(db.WithContext(context.TODO()))
|
||||
data.Version = servercfg.GetVersion()
|
||||
data.Servers = getServerCount()
|
||||
nodes, _ := GetAllNodes()
|
||||
@@ -143,7 +146,10 @@ func setTelemetryTimestamp(telRecord *models.Telemetry) error {
|
||||
func getClientCount(nodes []models.Node) clientCount {
|
||||
var count clientCount
|
||||
for _, node := range nodes {
|
||||
host, err := GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
+4
-9
@@ -18,14 +18,8 @@ func GetCurrentServerUsage() (limits models.Usage) {
|
||||
if cErr == nil {
|
||||
limits.Clients = len(clients)
|
||||
}
|
||||
users, err := GetUsers()
|
||||
if err == nil {
|
||||
limits.Users = len(users)
|
||||
}
|
||||
networks, err := GetNetworks()
|
||||
if err == nil {
|
||||
limits.Networks = len(networks)
|
||||
}
|
||||
limits.Users, _ = (&schema.User{}).Count(db.WithContext(context.TODO()))
|
||||
limits.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO()))
|
||||
limits.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO()))
|
||||
|
||||
nodes, _ := GetAllNodes()
|
||||
@@ -35,8 +29,9 @@ func GetCurrentServerUsage() (limits models.Usage) {
|
||||
}
|
||||
|
||||
limits.NetworkUsage = make(map[string]models.NetworkUsage)
|
||||
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, network := range networks {
|
||||
limits.NetworkUsage[network.NetID] = models.NetworkUsage{}
|
||||
limits.NetworkUsage[network.Name] = models.NetworkUsage{}
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
|
||||
+57
-115
@@ -1,68 +1,51 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// Pre-Define Permission Templates for default Roles
|
||||
var SuperAdminPermissionTemplate = models.UserRolePermissionTemplate{
|
||||
ID: models.SuperAdminRole,
|
||||
var SuperAdminPermissionTemplate = schema.UserRole{
|
||||
ID: schema.SuperAdminRole,
|
||||
Default: true,
|
||||
FullAccess: true,
|
||||
}
|
||||
|
||||
var AdminPermissionTemplate = models.UserRolePermissionTemplate{
|
||||
ID: models.AdminRole,
|
||||
var AdminPermissionTemplate = schema.UserRole{
|
||||
ID: schema.AdminRole,
|
||||
Default: true,
|
||||
FullAccess: true,
|
||||
}
|
||||
|
||||
var GetFilteredNodesByUserAccess = func(user models.User, nodes []models.Node) (filteredNodes []models.Node) {
|
||||
var GetFilteredNodesByUserAccess = func(user *schema.User, nodes []models.Node) (filteredNodes []models.Node) {
|
||||
return
|
||||
}
|
||||
|
||||
var CreateRole = func(r models.UserRolePermissionTemplate) error {
|
||||
var DeleteRole = func(r schema.UserRoleID, force bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var DeleteRole = func(r models.UserRoleID, force bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var FilterNetworksByRole = func(allnetworks []models.Network, user models.User) []models.Network {
|
||||
var FilterNetworksByRole = func(allnetworks []schema.Network, user *schema.User) []schema.Network {
|
||||
return allnetworks
|
||||
}
|
||||
|
||||
var IsGroupsValid = func(groups map[models.UserGroupID]struct{}) error {
|
||||
return nil
|
||||
}
|
||||
var IsGroupValid = func(groupID models.UserGroupID) error {
|
||||
return nil
|
||||
}
|
||||
var IsNetworkRolesValid = func(networkRoles map[models.NetworkID]map[models.UserRoleID]struct{}) error {
|
||||
var IsGroupsValid = func(groups map[schema.UserGroupID]struct{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var MigrateUserRoleAndGroups = func(u models.User) models.User {
|
||||
return u
|
||||
}
|
||||
|
||||
var MigrateToUUIDs = func() {}
|
||||
|
||||
var UpdateUserGwAccess = func(currentUser, changeUser models.User) {}
|
||||
|
||||
var UpdateRole = func(r models.UserRolePermissionTemplate) error { return nil }
|
||||
var UpdateUserGwAccess = func(currentUser, changeUser *schema.User) {}
|
||||
|
||||
var InitialiseRoles = userRolesInit
|
||||
var IntialiseGroups = func() {}
|
||||
var DeleteNetworkRoles = func(netID string) {}
|
||||
var CreateDefaultNetworkRolesAndGroups = func(netID models.NetworkID) {}
|
||||
var CreateDefaultUserPolicies = func(netID models.NetworkID) {
|
||||
var CreateDefaultNetworkRolesAndGroups = func(netID schema.NetworkID) {}
|
||||
var CreateDefaultUserPolicies = func(netID schema.NetworkID) {
|
||||
if netID.String() == "" {
|
||||
return
|
||||
}
|
||||
@@ -95,96 +78,55 @@ var CreateDefaultUserPolicies = func(netID models.NetworkID) {
|
||||
InsertAcl(defaultUserAcl)
|
||||
}
|
||||
}
|
||||
var ListUserGroups = func() ([]models.UserGroup, error) { return nil, nil }
|
||||
var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return }
|
||||
var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return }
|
||||
var AddGlobalNetRolesToAdmins = func(u *models.User) {}
|
||||
var GetUserGroup = func(groupId schema.UserGroupID) (userGrps schema.UserGroup, err error) { return }
|
||||
var AddGlobalNetRolesToAdmins = func(u *schema.User) {}
|
||||
var EmailInit = func() {}
|
||||
|
||||
// GetRole - fetches role template by id
|
||||
func GetRole(roleID models.UserRoleID) (models.UserRolePermissionTemplate, error) {
|
||||
// check if role already exists
|
||||
data, err := database.FetchRecord(database.USER_PERMISSIONS_TABLE_NAME, roleID.String())
|
||||
if err != nil {
|
||||
return models.UserRolePermissionTemplate{}, err
|
||||
}
|
||||
ur := models.UserRolePermissionTemplate{}
|
||||
err = json.Unmarshal([]byte(data), &ur)
|
||||
if err != nil {
|
||||
return ur, err
|
||||
}
|
||||
return ur, nil
|
||||
}
|
||||
|
||||
// ListPlatformRoles - lists user platform roles permission templates
|
||||
func ListPlatformRoles() ([]models.UserRolePermissionTemplate, error) {
|
||||
data, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return []models.UserRolePermissionTemplate{}, err
|
||||
}
|
||||
userRoles := []models.UserRolePermissionTemplate{}
|
||||
for _, dataI := range data {
|
||||
userRole := models.UserRolePermissionTemplate{}
|
||||
err := json.Unmarshal([]byte(dataI), &userRole)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if userRole.NetworkID != "" {
|
||||
continue
|
||||
}
|
||||
userRoles = append(userRoles, userRole)
|
||||
}
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
func GetAllRsrcIDForRsrc(rsrc models.RsrcType) models.RsrcID {
|
||||
func GetAllRsrcIDForRsrc(rsrc schema.RsrcType) schema.RsrcID {
|
||||
switch rsrc {
|
||||
case models.HostRsrc:
|
||||
return models.AllHostRsrcID
|
||||
case models.RelayRsrc:
|
||||
return models.AllRelayRsrcID
|
||||
case models.RemoteAccessGwRsrc:
|
||||
return models.AllRemoteAccessGwRsrcID
|
||||
case models.ExtClientsRsrc:
|
||||
return models.AllExtClientsRsrcID
|
||||
case models.InetGwRsrc:
|
||||
return models.AllInetGwRsrcID
|
||||
case models.EgressGwRsrc:
|
||||
return models.AllEgressGwRsrcID
|
||||
case models.NetworkRsrc:
|
||||
return models.AllNetworkRsrcID
|
||||
case models.EnrollmentKeysRsrc:
|
||||
return models.AllEnrollmentKeysRsrcID
|
||||
case models.UserRsrc:
|
||||
return models.AllUserRsrcID
|
||||
case models.DnsRsrc:
|
||||
return models.AllDnsRsrcID
|
||||
case models.FailOverRsrc:
|
||||
return models.AllFailOverRsrcID
|
||||
case models.AclRsrc:
|
||||
return models.AllAclsRsrcID
|
||||
case models.TagRsrc:
|
||||
return models.AllTagsRsrcID
|
||||
case models.PostureCheckRsrc:
|
||||
return models.AllPostureCheckRsrcID
|
||||
case models.NameserverRsrc:
|
||||
return models.AllNameserverRsrcID
|
||||
case models.JitAdminRsrc:
|
||||
return models.AllJitAdminRsrcID
|
||||
case models.JitUserRsrc:
|
||||
return models.AllJitUserRsrcID
|
||||
case models.UserActivityRsrc:
|
||||
return models.AllUserActivityRsrcID
|
||||
case models.TrafficFlow:
|
||||
return models.AllTrafficFlowRsrcID
|
||||
case schema.HostRsrc:
|
||||
return schema.AllHostRsrcID
|
||||
case schema.RelayRsrc:
|
||||
return schema.AllRelayRsrcID
|
||||
case schema.RemoteAccessGwRsrc:
|
||||
return schema.AllRemoteAccessGwRsrcID
|
||||
case schema.ExtClientsRsrc:
|
||||
return schema.AllExtClientsRsrcID
|
||||
case schema.InetGwRsrc:
|
||||
return schema.AllInetGwRsrcID
|
||||
case schema.EgressGwRsrc:
|
||||
return schema.AllEgressGwRsrcID
|
||||
case schema.NetworkRsrc:
|
||||
return schema.AllNetworkRsrcID
|
||||
case schema.EnrollmentKeysRsrc:
|
||||
return schema.AllEnrollmentKeysRsrcID
|
||||
case schema.UserRsrc:
|
||||
return schema.AllUserRsrcID
|
||||
case schema.DnsRsrc:
|
||||
return schema.AllDnsRsrcID
|
||||
case schema.FailOverRsrc:
|
||||
return schema.AllFailOverRsrcID
|
||||
case schema.AclRsrc:
|
||||
return schema.AllAclsRsrcID
|
||||
case schema.TagRsrc:
|
||||
return schema.AllTagsRsrcID
|
||||
case schema.PostureCheckRsrc:
|
||||
return schema.AllPostureCheckRsrcID
|
||||
case schema.NameserverRsrc:
|
||||
return schema.AllNameserverRsrcID
|
||||
case schema.JitAdminRsrc:
|
||||
return schema.AllJitAdminRsrcID
|
||||
case schema.JitUserRsrc:
|
||||
return schema.AllJitUserRsrcID
|
||||
case schema.UserActivityRsrc:
|
||||
return schema.AllUserActivityRsrcID
|
||||
case schema.TrafficFlow:
|
||||
return schema.AllTrafficFlowRsrcID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func userRolesInit() {
|
||||
d, _ := json.Marshal(SuperAdminPermissionTemplate)
|
||||
database.Insert(SuperAdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)
|
||||
d, _ = json.Marshal(AdminPermissionTemplate)
|
||||
database.Insert(AdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)
|
||||
|
||||
_ = SuperAdminPermissionTemplate.Upsert(db.WithContext(context.TODO()))
|
||||
_ = AdminPermissionTemplate.Upsert(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
+42
-106
@@ -1,120 +1,59 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
var (
|
||||
userCacheMutex = &sync.RWMutex{}
|
||||
usersCacheMap = make(map[string]models.User)
|
||||
)
|
||||
|
||||
func getUserFromCache(username string) (models.User, bool) {
|
||||
userCacheMutex.RLock()
|
||||
user, ok := usersCacheMap[username]
|
||||
userCacheMutex.RUnlock()
|
||||
return user, ok
|
||||
}
|
||||
|
||||
func getUsersFromCache() []models.User {
|
||||
userCacheMutex.RLock()
|
||||
users := make([]models.User, 0, len(usersCacheMap))
|
||||
for _, user := range usersCacheMap {
|
||||
users = append(users, user)
|
||||
}
|
||||
userCacheMutex.RUnlock()
|
||||
return users
|
||||
}
|
||||
|
||||
func storeUserInCache(user models.User) {
|
||||
userCacheMutex.Lock()
|
||||
usersCacheMap[user.UserName] = user
|
||||
userCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func deleteUserFromCache(username string) {
|
||||
userCacheMutex.Lock()
|
||||
delete(usersCacheMap, username)
|
||||
userCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
func loadUsersIntoCache(users map[string]models.User) {
|
||||
userCacheMutex.Lock()
|
||||
usersCacheMap = users
|
||||
userCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
// GetUser - gets a user
|
||||
func GetUser(username string) (*models.User, error) {
|
||||
if servercfg.CacheEnabled() {
|
||||
if user, ok := getUserFromCache(username); ok {
|
||||
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
|
||||
return &user, nil
|
||||
}
|
||||
}
|
||||
var user models.User
|
||||
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
|
||||
if err != nil {
|
||||
return &user, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(record), &user); err != nil {
|
||||
return &models.User{}, err
|
||||
}
|
||||
|
||||
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
|
||||
if servercfg.CacheEnabled() {
|
||||
storeUserInCache(user)
|
||||
}
|
||||
return &user, err
|
||||
}
|
||||
|
||||
// GetReturnUser - gets a user
|
||||
func GetReturnUser(username string) (models.ReturnUser, error) {
|
||||
u, err := GetUser(username)
|
||||
_user := &schema.User{
|
||||
Username: username,
|
||||
}
|
||||
err := _user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
return ToReturnUser(*u), nil
|
||||
|
||||
return ToReturnUser(_user), nil
|
||||
}
|
||||
|
||||
// ToReturnUser - gets a user as a return user
|
||||
func ToReturnUser(user models.User) models.ReturnUser {
|
||||
func ToReturnUser(user *schema.User) models.ReturnUser {
|
||||
return models.ReturnUser{
|
||||
UserName: user.UserName,
|
||||
UserName: user.Username,
|
||||
ExternalIdentityProviderID: user.ExternalIdentityProviderID,
|
||||
IsMFAEnabled: user.IsMFAEnabled,
|
||||
DisplayName: user.DisplayName,
|
||||
AccountDisabled: user.AccountDisabled,
|
||||
IsAdmin: user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole,
|
||||
IsSuperAdmin: user.PlatformRoleID == schema.SuperAdminRole,
|
||||
AuthType: user.AuthType,
|
||||
RemoteGwIDs: user.RemoteGwIDs,
|
||||
UserGroups: user.UserGroups,
|
||||
PlatformRoleID: user.PlatformRoleID,
|
||||
IsSuperAdmin: user.PlatformRoleID == models.SuperAdminRole,
|
||||
IsAdmin: user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole,
|
||||
NetworkRoles: user.NetworkRoles,
|
||||
LastLoginTime: user.LastLoginTime,
|
||||
// no need to set. field not in use.
|
||||
RemoteGwIDs: nil,
|
||||
UserGroups: user.UserGroups.Data(),
|
||||
PlatformRoleID: user.PlatformRoleID,
|
||||
// no need to set. field not in use.
|
||||
NetworkRoles: nil,
|
||||
LastLoginTime: user.LastLoginAt,
|
||||
CreatedBy: user.CreatedBy,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
// SetUserDefaults - sets the defaults of a user to avoid empty fields
|
||||
func SetUserDefaults(user *models.User) {
|
||||
if user.RemoteGwIDs == nil {
|
||||
user.RemoteGwIDs = make(map[string]struct{})
|
||||
}
|
||||
if len(user.NetworkRoles) == 0 {
|
||||
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
}
|
||||
if len(user.UserGroups) == 0 {
|
||||
user.UserGroups = make(map[models.UserGroupID]struct{})
|
||||
func SetUserDefaults(user *schema.User) {
|
||||
if len(user.UserGroups.Data()) == 0 {
|
||||
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,16 +66,13 @@ func SortUsers(unsortedUsers []models.ReturnUser) {
|
||||
|
||||
// GetSuperAdmin - fetches superadmin user
|
||||
func GetSuperAdmin() (models.ReturnUser, error) {
|
||||
users, err := GetUsers()
|
||||
_user := &schema.User{}
|
||||
err := _user.GetSuperAdmin(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return models.ReturnUser{}, err
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.PlatformRoleID == models.SuperAdminRole {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
return models.ReturnUser{}, errors.New("superadmin not found")
|
||||
|
||||
return ToReturnUser(_user), nil
|
||||
}
|
||||
|
||||
func InsertPendingUser(u *models.User) error {
|
||||
@@ -177,8 +113,8 @@ func ListPendingReturnUsers() ([]models.ReturnUser, error) {
|
||||
user := models.ReturnUser{}
|
||||
err = json.Unmarshal([]byte(record), &user)
|
||||
if err == nil {
|
||||
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
|
||||
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
|
||||
pendingUsers = append(pendingUsers, user)
|
||||
}
|
||||
}
|
||||
@@ -195,25 +131,25 @@ func ListPendingUsers() ([]models.User, error) {
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(record), &user)
|
||||
if err == nil {
|
||||
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
|
||||
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
|
||||
pendingUsers = append(pendingUsers, user)
|
||||
}
|
||||
}
|
||||
return pendingUsers, nil
|
||||
}
|
||||
|
||||
func GetUserMap() (map[string]models.User, error) {
|
||||
users, err := GetUsersDB()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
func GetUserMap() (map[string]schema.User, error) {
|
||||
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userMap := make(map[string]models.User, len(users))
|
||||
|
||||
userMap := make(map[string]schema.User, len(users))
|
||||
for _, user := range users {
|
||||
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
|
||||
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
|
||||
userMap[user.UserName] = user
|
||||
userMap[user.Username] = user
|
||||
}
|
||||
|
||||
return userMap, nil
|
||||
}
|
||||
|
||||
|
||||
+9
-21
@@ -2,10 +2,10 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
@@ -19,9 +19,9 @@ import (
|
||||
|
||||
"github.com/blang/semver"
|
||||
"github.com/c-robinson/iplib"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// IsBase64 - checks if a string is in base64 format
|
||||
@@ -57,23 +57,11 @@ func IsAddressInCIDR(address net.IP, cidr string) bool {
|
||||
|
||||
// SetNetworkNodesLastModified - sets the network nodes last modified
|
||||
func SetNetworkNodesLastModified(networkName string) error {
|
||||
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
network, err := GetParentNetwork(networkName)
|
||||
if err != nil {
|
||||
return err
|
||||
_network := &schema.Network{
|
||||
Name: networkName,
|
||||
NodesUpdatedAt: time.Now(),
|
||||
}
|
||||
network.NodesLastModified = timestamp
|
||||
data, err := json.Marshal(&network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = database.Insert(networkName, string(data), database.NETWORKS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return _network.UpdateNodesUpdatedAt(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
// RandomString - returns a random string in a charset
|
||||
@@ -270,7 +258,7 @@ func GetClientIP(r *http.Request) string {
|
||||
}
|
||||
|
||||
// CompareIfaceSlices compares two slices of Iface for deep equality (order-sensitive)
|
||||
func CompareIfaceSlices(a, b []models.Iface) bool {
|
||||
func CompareIfaceSlices(a, b []schema.Iface) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
@@ -281,7 +269,7 @@ func CompareIfaceSlices(a, b []models.Iface) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
func compareIface(a, b models.Iface) bool {
|
||||
func compareIface(a, b schema.Iface) bool {
|
||||
return a.Name == b.Name &&
|
||||
a.Address.IP.Equal(b.Address.IP) &&
|
||||
a.Address.Mask.String() == b.Address.Mask.String() &&
|
||||
|
||||
+7
-6
@@ -5,8 +5,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -46,8 +48,8 @@ func CheckZombies(newnode *models.Node) {
|
||||
|
||||
// checkForZombieHosts - checks if new host has the same macAddress as an existing host
|
||||
// if true, existing host is added to host zombie collection
|
||||
func checkForZombieHosts(h *models.Host) {
|
||||
hosts, err := GetAllHosts()
|
||||
func checkForZombieHosts(h *schema.Host) {
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(3, "error retrieving all hosts", err.Error())
|
||||
}
|
||||
@@ -118,12 +120,11 @@ func ManageZombies(ctx context.Context) {
|
||||
if len(hostZombies) > 0 {
|
||||
logger.Log(3, "checking host zombies")
|
||||
for i := len(hostZombies) - 1; i >= 0; i-- {
|
||||
host, err := GetHost(hostZombies[i].String())
|
||||
host := &schema.Host{ID: hostZombies[i]}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "error retrieving zombie host", err.Error())
|
||||
if host != nil {
|
||||
logger.Log(1, "deleting ", host.ID.String(), " from zombie list")
|
||||
}
|
||||
logger.Log(1, "deleting ", host.ID.String(), " from zombie list")
|
||||
hostZombies = append(hostZombies[:i], hostZombies[i+1:]...)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ var version = "v1.5.0"
|
||||
func main() {
|
||||
absoluteConfigPath := flag.String("c", "", "absolute path to configuration file")
|
||||
flag.Parse()
|
||||
setVerbosity()
|
||||
setupConfig(*absoluteConfigPath)
|
||||
servercfg.SetVersion(version)
|
||||
fmt.Println(models.RetrieveLogo()) // print the logo
|
||||
@@ -60,10 +61,6 @@ func main() {
|
||||
logic.SetAllocatedIpMap()
|
||||
defer logic.ClearAllocatedIpMap()
|
||||
setGarbageCollection()
|
||||
setVerbosity()
|
||||
if servercfg.DeployedByOperator() && !servercfg.IsPro {
|
||||
logic.SetFreeTierLimits()
|
||||
}
|
||||
defer db.CloseDB()
|
||||
defer database.CloseDB()
|
||||
|
||||
@@ -122,15 +119,19 @@ func initialize() { // Client Mode Prereq Check
|
||||
logger.FatalLog("error initializing database: ", err.Error())
|
||||
}
|
||||
|
||||
err = migrate.ToSQLSchema()
|
||||
if err != nil {
|
||||
// we shouldn't allow user to use the product until the migration is successfully done.
|
||||
panic(err)
|
||||
}
|
||||
|
||||
initializeUUID()
|
||||
|
||||
//initialize cache
|
||||
_, _ = logic.GetNetworks()
|
||||
_, _ = logic.GetAllNodes()
|
||||
_, _ = logic.GetAllHosts()
|
||||
_, _ = logic.GetAllExtClients()
|
||||
_ = logic.ListAcls()
|
||||
_, _ = logic.GetAllEnrollmentKeys()
|
||||
_, _ = logic.GetUsersDB()
|
||||
_ = logic.CleanExpiredSSOStates()
|
||||
|
||||
migrate.Run()
|
||||
|
||||
+62
-303
@@ -33,39 +33,24 @@ func Run() {
|
||||
updateEnrollmentKeys()
|
||||
assignSuperAdmin()
|
||||
createDefaultTagsAndPolicies()
|
||||
removeOldUserGrps()
|
||||
migrateToUUIDs()
|
||||
syncUsers()
|
||||
updateHosts()
|
||||
updateNodes()
|
||||
updateAcls()
|
||||
updateNewAcls()
|
||||
logic.MigrateToGws()
|
||||
migrateToEgressV1()
|
||||
updateNetworks()
|
||||
migrateNameservers()
|
||||
resync()
|
||||
deleteOldExtclients()
|
||||
checkAndDeprecateOldAcls()
|
||||
migrateJITEnabled()
|
||||
}
|
||||
|
||||
func migrateJITEnabled() {
|
||||
nets, _ := logic.GetNetworks()
|
||||
for _, netI := range nets {
|
||||
if netI.JITEnabled == "" {
|
||||
netI.JITEnabled = "no"
|
||||
logic.UpsertNetwork(netI)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkAndDeprecateOldAcls() {
|
||||
// check if everything is allowed on old acl and disable old acls
|
||||
nets, _ := logic.GetNetworks()
|
||||
nets, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
disableOldAcls := true
|
||||
for _, netI := range nets {
|
||||
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.NetID))
|
||||
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.Name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -79,7 +64,7 @@ func checkAndDeprecateOldAcls() {
|
||||
}
|
||||
if disableOldAcls {
|
||||
netI.DefaultACL = "yes"
|
||||
logic.UpsertNetwork(netI)
|
||||
logic.UpsertNetwork(&netI)
|
||||
}
|
||||
}
|
||||
if disableOldAcls {
|
||||
@@ -91,17 +76,6 @@ func checkAndDeprecateOldAcls() {
|
||||
}
|
||||
|
||||
func updateNetworks() {
|
||||
nets, _ := logic.GetNetworks()
|
||||
for _, netI := range nets {
|
||||
if netI.AutoJoin == "" {
|
||||
netI.AutoJoin = "true"
|
||||
logic.UpsertNetwork(netI)
|
||||
}
|
||||
if netI.AutoRemove == "" {
|
||||
netI.AutoRemove = "false"
|
||||
logic.UpsertNetwork(netI)
|
||||
}
|
||||
}
|
||||
initializeVirtualNATSettings()
|
||||
}
|
||||
|
||||
@@ -115,7 +89,7 @@ func initializeVirtualNATSettings() {
|
||||
logger.Log(1, "Initializing Virtual NAT settings for existing networks")
|
||||
defer logger.Log(1, "Completed initializing Virtual NAT settings for existing networks")
|
||||
|
||||
networks, err := logic.GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to get networks for Virtual NAT migration:", err.Error())
|
||||
return
|
||||
@@ -174,7 +148,7 @@ func initializeVirtualNATSettings() {
|
||||
|
||||
// If this network needs a unique pool, allocate one
|
||||
if needsUniquePool {
|
||||
uniquePool := allocateUniquePoolFromFallback(fallbackNet, poolPrefixLen, allocatedPools, network.NetID)
|
||||
uniquePool := allocateUniquePoolFromFallback(fallbackNet, poolPrefixLen, allocatedPools, network.Name)
|
||||
if uniquePool != "" {
|
||||
vpnCIDR = uniquePool
|
||||
allocatedPools[uniquePool] = struct{}{}
|
||||
@@ -182,14 +156,14 @@ func initializeVirtualNATSettings() {
|
||||
}
|
||||
|
||||
// Initialize virtual NAT defaults
|
||||
network.AssignVirtualNATDefaults(vpnCIDR, network.NetID)
|
||||
logic.AssignVirtualNATDefaults(&network, vpnCIDR)
|
||||
|
||||
// Save the updated network
|
||||
if err := logic.UpsertNetwork(network); err != nil {
|
||||
logger.Log(0, "failed to update network", network.NetID, "with Virtual NAT settings:", err.Error())
|
||||
if err := logic.UpsertNetwork(&network); err != nil {
|
||||
logger.Log(0, "failed to update network", network.Name, "with Virtual NAT settings:", err.Error())
|
||||
continue
|
||||
}
|
||||
logger.Log(1, "initialized Virtual NAT settings for network", network.NetID, "pool:", network.VirtualNATPoolIPv4)
|
||||
logger.Log(1, "initialized Virtual NAT settings for network", network.Name, "pool:", network.VirtualNATPoolIPv4)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,116 +256,6 @@ func cidrOverlaps(a, b *net.IPNet) bool {
|
||||
return a.Contains(b.IP) || b.Contains(a.IP)
|
||||
}
|
||||
|
||||
func migrateNameservers() {
|
||||
nets, _ := logic.GetNetworks()
|
||||
user, err := logic.GetSuperAdmin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, netI := range nets {
|
||||
_ = logic.CreateFallbackNameserver(netI.NetID)
|
||||
|
||||
_, cidr, err := net.ParseCIDR(netI.AddressRange)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
ns := &schema.Nameserver{
|
||||
NetworkID: netI.NetID,
|
||||
}
|
||||
nameservers, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
|
||||
for _, nsI := range nameservers {
|
||||
if len(nsI.Domains) != 0 {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
nsI.Domains = append(nsI.Domains, schema.NameserverDomain{
|
||||
Domain: matchDomain,
|
||||
})
|
||||
}
|
||||
|
||||
nsI.MatchDomains = []string{}
|
||||
|
||||
_ = nsI.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
|
||||
if len(netI.NameServers) > 0 {
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: "upstream nameservers",
|
||||
NetworkID: netI.NetID,
|
||||
Servers: []string{},
|
||||
MatchAll: true,
|
||||
Domains: []schema.NameserverDomain{
|
||||
{
|
||||
Domain: ".",
|
||||
},
|
||||
},
|
||||
Tags: datatypes.JSONMap{
|
||||
"*": struct{}{},
|
||||
},
|
||||
Nodes: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: user.UserName,
|
||||
}
|
||||
|
||||
for _, nsIP := range netI.NameServers {
|
||||
if net.ParseIP(nsIP) == nil {
|
||||
continue
|
||||
}
|
||||
if !cidr.Contains(net.ParseIP(nsIP)) {
|
||||
ns.Servers = append(ns.Servers, nsIP)
|
||||
}
|
||||
}
|
||||
ns.Create(db.WithContext(context.TODO()))
|
||||
netI.NameServers = []string{}
|
||||
logic.SaveNetwork(&netI)
|
||||
}
|
||||
}
|
||||
nodes, _ := logic.GetAllNodes()
|
||||
for _, node := range nodes {
|
||||
if !node.IsGw {
|
||||
continue
|
||||
}
|
||||
if node.IngressDNS != "" {
|
||||
if (node.Address.IP != nil && node.Address.IP.String() == node.IngressDNS) ||
|
||||
(node.Address6.IP != nil && node.Address6.IP.String() == node.IngressDNS) {
|
||||
continue
|
||||
}
|
||||
if node.IngressDNS == "8.8.8.8" || node.IngressDNS == "1.1.1.1" || node.IngressDNS == "9.9.9.9" {
|
||||
continue
|
||||
}
|
||||
h, err := logic.GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: fmt.Sprintf("%s gw nameservers", h.Name),
|
||||
NetworkID: node.Network,
|
||||
Servers: []string{node.IngressDNS},
|
||||
MatchAll: true,
|
||||
Domains: []schema.NameserverDomain{
|
||||
{
|
||||
Domain: ".",
|
||||
},
|
||||
},
|
||||
Nodes: datatypes.JSONMap{
|
||||
node.ID.String(): struct{}{},
|
||||
},
|
||||
Tags: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: user.UserName,
|
||||
}
|
||||
ns.Create(db.WithContext(context.TODO()))
|
||||
node.IngressDNS = ""
|
||||
logic.UpsertNode(&node)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// removes if any stale configurations from previous run.
|
||||
func resync() {
|
||||
|
||||
@@ -447,17 +311,18 @@ func assignSuperAdmin() {
|
||||
createdSuperAdmin := false
|
||||
owner := servercfg.GetOwnerEmail()
|
||||
if owner != "" {
|
||||
user, err := logic.GetUser(owner)
|
||||
user := &schema.User{Username: owner}
|
||||
err = user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
log.Fatal("error getting user", "user", owner, "error", err.Error())
|
||||
}
|
||||
user.PlatformRoleID = models.SuperAdminRole
|
||||
user.PlatformRoleID = schema.SuperAdminRole
|
||||
err = logic.UpsertUser(*user)
|
||||
if err != nil {
|
||||
log.Fatal(
|
||||
"error updating user to superadmin",
|
||||
"user",
|
||||
user.UserName,
|
||||
user.Username,
|
||||
"error",
|
||||
err.Error(),
|
||||
)
|
||||
@@ -466,7 +331,7 @@ func assignSuperAdmin() {
|
||||
}
|
||||
for _, u := range users {
|
||||
var isAdmin bool
|
||||
if u.PlatformRoleID == models.AdminRole {
|
||||
if u.PlatformRoleID == schema.AdminRole {
|
||||
isAdmin = true
|
||||
}
|
||||
if u.PlatformRoleID == "" && u.IsAdmin {
|
||||
@@ -474,19 +339,19 @@ func assignSuperAdmin() {
|
||||
}
|
||||
|
||||
if isAdmin {
|
||||
user, err := logic.GetUser(u.UserName)
|
||||
user := &schema.User{Username: u.UserName}
|
||||
err = user.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
slog.Error("error getting user", "user", u.UserName, "error", err.Error())
|
||||
continue
|
||||
}
|
||||
user.PlatformRoleID = models.SuperAdminRole
|
||||
user.IsSuperAdmin = true
|
||||
user.PlatformRoleID = schema.SuperAdminRole
|
||||
err = logic.UpsertUser(*user)
|
||||
if err != nil {
|
||||
slog.Error(
|
||||
"error updating user to superadmin",
|
||||
"user",
|
||||
user.UserName,
|
||||
user.Username,
|
||||
"error",
|
||||
err.Error(),
|
||||
)
|
||||
@@ -549,16 +414,16 @@ func updateEnrollmentKeys() {
|
||||
existingTags[t] = struct{}{}
|
||||
}
|
||||
}
|
||||
networks, _ := logic.GetNetworks()
|
||||
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, network := range networks {
|
||||
if _, ok := existingTags[network.NetID]; ok {
|
||||
if _, ok := existingTags[network.Name]; ok {
|
||||
continue
|
||||
}
|
||||
_, _ = logic.CreateEnrollmentKey(
|
||||
0,
|
||||
time.Time{},
|
||||
[]string{network.NetID},
|
||||
[]string{network.NetID},
|
||||
[]string{network.Name},
|
||||
[]string{network.Name},
|
||||
[]models.TagID{},
|
||||
true,
|
||||
uuid.Nil,
|
||||
@@ -566,57 +431,6 @@ func updateEnrollmentKeys() {
|
||||
false,
|
||||
false,
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func removeOldUserGrps() {
|
||||
rows, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for key, row := range rows {
|
||||
userG := models.UserGroup{}
|
||||
_ = json.Unmarshal([]byte(row), &userG)
|
||||
if userG.ID == "" {
|
||||
database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func updateHosts() {
|
||||
rows, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
if err != nil {
|
||||
logger.Log(0, "failed to fetch database records for hosts")
|
||||
}
|
||||
for _, row := range rows {
|
||||
var host models.Host
|
||||
if err := json.Unmarshal([]byte(row), &host); err != nil {
|
||||
logger.Log(0, "failed to unmarshal database row to host", "row", row)
|
||||
continue
|
||||
}
|
||||
if host.PersistentKeepalive == 0 {
|
||||
host.PersistentKeepalive = models.DefaultPersistentKeepAlive
|
||||
if err := logic.UpsertHost(&host); err != nil {
|
||||
logger.Log(0, "failed to upsert host", host.ID.String())
|
||||
continue
|
||||
}
|
||||
}
|
||||
if host.DNS == "" || (host.DNS != "yes" && host.DNS != "no") {
|
||||
if logic.GetServerSettings().ManageDNS {
|
||||
host.DNS = "yes"
|
||||
} else {
|
||||
host.DNS = "no"
|
||||
}
|
||||
if host.IsDefault {
|
||||
host.DNS = "yes"
|
||||
}
|
||||
logic.UpsertHost(&host)
|
||||
}
|
||||
if host.IsDefault && !host.AutoUpdate {
|
||||
host.AutoUpdate = true
|
||||
logic.UpsertHost(&host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -633,7 +447,10 @@ func updateNodes() {
|
||||
logic.UpsertNode(&node)
|
||||
}
|
||||
if node.IsIngressGateway {
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
go logic.DeleteRole(models.GetRAGRoleID(node.Network, host.ID.String()), true)
|
||||
}
|
||||
@@ -682,8 +499,8 @@ func updateAcls() {
|
||||
if !logic.GetServerSettings().OldAClsSupport {
|
||||
return
|
||||
}
|
||||
networks, err := logic.GetNetworks()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
slog.Error("acls migration failed. error getting networks", "error", err)
|
||||
return
|
||||
}
|
||||
@@ -691,19 +508,19 @@ func updateAcls() {
|
||||
// get current acls per network
|
||||
for _, network := range networks {
|
||||
var networkAcl acls.ACLContainer
|
||||
networkAcl, err := networkAcl.Get(acls.ContainerID(network.NetID))
|
||||
networkAcl, err := networkAcl.Get(acls.ContainerID(network.Name))
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) {
|
||||
continue
|
||||
}
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error getting acls for network: %s", network.NetID), "error", err)
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error getting acls for network: %s", network.Name), "error", err)
|
||||
continue
|
||||
}
|
||||
// convert old acls to new acls with clients
|
||||
// TODO: optimise O(n^2) operation
|
||||
clients, err := logic.GetNetworkExtClients(network.NetID)
|
||||
clients, err := logic.GetNetworkExtClients(network.Name)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error getting clients for network: %s", network.NetID), "error", err)
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error getting clients for network: %s", network.Name), "error", err)
|
||||
continue
|
||||
}
|
||||
clientsIdMap := make(map[string]struct{})
|
||||
@@ -749,7 +566,7 @@ func updateAcls() {
|
||||
continue
|
||||
}
|
||||
if nodeAcl == nil {
|
||||
slog.Warn("acls migration bad data: nil node acl", "node", id, "network", network.NetID)
|
||||
slog.Warn("acls migration bad data: nil node acl", "node", id, "network", network.Name)
|
||||
continue
|
||||
}
|
||||
nodeAcl[acls.AclID(client.ClientID)] = acls.Allowed
|
||||
@@ -793,19 +610,19 @@ func updateAcls() {
|
||||
}
|
||||
|
||||
// save new acls
|
||||
slog.Debug(fmt.Sprintf("(migration) saving new acls for network: %s", network.NetID), "networkAcl", networkAcl)
|
||||
if _, err := networkAcl.Save(acls.ContainerID(network.NetID)); err != nil {
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error saving new acls for network: %s", network.NetID), "error", err)
|
||||
slog.Debug(fmt.Sprintf("(migration) saving new acls for network: %s", network.Name), "networkAcl", networkAcl)
|
||||
if _, err := networkAcl.Save(acls.ContainerID(network.Name)); err != nil {
|
||||
slog.Error(fmt.Sprintf("error during acls migration. error saving new acls for network: %s", network.Name), "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Info(fmt.Sprintf("(migration) successfully saved new acls for network: %s", network.NetID))
|
||||
slog.Info(fmt.Sprintf("(migration) successfully saved new acls for network: %s", network.Name))
|
||||
}
|
||||
}
|
||||
|
||||
func updateNewAcls() {
|
||||
if servercfg.IsPro {
|
||||
userGroups, _ := logic.ListUserGroups()
|
||||
userGroupMap := make(map[models.UserGroupID]models.UserGroup)
|
||||
userGroups, _ := (&schema.UserGroup{}).ListAll(db.WithContext(context.TODO()))
|
||||
userGroupMap := make(map[schema.UserGroupID]schema.UserGroup)
|
||||
for _, userGroup := range userGroups {
|
||||
userGroupMap[userGroup.ID] = userGroup
|
||||
}
|
||||
@@ -815,14 +632,14 @@ func updateNewAcls() {
|
||||
aclSrc := make([]models.AclPolicyTag, 0)
|
||||
for _, src := range acl.Src {
|
||||
if src.ID == models.UserGroupAclID {
|
||||
userGroup, ok := userGroupMap[models.UserGroupID(src.Value)]
|
||||
userGroup, ok := userGroupMap[schema.UserGroupID(src.Value)]
|
||||
if !ok {
|
||||
// if the group doesn't exist, don't add it to the acl's src.
|
||||
continue
|
||||
} else {
|
||||
_, allNetworkAccess := userGroup.NetworkRoles[models.AllNetworks]
|
||||
_, allNetworkAccess := userGroup.NetworkRoles.Data()[schema.AllNetworks]
|
||||
if !allNetworkAccess {
|
||||
_, ok := userGroup.NetworkRoles[acl.NetworkID]
|
||||
_, ok := userGroup.NetworkRoles.Data()[acl.NetworkID]
|
||||
if !ok {
|
||||
// if the group doesn't have permissions for the acl's
|
||||
// network, don't add it to the acl's src.
|
||||
@@ -863,107 +680,46 @@ func MigrateEmqx() {
|
||||
|
||||
}
|
||||
|
||||
func migrateToUUIDs() {
|
||||
logic.MigrateToUUIDs()
|
||||
}
|
||||
|
||||
func syncUsers() {
|
||||
logger.Log(1, "Migrating Users (SyncUsers)")
|
||||
defer logger.Log(1, "Completed migrating Users (SyncUsers)")
|
||||
// create default network user roles for existing networks
|
||||
if servercfg.IsPro {
|
||||
networks, _ := logic.GetNetworks()
|
||||
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
for _, netI := range networks {
|
||||
logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID))
|
||||
logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(netI.Name))
|
||||
}
|
||||
}
|
||||
|
||||
users, err := logic.GetUsersDB()
|
||||
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err == nil {
|
||||
for _, user := range users {
|
||||
user := user
|
||||
needsUpdate := false
|
||||
|
||||
// Update admin flags based on platform role
|
||||
if user.PlatformRoleID == models.AdminRole && !user.IsAdmin {
|
||||
user.IsAdmin = true
|
||||
user.IsSuperAdmin = false
|
||||
needsUpdate = true
|
||||
user.AuthType = schema.BasicAuth
|
||||
if logic.IsOauthUser(&user) == nil {
|
||||
user.AuthType = schema.OAuth
|
||||
}
|
||||
if user.PlatformRoleID == models.SuperAdminRole && !user.IsSuperAdmin {
|
||||
user.IsSuperAdmin = true
|
||||
user.IsAdmin = true
|
||||
needsUpdate = true
|
||||
}
|
||||
if user.PlatformRoleID == models.PlatformUser || user.PlatformRoleID == models.ServiceUser {
|
||||
if user.IsSuperAdmin || user.IsAdmin {
|
||||
user.IsSuperAdmin = false
|
||||
user.IsAdmin = false
|
||||
needsUpdate = true
|
||||
}
|
||||
if len(user.UserGroups.Data()) == 0 {
|
||||
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
|
||||
}
|
||||
|
||||
if user.PlatformRoleID.String() != "" {
|
||||
// Initialize maps if nil
|
||||
if user.NetworkRoles == nil {
|
||||
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
needsUpdate = true
|
||||
}
|
||||
if user.UserGroups == nil {
|
||||
user.UserGroups = make(map[models.UserGroupID]struct{})
|
||||
needsUpdate = true
|
||||
}
|
||||
// Migrate user roles and groups, then add global net roles
|
||||
user = logic.MigrateUserRoleAndGroups(user)
|
||||
logic.AddGlobalNetRolesToAdmins(&user)
|
||||
needsUpdate = true
|
||||
} else {
|
||||
// Set auth type
|
||||
user.AuthType = models.BasicAuth
|
||||
if logic.IsOauthUser(&user) == nil {
|
||||
user.AuthType = models.OAuth
|
||||
}
|
||||
if len(user.NetworkRoles) == 0 {
|
||||
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
}
|
||||
if len(user.UserGroups) == 0 {
|
||||
user.UserGroups = make(map[models.UserGroupID]struct{})
|
||||
}
|
||||
|
||||
// We reach here only if the platform role id has not been set.
|
||||
//
|
||||
// Thus, we use the boolean fields to assign the role.
|
||||
if user.IsSuperAdmin {
|
||||
user.PlatformRoleID = models.SuperAdminRole
|
||||
} else if user.IsAdmin {
|
||||
user.PlatformRoleID = models.AdminRole
|
||||
} else {
|
||||
user.PlatformRoleID = models.ServiceUser
|
||||
}
|
||||
logic.AddGlobalNetRolesToAdmins(&user)
|
||||
user = logic.MigrateUserRoleAndGroups(user)
|
||||
needsUpdate = true
|
||||
}
|
||||
|
||||
// Only update user once after all changes are collected
|
||||
if needsUpdate {
|
||||
logic.UpsertUser(user)
|
||||
}
|
||||
logic.AddGlobalNetRolesToAdmins(&user)
|
||||
logic.UpsertUser(user)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func createDefaultTagsAndPolicies() {
|
||||
networks, err := logic.GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, network := range networks {
|
||||
logic.CreateDefaultTags(models.NetworkID(network.NetID))
|
||||
logic.CreateDefaultAclNetworkPolicies(models.NetworkID(network.NetID))
|
||||
logic.CreateDefaultTags(schema.NetworkID(network.Name))
|
||||
logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name))
|
||||
// delete old remote access gws policy
|
||||
logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.NetID, "all-remote-access-gws")})
|
||||
logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.Name, "all-remote-access-gws")})
|
||||
}
|
||||
logic.MigrateAclPolicies()
|
||||
if !servercfg.IsPro {
|
||||
@@ -986,7 +742,10 @@ func migrateToEgressV1() {
|
||||
}
|
||||
for _, node := range nodes {
|
||||
if node.IsEgressGateway {
|
||||
_, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err := host.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1020,7 +779,7 @@ func migrateToEgressV1() {
|
||||
MetaData: "",
|
||||
Default: false,
|
||||
ServiceType: models.Any,
|
||||
NetworkID: models.NetworkID(node.Network),
|
||||
NetworkID: schema.NetworkID(node.Network),
|
||||
Proto: models.ALL,
|
||||
RuleType: models.DevicePolicy,
|
||||
Src: []models.AclPolicyTag{
|
||||
@@ -1049,7 +808,7 @@ func migrateToEgressV1() {
|
||||
MetaData: "",
|
||||
Default: false,
|
||||
ServiceType: models.Any,
|
||||
NetworkID: models.NetworkID(node.Network),
|
||||
NetworkID: schema.NetworkID(node.Network),
|
||||
Proto: models.ALL,
|
||||
RuleType: models.UserPolicy,
|
||||
Src: []models.AclPolicyTag{
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"gorm.io/gorm"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// ToSQLSchema migrates the data from key-value
|
||||
// db to sql db.
|
||||
//
|
||||
// This function archives the old data and does not
|
||||
// delete it.
|
||||
//
|
||||
// Based on the db server, the archival is done in the
|
||||
// following way:
|
||||
//
|
||||
// 1. Sqlite: Moves the old data to a
|
||||
// netmaker_archive.db file.
|
||||
//
|
||||
// 2. Postgres: Moves the data to a netmaker_archive
|
||||
// schema within the same database.
|
||||
func ToSQLSchema() error {
|
||||
// initialize sql schema db.
|
||||
err := db.InitializeDB(schema.ListModels()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// migrate, if not done already.
|
||||
err = migrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// archive key-value schema db, if not done already.
|
||||
// ignore errors.
|
||||
_ = archive()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrate() error {
|
||||
// begin a new transaction.
|
||||
dbctx := db.BeginTx(context.TODO())
|
||||
commit := false
|
||||
defer func() {
|
||||
if commit {
|
||||
db.FromContext(dbctx).Commit()
|
||||
} else {
|
||||
db.FromContext(dbctx).Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// check if migrated already.
|
||||
migrationJob := &schema.Job{
|
||||
ID: "migration-v1.0.0",
|
||||
}
|
||||
err := migrationJob.Get(dbctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
// initialize key-value schema db.
|
||||
err := database.InitializeDatabase()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer database.CloseDB()
|
||||
|
||||
// migrate.
|
||||
// TODO: add migration code.
|
||||
|
||||
// mark migration job completed.
|
||||
err = migrationJob.Create(dbctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
commit = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func archive() error {
|
||||
dbServer := servercfg.GetDB()
|
||||
if dbServer != "sqlite" && dbServer != "postgres" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// begin a new transaction.
|
||||
dbctx := db.BeginTx(context.TODO())
|
||||
commit := false
|
||||
defer func() {
|
||||
if commit {
|
||||
db.FromContext(dbctx).Commit()
|
||||
} else {
|
||||
db.FromContext(dbctx).Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// check if key-value schema db archived already.
|
||||
archivalJob := &schema.Job{
|
||||
ID: "archival-v1.0.0",
|
||||
}
|
||||
err := archivalJob.Get(dbctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
// archive.
|
||||
switch dbServer {
|
||||
case "sqlite":
|
||||
err = sqliteArchiveOldData()
|
||||
default:
|
||||
err = pgArchiveOldData()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// mark archival job completed.
|
||||
err = archivalJob.Create(dbctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
commit = true
|
||||
} else {
|
||||
// remove the residual
|
||||
if dbServer == "sqlite" {
|
||||
_ = os.Remove(filepath.Join("data", "netmaker.db"))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqliteArchiveOldData() error {
|
||||
oldDBFilePath := filepath.Join("data", "netmaker.db")
|
||||
archiveDBFilePath := filepath.Join("data", "netmaker_archive.db")
|
||||
|
||||
// check if netmaker_archive.db exist.
|
||||
_, err := os.Stat(archiveDBFilePath)
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// rename old db file to netmaker_archive.db.
|
||||
return os.Rename(oldDBFilePath, archiveDBFilePath)
|
||||
}
|
||||
|
||||
func pgArchiveOldData() error {
|
||||
_, err := database.PGDB.Exec("CREATE SCHEMA IF NOT EXISTS netmaker_archive")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, table := range database.Tables {
|
||||
_, err := database.PGDB.Exec(
|
||||
fmt.Sprintf(
|
||||
"ALTER TABLE public.%s SET SCHEMA netmaker_archive",
|
||||
table,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
+23
-173
@@ -16,6 +16,7 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -23,10 +24,10 @@ import (
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
// TestSyncUsersLargeScale tests syncUsers() with a large number of users
|
||||
@@ -52,33 +53,25 @@ func TestSyncUsersLargeScale(t *testing.T) {
|
||||
startCreate := time.Now()
|
||||
|
||||
for i := 0; i < numUsers; i++ {
|
||||
user := models.User{
|
||||
UserName: "testuser" + uuid.New().String()[:8],
|
||||
user := schema.User{
|
||||
Username: "testuser" + uuid.New().String()[:8],
|
||||
Password: "testpassword123",
|
||||
DisplayName: "Test User " + uuid.New().String()[:8],
|
||||
IsAdmin: i%10 == 0, // 10% are admins
|
||||
IsSuperAdmin: i%100 == 0, // 1% are super admins
|
||||
PlatformRoleID: models.ServiceUser, // Most are service users
|
||||
PlatformRoleID: schema.ServiceUser, // Most are service users
|
||||
}
|
||||
|
||||
// Assign different platform roles
|
||||
if i%100 == 0 {
|
||||
user.PlatformRoleID = models.SuperAdminRole
|
||||
user.PlatformRoleID = schema.SuperAdminRole
|
||||
} else if i%10 == 0 {
|
||||
user.PlatformRoleID = models.AdminRole
|
||||
user.PlatformRoleID = schema.AdminRole
|
||||
} else if i%5 == 0 {
|
||||
user.PlatformRoleID = models.PlatformUser
|
||||
}
|
||||
|
||||
// Some users have network roles
|
||||
if i%3 == 0 {
|
||||
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
|
||||
user.NetworkRoles[models.NetworkID("test-network")] = make(map[models.UserRoleID]struct{})
|
||||
user.PlatformRoleID = schema.PlatformUser
|
||||
}
|
||||
|
||||
// Some users have user groups
|
||||
if i%4 == 0 {
|
||||
user.UserGroups = make(map[models.UserGroupID]struct{})
|
||||
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
|
||||
}
|
||||
|
||||
err := logic.UpsertUser(user)
|
||||
@@ -89,7 +82,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
|
||||
t.Logf("Created %d users in %v (avg: %v per user)", numUsers, createDuration, createDuration/time.Duration(numUsers))
|
||||
|
||||
// Verify users were created
|
||||
users, err := logic.GetUsersDB()
|
||||
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users), numUsers, "Expected at least %d users", numUsers)
|
||||
|
||||
@@ -103,7 +96,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
|
||||
syncDuration, len(users), syncDuration/time.Duration(len(users)))
|
||||
|
||||
// Verify users were migrated correctly
|
||||
usersAfter, err := logic.GetUsersDB()
|
||||
usersAfter, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(users), len(usersAfter), "User count should remain the same")
|
||||
|
||||
@@ -117,19 +110,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
|
||||
user := usersAfter[i]
|
||||
|
||||
// Verify platform role is set
|
||||
assert.NotEmpty(t, user.PlatformRoleID.String(), "User %s should have PlatformRoleID", user.UserName)
|
||||
|
||||
// Verify admin flags match platform role
|
||||
if user.PlatformRoleID == models.SuperAdminRole {
|
||||
assert.True(t, user.IsSuperAdmin, "SuperAdmin user should have IsSuperAdmin=true")
|
||||
assert.True(t, user.IsAdmin, "SuperAdmin user should have IsAdmin=true")
|
||||
} else if user.PlatformRoleID == models.AdminRole {
|
||||
assert.True(t, user.IsAdmin, "Admin user should have IsAdmin=true")
|
||||
assert.False(t, user.IsSuperAdmin, "Admin user should not have IsSuperAdmin=true")
|
||||
} else {
|
||||
assert.False(t, user.IsSuperAdmin, "Non-admin user should not have IsSuperAdmin=true")
|
||||
assert.False(t, user.IsAdmin, "Non-admin user should not have IsAdmin=true")
|
||||
}
|
||||
assert.NotEmpty(t, user.PlatformRoleID.String(), "User %s should have PlatformRoleID", user.Username)
|
||||
|
||||
// Verify user groups are initialized
|
||||
assert.NotNil(t, user.UserGroups, "User should have UserGroups map")
|
||||
@@ -166,20 +147,20 @@ func TestMigrateToUUIDsLargeScale(t *testing.T) {
|
||||
startCreate := time.Now()
|
||||
|
||||
for i := 0; i < numUsers; i++ {
|
||||
user := models.User{
|
||||
UserName: "testuser" + uuid.New().String()[:8],
|
||||
user := schema.User{
|
||||
Username: "testuser" + uuid.New().String()[:8],
|
||||
Password: "testpassword123",
|
||||
DisplayName: "Test User " + uuid.New().String()[:8],
|
||||
PlatformRoleID: models.ServiceUser,
|
||||
UserGroups: make(map[models.UserGroupID]struct{}),
|
||||
PlatformRoleID: schema.ServiceUser,
|
||||
UserGroups: datatypes.NewJSONType(make(map[schema.UserGroupID]struct{})),
|
||||
}
|
||||
|
||||
// Add some user groups with non-UUID IDs (to trigger migration)
|
||||
if i%2 == 0 {
|
||||
user.UserGroups[models.UserGroupID("old-group-1")] = struct{}{}
|
||||
user.UserGroups.Data()[("old-group-1")] = struct{}{}
|
||||
}
|
||||
if i%3 == 0 {
|
||||
user.UserGroups[models.UserGroupID("old-group-2")] = struct{}{}
|
||||
user.UserGroups.Data()[("old-group-2")] = struct{}{}
|
||||
}
|
||||
|
||||
err := logic.UpsertUser(user)
|
||||
@@ -190,140 +171,9 @@ func TestMigrateToUUIDsLargeScale(t *testing.T) {
|
||||
t.Logf("Created %d users in %v", numUsers, createDuration)
|
||||
|
||||
// Verify users were created
|
||||
users, err := logic.GetUsersDB()
|
||||
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(users), numUsers, "Expected at least %d users", numUsers)
|
||||
|
||||
// Test MigrateToUUIDs() performance
|
||||
t.Log("Running MigrateToUUIDs() migration...")
|
||||
startMigrate := time.Now()
|
||||
migrateToUUIDs()
|
||||
migrateDuration := time.Since(startMigrate)
|
||||
|
||||
t.Logf("MigrateToUUIDs() completed in %v for %d users (avg: %v per user)",
|
||||
migrateDuration, len(users), migrateDuration/time.Duration(len(users)))
|
||||
|
||||
// Verify users still exist after migration
|
||||
usersAfter, err := logic.GetUsersDB()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(users), len(usersAfter), "User count should remain the same")
|
||||
|
||||
// Performance assertion - MigrateToUUIDs should complete in reasonable time
|
||||
maxDuration := 30 * time.Second
|
||||
assert.Less(t, migrateDuration, maxDuration,
|
||||
"MigrateToUUIDs() took too long: %v (expected < %v)", migrateDuration, maxDuration)
|
||||
|
||||
t.Logf("✓ MigrateToUUIDs() performance test passed: %v for %d users", migrateDuration, len(users))
|
||||
}
|
||||
|
||||
// TestSyncUsersCorrectness tests that syncUsers() correctly migrates user data
|
||||
func TestSyncUsersCorrectness(t *testing.T) {
|
||||
// Initialize test database
|
||||
err := db.InitializeDB(schema.ListModels()...)
|
||||
require.NoError(t, err)
|
||||
defer db.CloseDB()
|
||||
|
||||
err = database.InitializeDatabase()
|
||||
require.NoError(t, err)
|
||||
defer database.CloseDB()
|
||||
|
||||
// Create test users with various states
|
||||
testCases := []struct {
|
||||
name string
|
||||
user models.User
|
||||
expectedRole models.UserRoleID
|
||||
expectedAdmin bool
|
||||
expectedSuper bool
|
||||
}{
|
||||
{
|
||||
name: "user with AdminRole but IsAdmin=false",
|
||||
user: models.User{
|
||||
UserName: "admin1",
|
||||
Password: "password",
|
||||
PlatformRoleID: models.AdminRole,
|
||||
IsAdmin: false,
|
||||
IsSuperAdmin: false,
|
||||
},
|
||||
expectedRole: models.AdminRole,
|
||||
expectedAdmin: true,
|
||||
expectedSuper: false,
|
||||
},
|
||||
{
|
||||
name: "user with SuperAdminRole but IsSuperAdmin=false",
|
||||
user: models.User{
|
||||
UserName: "superadmin1",
|
||||
Password: "password",
|
||||
PlatformRoleID: models.SuperAdminRole,
|
||||
IsAdmin: false,
|
||||
IsSuperAdmin: false,
|
||||
},
|
||||
expectedRole: models.SuperAdminRole,
|
||||
expectedAdmin: true,
|
||||
expectedSuper: true,
|
||||
},
|
||||
{
|
||||
name: "user with IsSuperAdmin=true but no PlatformRoleID",
|
||||
user: models.User{
|
||||
UserName: "superadmin2",
|
||||
Password: "password",
|
||||
PlatformRoleID: "",
|
||||
IsAdmin: true,
|
||||
IsSuperAdmin: true,
|
||||
},
|
||||
expectedRole: models.SuperAdminRole,
|
||||
expectedAdmin: true,
|
||||
expectedSuper: true,
|
||||
},
|
||||
{
|
||||
name: "user with IsAdmin=true but no PlatformRoleID",
|
||||
user: models.User{
|
||||
UserName: "admin2",
|
||||
Password: "password",
|
||||
PlatformRoleID: "",
|
||||
IsAdmin: true,
|
||||
IsSuperAdmin: false,
|
||||
},
|
||||
expectedRole: models.AdminRole,
|
||||
expectedAdmin: true,
|
||||
expectedSuper: false,
|
||||
},
|
||||
{
|
||||
name: "regular user with no role",
|
||||
user: models.User{
|
||||
UserName: "user1",
|
||||
Password: "password",
|
||||
PlatformRoleID: "",
|
||||
IsAdmin: false,
|
||||
IsSuperAdmin: false,
|
||||
},
|
||||
expectedRole: models.ServiceUser,
|
||||
expectedAdmin: false,
|
||||
expectedSuper: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create user
|
||||
err := logic.UpsertUser(tc.user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Run migration
|
||||
syncUsers()
|
||||
|
||||
// Verify migration
|
||||
user, err := logic.GetUser(tc.user.UserName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.expectedRole, user.PlatformRoleID, "PlatformRoleID should match")
|
||||
assert.Equal(t, tc.expectedAdmin, user.IsAdmin, "IsAdmin should match")
|
||||
assert.Equal(t, tc.expectedSuper, user.IsSuperAdmin, "IsSuperAdmin should match")
|
||||
assert.NotNil(t, user.UserGroups, "UserGroups should be initialized")
|
||||
assert.NotNil(t, user.NetworkRoles, "NetworkRoles should be initialized")
|
||||
|
||||
// Cleanup
|
||||
_ = database.DeleteRecord(database.USERS_TABLE_NAME, tc.user.UserName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSyncUsers benchmarks syncUsers() performance
|
||||
@@ -344,13 +194,13 @@ func BenchmarkSyncUsers(b *testing.B) {
|
||||
// Create test users
|
||||
numUsers := 1000
|
||||
for i := 0; i < numUsers; i++ {
|
||||
user := models.User{
|
||||
UserName: "benchuser" + uuid.New().String()[:8],
|
||||
user := schema.User{
|
||||
Username: "benchuser" + uuid.New().String()[:8],
|
||||
Password: "password",
|
||||
PlatformRoleID: models.ServiceUser,
|
||||
PlatformRoleID: schema.ServiceUser,
|
||||
}
|
||||
if i%10 == 0 {
|
||||
user.PlatformRoleID = models.AdminRole
|
||||
user.PlatformRoleID = schema.AdminRole
|
||||
}
|
||||
_ = logic.UpsertUser(user)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
package migrate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ToSQLSchema migrates the data from key-value
|
||||
// db to sql db.
|
||||
func ToSQLSchema() error {
|
||||
// begin a new transaction.
|
||||
dbctx := db.BeginTx(context.TODO())
|
||||
commit := false
|
||||
defer func() {
|
||||
if commit {
|
||||
db.FromContext(dbctx).Commit()
|
||||
} else {
|
||||
db.FromContext(dbctx).Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables.
|
||||
// future table migrations should be made below this block,
|
||||
// with a different version number and a similar check for whether the
|
||||
// migration was already done.
|
||||
migrationJob := &schema.Job{
|
||||
ID: "migration-v1.5.1",
|
||||
}
|
||||
err := migrationJob.Get(dbctx)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID))
|
||||
// migrate.
|
||||
err = migrateV1_5_1(dbctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// mark migration job completed.
|
||||
err = migrationJob.Create(dbctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID))
|
||||
commit = true
|
||||
} else {
|
||||
logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateV1_5_1(ctx context.Context) error {
|
||||
err := migrateUsers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migrateNetworks(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migrateUserRoles(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migrateUserGroups(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migrateHosts(ctx)
|
||||
}
|
||||
|
||||
func migrateUsers(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USERS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
var user models.User
|
||||
err = json.Unmarshal([]byte(record), &user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
platformRoleID := user.PlatformRoleID
|
||||
if user.PlatformRoleID == "" {
|
||||
if user.IsSuperAdmin {
|
||||
platformRoleID = schema.SuperAdminRole
|
||||
} else if user.IsAdmin {
|
||||
platformRoleID = schema.AdminRole
|
||||
} else {
|
||||
platformRoleID = schema.ServiceUser
|
||||
}
|
||||
}
|
||||
|
||||
_user := schema.User{
|
||||
ID: "",
|
||||
Username: user.UserName,
|
||||
DisplayName: user.DisplayName,
|
||||
PlatformRoleID: platformRoleID,
|
||||
ExternalIdentityProviderID: user.ExternalIdentityProviderID,
|
||||
AccountDisabled: user.AccountDisabled,
|
||||
AuthType: user.AuthType,
|
||||
Password: user.Password,
|
||||
IsMFAEnabled: user.IsMFAEnabled,
|
||||
TOTPSecret: user.TOTPSecret,
|
||||
LastLoginAt: user.LastLoginTime,
|
||||
UserGroups: datatypes.NewJSONType(user.UserGroups),
|
||||
CreatedBy: user.CreatedBy,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating user %s", _user.Username))
|
||||
|
||||
err = _user.Create(ctx)
|
||||
if err != nil {
|
||||
logger.Log(4, fmt.Sprintf("migrating user %s failed: %v", _user.Username, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateNetworks(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
var network models.Network
|
||||
err = json.Unmarshal([]byte(record), &network)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var autoJoin, autoRemove, jitEnabled bool
|
||||
|
||||
if network.AutoJoin == "false" {
|
||||
autoJoin = false
|
||||
} else {
|
||||
autoJoin = true
|
||||
}
|
||||
|
||||
if network.AutoRemove == "true" {
|
||||
autoRemove = true
|
||||
} else {
|
||||
autoRemove = false
|
||||
}
|
||||
|
||||
if network.JITEnabled == "yes" {
|
||||
jitEnabled = true
|
||||
} else {
|
||||
jitEnabled = false
|
||||
}
|
||||
|
||||
_network := &schema.Network{
|
||||
ID: "",
|
||||
Name: network.NetID,
|
||||
AddressRange: network.AddressRange,
|
||||
AddressRange6: network.AddressRange6,
|
||||
DefaultKeepAlive: int(network.DefaultKeepalive),
|
||||
DefaultACL: network.DefaultACL,
|
||||
DefaultMTU: network.DefaultMTU,
|
||||
AutoJoin: autoJoin,
|
||||
AutoRemove: autoRemove,
|
||||
AutoRemoveTags: network.AutoRemoveTags,
|
||||
AutoRemoveThreshold: network.AutoRemoveThreshold,
|
||||
JITEnabled: jitEnabled,
|
||||
VirtualNATPoolIPv4: network.VirtualNATPoolIPv4,
|
||||
VirtualNATSitePrefixLenIPv4: network.VirtualNATSitePrefixLenIPv4,
|
||||
NodesUpdatedAt: time.Unix(network.NodesLastModified, 0),
|
||||
CreatedBy: network.CreatedBy,
|
||||
CreatedAt: network.CreatedAt,
|
||||
UpdatedAt: time.Unix(network.NetworkLastModified, 0),
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s", _network.Name))
|
||||
|
||||
err = _network.Create(ctx)
|
||||
if err != nil {
|
||||
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateUserRoles(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
var _userRole schema.UserRole
|
||||
err = json.Unmarshal([]byte(record), &_userRole)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating user role %s", _userRole.ID))
|
||||
|
||||
err = _userRole.Create(ctx)
|
||||
if err != nil {
|
||||
logger.Log(4, fmt.Sprintf("migrating user role %s failed: %v", _userRole.ID, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateUserGroups(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
var _userGroup schema.UserGroup
|
||||
err = json.Unmarshal([]byte(record), &_userGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating user group %s", _userGroup.ID))
|
||||
|
||||
err = _userGroup.Create(ctx)
|
||||
if err != nil {
|
||||
logger.Log(4, fmt.Sprintf("migrating user group %s failed: %v", _userGroup.ID, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrateHosts(ctx context.Context) error {
|
||||
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
var host models.Host
|
||||
err = json.Unmarshal([]byte(record), &host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_host := &schema.Host{
|
||||
ID: host.ID,
|
||||
Verbosity: host.Verbosity,
|
||||
FirewallInUse: host.FirewallInUse,
|
||||
Version: host.Version,
|
||||
IPForwarding: host.IPForwarding,
|
||||
DaemonInstalled: host.DaemonInstalled,
|
||||
AutoUpdate: host.AutoUpdate,
|
||||
HostPass: host.HostPass,
|
||||
Name: host.Name,
|
||||
OS: host.OS,
|
||||
OSFamily: host.OSFamily,
|
||||
OSVersion: host.OSVersion,
|
||||
KernelVersion: host.KernelVersion,
|
||||
Interface: host.Interface,
|
||||
Debug: host.Debug,
|
||||
ListenPort: host.ListenPort,
|
||||
WgPublicListenPort: host.WgPublicListenPort,
|
||||
MTU: host.MTU,
|
||||
PublicKey: schema.WgKey{
|
||||
Key: host.PublicKey,
|
||||
},
|
||||
MacAddress: host.MacAddress,
|
||||
TrafficKeyPublic: host.TrafficKeyPublic,
|
||||
Nodes: host.Nodes,
|
||||
Interfaces: host.Interfaces,
|
||||
DefaultInterface: host.DefaultInterface,
|
||||
EndpointIP: host.EndpointIP,
|
||||
EndpointIPv6: host.EndpointIPv6,
|
||||
IsDocker: host.IsDocker,
|
||||
IsK8S: host.IsK8S,
|
||||
IsStaticPort: host.IsStaticPort,
|
||||
IsStatic: host.IsStatic,
|
||||
IsDefault: host.IsDefault,
|
||||
DNS: host.DNS,
|
||||
NatType: host.NatType,
|
||||
TurnEndpoint: nil,
|
||||
PersistentKeepalive: host.PersistentKeepalive,
|
||||
Location: host.Location,
|
||||
CountryCode: host.CountryCode,
|
||||
EnableFlowLogs: host.EnableFlowLogs,
|
||||
}
|
||||
|
||||
if host.TurnEndpoint != nil {
|
||||
_host.TurnEndpoint = &schema.AddrPort{
|
||||
AddrPort: *host.TurnEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
if _host.PersistentKeepalive == 0 {
|
||||
_host.PersistentKeepalive = models.DefaultPersistentKeepAlive
|
||||
}
|
||||
|
||||
if _host.DNS == "" || (_host.DNS != "yes" && _host.DNS != "no") {
|
||||
if logic.GetServerSettings().ManageDNS {
|
||||
_host.DNS = "yes"
|
||||
} else {
|
||||
_host.DNS = "no"
|
||||
}
|
||||
if _host.IsDefault {
|
||||
_host.DNS = "yes"
|
||||
}
|
||||
}
|
||||
|
||||
if _host.IsDefault && !_host.AutoUpdate {
|
||||
_host.AutoUpdate = true
|
||||
}
|
||||
|
||||
logger.Log(4, fmt.Sprintf("migrating host %s", _host.ID))
|
||||
|
||||
err = _host.Create(ctx)
|
||||
if err != nil {
|
||||
logger.Log(4, fmt.Sprintf("migrating host %s failed: %v", _host.ID, err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+3
-1
@@ -3,6 +3,8 @@ package models
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// AllowedTrafficDirection - allowed direction of traffic
|
||||
@@ -84,7 +86,7 @@ type Acl struct {
|
||||
Default bool `json:"default"`
|
||||
MetaData string `json:"meta_data"`
|
||||
Name string `json:"name"`
|
||||
NetworkID NetworkID `json:"network_id"`
|
||||
NetworkID schema.NetworkID `json:"network_id"`
|
||||
RuleType AclPolicyType `json:"policy_type"`
|
||||
Src []AclPolicyTag `json:"src_type"`
|
||||
Dst []AclPolicyTag `json:"dst_type"`
|
||||
|
||||
+6
-4
@@ -4,6 +4,8 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// ApiHost - the host struct for API usage
|
||||
@@ -47,8 +49,8 @@ type ApiIface struct {
|
||||
AddressString string `json:"addressString"`
|
||||
}
|
||||
|
||||
// Host.ConvertNMHostToAPI - converts a Netmaker host to an API editable host
|
||||
func (h *Host) ConvertNMHostToAPI() *ApiHost {
|
||||
// NewApiHostFromSchemaHost - converts a Netmaker host to an API editable host
|
||||
func NewApiHostFromSchemaHost(h *schema.Host) *ApiHost {
|
||||
a := ApiHost{}
|
||||
a.Debug = h.Debug
|
||||
a.EndpointIP = h.EndpointIP.String()
|
||||
@@ -96,8 +98,8 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
|
||||
|
||||
// APIHost.ConvertAPIHostToNMHost - convert's a given apihost struct to
|
||||
// a Host struct
|
||||
func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
|
||||
h := Host{}
|
||||
func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *schema.Host) *schema.Host {
|
||||
h := schema.Host{}
|
||||
h.ID = currentHost.ID
|
||||
h.HostPass = currentHost.HostPass
|
||||
h.DaemonInstalled = currentHost.DaemonInstalled
|
||||
|
||||
+2
-1
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
@@ -69,7 +70,7 @@ type ApiNode struct {
|
||||
Location string `json:"location"`
|
||||
Country string `json:"country"`
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
}
|
||||
|
||||
|
||||
+13
-18
@@ -1,23 +1,18 @@
|
||||
package models
|
||||
|
||||
type EgressNATMode string
|
||||
|
||||
const (
|
||||
VirtualNAT EgressNATMode = "virtual_nat"
|
||||
DirectNAT EgressNATMode = "direct_nat"
|
||||
)
|
||||
import "github.com/gravitl/netmaker/schema"
|
||||
|
||||
type EgressReq struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Network string `json:"network"`
|
||||
Description string `json:"description"`
|
||||
Nodes map[string]int `json:"nodes"`
|
||||
Tags map[string]int `json:"tags"`
|
||||
Range string `json:"range"`
|
||||
Domain string `json:"domain"`
|
||||
Nat bool `json:"nat"`
|
||||
Mode EgressNATMode `json:"mode"`
|
||||
Status bool `json:"status"`
|
||||
IsInetGw bool `json:"is_internet_gateway"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Network string `json:"network"`
|
||||
Description string `json:"description"`
|
||||
Nodes map[string]int `json:"nodes"`
|
||||
Tags map[string]int `json:"tags"`
|
||||
Range string `json:"range"`
|
||||
Domain string `json:"domain"`
|
||||
Nat bool `json:"nat"`
|
||||
Mode schema.EgressNATMode `json:"mode"`
|
||||
Status bool `json:"status"`
|
||||
IsInetGw bool `json:"is_internet_gateway"`
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -75,7 +76,7 @@ type APIEnrollmentKey struct {
|
||||
// RegisterResponse - the response to a successful enrollment register
|
||||
type RegisterResponse struct {
|
||||
ServerConf ServerConfig `json:"server_config"`
|
||||
RequestedHost Host `json:"requested_host"`
|
||||
RequestedHost schema.Host `json:"requested_host"`
|
||||
}
|
||||
|
||||
// EnrollmentKey.IsValid - checks if the key is still valid to use
|
||||
|
||||
+8
-80
@@ -1,84 +1,12 @@
|
||||
package models
|
||||
|
||||
type Action string
|
||||
|
||||
const (
|
||||
Create Action = "CREATE"
|
||||
Update Action = "UPDATE"
|
||||
Delete Action = "DELETE"
|
||||
DeleteAll Action = "DELETE_ALL"
|
||||
Login Action = "LOGIN"
|
||||
LogOut Action = "LOGOUT"
|
||||
Connect Action = "CONNECT"
|
||||
Sync Action = "SYNC"
|
||||
RefreshKey Action = "REFRESH_KEY"
|
||||
RefreshAllKeys Action = "REFRESH_ALL_KEYS"
|
||||
SyncAll Action = "SYNC_ALL"
|
||||
UpgradeAll Action = "UPGRADE_ALL"
|
||||
Disconnect Action = "DISCONNECT"
|
||||
JoinHostToNet Action = "JOIN_HOST_TO_NETWORK"
|
||||
RemoveHostFromNet Action = "REMOVE_HOST_FROM_NETWORK"
|
||||
EnableMFA Action = "ENABLE_MFA"
|
||||
DisableMFA Action = "DISABLE_MFA"
|
||||
EnforceMFA Action = "ENFORCE_MFA"
|
||||
UnenforceMFA Action = "UNENFORCE_MFA"
|
||||
EnableBasicAuth Action = "ENABLE_BASIC_AUTH"
|
||||
DisableBasicAuth Action = "DISABLE_BASIC_AUTH"
|
||||
EnableTelemetry Action = "ENABLE_TELEMETRY"
|
||||
DisableTelemetry Action = "DISABLE_TELEMETRY"
|
||||
UpdateClientSettings Action = "UPDATE_CLIENT_SETTINGS"
|
||||
UpdateAuthenticationSecuritySettings Action = "UPDATE_AUTHENTICATION_SECURITY_SETTINGS"
|
||||
UpdateMonitoringAndDebuggingSettings Action = "UPDATE_MONITORING_AND_DEBUGGING_SETTINGS"
|
||||
UpdateSMTPSettings Action = "UPDATE_EMAIL_SETTINGS"
|
||||
UpdateIDPSettings Action = "UPDATE_IDP_SETTINGS"
|
||||
EnableFlowLogs Action = "ENABLE_FLOW_LOGS"
|
||||
DisableFlowLogs Action = "DISABLE_FLOW_LOGS"
|
||||
GatewayAssign Action = "GATEWAY_ASSIGN"
|
||||
GatewayUnAssign Action = "GATEWAY_UNASSIGN"
|
||||
)
|
||||
|
||||
type SubjectType string
|
||||
|
||||
const (
|
||||
UserSub SubjectType = "USER"
|
||||
UserAccessTokenSub SubjectType = "USER_ACCESS_TOKEN"
|
||||
DeviceSub SubjectType = "DEVICE"
|
||||
NodeSub SubjectType = "NODE"
|
||||
GatewaySub SubjectType = "GATEWAY"
|
||||
SettingSub SubjectType = "SETTING"
|
||||
AclSub SubjectType = "ACL"
|
||||
TagSub SubjectType = "TAG"
|
||||
UserRoleSub SubjectType = "USER_ROLE"
|
||||
UserGroupSub SubjectType = "USER_GROUP"
|
||||
UserInviteSub SubjectType = "USER_INVITE"
|
||||
PendingUserSub SubjectType = "PENDING_USER"
|
||||
EgressSub SubjectType = "EGRESS"
|
||||
NetworkSub SubjectType = "NETWORK"
|
||||
DashboardSub SubjectType = "DASHBOARD"
|
||||
EnrollmentKeySub SubjectType = "ENROLLMENT_KEY"
|
||||
ClientAppSub SubjectType = "CLIENT-APP"
|
||||
NameserverSub SubjectType = "NAMESERVER"
|
||||
PostureCheckSub SubjectType = "POSTURE_CHECK"
|
||||
)
|
||||
|
||||
func (sub SubjectType) String() string {
|
||||
return string(sub)
|
||||
}
|
||||
|
||||
type Origin string
|
||||
|
||||
const (
|
||||
Dashboard Origin = "DASHBOARD"
|
||||
Api Origin = "API"
|
||||
NMCTL Origin = "NMCTL"
|
||||
ClientApp Origin = "CLIENT-APP"
|
||||
)
|
||||
import "github.com/gravitl/netmaker/schema"
|
||||
|
||||
type Subject struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type SubjectType `json:"subject_type"`
|
||||
Info interface{} `json:"info"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type schema.SubjectType `json:"subject_type"`
|
||||
Info interface{} `json:"info"`
|
||||
}
|
||||
|
||||
type Diff struct {
|
||||
@@ -87,11 +15,11 @@ type Diff struct {
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
Action Action
|
||||
Action schema.Action
|
||||
Source Subject
|
||||
Origin Origin
|
||||
Origin schema.Origin
|
||||
Target Subject
|
||||
TriggeredBy string
|
||||
NetworkID NetworkID
|
||||
NetworkID schema.NetworkID
|
||||
Diff Diff
|
||||
}
|
||||
|
||||
+3
-1
@@ -3,6 +3,8 @@ package models
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// ExtClient - struct for external clients
|
||||
@@ -37,7 +39,7 @@ type ExtClient struct {
|
||||
Country string `json:"country"`
|
||||
Location string `json:"location"` //format: lat,long
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
JITExpiresAt *time.Time `json:"jit_expires_at,omitempty" bson:"jit_expires_at,omitempty"` // JIT grant expiry time (nil if JIT not enabled or user is admin)
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
|
||||
+9
-8
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -63,7 +64,7 @@ type Host struct {
|
||||
MacAddress net.HardwareAddr `json:"macaddress" yaml:"macaddress"`
|
||||
TrafficKeyPublic []byte `json:"traffickeypublic" yaml:"traffickeypublic"`
|
||||
Nodes []string `json:"nodes" yaml:"nodes"`
|
||||
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
|
||||
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
|
||||
DefaultInterface string `json:"defaultinterface" yaml:"defaultinterface"`
|
||||
EndpointIP net.IP `json:"endpointip" yaml:"endpointip"`
|
||||
EndpointIPv6 net.IP `json:"endpointipv6" yaml:"endpointipv6"`
|
||||
@@ -150,7 +151,7 @@ const (
|
||||
// HostUpdate - struct for host update
|
||||
type HostUpdate struct {
|
||||
Action HostMqAction
|
||||
Host Host
|
||||
Host schema.Host
|
||||
Node Node
|
||||
Signal Signal
|
||||
EgressDomain EgressDomain
|
||||
@@ -182,10 +183,10 @@ type Signal struct {
|
||||
|
||||
// RegisterMsg - login message struct for hosts to join via SSO login
|
||||
type RegisterMsg struct {
|
||||
RegisterHost Host `json:"host"`
|
||||
Network string `json:"network,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
JoinAll bool `json:"join_all,omitempty"`
|
||||
Relay string `json:"relay,omitempty"`
|
||||
RegisterHost schema.Host `json:"host"`
|
||||
Network string `json:"network,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
JoinAll bool `json:"join_all,omitempty"`
|
||||
Relay string `json:"relay,omitempty"`
|
||||
}
|
||||
|
||||
+7
-5
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
// Metrics - metrics struct
|
||||
@@ -48,11 +50,11 @@ type HostInfoMap map[string]HostNetworkInfo
|
||||
|
||||
// HostNetworkInfo - holds info related to host networking (used for client side peer calculations)
|
||||
type HostNetworkInfo struct {
|
||||
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
|
||||
ListenPort int `json:"listen_port" yaml:"listen_port"`
|
||||
IsStaticPort bool `json:"is_static_port"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
Version string `json:"version"`
|
||||
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
|
||||
ListenPort int `json:"listen_port" yaml:"listen_port"`
|
||||
IsStaticPort bool `json:"is_static_port"`
|
||||
IsStatic bool `json:"is_static"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// PeerMap - peer map for ids and addresses in metrics
|
||||
|
||||
+28
-27
@@ -3,11 +3,12 @@ package models
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type HostPeerInfo struct {
|
||||
NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
|
||||
NetworkPeerIDs map[schema.NetworkID]PeerMap `json:"network_peers"`
|
||||
}
|
||||
|
||||
type PeerType int
|
||||
@@ -27,37 +28,37 @@ type PeerIdentity struct {
|
||||
|
||||
// HostPeerUpdate - struct for host peer updates
|
||||
type HostPeerUpdate struct {
|
||||
Host Host `json:"host"`
|
||||
Nodes []Node `json:"nodes"`
|
||||
ChangeDefaultGw bool `json:"change_default_gw"`
|
||||
DefaultGwIp net.IP `json:"default_gw_ip"`
|
||||
IsInternetGw bool `json:"is_inet_gw"`
|
||||
NodeAddrs []net.IPNet `json:"nodes_addrs"`
|
||||
Server string `json:"server"`
|
||||
ServerVersion string `json:"serverversion"`
|
||||
ServerAddrs []ServerAddr `json:"serveraddrs"`
|
||||
NodePeers []wgtypes.PeerConfig `json:"node_peers"`
|
||||
Peers []wgtypes.PeerConfig `json:"host_peers"`
|
||||
PeerIDs PeerMap `json:"peerids"`
|
||||
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty"`
|
||||
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
|
||||
FwUpdate FwUpdate `json:"fw_update"`
|
||||
ReplacePeers bool `json:"replace_peers"`
|
||||
NameServers []string `json:"name_servers"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
|
||||
AutoRelayNodes map[NetworkID][]Node `json:"auto_relay_nodes"`
|
||||
GwNodes map[NetworkID][]Node `json:"gw_nodes"`
|
||||
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
|
||||
Host schema.Host `json:"host"`
|
||||
Nodes []Node `json:"nodes"`
|
||||
ChangeDefaultGw bool `json:"change_default_gw"`
|
||||
DefaultGwIp net.IP `json:"default_gw_ip"`
|
||||
IsInternetGw bool `json:"is_inet_gw"`
|
||||
NodeAddrs []net.IPNet `json:"nodes_addrs"`
|
||||
Server string `json:"server"`
|
||||
ServerVersion string `json:"serverversion"`
|
||||
ServerAddrs []ServerAddr `json:"serveraddrs"`
|
||||
NodePeers []wgtypes.PeerConfig `json:"node_peers"`
|
||||
Peers []wgtypes.PeerConfig `json:"host_peers"`
|
||||
PeerIDs PeerMap `json:"peerids"`
|
||||
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty"`
|
||||
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
|
||||
FwUpdate FwUpdate `json:"fw_update"`
|
||||
ReplacePeers bool `json:"replace_peers"`
|
||||
NameServers []string `json:"name_servers"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
|
||||
AutoRelayNodes map[schema.NetworkID][]Node `json:"auto_relay_nodes"`
|
||||
GwNodes map[schema.NetworkID][]Node `json:"gw_nodes"`
|
||||
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
|
||||
ServerConfig
|
||||
OldPeerUpdateFields
|
||||
}
|
||||
|
||||
type EgressDomain struct {
|
||||
ID string `json:"id"`
|
||||
Node Node `json:"node"`
|
||||
Host Host `json:"host"`
|
||||
Domain string `json:"domain"`
|
||||
ID string `json:"id"`
|
||||
Node Node `json:"node"`
|
||||
Host schema.Host `json:"host"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
||||
type Nameserver struct {
|
||||
IPs []string `json:"ips"`
|
||||
|
||||
+3
-143
@@ -1,7 +1,6 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -13,15 +12,9 @@ type Network struct {
|
||||
NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=32,netid_valid"`
|
||||
NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified" swaggertype:"primitive,integer" format:"int64"`
|
||||
NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified" swaggertype:"primitive,integer" format:"int64"`
|
||||
DefaultInterface string `json:"defaultinterface" bson:"defaultinterface" validate:"min=1,max=35"`
|
||||
DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,min=1024,max=65535"`
|
||||
NodeLimit int32 `json:"nodelimit" bson:"nodelimit"`
|
||||
DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"`
|
||||
DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate:"omitempty,max=1000"`
|
||||
AllowManualSignUp string `json:"allowmanualsignup" bson:"allowmanualsignup" validate:"checkyesorno"`
|
||||
IsIPv4 string `json:"isipv4" bson:"isipv4" validate:"checkyesorno"`
|
||||
IsIPv6 string `json:"isipv6" bson:"isipv6" validate:"checkyesorno"`
|
||||
DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"`
|
||||
DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"`
|
||||
DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"`
|
||||
NameServers []string `json:"dns_nameservers"`
|
||||
@@ -33,145 +26,12 @@ type Network struct {
|
||||
// VirtualNATPoolIPv4 is the IPv4 CIDR pool from which virtual NAT ranges are allocated for egress gateways
|
||||
VirtualNATPoolIPv4 string `json:"virtual_nat_pool_ipv4"`
|
||||
// VirtualNATSitePrefixLenIPv4 is the prefix length (e.g., 24) for individual site allocations from the IPv4 virtual NAT pool
|
||||
VirtualNATSitePrefixLenIPv4 int `json:"virtual_nat_site_prefixlen_ipv4"`
|
||||
VirtualNATSitePrefixLenIPv4 int `json:"virtual_nat_site_prefixlen_ipv4"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// SaveData - sensitive fields of a network that should be kept the same
|
||||
type SaveData struct { // put sensitive fields here
|
||||
NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=32,netid_valid"`
|
||||
}
|
||||
|
||||
// Network.SetNodesLastModified - sets nodes last modified on network, depricated
|
||||
func (network *Network) SetNodesLastModified() {
|
||||
network.NodesLastModified = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Network.SetNetworkLastModified - sets network last modified time
|
||||
func (network *Network) SetNetworkLastModified() {
|
||||
network.NetworkLastModified = time.Now().Unix()
|
||||
}
|
||||
|
||||
// Network.SetDefaults - sets default values for a network struct
|
||||
func (network *Network) SetDefaults() (upsert bool) {
|
||||
if network.DefaultUDPHolePunch == "" {
|
||||
network.DefaultUDPHolePunch = "no"
|
||||
upsert = true
|
||||
}
|
||||
if network.DefaultInterface == "" {
|
||||
if len(network.NetID) < 33 {
|
||||
network.DefaultInterface = "nm-" + network.NetID
|
||||
} else {
|
||||
network.DefaultInterface = network.NetID
|
||||
}
|
||||
upsert = true
|
||||
}
|
||||
if network.DefaultListenPort == 0 {
|
||||
network.DefaultListenPort = 51821
|
||||
upsert = true
|
||||
}
|
||||
if network.NodeLimit == 0 {
|
||||
network.NodeLimit = 999999999
|
||||
upsert = true
|
||||
}
|
||||
if network.DefaultKeepalive == 0 {
|
||||
network.DefaultKeepalive = 20
|
||||
upsert = true
|
||||
}
|
||||
if network.AllowManualSignUp == "" {
|
||||
network.AllowManualSignUp = "no"
|
||||
upsert = true
|
||||
}
|
||||
|
||||
if network.IsIPv4 == "" {
|
||||
network.IsIPv4 = "yes"
|
||||
upsert = true
|
||||
}
|
||||
|
||||
if network.IsIPv6 == "" {
|
||||
network.IsIPv6 = "no"
|
||||
upsert = true
|
||||
}
|
||||
|
||||
if network.DefaultMTU == 0 {
|
||||
network.DefaultMTU = 1280
|
||||
upsert = true
|
||||
}
|
||||
|
||||
if network.DefaultACL == "" {
|
||||
network.DefaultACL = "yes"
|
||||
upsert = true
|
||||
}
|
||||
|
||||
if network.JITEnabled == "" {
|
||||
network.JITEnabled = "no"
|
||||
upsert = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// AssignVirtualNATDefaults determines safe defaults based on VPN CIDR
|
||||
func (network *Network) AssignVirtualNATDefaults(vpnCIDR string, networkID string) {
|
||||
const (
|
||||
cgnatCIDR = "100.64.0.0/10"
|
||||
fallbackIPv4Pool = "198.18.0.0/15"
|
||||
|
||||
defaultIPv4SitePrefix = 24
|
||||
)
|
||||
|
||||
// Parse CGNAT CIDR (should always succeed, but check for safety)
|
||||
_, cgnatNet, err := net.ParseCIDR(cgnatCIDR)
|
||||
if err != nil {
|
||||
// Fallback to default pool if CGNAT parsing fails (shouldn't happen)
|
||||
network.VirtualNATPoolIPv4 = fallbackIPv4Pool
|
||||
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
|
||||
return
|
||||
}
|
||||
|
||||
var virtualIPv4Pool string
|
||||
// Parse VPN CIDR - if it fails or is empty, use fallback
|
||||
if vpnCIDR == "" {
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
} else {
|
||||
_, vpnNet, err := net.ParseCIDR(vpnCIDR)
|
||||
if err != nil || vpnNet == nil {
|
||||
// Invalid VPN CIDR, use fallback
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
} else if !cidrOverlaps(vpnNet, cgnatNet) {
|
||||
// Safe to reuse VPN CIDR for Virtual NAT
|
||||
virtualIPv4Pool = vpnCIDR
|
||||
} else {
|
||||
// VPN is CGNAT — must not reuse
|
||||
virtualIPv4Pool = fallbackIPv4Pool
|
||||
}
|
||||
}
|
||||
|
||||
network.VirtualNATPoolIPv4 = virtualIPv4Pool
|
||||
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
|
||||
|
||||
}
|
||||
func cidrOverlaps(a, b *net.IPNet) bool {
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return a.Contains(b.IP) || b.Contains(a.IP)
|
||||
}
|
||||
|
||||
func (network *Network) GetNetworkNetworkCIDR4() *net.IPNet {
|
||||
if network.AddressRange == "" {
|
||||
return nil
|
||||
}
|
||||
_, netCidr, _ := net.ParseCIDR(network.AddressRange)
|
||||
return netCidr
|
||||
}
|
||||
func (network *Network) GetNetworkNetworkCIDR6() *net.IPNet {
|
||||
if network.AddressRange6 == "" {
|
||||
return nil
|
||||
}
|
||||
_, netCidr, _ := net.ParseCIDR(network.AddressRange6)
|
||||
return netCidr
|
||||
}
|
||||
|
||||
type NetworkStatResp struct {
|
||||
Network
|
||||
Hosts int `json:"hosts"`
|
||||
}
|
||||
|
||||
+8
-13
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -55,14 +56,7 @@ var seededRand *rand.Rand = rand.New(
|
||||
type NodeCheckin struct {
|
||||
Version string
|
||||
Connected bool
|
||||
Ifaces []Iface
|
||||
}
|
||||
|
||||
// Iface struct for local interfaces of a node
|
||||
type Iface struct {
|
||||
Name string `json:"name"`
|
||||
Address net.IPNet `json:"address"`
|
||||
AddressString string `json:"addressString"`
|
||||
Ifaces []schema.Iface
|
||||
}
|
||||
|
||||
// CommonNode - represents a commonn node data elements shared by netmaker and netclient
|
||||
@@ -127,7 +121,7 @@ type Node struct {
|
||||
Mutex *sync.Mutex `json:"-"`
|
||||
EgressDetails EgressDetails `json:"-"`
|
||||
PostureChecksViolations []Violation `json:"posture_check_violations"`
|
||||
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
|
||||
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
|
||||
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
|
||||
Location string `json:"location"` // Format: "lat,lon"
|
||||
CountryCode string `json:"country_code"`
|
||||
@@ -148,7 +142,7 @@ type LegacyNode struct {
|
||||
Address string `json:"address" bson:"address" yaml:"address" validate:"omitempty,ipv4"`
|
||||
Address6 string `json:"address6" bson:"address6" yaml:"address6" validate:"omitempty,ipv6"`
|
||||
LocalAddress string `json:"localaddress" bson:"localaddress" yaml:"localaddress" validate:"omitempty"`
|
||||
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
|
||||
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
|
||||
Name string `json:"name" bson:"name" yaml:"name" validate:"omitempty,max=62,in_charset"`
|
||||
NetworkSettings Network `json:"networksettings" bson:"networksettings" yaml:"networksettings" validate:"-"`
|
||||
ListenPort int32 `json:"listenport" bson:"listenport" yaml:"listenport" validate:"omitempty,numeric,min=1024,max=65535"`
|
||||
@@ -541,10 +535,10 @@ func (node *Node) DoesACLDeny() bool {
|
||||
return node.DefaultACL == "no"
|
||||
}
|
||||
|
||||
func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) {
|
||||
func (ln *LegacyNode) ConvertToNewNode() (*schema.Host, *Node) {
|
||||
var node Node
|
||||
//host:= logic.GetHost(node.HostID)
|
||||
var host Host
|
||||
var host schema.Host
|
||||
if host.ID.String() == "" {
|
||||
host.ID = uuid.New()
|
||||
host.FirewallInUse = ln.FirewallInUse
|
||||
@@ -554,7 +548,8 @@ func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) {
|
||||
host.Name = ln.Name
|
||||
host.ListenPort = int(ln.ListenPort)
|
||||
host.MTU = int(ln.MTU)
|
||||
host.PublicKey, _ = wgtypes.ParseKey(ln.PublicKey)
|
||||
pubkey, _ := wgtypes.ParseKey(ln.PublicKey)
|
||||
host.PublicKey = schema.WgKey{Key: pubkey}
|
||||
host.MacAddress, _ = net.ParseMAC(ln.MacAddress)
|
||||
host.TrafficKeyPublic = ln.TrafficKeys.Mine
|
||||
id, _ := uuid.Parse(ln.ID)
|
||||
|
||||
+45
-58
@@ -2,10 +2,10 @@ package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v4"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
@@ -141,6 +141,14 @@ type SuccessResponse struct {
|
||||
Response interface{}
|
||||
}
|
||||
|
||||
type PaginatedResponse struct {
|
||||
Data interface{} `json:"data"`
|
||||
Page int `json:"page"`
|
||||
PerPage int `json:"per_page"`
|
||||
Total int `json:"total"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// DisplayKey - what is displayed for key
|
||||
type DisplayKey struct {
|
||||
Name string `json:"name" bson:"name"`
|
||||
@@ -194,12 +202,12 @@ type EgressRangeMetric struct {
|
||||
// from. Might not be always set.
|
||||
EgressID string `json:"-"`
|
||||
// EgressName is the name of the egress gateway identified by EgressID. Might not be always set.
|
||||
EgressName string `json:"-"`
|
||||
Network string `json:"network"`
|
||||
VirtualNetwork string `json:"virtual_network"`
|
||||
RouteMetric uint32 `json:"route_metric"` // preffered range 1-999
|
||||
Nat bool `json:"nat"`
|
||||
Mode EgressNATMode `json:"nat_mode"`
|
||||
EgressName string `json:"-"`
|
||||
Network string `json:"network"`
|
||||
VirtualNetwork string `json:"virtual_network"`
|
||||
RouteMetric uint32 `json:"route_metric"` // preffered range 1-999
|
||||
Nat bool `json:"nat"`
|
||||
Mode schema.EgressNATMode `json:"nat_mode"`
|
||||
}
|
||||
|
||||
// EgressGatewayRequest - egress gateway request
|
||||
@@ -268,31 +276,31 @@ type TrafficKeys struct {
|
||||
|
||||
// HostPull - response of a host's pull
|
||||
type HostPull struct {
|
||||
Host Host `json:"host" yaml:"host"`
|
||||
Nodes []Node `json:"nodes" yaml:"nodes"`
|
||||
Peers []wgtypes.PeerConfig `json:"peers" yaml:"peers"`
|
||||
ServerConfig ServerConfig `json:"server_config" yaml:"server_config"`
|
||||
PeerIDs PeerMap `json:"peer_ids,omitempty" yaml:"peer_ids,omitempty"`
|
||||
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty" yaml:"host_network_info,omitempty"`
|
||||
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
|
||||
FwUpdate FwUpdate `json:"fw_update"`
|
||||
ChangeDefaultGw bool `json:"change_default_gw"`
|
||||
DefaultGwIp net.IP `json:"default_gw_ip"`
|
||||
IsInternetGw bool `json:"is_inet_gw"`
|
||||
EndpointDetection bool `json:"endpoint_detection"`
|
||||
NameServers []string `json:"name_servers"`
|
||||
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
AutoRelayNodes map[NetworkID][]Node `json:"auto_relay_nodes"`
|
||||
GwNodes map[NetworkID][]Node `json:"gw_nodes"`
|
||||
ReplacePeers bool `json:"replace_peers"`
|
||||
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
|
||||
Host schema.Host `json:"host" yaml:"host"`
|
||||
Nodes []Node `json:"nodes" yaml:"nodes"`
|
||||
Peers []wgtypes.PeerConfig `json:"peers" yaml:"peers"`
|
||||
ServerConfig ServerConfig `json:"server_config" yaml:"server_config"`
|
||||
PeerIDs PeerMap `json:"peer_ids,omitempty" yaml:"peer_ids,omitempty"`
|
||||
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty" yaml:"host_network_info,omitempty"`
|
||||
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
|
||||
FwUpdate FwUpdate `json:"fw_update"`
|
||||
ChangeDefaultGw bool `json:"change_default_gw"`
|
||||
DefaultGwIp net.IP `json:"default_gw_ip"`
|
||||
IsInternetGw bool `json:"is_inet_gw"`
|
||||
EndpointDetection bool `json:"endpoint_detection"`
|
||||
NameServers []string `json:"name_servers"`
|
||||
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
AutoRelayNodes map[schema.NetworkID][]Node `json:"auto_relay_nodes"`
|
||||
GwNodes map[schema.NetworkID][]Node `json:"gw_nodes"`
|
||||
ReplacePeers bool `json:"replace_peers"`
|
||||
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
|
||||
}
|
||||
|
||||
// NodeGet - struct for a single node get response
|
||||
type NodeGet struct {
|
||||
Node Node `json:"node" bson:"node" yaml:"node"`
|
||||
Host Host `json:"host" yaml:"host"`
|
||||
Host schema.Host `json:"host" yaml:"host"`
|
||||
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
|
||||
HostPeers []wgtypes.PeerConfig `json:"host_peers" bson:"host_peers" yaml:"host_peers"`
|
||||
ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"`
|
||||
@@ -302,7 +310,7 @@ type NodeGet struct {
|
||||
// NodeJoinResponse data returned to node in response to join
|
||||
type NodeJoinResponse struct {
|
||||
Node Node `json:"node" bson:"node" yaml:"node"`
|
||||
Host Host `json:"host" yaml:"host"`
|
||||
Host schema.Host `json:"host" yaml:"host"`
|
||||
ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"`
|
||||
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
|
||||
}
|
||||
@@ -336,17 +344,6 @@ type ServerConfig struct {
|
||||
OldAClsSupport bool `json:"-"`
|
||||
}
|
||||
|
||||
// User.NameInCharset - returns if name is in charset below or not
|
||||
func (user *User) NameInCharSet() bool {
|
||||
charset := "abcdefghijklmnopqrstuvwxyz1234567890-."
|
||||
for _, char := range user.UserName {
|
||||
if !strings.Contains(charset, strings.ToLower(string(char))) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ServerIDs - struct to hold server ids.
|
||||
type ServerIDs struct {
|
||||
ServerIDs []string `json:"server_ids"`
|
||||
@@ -354,9 +351,9 @@ type ServerIDs struct {
|
||||
|
||||
// JoinData - struct to hold data required for node to join a network on server
|
||||
type JoinData struct {
|
||||
Host Host `json:"host" yaml:"host"`
|
||||
Node Node `json:"node" yaml:"node"`
|
||||
Key string `json:"key" yaml:"key"`
|
||||
Host schema.Host `json:"host" yaml:"host"`
|
||||
Node Node `json:"node" yaml:"node"`
|
||||
Key string `json:"key" yaml:"key"`
|
||||
}
|
||||
|
||||
// HookFunc - function type for hooks that can accept optional parameters
|
||||
@@ -484,23 +481,13 @@ type PostureCheckDeviceInfo struct {
|
||||
AutoUpdate bool
|
||||
Tags map[TagID]struct{}
|
||||
IsUser bool
|
||||
UserGroups map[UserGroupID]struct{}
|
||||
UserGroups map[schema.UserGroupID]struct{}
|
||||
}
|
||||
|
||||
type Violation struct {
|
||||
CheckID string `json:"check_id"`
|
||||
Name string `json:"name"`
|
||||
Attribute string `json:"attribute"`
|
||||
Message string `json:"message"`
|
||||
Severity Severity `json:"severity"`
|
||||
CheckID string `json:"check_id"`
|
||||
Name string `json:"name"`
|
||||
Attribute string `json:"attribute"`
|
||||
Message string `json:"message"`
|
||||
Severity schema.Severity `json:"severity"`
|
||||
}
|
||||
|
||||
type Severity int
|
||||
|
||||
const (
|
||||
SeverityUnknown Severity = iota
|
||||
SeverityLow
|
||||
SeverityMedium
|
||||
SeverityHigh
|
||||
SeverityCritical
|
||||
)
|
||||
|
||||
+12
-10
@@ -3,6 +3,8 @@ package models
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
type TagID string
|
||||
@@ -21,19 +23,19 @@ func (t Tag) GetIDFromName() string {
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID TagID `json:"id"`
|
||||
TagName string `json:"tag_name"`
|
||||
Network NetworkID `json:"network"`
|
||||
ColorCode string `json:"color_code"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID TagID `json:"id"`
|
||||
TagName string `json:"tag_name"`
|
||||
Network schema.NetworkID `json:"network"`
|
||||
ColorCode string `json:"color_code"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type CreateTagReq struct {
|
||||
TagName string `json:"tag_name"`
|
||||
Network NetworkID `json:"network"`
|
||||
ColorCode string `json:"color_code"`
|
||||
TaggedNodes []ApiNode `json:"tagged_nodes"`
|
||||
TagName string `json:"tag_name"`
|
||||
Network schema.NetworkID `json:"network"`
|
||||
ColorCode string `json:"color_code"`
|
||||
TaggedNodes []ApiNode `json:"tagged_nodes"`
|
||||
}
|
||||
|
||||
type TagListResp struct {
|
||||
|
||||
+49
-187
@@ -5,35 +5,17 @@ import (
|
||||
"time"
|
||||
|
||||
jwt "github.com/golang-jwt/jwt/v4"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
type NetworkID string
|
||||
type RsrcType string
|
||||
type RsrcID string
|
||||
type UserRoleID string
|
||||
type UserGroupID string
|
||||
type AuthType string
|
||||
type TokenType string
|
||||
|
||||
var (
|
||||
BasicAuth AuthType = "basic_auth"
|
||||
OAuth AuthType = "oauth"
|
||||
)
|
||||
|
||||
func (r RsrcType) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func (rid RsrcID) String() string {
|
||||
return string(rid)
|
||||
}
|
||||
|
||||
func GetRAGRoleName(netID, hostName string) string {
|
||||
return fmt.Sprintf("netID-%s-rag-%s", netID, hostName)
|
||||
}
|
||||
|
||||
func GetRAGRoleID(netID, hostID string) UserRoleID {
|
||||
return UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
|
||||
func GetRAGRoleID(netID, hostID string) schema.UserRoleID {
|
||||
return schema.UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
|
||||
}
|
||||
|
||||
func (t TokenType) String() string {
|
||||
@@ -45,169 +27,49 @@ var (
|
||||
AccessTokenType TokenType = "access_token"
|
||||
)
|
||||
|
||||
var RsrcTypeMap = map[RsrcType]struct{}{
|
||||
HostRsrc: {},
|
||||
RelayRsrc: {},
|
||||
RemoteAccessGwRsrc: {},
|
||||
ExtClientsRsrc: {},
|
||||
InetGwRsrc: {},
|
||||
EgressGwRsrc: {},
|
||||
NetworkRsrc: {},
|
||||
EnrollmentKeysRsrc: {},
|
||||
UserRsrc: {},
|
||||
AclRsrc: {},
|
||||
DnsRsrc: {},
|
||||
FailOverRsrc: {},
|
||||
}
|
||||
|
||||
const AllNetworks NetworkID = "all_networks"
|
||||
const (
|
||||
HostRsrc RsrcType = "host"
|
||||
RelayRsrc RsrcType = "relay"
|
||||
RemoteAccessGwRsrc RsrcType = "remote_access_gw"
|
||||
GatewayRsrc RsrcType = "gateway"
|
||||
ExtClientsRsrc RsrcType = "extclient"
|
||||
InetGwRsrc RsrcType = "inet_gw"
|
||||
EgressGwRsrc RsrcType = "egress"
|
||||
NetworkRsrc RsrcType = "network"
|
||||
EnrollmentKeysRsrc RsrcType = "enrollment_key"
|
||||
UserRsrc RsrcType = "user"
|
||||
AclRsrc RsrcType = "acl"
|
||||
TagRsrc RsrcType = "tag"
|
||||
DnsRsrc RsrcType = "dns"
|
||||
NameserverRsrc RsrcType = "nameserver"
|
||||
FailOverRsrc RsrcType = "fail_over"
|
||||
MetricRsrc RsrcType = "metric"
|
||||
PostureCheckRsrc RsrcType = "posturecheck"
|
||||
JitAdminRsrc RsrcType = "jit_admin"
|
||||
JitUserRsrc RsrcType = "jit_user"
|
||||
UserActivityRsrc RsrcType = "user_activity"
|
||||
TrafficFlow RsrcType = "traffic_flow"
|
||||
)
|
||||
|
||||
const (
|
||||
AllHostRsrcID RsrcID = "all_host"
|
||||
AllRelayRsrcID RsrcID = "all_relay"
|
||||
AllRemoteAccessGwRsrcID RsrcID = "all_remote_access_gw"
|
||||
AllExtClientsRsrcID RsrcID = "all_extclients"
|
||||
AllInetGwRsrcID RsrcID = "all_inet_gw"
|
||||
AllEgressGwRsrcID RsrcID = "all_egress"
|
||||
AllNetworkRsrcID RsrcID = "all_network"
|
||||
AllEnrollmentKeysRsrcID RsrcID = "all_enrollment_key"
|
||||
AllUserRsrcID RsrcID = "all_user"
|
||||
AllDnsRsrcID RsrcID = "all_dns"
|
||||
AllFailOverRsrcID RsrcID = "all_fail_over"
|
||||
AllAclsRsrcID RsrcID = "all_acl"
|
||||
AllTagsRsrcID RsrcID = "all_tag"
|
||||
AllPostureCheckRsrcID RsrcID = "all_posturecheck"
|
||||
AllNameserverRsrcID RsrcID = "all_nameserver"
|
||||
AllJitAdminRsrcID RsrcID = "all_jit_admin"
|
||||
AllJitUserRsrcID RsrcID = "all_jit_user"
|
||||
AllUserActivityRsrcID RsrcID = "all_user_activity"
|
||||
AllTrafficFlowRsrcID RsrcID = "all_traffic_flow"
|
||||
)
|
||||
|
||||
// Pre-Defined User Roles
|
||||
|
||||
const (
|
||||
SuperAdminRole UserRoleID = "super-admin"
|
||||
AdminRole UserRoleID = "admin"
|
||||
ServiceUser UserRoleID = "service-user"
|
||||
PlatformUser UserRoleID = "platform-user"
|
||||
Auditor UserRoleID = "auditor"
|
||||
NetworkAdmin UserRoleID = "network-admin"
|
||||
NetworkUser UserRoleID = "network-user"
|
||||
)
|
||||
|
||||
func (r UserRoleID) String() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func (g UserGroupID) String() string {
|
||||
return string(g)
|
||||
}
|
||||
|
||||
func (n NetworkID) String() string {
|
||||
return string(n)
|
||||
}
|
||||
|
||||
type RsrcPermissionScope struct {
|
||||
Create bool `json:"create"`
|
||||
Read bool `json:"read"`
|
||||
Update bool `json:"update"`
|
||||
Delete bool `json:"delete"`
|
||||
VPNaccess bool `json:"vpn_access"`
|
||||
SelfOnly bool `json:"self_only"`
|
||||
}
|
||||
|
||||
type UserRolePermissionTemplate struct {
|
||||
ID UserRoleID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Default bool `json:"default"`
|
||||
MetaData string `json:"meta_data"`
|
||||
DenyDashboardAccess bool `json:"deny_dashboard_access"`
|
||||
FullAccess bool `json:"full_access"`
|
||||
NetworkID NetworkID `json:"network_id"`
|
||||
NetworkLevelAccess map[RsrcType]map[RsrcID]RsrcPermissionScope `json:"network_level_access"`
|
||||
GlobalLevelAccess map[RsrcType]map[RsrcID]RsrcPermissionScope `json:"global_level_access"`
|
||||
}
|
||||
|
||||
type CreateGroupReq struct {
|
||||
Group UserGroup `json:"user_group"`
|
||||
Members []string `json:"members"`
|
||||
}
|
||||
|
||||
type UserGroup struct {
|
||||
ID UserGroupID `json:"id"`
|
||||
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
|
||||
Default bool `json:"default"`
|
||||
Name string `json:"name"`
|
||||
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
|
||||
ColorCode string `json:"color_code"`
|
||||
MetaData string `json:"meta_data"`
|
||||
}
|
||||
|
||||
// User struct - struct for Users
|
||||
type User struct {
|
||||
UserName string `json:"username" bson:"username" validate:"min=3,in_charset|email"`
|
||||
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
|
||||
IsMFAEnabled bool `json:"is_mfa_enabled"`
|
||||
TOTPSecret string `json:"totp_secret"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AccountDisabled bool `json:"account_disabled"`
|
||||
Password string `json:"password" bson:"password" validate:"required,min=5"`
|
||||
IsAdmin bool `json:"isadmin" bson:"isadmin"` // deprecated
|
||||
IsSuperAdmin bool `json:"issuperadmin"` // deprecated
|
||||
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
|
||||
AuthType AuthType `json:"auth_type"`
|
||||
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
|
||||
PlatformRoleID UserRoleID `json:"platform_role_id"`
|
||||
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
|
||||
LastLoginTime time.Time `json:"last_login_time"`
|
||||
}
|
||||
|
||||
type ReturnUserWithRolesAndGroups struct {
|
||||
ReturnUser
|
||||
PlatformRole UserRolePermissionTemplate `json:"platform_role"`
|
||||
UserGroups map[UserGroupID]UserGroup `json:"user_group_ids"`
|
||||
UserName string `json:"username" bson:"username" validate:"min=3,in_charset|email"`
|
||||
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
|
||||
IsMFAEnabled bool `json:"is_mfa_enabled"`
|
||||
TOTPSecret string `json:"totp_secret"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AccountDisabled bool `json:"account_disabled"`
|
||||
Password string `json:"password" bson:"password" validate:"required,min=5"`
|
||||
IsAdmin bool `json:"isadmin" bson:"isadmin"` // deprecated
|
||||
IsSuperAdmin bool `json:"issuperadmin"` // deprecated
|
||||
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
|
||||
AuthType schema.AuthType `json:"auth_type"`
|
||||
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
|
||||
PlatformRoleID schema.UserRoleID `json:"platform_role_id"`
|
||||
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
|
||||
LastLoginTime time.Time `json:"last_login_time"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ReturnUser - return user struct
|
||||
type ReturnUser struct {
|
||||
UserName string `json:"username"`
|
||||
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
|
||||
IsMFAEnabled bool `json:"is_mfa_enabled"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AccountDisabled bool `json:"account_disabled"`
|
||||
IsAdmin bool `json:"isadmin"`
|
||||
IsSuperAdmin bool `json:"issuperadmin"`
|
||||
AuthType AuthType `json:"auth_type"`
|
||||
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
|
||||
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
|
||||
PlatformRoleID UserRoleID `json:"platform_role_id"`
|
||||
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
|
||||
LastLoginTime time.Time `json:"last_login_time"`
|
||||
NumAccessTokens int `json:"num_access_tokens"`
|
||||
UserName string `json:"username"`
|
||||
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
|
||||
IsMFAEnabled bool `json:"is_mfa_enabled"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AccountDisabled bool `json:"account_disabled"`
|
||||
IsAdmin bool `json:"isadmin"`
|
||||
IsSuperAdmin bool `json:"issuperadmin"`
|
||||
AuthType schema.AuthType `json:"auth_type"`
|
||||
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
|
||||
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
|
||||
PlatformRoleID schema.UserRoleID `json:"platform_role_id"`
|
||||
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
|
||||
LastLoginTime time.Time `json:"last_login_time"`
|
||||
NumAccessTokens int `json:"num_access_tokens"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// UserAuthParams - user auth params struct
|
||||
@@ -234,7 +96,7 @@ type UserTOTPVerificationParams struct {
|
||||
|
||||
// UserClaims - user claims struct
|
||||
type UserClaims struct {
|
||||
Role UserRoleID
|
||||
Role schema.UserRoleID
|
||||
UserName string
|
||||
Api string
|
||||
TokenType TokenType
|
||||
@@ -243,20 +105,20 @@ type UserClaims struct {
|
||||
}
|
||||
|
||||
type InviteUsersReq struct {
|
||||
UserEmails []string `json:"user_emails"`
|
||||
PlatformRoleID string `json:"platform_role_id"`
|
||||
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
|
||||
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
|
||||
UserEmails []string `json:"user_emails"`
|
||||
PlatformRoleID string `json:"platform_role_id"`
|
||||
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
|
||||
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
|
||||
}
|
||||
|
||||
// UserInvite - model for user invite
|
||||
type UserInvite struct {
|
||||
Email string `json:"email"`
|
||||
PlatformRoleID string `json:"platform_role_id"`
|
||||
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
|
||||
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
|
||||
InviteCode string `json:"invite_code"`
|
||||
InviteURL string `json:"invite_url"`
|
||||
Email string `json:"email"`
|
||||
PlatformRoleID string `json:"platform_role_id"`
|
||||
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
|
||||
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
|
||||
InviteCode string `json:"invite_code"`
|
||||
InviteURL string `json:"invite_url"`
|
||||
}
|
||||
|
||||
// UserMapping - user ip map with groups
|
||||
|
||||
+12
-9
@@ -1,17 +1,20 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/hostactions"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/gravitl/netmaker/utils"
|
||||
"golang.org/x/exp/slog"
|
||||
@@ -60,8 +63,8 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
|
||||
if ifaceDelta { // reduce number of unneeded updates, by only sending on iface changes
|
||||
if !newNode.Connected {
|
||||
err = PublishDeletedNodePeerUpdate(&newNode)
|
||||
host, err := logic.GetHost(newNode.HostID.String())
|
||||
if err != nil {
|
||||
host := &schema.Host{ID: newNode.HostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err != nil {
|
||||
slog.Error("failed to get host for the node", "nodeid", newNode.ID.String(), "error", err)
|
||||
return
|
||||
}
|
||||
@@ -87,8 +90,8 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
|
||||
slog.Error("error getting host.ID sent on ", "topic", msg.Topic(), "error", err)
|
||||
return
|
||||
}
|
||||
currentHost, err := logic.GetHost(id)
|
||||
if err != nil {
|
||||
currentHost := &schema.Host{ID: uuid.MustParse(id)}
|
||||
if err := currentHost.Get(db.WithContext(context.TODO())); err != nil {
|
||||
slog.Error("error getting host", "id", id, "error", err)
|
||||
return
|
||||
}
|
||||
@@ -159,7 +162,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteAndCleanupHost(h *models.Host) {
|
||||
func DeleteAndCleanupHost(h *schema.Host) {
|
||||
if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
|
||||
// delete EMQX credentials for host
|
||||
if err := emqx.DeleteEmqxUser(h.ID.String()); err != nil {
|
||||
@@ -180,7 +183,7 @@ func DeleteAndCleanupHost(h *models.Host) {
|
||||
slog.Error("failed to delete all nodes of host", "id", h.ID, "error", err)
|
||||
return
|
||||
}
|
||||
if err := logic.RemoveHostByID(h.ID.String()); err != nil {
|
||||
if err := (&schema.Host{ID: h.ID}).Delete(db.WithContext(context.TODO())); err != nil {
|
||||
slog.Error("failed to delete host", "id", h.ID, "error", err)
|
||||
return
|
||||
}
|
||||
@@ -209,8 +212,8 @@ func SignalPeer(signal models.Signal) {
|
||||
}
|
||||
signal.NetworkID = node.Network
|
||||
signal.IsPro = servercfg.IsPro
|
||||
peerHost, err := logic.GetHost(signal.ToHostID)
|
||||
if err != nil {
|
||||
peerHost := &schema.Host{ID: uuid.MustParse(signal.ToHostID)}
|
||||
if err := peerHost.Get(db.WithContext(context.TODO())); err != nil {
|
||||
slog.Error("failed to signal, peer host not found", "error", err)
|
||||
return
|
||||
}
|
||||
@@ -255,7 +258,7 @@ func ClientPeerUpdate(client mqtt.Client, msg mqtt.Message) {
|
||||
slog.Info("sent peer updates after signal received from", "id", id)
|
||||
}
|
||||
|
||||
func HandleHostCheckin(h, currentHost *models.Host) bool {
|
||||
func HandleHostCheckin(h, currentHost *schema.Host) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
+5
-2
@@ -2,6 +2,7 @@ package mq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,9 +12,11 @@ import (
|
||||
"time"
|
||||
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
@@ -77,7 +80,7 @@ func SendPullSYN() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -125,7 +128,7 @@ func KickOutClients() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
slog.Error("failed to migrate emqx: ", "error", err)
|
||||
return err
|
||||
|
||||
+14
-12
@@ -1,6 +1,7 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,9 +9,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
@@ -25,7 +28,7 @@ func PublishPeerUpdate(replacePeers bool) error {
|
||||
sendDNSSync()
|
||||
}
|
||||
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "err getting all hosts", err.Error())
|
||||
return err
|
||||
@@ -38,7 +41,7 @@ func PublishPeerUpdate(replacePeers bool) error {
|
||||
for _, host := range hosts {
|
||||
host := host
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
go func(host models.Host) {
|
||||
go func(host schema.Host) {
|
||||
if err = PublishSingleHostPeerUpdate(&host, allNodes, nil, nil, replacePeers, nil); err != nil {
|
||||
id := host.Name
|
||||
if host.ID != uuid.Nil {
|
||||
@@ -59,7 +62,7 @@ func PublishDeletedNodePeerUpdate(delNode *models.Node) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "err getting all hosts", err.Error())
|
||||
return err
|
||||
@@ -84,7 +87,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hosts, err := logic.GetAllHosts()
|
||||
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logger.Log(1, "err getting all hosts", err.Error())
|
||||
return err
|
||||
@@ -105,7 +108,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
|
||||
}
|
||||
|
||||
// PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host
|
||||
func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool, wg *sync.WaitGroup) error {
|
||||
func PublishSingleHostPeerUpdate(host *schema.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool, wg *sync.WaitGroup) error {
|
||||
if wg != nil {
|
||||
defer wg.Done()
|
||||
}
|
||||
@@ -136,8 +139,8 @@ func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, dele
|
||||
|
||||
// NodeUpdate -- publishes a node update
|
||||
func NodeUpdate(node *models.Node) error {
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
host := &schema.Host{ID: node.HostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return nil
|
||||
}
|
||||
if !servercfg.IsMessageQueueBackend() {
|
||||
@@ -265,16 +268,15 @@ func SendDNSSyncByNetwork(network string) error {
|
||||
}
|
||||
|
||||
func sendDNSSync() error {
|
||||
|
||||
networks, err := logic.GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err == nil && len(networks) > 0 {
|
||||
for _, v := range networks {
|
||||
k, err := logic.GetDNS(v.NetID)
|
||||
k = append(k, logic.EgressDNs(v.NetID)...)
|
||||
k, err := logic.GetDNS(v.Name)
|
||||
k = append(k, logic.EgressDNs(v.Name)...)
|
||||
if err == nil && len(k) > 0 {
|
||||
err = PushSyncDNS(k)
|
||||
if err != nil {
|
||||
slog.Warn("error publishing dns entry data for network ", v.NetID, err.Error())
|
||||
slog.Warn("error publishing dns entry data for network ", v.Name, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+8
-5
@@ -3,6 +3,7 @@ package mq
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
@@ -13,13 +14,15 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/netclient/ncutils"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"golang.org/x/exp/slog"
|
||||
)
|
||||
|
||||
func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
|
||||
func decryptMsgWithHost(host *schema.Host, msg []byte) ([]byte, error) {
|
||||
if host.OS == models.OS_Types.IoT { // just pass along IoT messages
|
||||
return msg, nil
|
||||
}
|
||||
@@ -44,8 +47,8 @@ func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
|
||||
if len(msg) <= 24 { // make sure message is of appropriate length
|
||||
return nil, fmt.Errorf("received invalid message from broker %v", msg)
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
host := &schema.Host{ID: node.HostID}
|
||||
if err := host.Get(db.WithContext(context.TODO())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -105,7 +108,7 @@ func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
|
||||
func encryptMsg(host *schema.Host, msg []byte) ([]byte, error) {
|
||||
if host.OS == models.OS_Types.IoT {
|
||||
return msg, nil
|
||||
}
|
||||
@@ -133,7 +136,7 @@ func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
|
||||
return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
|
||||
}
|
||||
|
||||
func publish(host *models.Host, dest string, msg []byte) error {
|
||||
func publish(host *schema.Host, dest string, msg []byte) error {
|
||||
|
||||
var encrypted []byte
|
||||
var encryptErr error
|
||||
|
||||
+25
-21
@@ -9,14 +9,15 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/microsoft"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var azure_ad_functions = map[string]interface{}{
|
||||
@@ -90,9 +91,10 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(content.UserPrincipalName)
|
||||
user := &schema.User{Username: content.UserPrincipalName}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
if inviteExists {
|
||||
// create user
|
||||
user, err := proLogic.PrepareOauthUserFromInvite(in)
|
||||
@@ -100,7 +102,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
user.UserName = content.UserPrincipalName
|
||||
user.Username = content.UserPrincipalName
|
||||
user.ExternalIdentityProviderID = string(content.ID)
|
||||
if err = logic.CreateUser(&user); err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -114,9 +116,9 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
UserName: content.Email,
|
||||
UserName: content.UserPrincipalName,
|
||||
ExternalIdentityProviderID: string(content.ID),
|
||||
AuthType: models.OAuth,
|
||||
AuthType: schema.OAuth,
|
||||
})
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -132,14 +134,15 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
// if user exists, then ensure user's auth type is
|
||||
// oauth before proceeding.
|
||||
if user.AuthType == models.BasicAuth {
|
||||
if user.AuthType == schema.BasicAuth {
|
||||
logger.Log(0, "invalid auth type: basic_auth")
|
||||
handleAuthTypeMismatch(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err = logic.GetUser(content.UserPrincipalName)
|
||||
user = &schema.User{Username: content.UserPrincipalName}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
handleOauthUserNotFound(w)
|
||||
return
|
||||
@@ -150,7 +153,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userRole := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err = userRole.Get(r.Context())
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
@@ -175,23 +179,23 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Login,
|
||||
Action: schema.Login,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: models.DashboardSub.String(),
|
||||
Name: models.DashboardSub.String(),
|
||||
Type: models.DashboardSub,
|
||||
Info: user,
|
||||
ID: schema.DashboardSub.String(),
|
||||
Name: schema.DashboardSub.String(),
|
||||
Type: schema.DashboardSub,
|
||||
Info: logic.ToReturnUser(user),
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(1, "completed azure OAuth sigin in for", user.UserName)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+user.UserName, http.StatusPermanentRedirect)
|
||||
logger.Log(1, "completed azure OAuth sigin in for", user.Username)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+user.Username, http.StatusPermanentRedirect)
|
||||
}
|
||||
|
||||
func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
|
||||
|
||||
+28
-23
@@ -9,14 +9,16 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var github_functions = map[string]interface{}{
|
||||
@@ -91,29 +93,30 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// if user exists with provider ID, convert them into email ID
|
||||
user, err := logic.GetUser(content.Login)
|
||||
user := &schema.User{Username: content.Login}
|
||||
err = user.Get(r.Context())
|
||||
if err == nil {
|
||||
// if user exists, then ensure user's auth type is
|
||||
// oauth before proceeding.
|
||||
if user.AuthType == models.BasicAuth {
|
||||
if user.AuthType == schema.BasicAuth {
|
||||
logger.Log(0, "invalid auth type: basic_auth")
|
||||
handleAuthTypeMismatch(w)
|
||||
return
|
||||
}
|
||||
|
||||
// checks if user exists with email
|
||||
_, err := logic.GetUser(content.Email)
|
||||
emailCheck := &schema.User{Username: content.Email}
|
||||
err = emailCheck.Get(r.Context())
|
||||
if err != nil {
|
||||
user.UserName = content.Email
|
||||
user.Username = content.Email
|
||||
user.ExternalIdentityProviderID = content.Login
|
||||
_ = logic.DeleteUser(content.Login)
|
||||
_ = logic.UpsertUser(*user)
|
||||
_ = user.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
}
|
||||
_, err = logic.GetUser(content.Email)
|
||||
emailCheck := &schema.User{Username: content.Email}
|
||||
err = emailCheck.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
if inviteExists {
|
||||
// create user
|
||||
user, err := proLogic.PrepareOauthUserFromInvite(in)
|
||||
@@ -136,7 +139,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
UserName: content.Email,
|
||||
ExternalIdentityProviderID: string(content.ID),
|
||||
AuthType: models.OAuth,
|
||||
AuthType: schema.OAuth,
|
||||
})
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -150,7 +153,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
}
|
||||
user, err = logic.GetUser(content.Email)
|
||||
user = &schema.User{Username: content.Email}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
handleOauthUserNotFound(w)
|
||||
return
|
||||
@@ -161,7 +165,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userRole := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err = userRole.Get(r.Context())
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
@@ -186,20 +191,20 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Login,
|
||||
Action: schema.Login,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: models.DashboardSub.String(),
|
||||
Name: models.DashboardSub.String(),
|
||||
Type: models.DashboardSub,
|
||||
Info: user,
|
||||
ID: schema.DashboardSub.String(),
|
||||
Name: schema.DashboardSub.String(),
|
||||
Type: schema.DashboardSub,
|
||||
Info: logic.ToReturnUser(user),
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(1, "completed github OAuth sigin in for", content.Email)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
|
||||
|
||||
+23
-18
@@ -3,20 +3,22 @@ package auth
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var google_functions = map[string]interface{}{
|
||||
@@ -93,9 +95,10 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(content.Email)
|
||||
user := &schema.User{Username: content.Email}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
if inviteExists {
|
||||
// create user
|
||||
user, err := proLogic.PrepareOauthUserFromInvite(in)
|
||||
@@ -108,7 +111,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
}
|
||||
logic.DeleteUserInvite(user.UserName)
|
||||
logic.DeleteUserInvite(user.Username)
|
||||
logic.DeletePendingUser(content.Email)
|
||||
} else {
|
||||
if !isEmailAllowed(content.Email) {
|
||||
@@ -118,7 +121,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
UserName: content.Email,
|
||||
ExternalIdentityProviderID: string(content.ID),
|
||||
AuthType: models.OAuth,
|
||||
AuthType: schema.OAuth,
|
||||
})
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -135,14 +138,15 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
// if user exists, then ensure user's auth type is
|
||||
// oauth before proceeding.
|
||||
if user.AuthType == models.BasicAuth {
|
||||
if user.AuthType == schema.BasicAuth {
|
||||
logger.Log(0, "invalid auth type: basic_auth")
|
||||
handleAuthTypeMismatch(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err = logic.GetUser(content.Email)
|
||||
user = &schema.User{Username: content.Email}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
logger.Log(0, "error fetching user: ", err.Error())
|
||||
handleOauthUserNotFound(w)
|
||||
@@ -154,7 +158,8 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userRole := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err = userRole.Get(r.Context())
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
@@ -180,20 +185,20 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Login,
|
||||
Action: schema.Login,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: models.DashboardSub.String(),
|
||||
Name: models.DashboardSub.String(),
|
||||
Type: models.DashboardSub,
|
||||
Info: user,
|
||||
ID: schema.DashboardSub.String(),
|
||||
Name: schema.DashboardSub.String(),
|
||||
Type: schema.DashboardSub,
|
||||
Info: logic.ToReturnUser(user),
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logger.Log(1, "completed google OAuth sigin in for", content.Email)
|
||||
|
||||
@@ -2,14 +2,16 @@ package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// HandleHeadlessSSOCallback - handle OAuth callback for headless logins such as Netmaker CLI
|
||||
@@ -60,13 +62,14 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
|
||||
handleOauthUserSignUpApprovalPending(w)
|
||||
return
|
||||
}
|
||||
user, err := logic.GetUser(userClaims.getUserName())
|
||||
user := &schema.User{Username: userClaims.getUserName()}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
UserName: userClaims.getUserName(),
|
||||
ExternalIdentityProviderID: string(userClaims.ID),
|
||||
AuthType: models.OAuth,
|
||||
AuthType: schema.OAuth,
|
||||
})
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -88,7 +91,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
jwt, jwtErr := logic.VerifyAuthRequest(models.UserAuthParams{
|
||||
UserName: user.UserName,
|
||||
UserName: user.Username,
|
||||
Password: newPass,
|
||||
}, logic.NetclientApp)
|
||||
if jwtErr != nil {
|
||||
|
||||
+23
-18
@@ -2,19 +2,21 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const OIDC_TIMEOUT = 10 * time.Second
|
||||
@@ -102,9 +104,10 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(content.Email)
|
||||
user := &schema.User{Username: content.Email}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
|
||||
if inviteExists {
|
||||
// create user
|
||||
user, err := proLogic.PrepareOauthUserFromInvite(in)
|
||||
@@ -117,7 +120,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
}
|
||||
logic.DeleteUserInvite(user.UserName)
|
||||
logic.DeleteUserInvite(user.Username)
|
||||
logic.DeletePendingUser(content.Email)
|
||||
} else {
|
||||
if !isEmailAllowed(content.Email) {
|
||||
@@ -127,7 +130,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
err = logic.InsertPendingUser(&models.User{
|
||||
UserName: content.Email,
|
||||
ExternalIdentityProviderID: string(content.ID),
|
||||
AuthType: models.OAuth,
|
||||
AuthType: schema.OAuth,
|
||||
})
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
@@ -143,14 +146,15 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
} else {
|
||||
// if user exists, then ensure user's auth type is
|
||||
// oauth before proceeding.
|
||||
if user.AuthType == models.BasicAuth {
|
||||
if user.AuthType == schema.BasicAuth {
|
||||
logger.Log(0, "invalid auth type: basic_auth")
|
||||
handleAuthTypeMismatch(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err = logic.GetUser(content.Email)
|
||||
user = &schema.User{Username: content.Email}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
handleOauthUserNotFound(w)
|
||||
return
|
||||
@@ -161,7 +165,8 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userRole, err := logic.GetRole(user.PlatformRoleID)
|
||||
userRole := &schema.UserRole{ID: user.PlatformRoleID}
|
||||
err = userRole.Get(r.Context())
|
||||
if err != nil {
|
||||
handleSomethingWentWrong(w)
|
||||
return
|
||||
@@ -186,20 +191,20 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Login,
|
||||
Action: schema.Login,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: models.DashboardSub.String(),
|
||||
Name: models.DashboardSub.String(),
|
||||
Type: models.DashboardSub,
|
||||
Info: user,
|
||||
ID: schema.DashboardSub.String(),
|
||||
Name: schema.DashboardSub.String(),
|
||||
Type: schema.DashboardSub,
|
||||
Info: logic.ToReturnUser(user),
|
||||
},
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
logger.Log(1, "completed OIDC OAuth signin in for", content.Email)
|
||||
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
|
||||
|
||||
@@ -8,9 +8,8 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/logic/pro/netcache"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,12 +68,13 @@ func HandleHostSSOCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// check if user exists
|
||||
user, err := logic.GetUser(userClaims.getUserName())
|
||||
user := &schema.User{Username: userClaims.getUserName()}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
handleOauthUserNotFound(w)
|
||||
return
|
||||
}
|
||||
if user.PlatformRoleID != models.AdminRole && user.PlatformRoleID != models.SuperAdminRole {
|
||||
if user.PlatformRoleID != schema.AdminRole && user.PlatformRoleID != schema.SuperAdminRole {
|
||||
response := returnErrTemplate(userClaims.getUserName(), "only admin users can register using SSO", state, reqKeyIf)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write(response)
|
||||
|
||||
+42
-35
@@ -7,7 +7,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
@@ -17,7 +17,9 @@ import (
|
||||
"github.com/gravitl/netmaker/pro/idp/google"
|
||||
"github.com/gravitl/netmaker/pro/idp/okta"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -126,8 +128,8 @@ func SyncFromIDP() error {
|
||||
}
|
||||
|
||||
func syncUsers(idpUsers []idp.User) error {
|
||||
dbUsers, err := logic.GetUsersDB()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
dbUsers, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -141,9 +143,9 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
idpUsersMap[user.Username] = struct{}{}
|
||||
}
|
||||
|
||||
dbUsersMap := make(map[string]models.User)
|
||||
dbUsersMap := make(map[string]*schema.User)
|
||||
for _, user := range dbUsers {
|
||||
dbUsersMap[user.UserName] = user
|
||||
dbUsersMap[user.Username] = &user
|
||||
}
|
||||
|
||||
filters := logic.GetServerSettings().UserFilters
|
||||
@@ -151,8 +153,10 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
for _, user := range idpUsers {
|
||||
if user.AccountArchived {
|
||||
// delete the user if it has been archived.
|
||||
user := dbUsersMap[user.Username]
|
||||
_ = deleteAndCleanUpUser(&user)
|
||||
user, ok := dbUsersMap[user.Username]
|
||||
if ok {
|
||||
_ = deleteAndCleanUpUser(user)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -172,14 +176,14 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
dbUser, ok := dbUsersMap[user.Username]
|
||||
if !ok {
|
||||
// create the user only if it doesn't exist.
|
||||
err = logic.CreateUser(&models.User{
|
||||
UserName: user.Username,
|
||||
err = logic.CreateUser(&schema.User{
|
||||
Username: user.Username,
|
||||
ExternalIdentityProviderID: user.ID,
|
||||
DisplayName: user.DisplayName,
|
||||
AccountDisabled: user.AccountDisabled,
|
||||
Password: password,
|
||||
AuthType: models.OAuth,
|
||||
PlatformRoleID: models.ServiceUser,
|
||||
AuthType: schema.OAuth,
|
||||
PlatformRoleID: schema.ServiceUser,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -191,7 +195,7 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
// created. Now, since the user is created, the pending user
|
||||
// can be deleted.
|
||||
_ = logic.DeletePendingUser(user.Username)
|
||||
} else if dbUser.AuthType == models.OAuth {
|
||||
} else if dbUser.AuthType == schema.OAuth {
|
||||
if dbUser.AccountDisabled != user.AccountDisabled ||
|
||||
dbUser.DisplayName != user.DisplayName ||
|
||||
dbUser.ExternalIdentityProviderID != user.ID {
|
||||
@@ -200,7 +204,7 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
dbUser.DisplayName = user.DisplayName
|
||||
dbUser.ExternalIdentityProviderID = user.ID
|
||||
|
||||
err = logic.UpsertUser(dbUser)
|
||||
err = logic.UpsertUser(*dbUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -213,10 +217,10 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
|
||||
for _, user := range dbUsersMap {
|
||||
if user.ExternalIdentityProviderID != "" {
|
||||
if _, ok := idpUsersMap[user.UserName]; !ok {
|
||||
if _, ok := idpUsersMap[user.Username]; !ok {
|
||||
// delete the user if it has been deleted on idp
|
||||
// or is filtered out.
|
||||
err = deleteAndCleanUpUser(&user)
|
||||
err = deleteAndCleanUpUser(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -228,13 +232,13 @@ func syncUsers(idpUsers []idp.User) error {
|
||||
}
|
||||
|
||||
func syncGroups(idpGroups []idp.Group) error {
|
||||
dbGroups, err := proLogic.ListUserGroups()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
dbGroups, err := (&schema.UserGroup{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbUsers, err := logic.GetUsersDB()
|
||||
if err != nil && !database.IsEmptyRecord(err) {
|
||||
dbUsers, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -243,17 +247,17 @@ func syncGroups(idpGroups []idp.Group) error {
|
||||
idpGroupsMap[group.ID] = struct{}{}
|
||||
}
|
||||
|
||||
dbGroupsMap := make(map[string]models.UserGroup)
|
||||
dbGroupsMap := make(map[string]schema.UserGroup)
|
||||
for _, group := range dbGroups {
|
||||
if group.ExternalIdentityProviderID != "" {
|
||||
dbGroupsMap[group.ExternalIdentityProviderID] = group
|
||||
}
|
||||
}
|
||||
|
||||
dbUsersMap := make(map[string]models.User)
|
||||
dbUsersMap := make(map[string]*schema.User)
|
||||
for _, user := range dbUsers {
|
||||
if user.ExternalIdentityProviderID != "" {
|
||||
dbUsersMap[user.ExternalIdentityProviderID] = user
|
||||
dbUsersMap[user.ExternalIdentityProviderID] = &user
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,7 +265,7 @@ func syncGroups(idpGroups []idp.Group) error {
|
||||
|
||||
filters := logic.GetServerSettings().GroupFilters
|
||||
|
||||
networks, err := logic.GetNetworks()
|
||||
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -269,7 +273,7 @@ func syncGroups(idpGroups []idp.Group) error {
|
||||
var aclsUpdated bool
|
||||
var acls []models.Acl
|
||||
for _, network := range networks {
|
||||
aclID := fmt.Sprintf("%s.%s-grp", network.NetID, models.NetworkUser)
|
||||
aclID := fmt.Sprintf("%s.%s-grp", network.Name, schema.NetworkUser)
|
||||
acl, err := logic.GetAcl(aclID)
|
||||
if err == nil {
|
||||
acls = append(acls, acl)
|
||||
@@ -295,7 +299,7 @@ func syncGroups(idpGroups []idp.Group) error {
|
||||
dbGroup.ExternalIdentityProviderID = group.ID
|
||||
dbGroup.Name = group.Name
|
||||
dbGroup.Default = false
|
||||
dbGroup.NetworkRoles = map[models.NetworkID]map[models.UserRoleID]struct{}{}
|
||||
dbGroup.NetworkRoles = datatypes.NewJSONType(schema.NetworkRoles{})
|
||||
err := proLogic.CreateUserGroup(&dbGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -323,27 +327,30 @@ func syncGroups(idpGroups []idp.Group) error {
|
||||
|
||||
for _, user := range dbUsers {
|
||||
// use dbGroup.Name because the group name may have been changed on idp.
|
||||
_, inNetmakerGroup := user.UserGroups[dbGroup.ID]
|
||||
_, inNetmakerGroup := user.UserGroups.Data()[dbGroup.ID]
|
||||
_, inIDPGroup := groupMembersMap[user.ExternalIdentityProviderID]
|
||||
|
||||
if inNetmakerGroup && !inIDPGroup {
|
||||
// use dbGroup.Name because the group name may have been changed on idp.
|
||||
delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups, dbGroup.ID)
|
||||
delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups.Data(), dbGroup.ID)
|
||||
modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
|
||||
}
|
||||
|
||||
if !inNetmakerGroup && inIDPGroup {
|
||||
// use dbGroup.Name because the group name may have been changed on idp.
|
||||
dbUsersMap[user.ExternalIdentityProviderID].UserGroups[dbGroup.ID] = struct{}{}
|
||||
dbUsersMap[user.ExternalIdentityProviderID].UserGroups.Data()[dbGroup.ID] = struct{}{}
|
||||
modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for userID := range modifiedUsers {
|
||||
err = logic.UpsertUser(dbUsersMap[userID])
|
||||
if err != nil {
|
||||
return err
|
||||
user, ok := dbUsersMap[userID]
|
||||
if ok {
|
||||
err = logic.UpsertUser(*user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,8 +463,8 @@ func filterGroupsByMembers(idpGroups []idp.Group, idpUsers []idp.User) []idp.Gro
|
||||
// TODO: deduplicate
|
||||
// The cyclic import between the package logic and mq requires this
|
||||
// function to be duplicated in multiple places.
|
||||
func deleteAndCleanUpUser(user *models.User) error {
|
||||
err := logic.DeleteUser(user.UserName)
|
||||
func deleteAndCleanUpUser(user *schema.User) error {
|
||||
err := logic.DeleteUser(user.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -469,7 +476,7 @@ func deleteAndCleanUpUser(user *models.User) error {
|
||||
return
|
||||
}
|
||||
for _, extclient := range extclients {
|
||||
if extclient.OwnerID == user.UserName {
|
||||
if extclient.OwnerID == user.Username {
|
||||
err = logic.DeleteExtClientAndCleanup(extclient)
|
||||
if err == nil {
|
||||
_ = mq.PublishDeletedClientPeerUpdate(&extclient)
|
||||
@@ -477,7 +484,7 @@ func deleteAndCleanUpUser(user *models.User) error {
|
||||
}
|
||||
}
|
||||
|
||||
go logic.DeleteUserInvite(user.UserName)
|
||||
go logic.DeleteUserInvite(user.Username)
|
||||
go mq.PublishPeerUpdate(false)
|
||||
if servercfg.IsDNSMode() {
|
||||
go logic.SetDNS()
|
||||
|
||||
@@ -65,7 +65,7 @@ func getAutoRelayGws(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
defaultPolicy, err := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
|
||||
defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
@@ -205,7 +205,10 @@ func autoRelayME(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -243,7 +246,7 @@ func autoRelayME(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
logic.GetNodeEgressInfo(&node, eli, acls)
|
||||
logic.GetNodeEgressInfo(&peerNode, eli, acls)
|
||||
logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
|
||||
@@ -370,8 +373,10 @@ func autoRelayMEUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -532,7 +537,10 @@ func checkautoRelayCtx(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -568,7 +576,7 @@ func checkautoRelayCtx(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
logic.GetNodeEgressInfo(&node, eli, acls)
|
||||
logic.GetNodeEgressInfo(&peerNode, eli, acls)
|
||||
logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
|
||||
|
||||
@@ -66,7 +66,7 @@ func listNetworkActivity(w http.ResponseWriter, r *http.Request) {
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
pageSize, _ := strconv.Atoi(r.URL.Query().Get("per_page"))
|
||||
ctx := db.WithContext(r.Context())
|
||||
netActivity, err := (&schema.Event{NetworkID: models.NetworkID(netID)}).ListByNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
|
||||
netActivity, err := (&schema.Event{NetworkID: schema.NetworkID(netID)}).ListByNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
@@ -100,12 +100,13 @@ func listUserActivity(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
caller, err := logic.GetUser(r.Header.Get("user"))
|
||||
caller := &schema.User{Username: r.Header.Get("user")}
|
||||
err := caller.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
if caller.UserName != username && caller.PlatformRoleID != models.SuperAdminRole && caller.PlatformRoleID != models.AdminRole {
|
||||
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && caller.PlatformRoleID != schema.AdminRole {
|
||||
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
|
||||
Code: http.StatusForbidden,
|
||||
Message: "you are not authorized to view this user's activity",
|
||||
@@ -190,7 +191,7 @@ func listActivity(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
var events []schema.Event
|
||||
e := &schema.Event{TriggeredBy: username, NetworkID: models.NetworkID(network)}
|
||||
e := &schema.Event{TriggeredBy: username, NetworkID: schema.NetworkID(network)}
|
||||
if username != "" && network != "" {
|
||||
events, err = e.ListByUserAndNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
|
||||
} else if username != "" && network == "" {
|
||||
|
||||
@@ -142,7 +142,10 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -179,7 +182,7 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
logic.GetNodeEgressInfo(&node, eli, acls)
|
||||
logic.GetNodeEgressInfo(&peerNode, eli, acls)
|
||||
logic.GetNodeEgressInfo(&failOverNode, eli, acls)
|
||||
@@ -295,7 +298,10 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
host := &schema.Host{
|
||||
ID: node.HostID,
|
||||
}
|
||||
err = host.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -331,7 +337,7 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
|
||||
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
|
||||
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
|
||||
logic.GetNodeEgressInfo(&node, eli, acls)
|
||||
logic.GetNodeEgressInfo(&peerNode, eli, acls)
|
||||
logic.GetNodeEgressInfo(&failOverNode, eli, acls)
|
||||
|
||||
+87
-78
@@ -1,6 +1,7 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
@@ -76,7 +77,8 @@ func handleJIT(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
@@ -93,7 +95,7 @@ func handleJIT(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleJITGet - handles GET requests for JIT status/requests
|
||||
func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
|
||||
func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
|
||||
statusFilter := r.URL.Query().Get("status") // "pending", "approved", "denied", "expired", or empty for all
|
||||
|
||||
// Parse pagination parameters (default to 0, db.SetPagination will apply defaults)
|
||||
@@ -121,19 +123,19 @@ func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"data": requests,
|
||||
"page": page,
|
||||
"per_page": pageSize,
|
||||
"total": total,
|
||||
"total_pages": totalPages,
|
||||
response := models.PaginatedResponse{
|
||||
Data: requests,
|
||||
Page: page,
|
||||
PerPage: pageSize,
|
||||
Total: int(total),
|
||||
TotalPages: totalPages,
|
||||
}
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, response, "fetched JIT requests")
|
||||
}
|
||||
|
||||
// handleJITPost - handles POST requests for JIT operations
|
||||
func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
|
||||
func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
|
||||
var req models.JITOperationRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -157,7 +159,7 @@ func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, use
|
||||
}
|
||||
|
||||
// handleEnableJIT - enables JIT on a network
|
||||
func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
|
||||
func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
|
||||
// Check if user is admin
|
||||
if !proLogic.IsNetworkAdmin(user, networkID) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can enable JIT"), "forbidden"))
|
||||
@@ -170,27 +172,27 @@ func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, u
|
||||
}
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: networkID,
|
||||
Name: networkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(networkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(networkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponse(w, r, "JIT enabled on network")
|
||||
}
|
||||
|
||||
// handleDisableJIT - disables JIT on a network
|
||||
func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
|
||||
func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
|
||||
// Check if user is admin
|
||||
if !proLogic.IsNetworkAdmin(user, networkID) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can disable JIT"), "forbidden"))
|
||||
@@ -203,27 +205,27 @@ func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string,
|
||||
}
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: networkID,
|
||||
Name: networkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(networkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(networkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponse(w, r, "JIT disabled on network")
|
||||
}
|
||||
|
||||
// handleApproveRequest - approves a JIT request
|
||||
func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID string, user *models.User, requestID string, expiresAtEpoch int64) {
|
||||
func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User, requestID string, expiresAtEpoch int64) {
|
||||
// Check if user is admin
|
||||
if !proLogic.IsNetworkAdmin(user, networkID) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can approve requests"), "forbidden"))
|
||||
@@ -250,40 +252,41 @@ func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID stri
|
||||
return
|
||||
}
|
||||
|
||||
grant, req, err := proLogic.ApproveJITRequest(requestID, expiresAt, user.UserName)
|
||||
grant, req, err := proLogic.ApproveJITRequest(requestID, expiresAt, user.Username)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
// Send approval email to user
|
||||
go func() {
|
||||
network, _ := logic.GetNetwork(networkID)
|
||||
network := &schema.Network{Name: networkID}
|
||||
_ = network.Get(r.Context())
|
||||
if err := email.SendJITApprovalEmail(grant, req, network); err != nil {
|
||||
slog.Error("failed to send approval notification", "error", err)
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: requestID,
|
||||
Name: networkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(networkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(networkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, grant, "JIT request approved")
|
||||
}
|
||||
|
||||
// handleDenyRequest - denies a JIT request
|
||||
func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string, user *models.User, requestID string) {
|
||||
func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User, requestID string) {
|
||||
// Check if user is admin
|
||||
if !proLogic.IsNetworkAdmin(user, networkID) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can deny requests"), "forbidden"))
|
||||
@@ -295,7 +298,7 @@ func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string,
|
||||
return
|
||||
}
|
||||
|
||||
request, err := proLogic.DenyJITRequest(requestID, user.UserName)
|
||||
request, err := proLogic.DenyJITRequest(requestID, user.Username)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -303,27 +306,28 @@ func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string,
|
||||
|
||||
// Send denial email to requester
|
||||
go func() {
|
||||
network, _ := logic.GetNetwork(networkID)
|
||||
network := &schema.Network{Name: networkID}
|
||||
_ = network.Get(db.WithContext(context.TODO()))
|
||||
if err := email.SendJITDeniedEmail(request, network); err != nil {
|
||||
slog.Error("failed to send JIT denied notification", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: requestID,
|
||||
Name: networkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(networkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(networkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponse(w, r, "JIT request denied")
|
||||
@@ -361,7 +365,8 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
@@ -418,9 +423,10 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Send email notification to user
|
||||
if revokedRequest != nil {
|
||||
network, err := logic.GetNetwork(networkID)
|
||||
network := &schema.Network{Name: networkID}
|
||||
err := network.Get(r.Context())
|
||||
if err == nil {
|
||||
if err := email.SendJITExpirationEmail(&grant, revokedRequest, network, true, user.UserName); err != nil {
|
||||
if err := email.SendJITExpirationEmail(&grant, revokedRequest, network, true, user.Username); err != nil {
|
||||
slog.Warn("failed to send revocation email", "grant_id", grantID, "user", revokedRequest.UserName, "error", err)
|
||||
}
|
||||
}
|
||||
@@ -432,20 +438,20 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: grantID,
|
||||
Name: networkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(networkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(networkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponse(w, r, "JIT grant revoked")
|
||||
@@ -473,24 +479,25 @@ func getUserJITNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
// Get all networks user has access to
|
||||
allNetworks, err := logic.GetNetworks()
|
||||
allNetworks, err := (&schema.Network{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
// Filter networks by user role
|
||||
userNetworks := logic.FilterNetworksByRole(allNetworks, *user)
|
||||
userNetworks := logic.FilterNetworksByRole(allNetworks, user)
|
||||
|
||||
// Build response with JIT status for each network
|
||||
networksWithJITStatus, err := proLogic.GetUserJITNetworksStatus(userNetworks, user.UserName)
|
||||
networksWithJITStatus, err := proLogic.GetUserJITNetworksStatus(userNetworks, user.Username)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
@@ -525,7 +532,8 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
network := r.URL.Query().Get("network")
|
||||
|
||||
user, err := logic.GetUser(username)
|
||||
user := &schema.User{Username: username}
|
||||
err := user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
|
||||
return
|
||||
@@ -550,17 +558,17 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// Check if user has access to the network by role
|
||||
allNetworks, err := logic.GetNetworks()
|
||||
allNetworks, err := (&schema.Network{}).ListAll(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
|
||||
// Filter networks by user role
|
||||
userNetworks := logic.FilterNetworksByRole(allNetworks, *user)
|
||||
userNetworks := logic.FilterNetworksByRole(allNetworks, user)
|
||||
hasAccess := false
|
||||
for _, network := range userNetworks {
|
||||
if network.NetID == req.NetworkID {
|
||||
if network.Name == req.NetworkID {
|
||||
hasAccess = true
|
||||
break
|
||||
}
|
||||
@@ -572,7 +580,7 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Create the JIT request
|
||||
request, err := proLogic.CreateJITRequest(req.NetworkID, user.UserName, req.Reason)
|
||||
request, err := proLogic.CreateJITRequest(req.NetworkID, user.Username, req.Reason)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
@@ -580,27 +588,28 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Send email notifications to network admins
|
||||
go func() {
|
||||
network, _ := logic.GetNetwork(req.NetworkID)
|
||||
network := &schema.Network{Name: req.NetworkID}
|
||||
_ = network.Get(r.Context())
|
||||
if err := email.SendJITRequestEmails(request, network); err != nil {
|
||||
slog.Error("failed to send JIT request notifications", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: user.UserName,
|
||||
Name: user.UserName,
|
||||
Type: models.UserSub,
|
||||
ID: user.Username,
|
||||
Name: user.Username,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: user.UserName,
|
||||
TriggeredBy: user.Username,
|
||||
Target: models.Subject{
|
||||
ID: request.ID,
|
||||
Name: req.NetworkID,
|
||||
Type: models.NetworkSub,
|
||||
Type: schema.NetworkSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(req.NetworkID),
|
||||
Origin: models.ClientApp,
|
||||
NetworkID: schema.NetworkID(req.NetworkID),
|
||||
Origin: schema.ClientApp,
|
||||
})
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, request, "JIT access request created")
|
||||
|
||||
@@ -92,20 +92,20 @@ func createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
Type: schema.PostureCheckSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(pc.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
@@ -131,7 +131,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
|
||||
return
|
||||
}
|
||||
_, err := logic.GetNetwork(network)
|
||||
err := (&schema.Network{Name: network}).Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network not found"), logic.BadReq))
|
||||
return
|
||||
@@ -151,7 +151,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check")
|
||||
return
|
||||
}
|
||||
pc := schema.PostureCheck{NetworkID: models.NetworkID(network)}
|
||||
pc := schema.PostureCheck{NetworkID: schema.NetworkID(network)}
|
||||
list, err := pc.ListByNetwork(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
@@ -202,24 +202,24 @@ func updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
updateStatus = true
|
||||
}
|
||||
event := &models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: updatePc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
Type: schema.PostureCheckSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: pc,
|
||||
New: updatePc,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(pc.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
}
|
||||
pc.Tags = updatePc.Tags
|
||||
pc.UserGroups = updatePc.UserGroups
|
||||
@@ -278,20 +278,20 @@ func deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: pc.ID,
|
||||
Name: pc.Name,
|
||||
Type: models.PostureCheckSub,
|
||||
Type: schema.PostureCheckSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(pc.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
NetworkID: schema.NetworkID(pc.NetworkID),
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: pc,
|
||||
New: nil,
|
||||
|
||||
+19
-17
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
func TagHandlers(r *mux.Router) {
|
||||
@@ -44,12 +45,12 @@ func getTags(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
// check if network exists
|
||||
_, err := logic.GetNetwork(netID)
|
||||
err := (&schema.Network{Name: netID}).Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
tags, err := proLogic.ListTagsWithNodes(models.NetworkID(netID))
|
||||
tags, err := proLogic.ListTagsWithNodes(schema.NetworkID(netID))
|
||||
if err != nil {
|
||||
logger.Log(0, r.Header.Get("user"), "failed to get all network tag entries: ", err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
@@ -77,13 +78,14 @@ func createTag(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
user, err := logic.GetUser(r.Header.Get("user"))
|
||||
user := &schema.User{Username: r.Header.Get("user")}
|
||||
err = user.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
// check if tag network exists
|
||||
_, err = logic.GetNetwork(req.Network.String())
|
||||
err = (&schema.Network{Name: req.Network.String()}).Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.Network.String()), "badrequest"))
|
||||
return
|
||||
@@ -93,7 +95,7 @@ func createTag(w http.ResponseWriter, r *http.Request) {
|
||||
ID: models.TagID(fmt.Sprintf("%s.%s", req.Network, req.TagName)),
|
||||
TagName: req.TagName,
|
||||
Network: req.Network,
|
||||
CreatedBy: user.UserName,
|
||||
CreatedBy: user.Username,
|
||||
ColorCode: req.ColorCode,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
@@ -138,20 +140,20 @@ func createTag(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Action: schema.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: tag.ID.String(),
|
||||
Name: tag.TagName,
|
||||
Type: models.TagSub,
|
||||
Type: schema.TagSub,
|
||||
},
|
||||
NetworkID: tag.Network,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
})
|
||||
go mq.PublishPeerUpdate(false)
|
||||
|
||||
@@ -189,23 +191,23 @@ func updateTag(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
e := &models.Event{
|
||||
Action: models.Update,
|
||||
Action: schema.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: tag.ID.String(),
|
||||
Name: tag.TagName,
|
||||
Type: models.TagSub,
|
||||
Type: schema.TagSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: tag,
|
||||
},
|
||||
NetworkID: tag.Network,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
}
|
||||
updateTag.NewName = strings.TrimSpace(updateTag.NewName)
|
||||
var newID models.TagID
|
||||
@@ -290,20 +292,20 @@ func deleteTag(w http.ResponseWriter, r *http.Request) {
|
||||
mq.PublishPeerUpdate(false)
|
||||
}()
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Action: schema.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
Type: schema.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: tag.ID.String(),
|
||||
Name: tag.TagName,
|
||||
Type: models.TagSub,
|
||||
Type: schema.TagSub,
|
||||
},
|
||||
NetworkID: tag.Network,
|
||||
Origin: models.Dashboard,
|
||||
Origin: schema.Dashboard,
|
||||
Diff: models.Diff{
|
||||
Old: tag,
|
||||
New: nil,
|
||||
|
||||
+299
-200
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -3,8 +3,8 @@ package email
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ func (invite UserInvitedMail) GetBody(info Notification) string {
|
||||
WithHtml("<br>").
|
||||
WithHtml(fmt.Sprintf("<li><a href=\"%s\">Download the Netmaker Desktop App</a>.</li>", downloadLink))
|
||||
|
||||
if invite.PlatformRoleID == models.AdminRole.String() || invite.PlatformRoleID == models.PlatformUser.String() {
|
||||
if invite.PlatformRoleID == schema.AdminRole.String() || invite.PlatformRoleID == schema.PlatformUser.String() {
|
||||
content = content.
|
||||
WithHtml("<br>").
|
||||
WithHtml(fmt.Sprintf("<li>Access the <a href=\"%s\">Netmaker Dashboard</a> - use it to manage your network settings and view network status.</li>", dashboardURL))
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
@@ -14,11 +13,11 @@ type JITApprovedMail struct {
|
||||
BodyBuilder EmailBodyBuilder
|
||||
Grant *schema.JITGrant
|
||||
Request *schema.JITRequest
|
||||
Network models.Network
|
||||
Network *schema.Network
|
||||
}
|
||||
|
||||
// SendJITApprovalEmail - sends email notification to user when JIT request is approved
|
||||
func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, network models.Network) error {
|
||||
func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, network *schema.Network) error {
|
||||
mail := JITApprovedMail{
|
||||
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
|
||||
Grant: grant,
|
||||
@@ -40,17 +39,17 @@ func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, ne
|
||||
|
||||
// GetSubject - gets the subject of the email
|
||||
func (mail JITApprovedMail) GetSubject(info Notification) string {
|
||||
return fmt.Sprintf("JIT Access Approved: %s", mail.Network.NetID)
|
||||
return fmt.Sprintf("JIT Access Approved: %s", mail.Network.Name)
|
||||
}
|
||||
|
||||
// GetBody - gets the body of the email
|
||||
func (mail JITApprovedMail) GetBody(info Notification) string {
|
||||
content := mail.BodyBuilder.
|
||||
WithHeadline("JIT Access Approved").
|
||||
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been approved.", mail.Network.NetID)).
|
||||
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been approved.", mail.Network.Name)).
|
||||
WithParagraph("Access Details:").
|
||||
WithHtml("<ul>").
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Granted At:</strong> %s</li>", formatUTCTime(mail.Grant.GrantedAt))).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Expires At:</strong> %s</li>", formatUTCTime(mail.Grant.ExpiresAt))).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Approved By:</strong> %s</li>", mail.Request.ApprovedBy)).
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
@@ -13,11 +12,11 @@ import (
|
||||
type JITDeniedMail struct {
|
||||
BodyBuilder EmailBodyBuilder
|
||||
Request *schema.JITRequest
|
||||
Network models.Network
|
||||
Network *schema.Network
|
||||
}
|
||||
|
||||
// SendJITDeniedEmail - sends email notification to user when JIT request is denied
|
||||
func SendJITDeniedEmail(request *schema.JITRequest, network models.Network) error {
|
||||
func SendJITDeniedEmail(request *schema.JITRequest, network *schema.Network) error {
|
||||
mail := JITDeniedMail{
|
||||
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
|
||||
Request: request,
|
||||
@@ -38,17 +37,17 @@ func SendJITDeniedEmail(request *schema.JITRequest, network models.Network) erro
|
||||
|
||||
// GetSubject - gets the subject of the email
|
||||
func (mail JITDeniedMail) GetSubject(info Notification) string {
|
||||
return fmt.Sprintf("JIT Access Request Denied: %s", mail.Network.NetID)
|
||||
return fmt.Sprintf("JIT Access Request Denied: %s", mail.Network.Name)
|
||||
}
|
||||
|
||||
// GetBody - gets the body of the email
|
||||
func (mail JITDeniedMail) GetBody(info Notification) string {
|
||||
content := mail.BodyBuilder.
|
||||
WithHeadline("JIT Access Request Denied").
|
||||
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been denied.", mail.Network.NetID)).
|
||||
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been denied.", mail.Network.Name)).
|
||||
WithParagraph("Request Details:").
|
||||
WithHtml("<ul>").
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Requested At:</strong> %s</li>", formatUTCTime(mail.Request.RequestedAt))).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Denied At:</strong> %s</li>", formatUTCTime(mail.Request.ApprovedAt))).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Denied By:</strong> %s</li>", mail.Request.ApprovedBy)).
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
@@ -15,14 +14,14 @@ type JITExpiredMail struct {
|
||||
BodyBuilder EmailBodyBuilder
|
||||
Grant *schema.JITGrant
|
||||
Request *schema.JITRequest
|
||||
Network models.Network
|
||||
Network *schema.Network
|
||||
IsRevoked bool
|
||||
RevokedBy string // set when IsRevoked is true
|
||||
}
|
||||
|
||||
// SendJITExpirationEmail - sends email notification to user when JIT grant expires or is revoked
|
||||
// revokedBy is the username of the admin who revoked the grant; empty when grant expired naturally
|
||||
func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest, network models.Network, isRevoked bool, revokedBy string) error {
|
||||
func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest, network *schema.Network, isRevoked bool, revokedBy string) error {
|
||||
mail := JITExpiredMail{
|
||||
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
|
||||
Grant: grant,
|
||||
@@ -47,9 +46,9 @@ func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest,
|
||||
// GetSubject - gets the subject of the email
|
||||
func (mail JITExpiredMail) GetSubject(info Notification) string {
|
||||
if mail.IsRevoked {
|
||||
return fmt.Sprintf("JIT Access Revoked: %s", mail.Network.NetID)
|
||||
return fmt.Sprintf("JIT Access Revoked: %s", mail.Network.Name)
|
||||
}
|
||||
return fmt.Sprintf("JIT Access Expired: %s", mail.Network.NetID)
|
||||
return fmt.Sprintf("JIT Access Expired: %s", mail.Network.Name)
|
||||
}
|
||||
|
||||
// GetBody - gets the body of the email
|
||||
@@ -57,10 +56,10 @@ func (mail JITExpiredMail) GetBody(info Notification) string {
|
||||
var headline, message string
|
||||
if mail.IsRevoked {
|
||||
headline = "JIT Access Revoked"
|
||||
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has been revoked by an administrator.", mail.Network.NetID)
|
||||
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has been revoked by an administrator.", mail.Network.Name)
|
||||
} else {
|
||||
headline = "JIT Access Expired"
|
||||
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has expired.", mail.Network.NetID)
|
||||
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has expired.", mail.Network.Name)
|
||||
}
|
||||
|
||||
builder := mail.BodyBuilder.
|
||||
@@ -68,7 +67,7 @@ func (mail JITExpiredMail) GetBody(info Notification) string {
|
||||
WithParagraph(message).
|
||||
WithParagraph("Access Details:").
|
||||
WithHtml("<ul>").
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Granted At:</strong> %s</li>", formatUTCTime(mail.Grant.GrantedAt))).
|
||||
WithHtml(fmt.Sprintf("<li><strong>Expired At:</strong> %s</li>", formatUTCTime(mail.Grant.ExpiresAt)))
|
||||
if mail.IsRevoked && mail.RevokedBy != "" {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user