NM-163: Users, Groups, Roles, Networks and Hosts Table Migration (#3910)

* feat(go): add user schema;

* feat(go): migrate to user schema;

* feat(go): add audit fields;

* feat(go): remove unused fields from the network model;

* feat(go): add network schema;

* feat(go): migrate to network schema;

* refactor(go): add comment to clarify migration logic;

* fix(go): test failures;

* fix(go): test failures;

* feat(go): change membership table to store memberships at all scopes;

* feat(go): add schema for access grants;

* feat(go): remove nameservers from new networks table; ensure db passed for schema functions;

* feat(go): set max conns for sqlite to 1;

* fix(go): issues updating user account status;

* refactor(go): remove converters and access grants;

* refactor(go): add json tags in schema models;

* refactor(go): rename file to migrate_v1_6_0.go;

* refactor(go): add user groups and user roles tables; use schema tables;

* refactor(go): inline get and list from schema package;

* refactor(go): inline get network and list users from schema package;

* fix(go): staticcheck issues;

* fix(go): remove test not in use; fix test case;

* fix(go): validate network;

* fix(go): resolve static checks;

* fix(go): new models errors;

* fix(go): test errors;

* fix(go): handle no records;

* fix(go): add validations for user object;

* fix(go): set correct extclient status;

* fix(go): test error;

* feat(go): make schema the base package;

* feat(go): add host schema;

* feat(go): use schema host everywhere;

* feat(go): inline get host, list hosts and delete host;

* feat(go): use non-ptr value;

* feat(go): use save to upsert all fields;

* feat(go): use save to upsert all fields;

* feat(go): save turn endpoint as string;

* feat(go): check for gorm error record not found;

* fix(go): test failures;

* fix(go): update all network fields;

* fix(go): update all network fields;

* feat(go): add paginated list networks api;

* feat(go): add paginated list users api;

* feat(go): add paginated list hosts api;

* feat(go): add pagination to list groups api;

* fix(go): comment;

* fix(go): implement marshal and unmarshal text for custom types;

* fix(go): implement marshal and unmarshal json for custom types;

* fix(go): just use the old model for unmarshalling;

* fix(go): implement marshal and unmarshal json for custom types;

* feat(go): remove paginated list networks api;

* feat(go): use custom paginated response object;

* fix(go): ensure default values for page and per_page are used when not passed;

* fix(go): rename v1.6.0 to v1.5.1;

* fix(go): check for gorm.ErrRecordNotFound instead of database.IsEmptyRecord;

* fix(go): use host id, not pending host id;

* feat(go): add filters to paginated apis;

* feat(go): add filters to paginated apis;

* feat(go): remove check for max username length;

* feat(go): add filters to count as well;

* feat(go): use library to check email address validity;

* feat(go): ignore pagination if params not passed;

* fix(go): pagination issues;

* fix(go): check exists before using;

* fix(go): remove debug log;

* fix(go): use gorm err record not found;

* fix(go): use gorm err record not found;

* fix(go): use user principal name when creating pending user;

* fix(go): use schema package for consts;

* fix(go): prevent disabling superadmin user;

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): swap is admin and is superadmin;

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): remove dead code block;

https://github.com/gravitl/netmaker/pull/3910#discussion_r2928837937

* fix(go): incorrect message when trying to disable self;

https://github.com/gravitl/netmaker/pull/3910#discussion_r2928837934

* fix(go): use correct header;

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): return after error response;

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): use correct order of params;

https://github.com/gravitl/netmaker/pull/3910#discussion_r2929593036

* fix(go): set default values for page and page size; use v2 instead of /list;

* Update logic/auth.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* Update schema/user_roles.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): syntax error;

* fix(go): set default values when page and per_page are not passed or 0;

* fix(go): use uuid.parse instead of uuid.must parse;

* fix(go): review errors;

* fix(go): review errors;

* Update controllers/user.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* Update controllers/user.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* NM-163: fix errors:

* Update db/types/options.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* fix(go): persist return user in event;

* Update db/types/options.go

Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>

* NM-163: duplicate lines of code

* NM-163: fix(go): fix missing return and filter parsing in user controller

- Add missing return after error response in updateUserAccountStatus
  to prevent double-response and spurious ext-client side-effects
- Use switch statements in listUsers to skip unrecognized
  account_status and mfa_status filter values

* fix(go): check for both min and max page size;

* fix(go): enclose transfer superadmin in transaction;

* fix(go): review errors;

* fix(go): remove free tier checks;

* fix(go): review fixes;

---------

Co-authored-by: VishalDalwadi <dalwadivishal26@gmail.com>
Co-authored-by: Vishal Dalwadi <51291657+VishalDalwadi@users.noreply.github.com>
Co-authored-by: tenki-reviewer[bot] <262613592+tenki-reviewer[bot]@users.noreply.github.com>
This commit is contained in:
Abhishek Kondur
2026-03-17 19:36:52 +05:30
committed by GitHub
parent 6a4b34b61c
commit edda2868fc
130 changed files with 4958 additions and 5168 deletions
+8 -5
View File
@@ -1,8 +1,10 @@
package auth
import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"context"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
)
// == consts ==
@@ -10,11 +12,12 @@ const (
node_signin_length = 64
)
func isUserIsAllowed(username, network string) (*models.User, error) {
func isUserIsAllowed(username, network string) (*schema.User, error) {
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err := user.Get(db.WithContext(context.TODO()))
if err != nil { // user must not exist, so try to make one
return &models.User{}, err
return &schema.User{}, err
}
return user, nil
+25 -21
View File
@@ -180,23 +180,26 @@ func SessionHandler(conn *websocket.Conn) {
handleHostRegErr(conn, err)
return
}
currHost, err := logic.GetHost(result.Host.ID.String())
currHost := &schema.Host{
ID: result.Host.ID,
}
err = currHost.Get(db.WithContext(context.TODO()))
if err != nil {
handleHostRegErr(conn, err)
return
}
var currentNetworks = []string{}
var currentNetworks []string
if result.ALL {
currentNets, err := logic.GetNetworks()
if err == nil && len(currentNets) > 0 {
for i := range currentNets {
currentNetworks = append(currentNetworks, currentNets[i].NetID)
_networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err == nil && len(_networks) > 0 {
for i := range _networks {
currentNetworks = append(currentNetworks, _networks[i].Name)
}
}
} else if len(result.Network) > 0 {
currentNetworks = append(currentNetworks, result.Network)
}
var netsToAdd = []string{} // track the networks not currently owned by host
var netsToAdd []string // track the networks not currently owned by host
hostNets := logic.GetHostNetworks(currHost.ID.String())
for _, newNet := range currentNetworks {
if !logic.StringSliceContains(hostNets, newNet) {
@@ -240,13 +243,14 @@ func SessionHandler(conn *websocket.Conn) {
}
// CheckNetRegAndHostUpdate - run through networks and send a host update
func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username string) {
func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *schema.Host, username string) {
// publish host update through MQ
featureFlags := logic.GetFeatureFlags()
for _, netID := range key.Networks {
if network, err := logic.GetNetwork(netID); err == nil {
if featureFlags.EnableDeviceApproval && network.AutoJoin == "false" {
if logic.DoesHostExistinTheNetworkAlready(h, models.NetworkID(netID)) {
network := &schema.Network{Name: netID}
if err := network.Get(db.WithContext(context.TODO())); err == nil {
if featureFlags.EnableDeviceApproval && !network.AutoJoin {
if logic.DoesHostExistinTheNetworkAlready(h, schema.NetworkID(netID)) {
continue
}
if err := (&schema.PendingHost{
@@ -275,37 +279,37 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username
if len(username) > 0 {
logic.LogEvent(&models.Event{
Action: models.JoinHostToNet,
Action: schema.JoinHostToNet,
Source: models.Subject{
ID: username,
Name: username,
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: username,
Target: models.Subject{
ID: h.ID.String(),
Name: h.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
NetworkID: models.NetworkID(netID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(netID),
Origin: schema.Dashboard,
})
} else {
logic.LogEvent(&models.Event{
Action: models.JoinHostToNet,
Action: schema.JoinHostToNet,
Source: models.Subject{
ID: key.Value,
Name: key.Tags[0],
Type: models.EnrollmentKeySub,
Type: schema.EnrollmentKeySub,
},
TriggeredBy: username,
Target: models.Subject{
ID: h.ID.String(),
Name: h.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
NetworkID: models.NetworkID(netID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(netID),
Origin: schema.Dashboard,
})
}
+5 -23
View File
@@ -6,7 +6,7 @@ import (
"os"
"github.com/gravitl/netmaker/cli/functions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/spf13/cobra"
)
@@ -16,7 +16,7 @@ var networkCreateCmd = &cobra.Command{
Long: `Create a Network`,
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
network := &models.Network{}
network := &schema.Network{}
if networkDefinitionFilePath != "" {
content, err := os.ReadFile(networkDefinitionFilePath)
if err != nil {
@@ -26,28 +26,15 @@ var networkCreateCmd = &cobra.Command{
log.Fatal(err)
}
} else {
network.NetID = netID
network.Name = name
network.AddressRange = address
if address6 != "" {
network.AddressRange6 = address6
network.IsIPv6 = "yes"
}
if address == "" {
network.IsIPv4 = "no"
}
if udpHolePunch {
network.DefaultUDPHolePunch = "yes"
}
if defaultACL {
network.DefaultACL = "yes"
}
network.DefaultInterface = defaultInterface
network.DefaultListenPort = int32(defaultListenPort)
network.NodeLimit = int32(nodeLimit)
network.DefaultKeepalive = int32(defaultKeepalive)
if allowManualSignUp {
network.AllowManualSignUp = "yes"
}
network.DefaultKeepAlive = defaultKeepalive
network.DefaultMTU = int32(defaultMTU)
}
functions.PrettyPrint(functions.CreateNetwork(network))
@@ -56,17 +43,12 @@ var networkCreateCmd = &cobra.Command{
func init() {
networkCreateCmd.Flags().StringVar(&networkDefinitionFilePath, "file", "", "Path to network_definition.json")
networkCreateCmd.Flags().StringVar(&netID, "name", "", "Name of the network")
networkCreateCmd.Flags().StringVar(&name, "name", "", "Name of the network")
networkCreateCmd.MarkFlagsMutuallyExclusive("file", "name")
networkCreateCmd.Flags().StringVar(&address, "ipv4_addr", "", "IPv4 address of the network")
networkCreateCmd.Flags().StringVar(&address6, "ipv6_addr", "", "IPv6 address of the network")
networkCreateCmd.Flags().BoolVar(&udpHolePunch, "udp_hole_punch", false, "Enable UDP Hole Punching ?")
networkCreateCmd.Flags().BoolVar(&defaultACL, "default_acl", false, "Enable default Access Control List ?")
networkCreateCmd.Flags().StringVar(&defaultInterface, "interface", "", "Name of the network interface")
networkCreateCmd.Flags().IntVar(&defaultListenPort, "listen_port", 51821, "Default wireguard port each node will attempt to use")
networkCreateCmd.Flags().IntVar(&nodeLimit, "node_limit", 999999999, "Maximum number of nodes that can be associated with this network")
networkCreateCmd.Flags().IntVar(&defaultKeepalive, "keep_alive", 20, "Keep Alive in seconds")
networkCreateCmd.Flags().IntVar(&defaultMTU, "mtu", 1280, "MTU size")
networkCreateCmd.Flags().BoolVar(&allowManualSignUp, "manual_signup", false, "Allow manual signup ?")
rootCmd.AddCommand(networkCreateCmd)
}
+1 -6
View File
@@ -2,15 +2,10 @@ package network
var (
networkDefinitionFilePath string
netID string
name string
address string
address6 string
udpHolePunch bool
defaultACL bool
defaultInterface string
defaultListenPort int
nodeLimit int
defaultKeepalive int
allowManualSignUp bool
defaultMTU int
)
+3 -3
View File
@@ -24,9 +24,9 @@ var networkListCmd = &cobra.Command{
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NetId", "Address Range (IPv4)", "Address Range (IPv6)", "Network Last Modified", "Nodes Last Modified"})
for _, n := range *networks {
networkLastModified := time.Unix(n.NetworkLastModified, 0).Format(time.RFC3339)
nodesLastModified := time.Unix(n.NodesLastModified, 0).Format(time.RFC3339)
table.Append([]string{n.NetID, n.AddressRange, n.AddressRange6, networkLastModified, nodesLastModified})
networkLastModified := n.UpdatedAt.Format(time.RFC3339)
nodesLastModified := n.NodesUpdatedAt.Format(time.RFC3339)
table.Append([]string{n.Name, n.AddressRange, n.AddressRange6, networkLastModified, nodesLastModified})
}
table.Render()
}
-27
View File
@@ -1,27 +0,0 @@
package network
import (
"log"
"strconv"
"github.com/gravitl/netmaker/cli/functions"
"github.com/spf13/cobra"
)
var networkNodeLimitCmd = &cobra.Command{
Use: "node_limit [NETWORK NAME] [NEW LIMIT]",
Short: "Update network nodel limit",
Long: `Update network nodel limit`,
Args: cobra.ExactArgs(2),
Run: func(cmd *cobra.Command, args []string) {
nodelimit, err := strconv.ParseInt(args[1], 10, 32)
if err != nil {
log.Fatal(err)
}
functions.PrettyPrint(functions.UpdateNetworkNodeLimit(args[0], int32(nodelimit)))
},
}
func init() {
rootCmd.AddCommand(networkNodeLimitCmd)
}
+7 -19
View File
@@ -1,11 +1,10 @@
package user
import (
"strings"
"github.com/gravitl/netmaker/cli/functions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/spf13/cobra"
"gorm.io/datatypes"
)
var userCreateCmd = &cobra.Command{
@@ -14,24 +13,13 @@ var userCreateCmd = &cobra.Command{
Short: "Create a new user",
Long: `Create a new user`,
Run: func(cmd *cobra.Command, args []string) {
user := &models.User{UserName: username, Password: password, PlatformRoleID: models.UserRoleID(platformID)}
if len(networkRoles) > 0 {
netRolesMap := make(map[models.NetworkID]map[models.UserRoleID]struct{})
for netID, netRoles := range networkRoles {
roleMap := make(map[models.UserRoleID]struct{})
for _, roleID := range strings.Split(netRoles, " ") {
roleMap[models.UserRoleID(roleID)] = struct{}{}
}
netRolesMap[models.NetworkID(netID)] = roleMap
}
user.NetworkRoles = netRolesMap
}
user := &schema.User{Username: username, Password: password, PlatformRoleID: schema.UserRoleID(platformID)}
if len(groups) > 0 {
grMap := make(map[models.UserGroupID]struct{})
grMap := make(map[schema.UserGroupID]struct{})
for _, groupID := range groups {
grMap[models.UserGroupID(groupID)] = struct{}{}
grMap[schema.UserGroupID(groupID)] = struct{}{}
}
user.UserGroups = grMap
user.UserGroups = datatypes.NewJSONType(grMap)
}
functions.PrettyPrint(functions.CreateUser(user))
@@ -42,7 +30,7 @@ func init() {
userCreateCmd.Flags().StringVar(&username, "name", "", "Name of the user")
userCreateCmd.Flags().StringVar(&password, "password", "", "Password of the user")
userCreateCmd.Flags().StringVarP(&platformID, "platform-role", "r", models.ServiceUser.String(),
userCreateCmd.Flags().StringVarP(&platformID, "platform-role", "r", schema.ServiceUser.String(),
"Platform Role of the user; run `nmctl roles list` to see available user roles")
userCreateCmd.MarkFlagRequired("name")
userCreateCmd.MarkFlagRequired("password")
+2 -2
View File
@@ -35,7 +35,7 @@ var userGroupListCmd = &cobra.Command{
for _, d := range data {
roleInfoStr := ""
for netID, netRoleMap := range d.NetworkRoles {
for netID, netRoleMap := range d.NetworkRoles.Data() {
roleList := []string{}
for roleID := range netRoleMap {
roleList = append(roleList, roleID.String())
@@ -88,7 +88,7 @@ var userGroupGetCmd = &cobra.Command{
h := []string{"ID", "MetaData", "Network Roles"}
table.SetHeader(h)
roleInfoStr := ""
for netID, netRoleMap := range data.NetworkRoles {
for netID, netRoleMap := range data.NetworkRoles.Data() {
roleList := []string{}
for roleID := range netRoleMap {
roleList = append(roleList, roleID.String())
+7 -19
View File
@@ -1,11 +1,10 @@
package user
import (
"strings"
"github.com/gravitl/netmaker/cli/functions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/spf13/cobra"
"gorm.io/datatypes"
)
var userUpdateCmd = &cobra.Command{
@@ -14,27 +13,16 @@ var userUpdateCmd = &cobra.Command{
Short: "Update a user",
Long: `Update a user`,
Run: func(cmd *cobra.Command, args []string) {
user := &models.User{UserName: args[0]}
user := &schema.User{Username: args[0]}
if platformID != "" {
user.PlatformRoleID = models.UserRoleID(platformID)
}
if len(networkRoles) > 0 {
netRolesMap := make(map[models.NetworkID]map[models.UserRoleID]struct{})
for netID, netRoles := range networkRoles {
roleMap := make(map[models.UserRoleID]struct{})
for _, roleID := range strings.Split(netRoles, ",") {
roleMap[models.UserRoleID(roleID)] = struct{}{}
}
netRolesMap[models.NetworkID(netID)] = roleMap
}
user.NetworkRoles = netRolesMap
user.PlatformRoleID = schema.UserRoleID(platformID)
}
if len(groups) > 0 {
grMap := make(map[models.UserGroupID]struct{})
grMap := make(map[schema.UserGroupID]struct{})
for _, groupID := range groups {
grMap[models.UserGroupID(groupID)] = struct{}{}
grMap[schema.UserGroupID(groupID)] = struct{}{}
}
user.UserGroups = grMap
user.UserGroups = datatypes.NewJSONType(grMap)
}
functions.PrettyPrint(functions.UpdateUser(user))
},
+7 -20
View File
@@ -1,37 +1,24 @@
package functions
import (
"fmt"
"net/http"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
// CreateNetwork - creates a network
func CreateNetwork(payload *models.Network) *models.Network {
return request[models.Network](http.MethodPost, "/api/networks", payload)
}
// UpdateNetwork - updates a network
func UpdateNetwork(name string, payload *models.Network) *models.Network {
return request[models.Network](http.MethodPut, "/api/networks/"+name, payload)
}
// UpdateNetworkNodeLimit - updates a network
func UpdateNetworkNodeLimit(name string, nodeLimit int32) *models.Network {
return request[models.Network](http.MethodPut, fmt.Sprintf("/api/networks/%s/nodelimit", name), &models.Network{
NodeLimit: nodeLimit,
})
func CreateNetwork(payload *schema.Network) *schema.Network {
return request[schema.Network](http.MethodPost, "/api/networks", payload)
}
// GetNetworks - fetch all networks
func GetNetworks() *[]models.Network {
return request[[]models.Network](http.MethodGet, "/api/networks", nil)
func GetNetworks() *[]schema.Network {
return request[[]schema.Network](http.MethodGet, "/api/networks", nil)
}
// GetNetwork - fetch a single network
func GetNetwork(name string) *models.Network {
return request[models.Network](http.MethodGet, "/api/networks/"+name, nil)
func GetNetwork(name string) *schema.Network {
return request[schema.Network](http.MethodGet, "/api/networks/"+name, nil)
}
// DeleteNetwork - delete a network
+11 -10
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
// HasAdmin - check if server has an admin user
@@ -14,13 +15,13 @@ func HasAdmin() *bool {
}
// CreateUser - create a user
func CreateUser(payload *models.User) *models.User {
return request[models.User](http.MethodPost, "/api/users/"+payload.UserName, payload)
func CreateUser(payload *schema.User) *schema.User {
return request[schema.User](http.MethodPost, "/api/users/"+payload.Username, payload)
}
// UpdateUser - update a user
func UpdateUser(payload *models.User) *models.User {
return request[models.User](http.MethodPut, "/api/users/"+payload.UserName, payload)
func UpdateUser(payload *schema.User) *schema.User {
return request[schema.User](http.MethodPut, "/api/users/"+payload.Username, payload)
}
// DeleteUser - delete a user
@@ -29,8 +30,8 @@ func DeleteUser(username string) *string {
}
// GetUser - fetch a single user
func GetUser(username string) *models.User {
return request[models.User](http.MethodGet, "/api/users/"+username, nil)
func GetUser(username string) *schema.User {
return request[schema.User](http.MethodGet, "/api/users/"+username, nil)
}
// ListUsers - fetch all users
@@ -38,7 +39,7 @@ func ListUsers() *[]models.ReturnUser {
return request[[]models.ReturnUser](http.MethodGet, "/api/users", nil)
}
func ListUserRoles() (roles []models.UserRolePermissionTemplate) {
func ListUserRoles() (roles []schema.UserRole) {
resp := request[models.SuccessResponse](http.MethodGet, "/api/v1/users/roles", nil)
d, _ := json.Marshal(resp.Response)
json.Unmarshal(d, &roles)
@@ -48,14 +49,14 @@ func ListUserRoles() (roles []models.UserRolePermissionTemplate) {
func DeleteUserRole(roleID string) *models.SuccessResponse {
return request[models.SuccessResponse](http.MethodDelete, fmt.Sprintf("/api/v1/users/role?role_id=%s", roleID), nil)
}
func GetUserRole(roleID string) (role models.UserRolePermissionTemplate) {
func GetUserRole(roleID string) (role schema.UserRole) {
resp := request[models.SuccessResponse](http.MethodGet, fmt.Sprintf("/api/v1/users/role?role_id=%s", roleID), nil)
d, _ := json.Marshal(resp.Response)
json.Unmarshal(d, &role)
return
}
func ListUserGrps() (groups []models.UserGroup) {
func ListUserGrps() (groups []schema.UserGroup) {
resp := request[models.SuccessResponse](http.MethodGet, "/api/v1/users/groups", nil)
d, _ := json.Marshal(resp.Response)
json.Unmarshal(d, &groups)
@@ -66,7 +67,7 @@ func DeleteUserGrp(grpID string) *models.SuccessResponse {
return request[models.SuccessResponse](http.MethodDelete, fmt.Sprintf("/api/v1/users/group?group_id=%s", grpID), nil)
}
func GetUserGrp(grpID string) (group models.UserGroup) {
func GetUserGrp(grpID string) (group schema.UserGroup) {
resp := request[models.SuccessResponse](http.MethodGet, fmt.Sprintf("/api/v1/users/group?group_id=%s", grpID), nil)
d, _ := json.Marshal(resp.Response)
json.Unmarshal(d, &group)
+17 -16
View File
@@ -211,12 +211,12 @@ func getAcls(w http.ResponseWriter, r *http.Request) {
return
}
// check if network exists
_, err := logic.GetNetwork(netID)
err := (&schema.Network{Name: netID}).Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
acls, err := logic.ListAclsByNetwork(models.NetworkID(netID))
acls, err := logic.ListAclsByNetwork(schema.NetworkID(netID))
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get all network acl entries: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -276,7 +276,8 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
user, err := logic.GetUser(r.Header.Get("user"))
user := &schema.User{Username: r.Header.Get("user")}
err = user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -289,7 +290,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
acl := req
acl.ID = uuid.New().String()
acl.CreatedBy = user.UserName
acl.CreatedBy = user.Username
acl.CreatedAt = time.Now().UTC()
acl.Default = false
if acl.ServiceType == models.Any {
@@ -312,20 +313,20 @@ func createAcl(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: acl.ID,
Name: acl.Name,
Type: models.AclSub,
Type: schema.AclSub,
},
NetworkID: acl.NetworkID,
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
go mq.PublishPeerUpdate(true)
logic.ReturnSuccessResponseWithJson(w, r, acl, "created acl successfully")
@@ -374,24 +375,24 @@ func updateAcl(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: acl.ID,
Name: acl.Name,
Type: models.AclSub,
Type: schema.AclSub,
},
Diff: models.Diff{
Old: acl,
New: updateAcl.Acl,
},
NetworkID: acl.NetworkID,
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
go mq.PublishPeerUpdate(true)
logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name)
@@ -428,20 +429,20 @@ func deleteAcl(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: acl.ID,
Name: acl.Name,
Type: models.AclSub,
Type: schema.AclSub,
},
NetworkID: acl.NetworkID,
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: acl,
New: nil,
+15 -15
View File
@@ -140,20 +140,20 @@ func createNs(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: ns.ID,
Name: ns.Name,
Type: models.NameserverSub,
Type: schema.NameserverSub,
},
NetworkID: models.NetworkID(ns.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(ns.NetworkID),
Origin: schema.Dashboard,
})
go mq.PublishPeerUpdate(false)
@@ -252,24 +252,24 @@ func updateNs(w http.ResponseWriter, r *http.Request) {
updateFallback = true
}
event := &models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: ns.ID,
Name: updateNs.Name,
Type: models.NameserverSub,
Type: schema.NameserverSub,
},
Diff: models.Diff{
Old: ns,
New: updateNs,
},
NetworkID: models.NetworkID(ns.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(ns.NetworkID),
Origin: schema.Dashboard,
}
if !ns.Default {
@@ -352,20 +352,20 @@ func deleteNs(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: ns.ID,
Name: ns.Name,
Type: models.NameserverSub,
Type: schema.NameserverSub,
},
NetworkID: models.NetworkID(ns.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(ns.NetworkID),
Origin: schema.Dashboard,
Diff: models.Diff{
Old: ns,
New: nil,
+9 -4
View File
@@ -1,11 +1,13 @@
package controller
import (
"fmt"
"net"
"os"
"testing"
"github.com/google/uuid"
"github.com/gravitl/netmaker/schema"
"github.com/stretchr/testify/assert"
"github.com/txn2/txeh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -15,7 +17,7 @@ import (
"github.com/gravitl/netmaker/models"
)
var dnsHost models.Host
var dnsHost schema.Host
func TestGetAllDNS(t *testing.T) {
deleteAllDNS(t)
@@ -425,14 +427,17 @@ func TestValidateDNSCreate(t *testing.T) {
func createHost() {
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
dnsHost = models.Host{
dnsHost = schema.Host{
ID: uuid.New(),
PublicKey: k.PublicKey(),
PublicKey: schema.WgKey{Key: k.PublicKey()},
HostPass: "password",
OS: "linux",
Name: "dnshost",
}
_ = logic.CreateHost(&dnsHost)
err := logic.CreateHost(&dnsHost)
if err != nil {
fmt.Println("ERROR CREATING HOST", err.Error())
}
}
func deleteAllDNS(t *testing.T) {
+34 -26
View File
@@ -70,7 +70,8 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
egressRange = "*"
req.Domain = ""
}
network, err := logic.GetNetwork(req.Network)
network := &schema.Network{Name: req.Network}
err = network.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -91,7 +92,7 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
CreatedBy: r.Header.Get("user"),
CreatedAt: time.Now().UTC(),
}
if err := logic.AssignVirtualRangeToEgress(&network, &e); err != nil {
if err := logic.AssignVirtualRangeToEgress(network, &e); err != nil {
logger.Log(0, "error assigning virtual range to egress: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -121,20 +122,20 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: e.ID,
Name: e.Name,
Type: models.EgressSub,
Type: schema.EgressSub,
},
NetworkID: models.NetworkID(e.Network),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(e.Network),
Origin: schema.Dashboard,
})
// for nodeID := range e.Nodes {
// node, err := logic.GetNodeByID(nodeID)
@@ -151,8 +152,11 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
if err != nil {
continue
}
host, _ := logic.GetHost(node.HostID.String())
if host == nil {
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
continue
}
mq.HostUpdate(&models.HostUpdate{
@@ -227,7 +231,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
network, err := logic.GetNetwork(req.Network)
network := &schema.Network{Name: req.Network}
err = network.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -271,8 +276,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
// Update mode and NAT before calling AssignVirtualRangeToEgress
// This ensures the function sees the new values
if req.Mode != models.VirtualNAT || !req.Nat {
e.Mode = models.DirectNAT
if req.Mode != schema.VirtualNAT || !req.Nat {
e.Mode = schema.DirectNAT
if !req.Nat {
e.Mode = ""
}
@@ -284,8 +289,8 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
e.Nat = req.Nat
// Assign virtual range if switching to virtual NAT mode from a different mode,
// or if already in virtual NAT mode but virtual range is empty
if (oldMode != models.VirtualNAT) || (e.VirtualRange == "") {
if err := logic.AssignVirtualRangeToEgress(&network, &e); err != nil {
if (oldMode != schema.VirtualNAT) || (e.VirtualRange == "") {
if err := logic.AssignVirtualRangeToEgress(network, &e); err != nil {
logger.Log(0, "error assigning virtual range to egress: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -293,23 +298,23 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
}
}
event := &models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: e.ID,
Name: e.Name,
Type: models.EgressSub,
Type: schema.EgressSub,
},
Diff: models.Diff{
Old: e,
},
NetworkID: models.NetworkID(e.Network),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(e.Network),
Origin: schema.Dashboard,
}
e.Nodes = make(datatypes.JSONMap)
e.Tags = make(datatypes.JSONMap)
@@ -374,8 +379,11 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
if err != nil {
continue
}
host, _ := logic.GetHost(node.HostID.String())
if host == nil {
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
continue
}
mq.HostUpdate(&models.HostUpdate{
@@ -426,20 +434,20 @@ func deleteEgress(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: e.ID,
Name: e.Name,
Type: models.EgressSub,
Type: schema.EgressSub,
},
NetworkID: models.NetworkID(e.Network),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(e.Network),
Origin: schema.Dashboard,
Diff: models.Diff{
Old: e,
New: nil,
+24 -17
View File
@@ -11,6 +11,7 @@ import (
"github.com/go-playground/validator/v10"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/auth"
"github.com/gravitl/netmaker/logger"
@@ -87,19 +88,19 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: keyID,
Name: key.Tags[0],
Type: models.EnrollmentKeySub,
Type: schema.EnrollmentKeySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: key,
New: nil,
@@ -204,19 +205,19 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: newEnrollmentKey.Value,
Name: newEnrollmentKey.Tags[0],
Type: models.EnrollmentKeySub,
Type: schema.EnrollmentKeySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(2, r.Header.Get("user"), "created enrollment key")
w.WriteHeader(http.StatusOK)
@@ -269,23 +270,23 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: newEnrollmentKey.Value,
Name: newEnrollmentKey.Tags[0],
Type: models.EnrollmentKeySub,
Type: schema.EnrollmentKeySub,
},
Diff: models.Diff{
Old: currKey,
New: newEnrollmentKey,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
slog.Info("updated enrollment key", "id", keyId)
w.WriteHeader(http.StatusOK)
@@ -298,7 +299,7 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
// @Accept json
// @Produce json
// @Param token path string true "Enrollment Key Token"
// @Param body body models.Host true "Host registration parameters"
// @Param body body schema.Host true "Host registration parameters"
// @Success 200 {object} models.RegisterResponse
// @Failure 400 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
@@ -314,7 +315,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
return
}
// get the host
var newHost models.Host
var newHost schema.Host
if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil {
logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
err.Error())
@@ -383,7 +384,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
KernelVersion: newHost.KernelVersion,
AutoUpdate: newHost.AutoUpdate,
Tags: keyTags,
}, models.NetworkID(netI))
}, schema.NetworkID(netI))
pcviolations = append(pcviolations, violations...)
if len(violations) > 0 {
skipViolatedNetworks = append(skipViolatedNetworks, netI)
@@ -433,7 +434,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
// }
// }
// enrollmentKey.Networks = networksToAdd
currHost, err := logic.GetHost(newHost.ID.String())
currHost := &schema.Host{
ID: newHost.ID,
}
err := currHost.Get(r.Context())
if err != nil {
slog.Error("failed registration", "hostID", newHost.ID.String(), "hostName", newHost.Name, "error", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -447,7 +451,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
return
}
}
host, err := logic.GetHost(newHost.ID.String())
host := &schema.Host{
ID: newHost.ID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
+48 -28
View File
@@ -46,7 +46,7 @@ func extClientHandlers(r *mux.Router) {
Methods(http.MethodPut)
r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient))).
Methods(http.MethodDelete)
r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, checkFreeTierLimits(limitChoiceMachines, http.HandlerFunc(createExtClient)))).
r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, http.HandlerFunc(createExtClient))).
Methods(http.MethodPost)
// unused API
//r.HandleFunc("/api/v1/client_conf/{network}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientHAConf))).Methods(http.MethodGet)
@@ -85,9 +85,15 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) {
username := r.Header.Get("user")
if r.Header.Get("ismaster") != "yes" {
user, err := logic.GetUser(username)
user := &schema.User{
Username: username,
}
err := user.Get(r.Context())
if err == nil {
userRole, err := logic.GetRole(user.PlatformRoleID)
userRole := &schema.UserRole{
ID: user.PlatformRoleID,
}
err := userRole.Get(r.Context())
if err != nil || !userRole.FullAccess {
filtered := []models.ExtClient{}
for _, ec := range extclients {
@@ -218,9 +224,12 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
return
}
eli, _ := (&schema.Egress{Network: gwnode.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := logic.ListAclsByNetwork(models.NetworkID(client.Network))
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(client.Network))
logic.GetNodeEgressInfo(&gwnode, eli, acls)
host, err := logic.GetHost(gwnode.HostID.String())
host := &schema.Host{
ID: gwnode.HostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(
0,
@@ -235,7 +244,8 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
return
}
network, err := logic.GetParentNetwork(client.Network)
network := &schema.Network{Name: client.Network}
err = network.Get(r.Context())
if err != nil {
logger.Log(
1,
@@ -288,8 +298,8 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
}
keepalive := ""
if network.DefaultKeepalive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepalive))
if network.DefaultKeepAlive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive))
}
if gwnode.IngressPersistentKeepalive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive))
@@ -428,7 +438,8 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
networkid := params["network"]
network, err := logic.GetParentNetwork(networkid)
network := &schema.Network{Name: networkid}
err := network.Get(r.Context())
if err != nil {
logger.Log(
1,
@@ -441,7 +452,7 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
}
// fetch client based on availability
nodes, _ := logic.GetNetworkNodes(networkid)
defaultPolicy, _ := logic.GetDefaultPolicy(models.NetworkID(networkid), models.DevicePolicy)
defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(networkid), models.DevicePolicy)
var targetGwID string
var connectionCnt int = -1
for _, nodeI := range nodes {
@@ -475,7 +486,10 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
host, err := logic.GetHost(gwnode.HostID.String())
host := &schema.Host{
ID: gwnode.HostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway host for node [%s] info: %v", gwnode.ID, err))
@@ -487,12 +501,13 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("ismaster") == "yes" {
userName = logic.MasterUser
} else {
caller, err := logic.GetUser(r.Header.Get("user"))
caller := &schema.User{Username: r.Header.Get("user")}
err = caller.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
userName = caller.UserName
userName = caller.Username
}
// create client
var extclient models.ExtClient
@@ -540,8 +555,8 @@ func GetExtClientHAConf(w http.ResponseWriter, r *http.Request) {
}
keepalive := ""
if network.DefaultKeepalive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepalive))
if network.DefaultKeepAlive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepAlive))
}
if gwnode.IngressPersistentKeepalive != 0 {
keepalive = "PersistentKeepalive = " + strconv.Itoa(int(gwnode.IngressPersistentKeepalive))
@@ -711,12 +726,13 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("ismaster") == "yes" {
userName = logic.MasterUser
} else {
caller, err := logic.GetUser(r.Header.Get("user"))
caller := &schema.User{Username: r.Header.Get("user")}
err = caller.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
userName = caller.UserName
userName = caller.Username
// check if user has a config already for remote access client
extclients, err := logic.GetNetworkExtClients(node.Network)
if err != nil {
@@ -742,7 +758,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
// let's first confirm that none of the user's extclients for this gw have device id.
for _, extclient := range extclients {
if extclient.DeviceID == customExtClient.DeviceID &&
extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
extclient.OwnerID == caller.Username && nodeid == extclient.IngressGatewayID {
if jitGrant != nil {
extclient.JITExpiresAt = &jitGrant.ExpiresAt
_ = logic.SaveExtClient(&extclient)
@@ -758,7 +774,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
for _, extclient := range extclients {
if extclient.RemoteAccessClientID != "" &&
extclient.RemoteAccessClientID == customExtClient.RemoteAccessClientID &&
extclient.OwnerID == caller.UserName && nodeid == extclient.IngressGatewayID {
extclient.OwnerID == caller.Username && nodeid == extclient.IngressGatewayID {
if customExtClient.DeviceID != "" && extclient.DeviceID == "" {
// This extclient doesnt include a device ID (and neither do the others).
// We patch it by assigning the device ID from the incoming request.
@@ -793,7 +809,10 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
dns := gwDNS
extclient.DNS = dns
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get ingress gateway host for node [%s] info: %v", nodeid, err))
@@ -803,7 +822,8 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
listenPort := logic.GetPeerListenPort(host)
extclient.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
extclient.Enabled = true
parentNetwork, err := logic.GetNetwork(node.Network)
parentNetwork := &schema.Network{Name: node.Network}
err = parentNetwork.Get(r.Context())
if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
extclient.Enabled = parentNetwork.DefaultACL == "yes"
}
@@ -835,7 +855,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if extclient.DeviceID != "" {
// check for violations connecting from desktop app
staticNode := extclient.ConvertToStaticNode()
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), models.NetworkID(extclient.Network))
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(extclient.Network))
if len(violations) > 0 {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
return
@@ -885,21 +905,21 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
if extclient.RemoteAccessClientID != "" {
// if created by user from client app, log event
logic.LogEvent(&models.Event{
Action: models.Connect,
Action: schema.Connect,
Source: models.Subject{
ID: userName,
Name: userName,
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: userName,
Target: models.Subject{
ID: extclient.Network,
Name: extclient.Network,
Type: models.NetworkSub,
Type: schema.NetworkSub,
Info: extclient,
},
NetworkID: models.NetworkID(extclient.Network),
Origin: models.ClientApp,
NetworkID: schema.NetworkID(extclient.Network),
Origin: schema.ClientApp,
})
}
@@ -1008,7 +1028,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
if newclient.DeviceID != "" && newclient.Enabled {
// check for violations connecting from desktop app
staticNode := newclient.ConvertToStaticNode()
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), models.NetworkID(newclient.Network))
violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(newclient.Network))
if len(violations) > 0 {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden))
return
+44 -23
View File
@@ -1,6 +1,7 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -9,16 +10,18 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)
func gwHandlers(r *mux.Router) {
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createGateway)))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete)
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", logic.SecurityCheck(true, http.HandlerFunc(assignGw))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", logic.SecurityCheck(true, http.HandlerFunc(unassignGw))).Methods(http.MethodPost)
@@ -48,7 +51,10 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -143,19 +149,19 @@ func createGateway(w http.ResponseWriter, r *http.Request) {
logic.GetNodeStatus(&relayNode, false)
apiNode := relayNode.ConvertToAPINode()
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: node.ID.String(),
Name: host.Name,
Type: models.GatewaySub,
Type: schema.GatewaySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
host.IsStaticPort = true
logic.UpsertHost(host)
@@ -210,7 +216,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -225,7 +234,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
logger.Log(1, r.Header.Get("user"), "deleted gw", nodeid, "on network", netid)
go func() {
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err == nil {
allNodes, err := logic.GetAllNodes()
if err != nil {
@@ -246,7 +258,10 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
)
}
h, err := logic.GetHost(relayedNode.HostID.String())
h := &schema.Host{
ID: relayedNode.HostID,
}
err = h.Get(db.WithContext(context.TODO()))
if err == nil {
if h.OS == models.OS_Types.IoT {
nodes, err := logic.GetAllNodes()
@@ -283,19 +298,19 @@ func deleteGateway(w http.ResponseWriter, r *http.Request) {
}()
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: node.ID.String(),
Name: host.Name,
Type: models.GatewaySub,
Type: schema.GatewaySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: node,
New: node,
@@ -405,7 +420,10 @@ func assignGw(w http.ResponseWriter, r *http.Request) {
newNodes = logic.UniqueStrings(newNodes)
logic.UpdateRelayNodes(gatewayNode.ID.String(), gatewayNode.RelayedNodes, newNodes)
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -416,19 +434,19 @@ func assignGw(w http.ResponseWriter, r *http.Request) {
nodeid, netid))
logic.LogEvent(&models.Event{
Action: models.GatewayAssign,
Action: schema.GatewayAssign,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: node.ID.String(),
Name: host.Name,
Type: models.GatewaySub,
Type: schema.GatewaySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logic.GetNodeStatus(&node, false)
@@ -465,7 +483,10 @@ func unassignGw(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -536,19 +557,19 @@ func unassignGw(w http.ResponseWriter, r *http.Request) {
nodeid, netid))
logic.LogEvent(&models.Event{
Action: models.GatewayUnAssign,
Action: schema.GatewayUnAssign,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: node.ID.String(),
Name: host.Name,
Type: models.GatewaySub,
Type: schema.GatewaySub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logic.GetNodeStatus(&node, false)
+304 -104
View File
@@ -1,16 +1,19 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
dbtypes "github.com/gravitl/netmaker/db/types"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
@@ -24,6 +27,8 @@ import (
func hostHandlers(r *mux.Router) {
r.HandleFunc("/api/hosts", logic.SecurityCheck(true, http.HandlerFunc(getHosts))).
Methods(http.MethodGet)
r.HandleFunc("/api/v1/hosts", logic.SecurityCheck(true, http.HandlerFunc(listHosts))).
Methods(http.MethodGet)
r.HandleFunc("/api/hosts/keys", logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys))).
Methods(http.MethodPut)
r.HandleFunc("/api/hosts/sync", logic.SecurityCheck(true, http.HandlerFunc(syncHosts))).
@@ -36,10 +41,10 @@ func hostHandlers(r *mux.Router) {
Methods(http.MethodPost)
r.HandleFunc("/api/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(updateHost))).
Methods(http.MethodPut)
// used by netclient
// used by netclient
r.HandleFunc("/api/hosts/{hostid}", AuthorizeHost(http.HandlerFunc(deleteHost))).
Methods(http.MethodDelete)
// used by UI
// used by UI
r.HandleFunc("/api/v1/ui/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(deleteHost))).
Methods(http.MethodDelete)
r.HandleFunc("/api/hosts/{hostid}/upgrade", logic.SecurityCheck(true, http.HandlerFunc(upgradeHost))).
@@ -88,14 +93,14 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
go func() {
slog.Info("requesting all hosts to upgrade", "user", user)
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
slog.Error("failed to retrieve all hosts", "user", user, "error", err)
return
}
for _, host := range hosts {
go func(host models.Host) {
go func(host schema.Host) {
hostUpdate := models.HostUpdate{
Action: action,
Host: host,
@@ -109,19 +114,19 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.UpgradeAll,
Action: schema.UpgradeAll,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: "All Hosts",
Name: "All Hosts",
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
slog.Info("upgrade all hosts request received", "user", user)
logic.ReturnSuccessResponse(w, r, "upgrade all hosts request received")
@@ -137,7 +142,18 @@ func upgradeHosts(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} models.ErrorResponse
// upgrade host is a handler to send upgrade message to a host
func upgradeHost(w http.ResponseWriter, r *http.Request) {
host, err := logic.GetHost(mux.Vars(r)["hostid"])
hostIDStr := mux.Vars(r)["hostid"]
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(r.Context())
if err != nil {
slog.Error("failed to find host", "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "notfound"))
@@ -168,7 +184,7 @@ func upgradeHost(w http.ResponseWriter, r *http.Request) {
func getHosts(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
currentHosts, err := logic.GetAllHosts()
currentHosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -182,6 +198,73 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(apiHosts)
}
// @Summary List all hosts
// @Router /api/v1/hosts [get]
// @Tags Hosts
// @Security oauth
// @Produce json
// @Param os query []string false "Filter by OS" Enums(windows, linux, darwin)
// @Param page query int false "Page number"
// @Param per_page query int false "Items per page"
// @Success 200 {array} models.ApiHost
// @Failure 500 {object} models.ErrorResponse
func listHosts(w http.ResponseWriter, r *http.Request) {
var osFilters []interface{}
for _, filter := range r.URL.Query()["os"] {
osFilters = append(osFilters, filter)
}
var page, pageSize int
page, _ = strconv.Atoi(r.URL.Query().Get("page"))
if page == 0 {
page = 1
}
pageSize, _ = strconv.Atoi(r.URL.Query().Get("per_page"))
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
currentHosts, err := (&schema.Host{}).ListAll(
r.Context(),
dbtypes.WithFilter("os", osFilters...),
dbtypes.InAscOrder("name"),
dbtypes.WithPagination(page, pageSize),
)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
apiHosts := logic.GetAllHostsAPI(currentHosts[:])
logger.Log(2, r.Header.Get("user"), "fetched all hosts")
total, err := (&schema.Host{}).Count(
r.Context(),
dbtypes.WithFilter("os", osFilters...),
)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
totalPages := (total + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
response := models.PaginatedResponse{
Data: apiHosts,
Page: page,
PerPage: pageSize,
Total: total,
TotalPages: totalPages,
}
logic.ReturnSuccessResponseWithJson(w, r, response, "fetched hosts")
}
// @Summary Used by clients for "pull" command
// @Router /api/v1/host [get]
// @Tags Hosts
@@ -190,9 +273,8 @@ func getHosts(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} models.HostPull
// @Failure 500 {object} models.ErrorResponse
func pull(w http.ResponseWriter, r *http.Request) {
hostID := r.Header.Get(hostIDHeader) // return JSON/API formatted keys
if len(hostID) == 0 {
hostIDStr := r.Header.Get(hostIDHeader) // return JSON/API formatted keys
if len(hostIDStr) == 0 {
logger.Log(0, "no host authorized to pull")
logic.ReturnErrorResponse(
w,
@@ -201,9 +283,20 @@ func pull(w http.ResponseWriter, r *http.Request) {
)
return
}
host, err := logic.GetHost(hostID)
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
logger.Log(0, "no host found during pull", hostID)
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, "no host found during pull", hostIDStr)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
@@ -228,13 +321,13 @@ func pull(w http.ResponseWriter, r *http.Request) {
}
allNodes, err := logic.GetAllNodes()
if err != nil {
logger.Log(0, "failed to get nodes: ", hostID)
logger.Log(0, "failed to get nodes: ", hostIDStr)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
hPU, err := logic.GetPeerUpdateForHost("", host, allNodes, nil, nil)
if err != nil {
logger.Log(0, "could not pull peers for host", hostID, err.Error())
logger.Log(0, "could not pull peers for host", hostIDStr, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
@@ -269,7 +362,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
AddressIdentityMap: hPU.AddressIdentityMap,
}
logger.Log(1, hostID, host.Name, "completed a pull")
logger.Log(1, hostIDStr, host.Name, "completed a pull")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&response)
}
@@ -293,8 +386,18 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
return
}
hostID, err := uuid.Parse(newHostData.ID)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
// confirm host exists
currHost, err := logic.GetHost(newHostData.ID)
currHost := &schema.Host{
ID: hostID,
}
err = currHost.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update a host:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -350,25 +453,25 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
}()
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: currHost.ID.String(),
Name: newHost.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Diff: models.Diff{
Old: currHost,
New: newHost,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
apiHostData := newHost.ConvertNMHostToAPI()
apiHostData := models.NewApiHostFromSchemaHost(newHost)
logger.Log(2, r.Header.Get("user"), "updated host", newHost.ID.String())
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(apiHostData)
@@ -384,10 +487,19 @@ func updateHost(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} models.ErrorResponse
func hostUpdateFallback(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
hostid := params["hostid"]
currentHost, err := logic.GetHost(hostid)
hostIDStr := params["hostid"]
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
slog.Error("error getting host", "id", hostid, "error", err)
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
currentHost := &schema.Host{
ID: hostID,
}
err = currentHost.Get(r.Context())
if err != nil {
slog.Error("error getting host", "id", hostIDStr, "error", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
@@ -466,11 +578,19 @@ func hostUpdateFallback(w http.ResponseWriter, r *http.Request) {
// @Failure 500 {object} models.ErrorResponse
func deleteHost(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
hostid := params["hostid"]
hostIDStr := params["hostid"]
forceDelete := r.URL.Query().Get("force") == "true"
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
// confirm host exists
currHost, err := logic.GetHost(hostid)
currHost := &schema.Host{
ID: hostID,
}
err = currHost.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to delete a host:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -519,25 +639,25 @@ func deleteHost(w http.ResponseWriter, r *http.Request) {
HostID: currHost.ID.String(),
}).DeleteAllPendingHosts(db.WithContext(r.Context()))
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: currHost.ID.String(),
Name: currHost.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: currHost,
New: nil,
},
})
apiHostData := currHost.ConvertNMHostToAPI()
apiHostData := models.NewApiHostFromSchemaHost(currHost)
logger.Log(2, r.Header.Get("user"), "removed host", currHost.Name)
logic.ReturnSuccessResponseWithJson(w, r, apiHostData, "deleted host "+currHost.Name)
}
@@ -553,9 +673,9 @@ func deleteHost(w http.ResponseWriter, r *http.Request) {
func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
hostid := params["hostid"]
hostIDStr := params["hostid"]
network := params["network"]
if hostid == "" || network == "" {
if hostIDStr == "" || network == "" {
logic.ReturnErrorResponse(
w,
r,
@@ -563,10 +683,20 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
)
return
}
// confirm host exists
currHost, err := logic.GetHost(hostid)
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostid, err.Error())
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
// confirm host exists
currHost := &schema.Host{
ID: hostID,
}
err = currHost.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to find host:", hostIDStr, err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
return
}
@@ -579,7 +709,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
OSVersion: currHost.OSVersion,
KernelVersion: currHost.KernelVersion,
AutoUpdate: currHost.AutoUpdate,
}, models.NetworkID(network))
}, schema.NetworkID(network))
if len(violations) > 0 {
logic.ReturnErrorResponseWithJson(w, r, violations, logic.FormatError(errors.New("posture check violations"), logic.BadReq))
return
@@ -590,7 +720,7 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
0,
r.Header.Get("user"),
"failed to add host to network:",
hostid,
hostIDStr,
network,
err.Error(),
)
@@ -623,20 +753,20 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
fmt.Sprintf("added host %s to network %s", currHost.Name, network),
)
logic.LogEvent(&models.Event{
Action: models.JoinHostToNet,
Action: schema.JoinHostToNet,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: currHost.ID.String(),
Name: currHost.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
NetworkID: models.NetworkID(network),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(network),
Origin: schema.Dashboard,
})
w.WriteHeader(http.StatusOK)
}
@@ -653,10 +783,10 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
hostid := params["hostid"]
hostIDStr := params["hostid"]
network := params["network"]
forceDelete := r.URL.Query().Get("force") == "true"
if hostid == "" || network == "" {
if hostIDStr == "" || network == "" {
logic.ReturnErrorResponse(
w,
r,
@@ -664,17 +794,26 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
)
return
}
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
// confirm host exists
currHost, err := logic.GetHost(hostid)
currHost := &schema.Host{
ID: hostID,
}
err = currHost.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) {
// check if there is any daemon nodes that needs to be deleted
node, err := logic.GetNodeByHostRef(hostid, network)
node, err := logic.GetNodeByHostRef(hostIDStr, network)
if err != nil {
slog.Error(
"couldn't get node for host",
"hostid",
hostid,
hostIDStr,
"network",
network,
"error",
@@ -685,7 +824,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
}
if err = logic.DeleteNodeByID(&node); err != nil {
slog.Error("failed to force delete daemon node",
"nodeid", node.ID.String(), "hostid", hostid, "network", network, "error", err)
"nodeid", node.ID.String(), "hostid", hostIDStr, "network", network, "error", err)
logic.ReturnErrorResponse(
w,
r,
@@ -709,12 +848,12 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
if err != nil {
if node == nil && forceDelete {
// force cleanup the node
node, err := logic.GetNodeByHostRef(hostid, network)
node, err := logic.GetNodeByHostRef(hostIDStr, network)
if err != nil {
slog.Error(
"couldn't get node for host",
"hostid",
hostid,
hostIDStr,
"network",
network,
"error",
@@ -725,7 +864,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
}
if err = logic.DeleteNodeByID(&node); err != nil {
slog.Error("failed to force delete daemon node",
"nodeid", node.ID.String(), "hostid", hostid, "network", network, "error", err)
"nodeid", node.ID.String(), "hostid", hostIDStr, "network", network, "error", err)
logic.ReturnErrorResponse(
w,
r,
@@ -743,7 +882,7 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
0,
r.Header.Get("user"),
"failed to remove host from network:",
hostid,
hostIDStr,
network,
err.Error(),
)
@@ -766,20 +905,20 @@ func deleteHostFromNetwork(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.RemoveHostFromNet,
Action: schema.RemoveHostFromNet,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: currHost.ID.String(),
Name: currHost.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
NetworkID: models.NetworkID(network),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(network),
Origin: schema.Dashboard,
})
logger.Log(
2,
@@ -829,7 +968,17 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) {
logic.ReturnErrorResponse(response, request, errorResponse)
return
}
host, err := logic.GetHost(authRequest.ID)
hostID, err := uuid.Parse(authRequest.ID)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(response, request, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(request.Context())
if err != nil {
errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error()
@@ -900,9 +1049,17 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) {
// @Failure 400 {object} models.ErrorResponse
func signalPeer(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
hostid := params["hostid"]
hostIDStr := params["hostid"]
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
// confirm host exists
_, err := logic.GetHost(hostid)
err = (&schema.Host{
ID: hostID,
}).Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get host:", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
@@ -923,7 +1080,16 @@ func signalPeer(w http.ResponseWriter, r *http.Request) {
return
}
signal.IsPro = servercfg.IsPro
peerHost, err := logic.GetHost(signal.ToHostID)
hostID, err = uuid.Parse(signal.ToHostID)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
peerHost := &schema.Host{
ID: hostID,
}
err = peerHost.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(
w,
@@ -962,7 +1128,7 @@ func signalPeer(w http.ResponseWriter, r *http.Request) {
func updateAllKeys(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{}
w.Header().Set("Content-Type", "application/json")
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error()
@@ -988,19 +1154,19 @@ func updateAllKeys(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.RefreshAllKeys,
Action: schema.RefreshAllKeys,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: "All Devices",
Name: "All Devices",
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(2, r.Header.Get("user"), "updated keys for all hosts")
w.WriteHeader(http.StatusOK)
@@ -1017,10 +1183,19 @@ func updateKeys(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{}
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
hostid := params["hostid"]
host, err := logic.GetHost(hostid)
hostIDStr := params["hostid"]
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
logger.Log(0, "failed to retrieve host", hostid, err.Error())
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, "failed to retrieve host", hostIDStr, err.Error())
errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error()
logger.Log(0, r.Header.Get("user"),
@@ -1038,19 +1213,19 @@ func updateKeys(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.RefreshKey,
Action: schema.RefreshKey,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: host.ID.String(),
Name: host.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(2, r.Header.Get("user"), "updated key on host", host.Name)
w.WriteHeader(http.StatusOK)
@@ -1069,14 +1244,14 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
go func() {
slog.Info("requesting all hosts to sync", "user", user)
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
slog.Error("failed to retrieve all hosts", "user", user, "error", err)
return
}
for _, host := range hosts {
go func(host models.Host) {
go func(host schema.Host) {
hostUpdate := models.HostUpdate{
Action: models.RequestPull,
Host: host,
@@ -1091,19 +1266,19 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.SyncAll,
Action: schema.SyncAll,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: "All Devices",
Name: "All Devices",
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
slog.Info("sync all hosts request received", "user", user)
logic.ReturnSuccessResponse(w, r, "sync all hosts request received")
@@ -1117,12 +1292,20 @@ func syncHosts(w http.ResponseWriter, r *http.Request) {
// @Success 200 {string} string "OK"
// @Failure 400 {object} models.ErrorResponse
func syncHost(w http.ResponseWriter, r *http.Request) {
hostId := mux.Vars(r)["hostid"]
hostIDStr := mux.Vars(r)["hostid"]
var errorResponse = models.ErrorResponse{}
w.Header().Set("Content-Type", "application/json")
host, err := logic.GetHost(hostId)
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(r.Context())
if err != nil {
slog.Error("failed to retrieve host", "user", r.Header.Get("user"), "error", err)
errorResponse.Code = http.StatusBadRequest
@@ -1141,26 +1324,26 @@ func syncHost(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.Sync,
Action: schema.Sync,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: host.ID.String(),
Name: host.Name,
Type: models.DeviceSub,
Type: schema.DeviceSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
slog.Info("requested host pull", "user", r.Header.Get("user"), "host", host.ID.String())
w.WriteHeader(http.StatusOK)
}
func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
currentHosts, err := logic.GetAllHosts()
currentHosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to fetch hosts: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -1194,10 +1377,18 @@ func delEmqxHosts(w http.ResponseWriter, r *http.Request) {
// @Success 200 {object} models.HostPeerInfo
// @Failure 500 {object} models.ErrorResponse
func getHostPeerInfo(w http.ResponseWriter, r *http.Request) {
hostId := mux.Vars(r)["hostid"]
hostIDStr := mux.Vars(r)["hostid"]
var errorResponse = models.ErrorResponse{}
host, err := logic.GetHost(hostId)
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
host := &schema.Host{
ID: hostID,
}
err = host.Get(r.Context())
if err != nil {
slog.Error("failed to retrieve host", "error", err)
errorResponse.Code = http.StatusBadRequest
@@ -1263,7 +1454,16 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) {
})
return
}
h, err := logic.GetHost(p.HostID)
hostID, err := uuid.Parse(p.HostID)
if err != nil {
err = fmt.Errorf("failed to parse host id: %w", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
return
}
h := &schema.Host{
ID: hostID,
}
err = h.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusBadRequest,
@@ -1364,22 +1564,22 @@ func rejectPendingHost(w http.ResponseWriter, r *http.Request) {
// addDefaultHostToNetworks enrolls a newly-made-default host into every
// existing network it is not already part of, applying the standard default
// host operations for each network.
func addDefaultHostToNetworks(host *models.Host) {
networks, err := logic.GetNetworks()
func addDefaultHostToNetworks(host *schema.Host) {
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(0, "failed to get networks for default host ops:", err.Error())
return
}
for _, network := range networks {
if network.AutoJoin != "true" {
if !network.AutoJoin {
continue
}
newNode, err := logic.UpdateHostNetwork(host, network.NetID, true)
newNode, err := logic.UpdateHostNetwork(host, network.Name, true)
if err != nil {
logger.Log(2, "skipping network", network.NetID, "for default host", host.Name, ":", err.Error())
logger.Log(2, "skipping network", network.Name, "for default host", host.Name, ":", err.Error())
continue
}
logger.Log(1, "added default host", host.Name, "to network", network.NetID)
logger.Log(1, "added default host", host.Name, "to network", network.Name)
if len(host.Nodes) == 1 {
mq.HostUpdate(&models.HostUpdate{
Action: models.RequestPull,
@@ -1393,10 +1593,10 @@ func addDefaultHostToNetworks(host *models.Host) {
Node: *newNode,
})
}
logic.CreateIngressGateway(network.NetID, newNode.ID.String(), models.IngressRequest{})
logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{})
logic.CreateRelay(models.RelayRequest{
NodeID: newNode.ID.String(),
NetID: network.NetID,
NetID: network.Name,
})
}
}
+5 -1
View File
@@ -10,6 +10,7 @@ import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
)
@@ -33,7 +34,10 @@ func createInternetGw(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
-75
View File
@@ -1,75 +0,0 @@
package controller
import (
"net/http"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
)
// limit consts
const (
limitChoiceNetworks = iota
limitChoiceUsers
limitChoiceMachines
limitChoiceIngress
limitChoiceEgress
)
func checkFreeTierLimits(limitChoice int, next http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var errorResponse = models.ErrorResponse{
Code: http.StatusForbidden, Message: "free tier limits exceeded on ",
}
if logic.FreeTier { // check that free tier limits not exceeded
switch limitChoice {
case limitChoiceNetworks:
currentNetworks, err := logic.GetNetworks()
if (err != nil && !database.IsEmptyRecord(err)) ||
len(currentNetworks) >= logic.NetworksLimit {
errorResponse.Message += "networks"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
case limitChoiceUsers:
users, err := logic.GetUsers()
if (err != nil && !database.IsEmptyRecord(err)) ||
len(users) >= logic.UsersLimit {
errorResponse.Message += "users"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
case limitChoiceMachines:
hosts, hErr := logic.GetAllHosts()
clients, cErr := logic.GetAllExtClients()
if (hErr != nil && !database.IsEmptyRecord(hErr)) ||
(cErr != nil && !database.IsEmptyRecord(cErr)) ||
len(hosts)+len(clients) >= logic.MachinesLimit {
errorResponse.Message += "machines"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
case limitChoiceIngress:
ingresses, err := logic.GetAllIngresses()
if (err != nil && !database.IsEmptyRecord(err)) ||
len(ingresses) >= logic.IngressesLimit {
errorResponse.Message += "ingresses"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
case limitChoiceEgress:
egresses, err := logic.GetAllEgresses()
if (err != nil && !database.IsEmptyRecord(err)) ||
len(egresses) >= logic.EgressesLimit {
errorResponse.Message += "egresses"
logic.ReturnErrorResponse(w, r, errorResponse)
return
}
}
}
next.ServeHTTP(w, r)
}
}
+27 -27
View File
@@ -6,7 +6,7 @@ import (
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
func userMiddleWare(handler http.Handler) http.Handler {
@@ -33,79 +33,79 @@ func userMiddleWare(handler http.Handler) http.Handler {
r.Header.Set("NET_ID", r.URL.Query().Get("network"))
}
if strings.Contains(route, "hosts") || strings.Contains(route, "nodes") {
r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
r.Header.Set("TARGET_RSRC", schema.HostRsrc.String())
}
if strings.Contains(route, "dns") {
r.Header.Set("TARGET_RSRC", models.DnsRsrc.String())
r.Header.Set("TARGET_RSRC", schema.DnsRsrc.String())
}
if strings.Contains(route, "rac") {
r.Header.Set("RAC", "true")
}
if strings.Contains(route, "users") {
r.Header.Set("TARGET_RSRC", models.UserRsrc.String())
r.Header.Set("TARGET_RSRC", schema.UserRsrc.String())
}
if strings.Contains(route, "ingress") {
r.Header.Set("TARGET_RSRC", models.RemoteAccessGwRsrc.String())
r.Header.Set("TARGET_RSRC", schema.RemoteAccessGwRsrc.String())
}
if strings.Contains(route, "createrelay") || strings.Contains(route, "deleterelay") {
r.Header.Set("TARGET_RSRC", models.RelayRsrc.String())
r.Header.Set("TARGET_RSRC", schema.RelayRsrc.String())
}
if strings.Contains(route, "gateway") {
r.Header.Set("TARGET_RSRC", models.GatewayRsrc.String())
r.Header.Set("TARGET_RSRC", schema.GatewayRsrc.String())
}
if strings.Contains(route, "egress") {
r.Header.Set("TARGET_RSRC", models.EgressGwRsrc.String())
r.Header.Set("TARGET_RSRC", schema.EgressGwRsrc.String())
}
if strings.Contains(route, "networks") {
r.Header.Set("TARGET_RSRC", models.NetworkRsrc.String())
r.Header.Set("TARGET_RSRC", schema.NetworkRsrc.String())
}
// check 'graph' after 'networks', otherwise the
// header will be overwritten.
if strings.Contains(route, "graph") {
r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
r.Header.Set("TARGET_RSRC", schema.HostRsrc.String())
}
if strings.Contains(route, "acls") {
r.Header.Set("TARGET_RSRC", models.AclRsrc.String())
r.Header.Set("TARGET_RSRC", schema.AclRsrc.String())
}
if strings.Contains(route, "tags") {
r.Header.Set("TARGET_RSRC", models.TagRsrc.String())
r.Header.Set("TARGET_RSRC", schema.TagRsrc.String())
}
if strings.Contains(route, "extclients") || strings.Contains(route, "client_conf") {
r.Header.Set("TARGET_RSRC", models.ExtClientsRsrc.String())
r.Header.Set("TARGET_RSRC", schema.ExtClientsRsrc.String())
}
if strings.Contains(route, "enrollment-keys") {
r.Header.Set("TARGET_RSRC", models.EnrollmentKeysRsrc.String())
r.Header.Set("TARGET_RSRC", schema.EnrollmentKeysRsrc.String())
}
if strings.Contains(route, "posture_check") {
r.Header.Set("TARGET_RSRC", models.PostureCheckRsrc.String())
r.Header.Set("TARGET_RSRC", schema.PostureCheckRsrc.String())
}
if strings.Contains(route, "activity") {
r.Header.Set("TARGET_RSRC", models.UserActivityRsrc.String())
r.Header.Set("TARGET_RSRC", schema.UserActivityRsrc.String())
}
if strings.Contains(route, "nameserver") {
r.Header.Set("TARGET_RSRC", models.NameserverRsrc.String())
r.Header.Set("TARGET_RSRC", schema.NameserverRsrc.String())
}
if strings.Contains(route, "jit") {
r.Header.Set("TARGET_RSRC", models.JitAdminRsrc.String())
r.Header.Set("TARGET_RSRC", schema.JitAdminRsrc.String())
}
if strings.Contains(route, "jit_user") {
r.Header.Set("TARGET_RSRC", models.JitUserRsrc.String())
r.Header.Set("TARGET_RSRC", schema.JitUserRsrc.String())
}
if strings.Contains(route, "metrics") {
r.Header.Set("TARGET_RSRC", models.MetricRsrc.String())
r.Header.Set("TARGET_RSRC", schema.MetricRsrc.String())
}
if strings.Contains(route, "flows") {
r.Header.Set("TARGET_RSRC", models.TrafficFlow.String())
r.Header.Set("TARGET_RSRC", schema.TrafficFlow.String())
}
if keyID, ok := params["keyID"]; ok {
r.Header.Set("TARGET_RSRC_ID", keyID)
}
if nodeID, ok := params["nodeid"]; ok && r.Header.Get("TARGET_RSRC") != models.ExtClientsRsrc.String() {
if nodeID, ok := params["nodeid"]; ok && r.Header.Get("TARGET_RSRC") != schema.ExtClientsRsrc.String() {
r.Header.Set("TARGET_RSRC_ID", nodeID)
}
if strings.Contains(route, "failover") {
r.Header.Set("TARGET_RSRC", models.FailOverRsrc.String())
r.Header.Set("TARGET_RSRC", schema.FailOverRsrc.String())
nodeID := r.Header.Get("TARGET_RSRC_ID")
node, _ := logic.GetNodeByID(nodeID)
r.Header.Set("NET_ID", node.Network)
@@ -133,10 +133,10 @@ func userMiddleWare(handler http.Handler) http.Handler {
}
}
if r.Header.Get("NET_ID") == "" && (r.Header.Get("TARGET_RSRC_ID") == "" ||
r.Header.Get("TARGET_RSRC") == models.EnrollmentKeysRsrc.String() ||
r.Header.Get("TARGET_RSRC") == models.UserRsrc.String()) ||
(r.Header.Get("TARGET_RSRC") == models.UserActivityRsrc.String() && route != "/api/v1/network/activity") ||
r.Header.Get("TARGET_RSRC") == models.TrafficFlow.String() {
r.Header.Get("TARGET_RSRC") == schema.EnrollmentKeysRsrc.String() ||
r.Header.Get("TARGET_RSRC") == schema.UserRsrc.String()) ||
(r.Header.Get("TARGET_RSRC") == schema.UserActivityRsrc.String() && route != "/api/v1/network/activity") ||
r.Header.Get("TARGET_RSRC") == schema.TrafficFlow.String() {
r.Header.Set("IS_GLOBAL_ACCESS", "yes")
}
r.Header.Set("RSRC_TYPE", r.Header.Get("TARGET_RSRC"))
+7 -5
View File
@@ -13,6 +13,7 @@ import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slog"
@@ -30,9 +31,9 @@ import (
// @Failure 400 {object} models.ErrorResponse
func migrate(w http.ResponseWriter, r *http.Request) {
data := models.MigrationData{}
host := models.Host{}
host := schema.Host{}
node := models.Node{}
nodes := []models.Node{}
var nodes []models.Node
server := models.ServerConfig{}
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
@@ -127,9 +128,9 @@ func migrate(w http.ResponseWriter, r *http.Request) {
}
}
func convertLegacyHostNode(legacy models.LegacyNode) (models.Host, models.Node) {
func convertLegacyHostNode(legacy models.LegacyNode) (schema.Host, models.Node) {
//convert host
host := models.Host{}
host := schema.Host{}
host.ID = uuid.New()
host.IPForwarding = models.ParseBool(legacy.IPForwarding)
host.AutoUpdate = logic.AutoUpdateEnabled()
@@ -139,7 +140,8 @@ func convertLegacyHostNode(legacy models.LegacyNode) (models.Host, models.Node)
host.ListenPort = 51821
}
host.MTU = int(legacy.MTU)
host.PublicKey, _ = wgtypes.ParseKey(legacy.PublicKey)
pubKey, _ := wgtypes.ParseKey(legacy.PublicKey)
host.PublicKey = schema.WgKey{Key: pubKey}
host.MacAddress = net.HardwareAddr(legacy.MacAddress)
host.TrafficKeyPublic = legacy.TrafficKeys.Mine
host.Nodes = append([]string{}, legacy.ID)
+69 -60
View File
@@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/schema"
"golang.org/x/exp/slog"
"github.com/gravitl/netmaker/database"
@@ -27,7 +28,7 @@ func networkHandlers(r *mux.Router) {
Methods(http.MethodGet)
r.HandleFunc("/api/v1/networks/stats", logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats))).
Methods(http.MethodGet)
r.HandleFunc("/api/networks", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceNetworks, http.HandlerFunc(createNetwork)))).
r.HandleFunc("/api/networks", logic.SecurityCheck(true, http.HandlerFunc(createNetwork))).
Methods(http.MethodPost)
r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(getNetwork))).
Methods(http.MethodGet)
@@ -52,26 +53,26 @@ func networkHandlers(r *mux.Router) {
// @Tags Networks
// @Security oauth
// @Produce json
// @Success 200 {object} models.Network
// @Success 200 {array} schema.Network
// @Failure 500 {object} models.ErrorResponse
func getNetworks(w http.ResponseWriter, r *http.Request) {
var err error
allnetworks, err := logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
allnetworks, err := (&schema.Network{}).ListAll(r.Context())
if err != nil {
slog.Error("failed to fetch networks", "error", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if r.Header.Get("ismaster") != "yes" {
username := r.Header.Get("user")
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err = user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
allnetworks = logic.FilterNetworksByRole(allnetworks, *user)
allnetworks = logic.FilterNetworksByRole(allnetworks, user)
}
logger.Log(2, r.Header.Get("user"), "fetched networks.")
@@ -90,32 +91,38 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
func getNetworksStats(w http.ResponseWriter, r *http.Request) {
var err error
allnetworks, err := logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
allnetworks, err := (&schema.Network{}).ListAll(r.Context())
if err != nil {
slog.Error("failed to fetch networks", "error", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if r.Header.Get("ismaster") != "yes" {
username := r.Header.Get("user")
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err = user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
allnetworks = logic.FilterNetworksByRole(allnetworks, *user)
allnetworks = logic.FilterNetworksByRole(allnetworks, user)
}
allNodes, err := logic.GetAllNodes()
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
netstats := []models.NetworkStatResp{}
type networkStatResp struct {
Network *schema.Network
Hosts int
}
var netstats []networkStatResp
logic.SortNetworks(allnetworks[:])
for _, network := range allnetworks {
netstats = append(netstats, models.NetworkStatResp{
Network: network,
Hosts: len(logic.GetNetworkNodesMemory(allNodes, network.NetID)),
netstats = append(netstats, networkStatResp{
Network: &network,
Hosts: len(logic.GetNetworkNodesMemory(allNodes, network.Name)),
})
}
logger.Log(2, r.Header.Get("user"), "fetched networks.")
@@ -128,7 +135,7 @@ func getNetworksStats(w http.ResponseWriter, r *http.Request) {
// @Security oauth
// @Param networkname path string true "Network name"
// @Produce json
// @Success 200 {object} models.Network
// @Success 200 {object} schema.Network
// @Failure 404 {object} models.ErrorResponse
// @Failure 500 {object} models.ErrorResponse
func getNetwork(w http.ResponseWriter, r *http.Request) {
@@ -136,7 +143,8 @@ func getNetwork(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var params = mux.Vars(r)
netname := params["networkname"]
network, err := logic.GetNetwork(netname)
network := &schema.Network{Name: netname}
err := network.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to fetch network [%s] info: %v",
netname, err))
@@ -383,7 +391,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) {
}
// update ingress gateways of associated clients
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(r.Context())
if err != nil {
slog.Error(
"failed to fetch hosts after network ACL update. skipping publish extclients ACL",
@@ -392,7 +400,7 @@ func updateNetworkACLv2(w http.ResponseWriter, r *http.Request) {
)
return
}
hostsMap := make(map[uuid.UUID]models.Host)
hostsMap := make(map[uuid.UUID]schema.Host)
for _, host := range hosts {
hostsMap[host.ID] = host
}
@@ -486,14 +494,14 @@ func getNetworkEgressRoutes(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
netname := params["networkname"]
// check if network exists
_, err := logic.GetNetwork(netname)
err := (&schema.Network{Name: netname}).Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to fetch ACLs for network [%s]: %v", netname, err))
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
nodeEgressRoutes, _, err := logic.GetEgressRanges(models.NetworkID(netname))
nodeEgressRoutes, _, err := logic.GetEgressRanges(schema.NetworkID(netname))
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -538,8 +546,8 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
}
go logic.UnlinkNetworkAndTagsFromEnrollmentKeys(network, true)
go logic.DeleteNetworkRoles(network)
go logic.DeleteAllNetworkTags(models.NetworkID(network))
go logic.DeleteNetworkPolicies(models.NetworkID(network))
go logic.DeleteAllNetworkTags(schema.NetworkID(network))
go logic.DeleteNetworkPolicies(schema.NetworkID(network))
//delete network from allocated ip map
go logic.RemoveNetworkFromAllocatedIpMap(network)
go func() {
@@ -561,19 +569,19 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: network,
Name: network,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: network,
New: nil,
@@ -588,15 +596,15 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
// @Router /api/networks [post]
// @Tags Networks
// @Security oauth
// @Param body body models.Network true "Network details"
// @Param body body schema.Network true "Network details"
// @Produce json
// @Success 200 {object} models.Network
// @Success 200 {object} schema.Network
// @Failure 400 {object} models.ErrorResponse
func createNetwork(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var network models.Network
var network schema.Network
// we decode our body request params
err := json.NewDecoder(r.Body).Decode(&network)
@@ -608,9 +616,9 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
}
featureFlags := logic.GetFeatureFlags()
if !featureFlags.EnableDeviceApproval {
network.AutoJoin = "true"
network.AutoJoin = true
}
if len(network.NetID) > 32 {
if len(network.Name) > 32 {
err := errors.New("network name shouldn't exceed 32 characters")
logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error())
@@ -664,7 +672,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
}
}
}
if network.AutoRemove == "true" {
if network.AutoRemove {
if network.AutoRemoveThreshold == 0 {
network.AutoRemoveThreshold = 60
}
@@ -672,23 +680,23 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
if network.AutoRemoveTags == nil {
network.AutoRemoveTags = []string{}
}
network, err = logic.CreateNetwork(network)
err = logic.CreateNetwork(&network)
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to create network: ",
err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(network.NetID))
logic.CreateDefaultAclNetworkPolicies(models.NetworkID(network.NetID))
logic.CreateDefaultTags(models.NetworkID(network.NetID))
logic.AddNetworkToAllocatedIpMap(network.NetID)
logic.CreateFallbackNameserver(network.NetID)
logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(network.Name))
logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name))
logic.CreateDefaultTags(schema.NetworkID(network.Name))
logic.AddNetworkToAllocatedIpMap(network.Name)
logic.CreateFallbackNameserver(network.Name)
if featureFlags.EnableOverlappingEgressRanges {
// assign virtual NAT pool fields
network.AssignVirtualNATDefaults(network.AddressRange, network.NetID)
logic.AssignVirtualNATDefaults(&network, network.AddressRange)
// Update network with virtual NAT settings
if err := logic.UpsertNetwork(network); err != nil {
if err := logic.UpsertNetwork(&network); err != nil {
logger.Log(0, r.Header.Get("user"), "failed to update network with virtual NAT settings:", err.Error())
}
}
@@ -696,14 +704,14 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
defaultHosts := logic.GetDefaultHosts()
for i := range defaultHosts {
currHost := &defaultHosts[i]
newNode, err := logic.UpdateHostNetwork(currHost, network.NetID, true)
newNode, err := logic.UpdateHostNetwork(currHost, network.Name, true)
if err != nil {
logger.Log(
0,
r.Header.Get("user"),
"failed to add host to network:",
currHost.ID.String(),
network.NetID,
network.Name,
err.Error(),
)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -721,7 +729,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
r.Header.Get("user"),
"failed to add host to network:",
currHost.ID.String(),
network.NetID,
network.Name,
err.Error(),
)
}
@@ -736,7 +744,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
r.Header.Get("user"),
"failed to add host to network:",
currHost.ID.String(),
network.NetID,
network.Name,
err.Error(),
)
}
@@ -745,10 +753,10 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
// make host failover
logic.CreateFailOver(*newNode)
// make host remote access gateway
logic.CreateIngressGateway(network.NetID, newNode.ID.String(), models.IngressRequest{})
logic.CreateIngressGateway(network.Name, newNode.ID.String(), models.IngressRequest{})
logic.CreateRelay(models.RelayRequest{
NodeID: newNode.ID.String(),
NetID: network.NetID,
NetID: network.Name,
})
}
// send peer updates
@@ -757,22 +765,22 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: network.NetID,
Name: network.NetID,
Type: models.NetworkSub,
ID: network.Name,
Name: network.Name,
Type: schema.NetworkSub,
Info: network,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(1, r.Header.Get("user"), "created network", network.NetID)
logger.Log(1, r.Header.Get("user"), "created network", network.Name)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(network)
}
@@ -782,15 +790,15 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
// @Tags Networks
// @Security oauth
// @Param networkname path string true "Network name"
// @Param body body models.Network true "Network details"
// @Param body body schema.Network true "Network details"
// @Produce json
// @Success 200 {object} models.Network
// @Success 200 {object} schema.Network
// @Failure 400 {object} models.ErrorResponse
func updateNetwork(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
var payload models.Network
var payload schema.Network
// we decode our body request params
err := json.NewDecoder(r.Body).Decode(&payload)
@@ -800,20 +808,21 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
return
}
netOld, err := logic.GetNetwork(payload.NetID)
netOld := &schema.Network{Name: payload.Name}
err = netOld.Get(r.Context())
if err != nil {
slog.Info("error fetching network", "user", r.Header.Get("user"), "err", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
err = logic.UpdateNetwork(&netOld, &payload)
err = logic.UpdateNetwork(netOld, &payload)
if err != nil {
slog.Info("failed to update network", "user", r.Header.Get("user"), "err", err)
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
go mq.PublishPeerUpdate(false)
slog.Info("updated network", "network", payload.NetID, "user", r.Header.Get("user"))
slog.Info("updated network", "network", payload.Name, "user", r.Header.Get("user"))
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(payload)
}
+54 -65
View File
@@ -7,6 +7,7 @@ import (
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"gorm.io/gorm"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
@@ -19,11 +20,11 @@ import (
type NetworkValidationTestCase struct {
testname string
network models.Network
network schema.Network
errMessage string
}
var netHost models.Host
var netHost schema.Host
func TestMain(m *testing.M) {
db.InitializeDB(schema.ListModels()...)
@@ -31,10 +32,10 @@ func TestMain(m *testing.M) {
database.InitializeDatabase()
defer database.CloseDB()
logic.CreateSuperAdmin(&models.User{
UserName: "admin",
logic.CreateSuperAdmin(&schema.User{
Username: "admin",
Password: "password",
PlatformRoleID: models.SuperAdminRole,
PlatformRoleID: schema.SuperAdminRole,
})
peerUpdate := make(chan *models.Node)
go logic.ManageZombies(context.Background())
@@ -51,27 +52,29 @@ func TestMain(m *testing.M) {
func TestCreateNetwork(t *testing.T) {
deleteAllNetworks()
var network models.Network
network.NetID = "skynet1"
var network schema.Network
network.Name = "skynet1"
network.AddressRange = "10.10.0.1/24"
// if tests break - check here (removed displayname)
//network.DisplayName = "mynetwork"
_, err := logic.CreateNetwork(network)
err := logic.CreateNetwork(&network)
assert.Nil(t, err)
}
func TestGetNetwork(t *testing.T) {
createNet()
t.Run("GetExistingNetwork", func(t *testing.T) {
network, err := logic.GetNetwork("skynet")
network := &schema.Network{Name: "skynet"}
err := network.Get(db.WithContext(context.TODO()))
assert.Nil(t, err)
assert.Equal(t, "skynet", network.NetID)
assert.Equal(t, "skynet", network.Name)
})
t.Run("GetNonExistantNetwork", func(t *testing.T) {
network, err := logic.GetNetwork("doesnotexist")
assert.EqualError(t, err, "no result found")
assert.Equal(t, "", network.NetID)
network := &schema.Network{Name: "doesnotexist"}
err := network.Get(db.WithContext(context.TODO()))
assert.EqualError(t, err, gorm.ErrRecordNotFound.Error())
assert.Equal(t, "", network.ID)
})
}
@@ -128,65 +131,49 @@ func TestValidateNetwork(t *testing.T) {
cases := []NetworkValidationTestCase{
{
testname: "InvalidAddress",
network: models.Network{
NetID: "skynet",
network: schema.Network{
Name: "skynet",
AddressRange: "10.0.0.256",
},
errMessage: "Field validation for 'AddressRange' failed on the 'cidrv4' tag",
errMessage: "invalid CIDR address: 10.0.0.256",
},
{
testname: "InvalidAddress6",
network: models.Network{
NetID: "skynet1",
network: schema.Network{
Name: "skynet1",
AddressRange6: "2607::ffff/130",
},
errMessage: "Field validation for 'AddressRange6' failed on the 'cidrv6' tag",
errMessage: "invalid CIDR address: 2607::ffff/130",
},
{
testname: "InvalidNetID",
network: models.Network{
NetID: "with spaces",
network: schema.Network{
Name: "with spaces",
},
errMessage: "Field validation for 'NetID' failed on the 'netid_valid' tag",
errMessage: "invalid character(s) in network name",
},
{
testname: "NetIDTooLong",
network: models.Network{
NetID: "LongNetIDNameForMaxCharactersTest",
network: schema.Network{
Name: "LongNetIDNameForMaxCharactersTest",
},
errMessage: "Field validation for 'NetID' failed on the 'max' tag",
},
{
testname: "ListenPortTooLow",
network: models.Network{
NetID: "skynet",
DefaultListenPort: 1023,
},
errMessage: "Field validation for 'DefaultListenPort' failed on the 'min' tag",
},
{
testname: "ListenPortTooHigh",
network: models.Network{
NetID: "skynet",
DefaultListenPort: 65536,
},
errMessage: "Field validation for 'DefaultListenPort' failed on the 'max' tag",
errMessage: "network name cannot be longer than 32 characters",
},
{
testname: "KeepAliveTooBig",
network: models.Network{
NetID: "skynet",
DefaultKeepalive: 1010,
network: schema.Network{
Name: "skynet",
DefaultKeepAlive: 1010,
},
errMessage: "Field validation for 'DefaultKeepalive' failed on the 'max' tag",
errMessage: "default keep alive must be less than 1000",
},
}
for _, tc := range cases {
t.Run(tc.testname, func(t *testing.T) {
t.Log(tc.testname)
network := models.Network(tc.network)
network.SetDefaults()
network := tc.network
err := logic.ValidateNetwork(&network, false)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), tc.errMessage) // test passes if err.Error() contains the expected errMessage.
})
@@ -200,7 +187,8 @@ func TestIpv6Network(t *testing.T) {
deleteAllNetworks()
createNet()
createNetDualStack()
network, err := logic.GetNetwork("skynet6")
network := &schema.Network{Name: "skynet6"}
err := network.Get(db.WithContext(context.TODO()))
t.Run("Test Network Create IPv6", func(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, network.AddressRange6, "fde6:be04:fa5e:d076::/64")
@@ -216,46 +204,47 @@ func TestIpv6Network(t *testing.T) {
func deleteAllNetworks() {
deleteAllNodes()
database.DeleteAllRecords(database.NETWORKS_TABLE_NAME)
_networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, _network := range _networks {
_ = _network.Delete(db.WithContext(context.TODO()))
}
}
func createNet() {
var network models.Network
network.NetID = "skynet"
var network schema.Network
network.Name = "skynet"
network.AddressRange = "10.0.0.1/24"
_, err := logic.GetNetwork("skynet")
err := (&schema.Network{Name: "skynet"}).Get(db.WithContext(context.TODO()))
if err != nil {
logic.CreateNetwork(network)
logic.CreateNetwork(&network)
}
}
func createNetv1(netId string) {
var network models.Network
network.NetID = netId
var network schema.Network
network.Name = netId
network.AddressRange = "100.0.0.1/24"
_, err := logic.GetNetwork(netId)
err := (&schema.Network{Name: netId}).Get(db.WithContext(context.TODO()))
if err != nil {
logic.CreateNetwork(network)
logic.CreateNetwork(&network)
}
}
func createNetDualStack() {
var network models.Network
network.NetID = "skynet6"
var network schema.Network
network.Name = "skynet6"
network.AddressRange = "10.1.2.0/24"
network.AddressRange6 = "fde6:be04:fa5e:d076::/64"
network.IsIPv4 = "yes"
network.IsIPv6 = "yes"
_, err := logic.GetNetwork("skynet6")
err := (&schema.Network{Name: "skynet6"}).Get(db.WithContext(context.TODO()))
if err != nil {
logic.CreateNetwork(network)
logic.CreateNetwork(&network)
}
}
func createNetHost() {
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
netHost = models.Host{
netHost = schema.Host{
ID: uuid.New(),
PublicKey: k.PublicKey(),
PublicKey: schema.WgKey{Key: k.PublicKey()},
HostPass: "password",
OS: "linux",
Name: "nethost",
+26 -14
View File
@@ -13,6 +13,7 @@ import (
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slog"
@@ -27,9 +28,9 @@ func nodeHandlers(r *mux.Router) {
r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(getNode))).Methods(http.MethodGet)
r.HandleFunc("/api/nodes/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(updateNode))).Methods(http.MethodPut)
r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(deleteNode))).Methods(http.MethodDelete)
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceEgress, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway))).Methods(http.MethodDelete)
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceIngress, http.HandlerFunc(createGateway)))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost)
r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete)
r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods(http.MethodPost)
r.HandleFunc("/api/v1/nodes/{network}/status", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus))).Methods(http.MethodGet)
@@ -81,7 +82,10 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
return
}
}
host, err := logic.GetHost(result.HostID.String())
host := &schema.Host{
ID: result.HostID,
}
err = host.Get(request.Context())
if err != nil {
errorResponse.Code = http.StatusBadRequest
errorResponse.Message = err.Error()
@@ -234,16 +238,18 @@ func getAllNodes(w http.ResponseWriter, r *http.Request) {
}
username := r.Header.Get("user")
if r.Header.Get("ismaster") == "no" {
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err = user.Get(r.Context())
if err != nil {
return
}
userPlatformRole, err := logic.GetRole(user.PlatformRoleID)
userPlatformRole := &schema.UserRole{ID: user.PlatformRoleID}
err = userPlatformRole.Get(r.Context())
if err != nil {
return
}
if !userPlatformRole.FullAccess {
nodes = logic.GetFilteredNodesByUserAccess(*user, nodes)
nodes = logic.GetFilteredNodesByUserAccess(user, nodes)
}
}
@@ -272,7 +278,7 @@ func getNetworkNodeStatus(w http.ResponseWriter, r *http.Request) {
var params = mux.Vars(r)
netID := params["network"]
// validate network
_, err := logic.GetNetwork(netID)
err := (&schema.Network{Name: netID}).Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to get network %v", err), "badrequest"))
return
@@ -313,7 +319,10 @@ func getNode(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("error fetching host for node [ %s ] info: %v", nodeid, err))
@@ -549,7 +558,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
return
}
}
host, err := logic.GetHost(newNode.HostID.String())
host := &schema.Host{
ID: newNode.HostID,
}
err = host.Get(r.Context())
if err != nil {
logger.Log(0, r.Header.Get("user"),
fmt.Sprintf("failed to get host for node [ %s ] info: %v", nodeid, err))
@@ -621,7 +633,7 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
}
newNode.PostureChecksViolations,
newNode.PostureCheckVolationSeverityLevel = logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(newNode),
models.NetworkID(newNode.Network))
schema.NetworkID(newNode.Network))
newNode.LastEvaluatedAt = time.Now().UTC()
logic.UpsertNode(newNode)
logic.GetNodeStatus(newNode, false)
@@ -636,23 +648,23 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
currentNode.Network,
)
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: newNode.ID.String(),
Name: host.Name,
Type: models.NodeSub,
Type: schema.NodeSub,
},
Diff: models.Diff{
Old: currentNode,
New: newNode,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(apiNode)
+13 -8
View File
@@ -4,19 +4,23 @@ import (
"net"
"testing"
"context"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
var nonLinuxHost models.Host
var linuxHost models.Host
var nonLinuxHost schema.Host
var linuxHost schema.Host
func TestGetNetworkNodes(t *testing.T) {
deleteAllNetworks()
@@ -95,10 +99,11 @@ func TestNodeACLs(t *testing.T) {
t.Run("node acls correct after add new node not allowed", func(t *testing.T) {
node3 := createNodeWithParams("", "10.0.0.100/32")
createNodeHosts()
n, e := logic.GetNetwork(node3.Network)
n := &schema.Network{Name: node3.Network}
e := n.Get(db.WithContext(context.TODO()))
assert.Nil(t, e)
n.DefaultACL = "no"
e = logic.SaveNetwork(&n)
e = logic.SaveNetwork(n)
assert.Nil(t, e)
err := logic.AssociateNodeToHost(node3, &linuxHost)
assert.Nil(t, err)
@@ -159,18 +164,18 @@ func createNodeWithParams(network, address string) *models.Node {
func createNodeHosts() {
k, _ := wgtypes.ParseKey("DM5qhLAE20PG9BbfBCger+Ac9D2NDOwCtY1rbYDLf34=")
linuxHost = models.Host{
linuxHost = schema.Host{
ID: uuid.New(),
PublicKey: k.PublicKey(),
PublicKey: schema.WgKey{Key: k.PublicKey()},
HostPass: "password",
OS: "linux",
Name: "linuxhost",
}
_ = logic.CreateHost(&linuxHost)
nonLinuxHost = models.Host{
nonLinuxHost = schema.Host{
ID: uuid.New(),
OS: "windows",
PublicKey: k.PublicKey(),
PublicKey: schema.WgKey{Key: k.PublicKey()},
Name: "windowshost",
HostPass: "password",
}
+30 -24
View File
@@ -1,6 +1,7 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -13,6 +14,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
ch "github.com/gravitl/netmaker/clickhouse"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"golang.org/x/exp/slog"
"github.com/gravitl/netmaker/database"
@@ -38,8 +41,11 @@ func serverHandlers(r *mux.Router) {
"/api/server/shutdown", logic.SecurityCheck(true,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("ismaster") != "yes" {
caller, err := logic.GetUser(r.Header.Get("user"))
if err != nil || caller.PlatformRoleID != models.SuperAdminRole {
caller := &schema.User{
Username: r.Header.Get("user"),
}
err := caller.Get(r.Context())
if err != nil || caller.PlatformRoleID != schema.SuperAdminRole {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only a super-admin can shut down the server"), "forbidden"))
return
}
@@ -255,7 +261,7 @@ func updateSettings(w http.ResponseWriter, r *http.Request) {
return
}
if superAdmin.AuthType == models.OAuth {
if superAdmin.AuthType == schema.OAuth {
err := fmt.Errorf(
"cannot remove IdP integration because an OAuth user has the super-admin role; transfer the super-admin role to another user first",
)
@@ -293,19 +299,19 @@ func updateSettings(w http.ResponseWriter, r *http.Request) {
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: models.SettingSub.String(),
Name: models.SettingSub.String(),
Type: models.SettingSub,
ID: schema.SettingSub.String(),
Name: schema.SettingSub.String(),
Type: schema.SettingSub,
},
Diff: models.Diff{
Old: currSettings,
New: req,
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
go reInit(currSettings, req, force == "true")
logic.ReturnSuccessResponseWithJson(w, r, req, "updated server settings successfully")
@@ -333,7 +339,7 @@ func reInit(curr, new models.ServerSettings, force bool) {
if force || !new.EnableFlowLogs {
if curr.NetclientAutoUpdate != new.NetclientAutoUpdate ||
curr.EnableFlowLogs != new.EnableFlowLogs {
hosts, _ := logic.GetAllHosts()
hosts, _ := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
for _, host := range hosts {
if curr.NetclientAutoUpdate != new.NetclientAutoUpdate {
host.AutoUpdate = new.NetclientAutoUpdate
@@ -355,7 +361,7 @@ func reInit(curr, new models.ServerSettings, force bool) {
go mq.PublishPeerUpdate(false)
}
func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action {
func identifySettingsUpdateAction(old, new models.ServerSettings) schema.Action {
// TODO: here we are relying on the dashboard to only
// make singular updates, but it's possible that the
// API can be called to make multiple changes to the
@@ -363,33 +369,33 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
// events or create singular update APIs.
if old.MFAEnforced != new.MFAEnforced {
if new.MFAEnforced {
return models.EnforceMFA
return schema.EnforceMFA
} else {
return models.UnenforceMFA
return schema.UnenforceMFA
}
}
if old.BasicAuth != new.BasicAuth {
if new.BasicAuth {
return models.EnableBasicAuth
return schema.EnableBasicAuth
} else {
return models.DisableBasicAuth
return schema.DisableBasicAuth
}
}
if old.Telemetry != new.Telemetry {
if new.Telemetry == "off" {
return models.DisableTelemetry
return schema.DisableTelemetry
} else {
return models.EnableTelemetry
return schema.EnableTelemetry
}
}
if old.EnableFlowLogs != new.EnableFlowLogs {
if new.EnableFlowLogs {
return models.EnableFlowLogs
return schema.EnableFlowLogs
} else {
return models.DisableFlowLogs
return schema.DisableFlowLogs
}
}
@@ -398,19 +404,19 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
old.ManageDNS != new.ManageDNS ||
old.DefaultDomain != new.DefaultDomain ||
old.EndpointDetection != new.EndpointDetection {
return models.UpdateClientSettings
return schema.UpdateClientSettings
}
if old.AllowedEmailDomains != new.AllowedEmailDomains ||
old.JwtValidityDuration != new.JwtValidityDuration {
return models.UpdateAuthenticationSecuritySettings
return schema.UpdateAuthenticationSecuritySettings
}
if old.Verbosity != new.Verbosity ||
old.MetricsPort != new.MetricsPort ||
old.MetricInterval != new.MetricInterval ||
old.AuditLogsRetentionPeriodInDays != new.AuditLogsRetentionPeriodInDays {
return models.UpdateMonitoringAndDebuggingSettings
return schema.UpdateMonitoringAndDebuggingSettings
}
if old.EmailSenderAddr != new.EmailSenderAddr ||
@@ -418,7 +424,7 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
old.EmailSenderPassword != new.EmailSenderPassword ||
old.SmtpHost != new.SmtpHost ||
old.SmtpPort != new.SmtpPort {
return models.UpdateSMTPSettings
return schema.UpdateSMTPSettings
}
if old.AuthProvider != new.AuthProvider ||
@@ -432,10 +438,10 @@ func identifySettingsUpdateAction(old, new models.ServerSettings) models.Action
old.AzureTenant != new.AzureTenant ||
!cmp.Equal(old.GroupFilters, new.GroupFilters) ||
cmp.Equal(old.UserFilters, new.UserFilters) {
return models.UpdateIDPSettings
return schema.UpdateIDPSettings
}
return models.Update
return schema.Update
}
// @Summary Get feature flags for this server
+426 -350
View File
File diff suppressed because it is too large Load Diff
+7 -5
View File
@@ -101,11 +101,9 @@ const (
var dbMutex sync.RWMutex
var Tables = []string{
NETWORKS_TABLE_NAME,
NODES_TABLE_NAME,
CERTS_TABLE_NAME,
DELETED_NODES_TABLE_NAME,
USERS_TABLE_NAME,
DNS_TABLE_NAME,
EXT_CLIENT_TABLE_NAME,
PEERS_TABLE_NAME,
@@ -116,18 +114,22 @@ var Tables = []string{
SSO_STATE_CACHE,
METRICS_TABLE_NAME,
NETWORK_USER_TABLE_NAME,
USER_GROUPS_TABLE_NAME,
CACHE_TABLE_NAME,
HOSTS_TABLE_NAME,
ENROLLMENT_KEYS_TABLE_NAME,
HOST_ACTIONS_TABLE_NAME,
PENDING_USERS_TABLE_NAME,
USER_PERMISSIONS_TABLE_NAME,
USER_INVITES_TABLE_NAME,
TAG_TABLE_NAME,
ACLS_TABLE_NAME,
PEER_ACK_TABLE,
SERVER_SETTINGS,
// The following tables are to be migrated, but we still need them so that the migration function
// doesn't fail with table does not exist.
USERS_TABLE_NAME,
USER_GROUPS_TABLE_NAME,
USER_PERMISSIONS_TABLE_NAME,
NETWORKS_TABLE_NAME,
HOSTS_TABLE_NAME,
}
func getCurrentDB() map[string]interface{} {
+1
View File
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"github.com/gravitl/netmaker/db"
_ "github.com/mattn/go-sqlite3" // need to blank import this package
)
+7 -2
View File
@@ -1,11 +1,16 @@
package database
import "strings"
import (
"errors"
"strings"
"gorm.io/gorm"
)
// IsEmptyRecord - checks for if it's an empty record error or not
func IsEmptyRecord(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS)
return strings.Contains(err.Error(), NO_RECORD) || strings.Contains(err.Error(), NO_RECORDS) || errors.Is(err, gorm.ErrRecordNotFound)
}
+13 -1
View File
@@ -49,7 +49,19 @@ func (s *sqliteConnector) connect() (*gorm.DB, error) {
}
}
return gorm.Open(sqlite.Open(dbFilePath), &gorm.Config{
db, err := gorm.Open(sqlite.Open(dbFilePath), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
return nil, err
}
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
sqlDB.SetMaxIdleConns(1)
return db, nil
}
+61
View File
@@ -0,0 +1,61 @@
package types
import (
"fmt"
"gorm.io/gorm"
)
type Option func(db *gorm.DB) *gorm.DB
func WithPagination(page, pageSize int) Option {
return func(db *gorm.DB) *gorm.DB {
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
offset := (page - 1) * pageSize
return db.Offset(offset).Limit(pageSize)
}
}
// WithFilter applies a WHERE clause for the given column.
// IMPORTANT: `field` MUST be a trusted, hardcoded column name.
// NEVER pass user-supplied strings as `field`.
func WithFilter(field string, value ...interface{}) Option {
return func(db *gorm.DB) *gorm.DB {
if len(value) == 0 {
return db
}
if len(value) == 1 {
return db.Where(fmt.Sprintf("%s = ?", field), value[0])
}
return db.Where(fmt.Sprintf("%s IN ?", field), value)
}
}
func InAscOrder(fields ...string) Option {
return func(db *gorm.DB) *gorm.DB {
for _, field := range fields {
db = db.Order(fmt.Sprintf("%s ASC", field))
}
return db
}
}
func InDescOrder(fields ...string) Option {
return func(db *gorm.DB) *gorm.DB {
for _, field := range fields {
db = db.Order(fmt.Sprintf("%s DESC", field))
}
return db
}
}
+11 -11
View File
@@ -17,8 +17,8 @@ import (
)
var (
testNetwork = &models.Network{
NetID: "not-a-network",
_testNetwork = &schema.Network{
Name: "not-a-network",
}
testExternalClient = &models.ExtClient{
ClientID: "testExtClient",
@@ -31,10 +31,10 @@ func TestMain(m *testing.M) {
database.InitializeDatabase()
defer database.CloseDB()
logic.CreateSuperAdmin(&models.User{
UserName: "superadmin",
logic.CreateSuperAdmin(&schema.User{
Username: "superadmin",
Password: "password",
PlatformRoleID: models.SuperAdminRole,
PlatformRoleID: schema.SuperAdminRole,
})
peerUpdate := make(chan *models.Node)
go logic.ManageZombies(context.Background())
@@ -48,18 +48,18 @@ func TestMain(m *testing.M) {
}
func TestNetworkExists(t *testing.T) {
database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
exists, err := logic.NetworkExists(testNetwork.NetID)
assert.NotNil(t, err)
_ = _testNetwork.Delete(db.WithContext(context.TODO()))
exists, err := logic.NetworkExists(_testNetwork.Name)
assert.Nil(t, err)
assert.False(t, exists)
err = logic.SaveNetwork(testNetwork)
err = logic.SaveNetwork(_testNetwork)
assert.Nil(t, err)
exists, err = logic.NetworkExists(testNetwork.NetID)
exists, err = logic.NetworkExists(_testNetwork.Name)
assert.Nil(t, err)
assert.True(t, exists)
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, testNetwork.NetID)
err = _testNetwork.Delete(db.WithContext(context.TODO()))
assert.Nil(t, err)
}
+31 -31
View File
@@ -43,9 +43,9 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) {
return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16())
})
}()
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
nodes, _ := GetNetworkNodes(node.Network)
nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), true)...)
nodes = append(nodes, GetStaticNodesByNetwork(schema.NetworkID(node.Network), true)...)
rules = GetFwRulesForUserNodesOnGw(node, nodes)
if defaultDevicePolicy.Enabled {
return
@@ -417,10 +417,10 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) {
defer func() {
sortIPs(ips)
}()
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
extclients := GetStaticNodesByNetwork(models.NetworkID(node.Network), false)
extclients := GetStaticNodesByNetwork(schema.NetworkID(node.Network), false)
for _, extclient := range extclients {
if extclient.IsUserNode && defaultUserPolicy.Enabled {
continue
@@ -516,11 +516,11 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu
}
var taggedNodes map[models.TagID][]models.Node
if targetnode.IsIngressGateway {
taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), false)
taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), false)
} else {
taggedNodes = GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true)
}
acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
acls := ListDevicePolicies(schema.NetworkID(targetnode.Network))
var targetNodeTags = make(map[models.TagID]struct{})
if targetnode.Mutex != nil {
targetnode.Mutex.Lock()
@@ -796,9 +796,9 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR
defer func() {
rules = GetEgressUserRulesForNode(&targetnode, rules)
}()
taggedNodes := GetTagMapWithNodesByNetwork(models.NetworkID(targetnode.Network), true)
taggedNodes := GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true)
acls := ListDevicePolicies(models.NetworkID(targetnode.Network))
acls := ListDevicePolicies(schema.NetworkID(targetnode.Network))
var targetNodeTags = make(map[models.TagID]struct{})
targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{}
targetNodeTags["*"] = struct{}{}
@@ -1070,7 +1070,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool {
}
if checkDefaultPolicy {
// check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
if err == nil {
if defaultPolicy.Enabled {
return true
@@ -1079,7 +1079,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool {
}
// list device policies
policies := ListDevicePolicies(models.NetworkID(peer.Network))
policies := ListDevicePolicies(schema.NetworkID(peer.Network))
srcMap := make(map[string]struct{})
dstMap := make(map[string]struct{})
defer func() {
@@ -1176,9 +1176,9 @@ func CheckTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.N
}
var (
CreateDefaultTags = func(netID models.NetworkID) {}
CreateDefaultTags = func(netID schema.NetworkID) {}
DeleteAllNetworkTags = func(networkID models.NetworkID) {}
DeleteAllNetworkTags = func(networkID schema.NetworkID) {}
IsUserAllowedToCommunicate = func(userName string, peer models.Node) (bool, []models.Acl) {
return false, []models.Acl{}
@@ -1213,7 +1213,7 @@ func MigrateAclPolicies() {
func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool {
// check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
if err == nil {
if defaultPolicy.Enabled {
return true
@@ -1244,7 +1244,7 @@ func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool {
node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{}
}
// list device policies
policies := ListDevicePolicies(models.NetworkID(node.Network))
policies := ListDevicePolicies(schema.NetworkID(node.Network))
srcMap := make(map[string]struct{})
dstMap := make(map[string]struct{})
defer func() {
@@ -1332,7 +1332,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
peerTags[models.TagID(peerId)] = struct{}{}
if checkDefaultPolicy {
// check default policy if all allowed return true
defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
if err == nil {
if defaultPolicy.Enabled {
return true, []models.Acl{defaultPolicy}
@@ -1344,7 +1344,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
allowedPolicies = UniquePolicies(allowedPolicies)
}()
// list device policies
policies := ListDevicePolicies(models.NetworkID(peer.Network))
policies := ListDevicePolicies(schema.NetworkID(peer.Network))
srcMap := make(map[string]struct{})
dstMap := make(map[string]struct{})
defer func() {
@@ -1462,7 +1462,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool)
}
// GetDefaultPolicy - fetches default policy in the network by ruleType
func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
func GetDefaultPolicy(netID schema.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) {
aclID := "all-users"
if ruleType == models.DevicePolicy {
aclID = "all-nodes"
@@ -1505,7 +1505,7 @@ func GetDefaultPolicy(netID models.NetworkID, ruleType models.AclPolicyType) (mo
}
// ListAcls - lists all acl policies
func ListAclsByNetwork(netID models.NetworkID) ([]models.Acl, error) {
func ListAclsByNetwork(netID schema.NetworkID) ([]models.Acl, error) {
allAcls := ListAcls()
netAcls := []models.Acl{}
@@ -1538,7 +1538,7 @@ func ListEgressAcls(eID string) ([]models.Acl, error) {
}
// ListDevicePolicies - lists all device policies in a network
func ListDevicePolicies(netID models.NetworkID) []models.Acl {
func ListDevicePolicies(netID schema.NetworkID) []models.Acl {
allAcls := ListAcls()
deviceAcls := []models.Acl{}
for _, acl := range allAcls {
@@ -1550,7 +1550,7 @@ func ListDevicePolicies(netID models.NetworkID) []models.Acl {
}
// ListUserPolicies - lists all user policies in a network
func ListUserPolicies(netID models.NetworkID) []models.Acl {
func ListUserPolicies(netID schema.NetworkID) []models.Acl {
allAcls := ListAcls()
userAcls := []models.Acl{}
for _, acl := range allAcls {
@@ -1697,7 +1697,7 @@ func UniquePolicies(items []models.Acl) []models.Acl {
}
// DeleteNetworkPolicies - deletes all default network acl policies
func DeleteNetworkPolicies(netId models.NetworkID) {
func DeleteNetworkPolicies(netId schema.NetworkID) {
acls, _ := ListAclsByNetwork(netId)
for _, acl := range acls {
if acl.NetworkID == netId {
@@ -1716,7 +1716,7 @@ func SortAclEntrys(acls []models.Acl) {
// ValidateCreateAclReq - validates create req for acl
func ValidateCreateAclReq(req models.Acl) error {
// check if acl network exists
_, err := GetNetwork(req.NetworkID.String())
err := (&schema.Network{Name: req.NetworkID.String()}).Get(db.WithContext(context.TODO()))
if err != nil {
return errors.New("failed to get network details for " + req.NetworkID.String())
}
@@ -1726,17 +1726,17 @@ func ValidateCreateAclReq(req models.Acl) error {
// }
for _, src := range req.Src {
if src.ID == models.UserGroupAclID {
userGroup, err := GetUserGroup(models.UserGroupID(src.Value))
userGroup, err := GetUserGroup(schema.UserGroupID(src.Value))
if err != nil {
return err
}
_, ok := userGroup.NetworkRoles[models.AllNetworks]
_, ok := userGroup.NetworkRoles.Data()[schema.AllNetworks]
if ok {
continue
}
_, ok = userGroup.NetworkRoles[req.NetworkID]
_, ok = userGroup.NetworkRoles.Data()[req.NetworkID]
if !ok {
return fmt.Errorf("user group %s does not have access to network %s", src.Value, req.NetworkID)
}
@@ -1824,7 +1824,7 @@ func RemoveNodeFromAclPolicy(node models.Node) {
} else {
nodeID = node.ID.String()
}
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
for _, acl := range acls {
delete := false
update := false
@@ -1891,7 +1891,7 @@ func RemoveNodeFromAclPolicy(node models.Node) {
}
// CreateDefaultAclNetworkPolicies - create default acl network policies
func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
func CreateDefaultAclNetworkPolicies(netID schema.NetworkID) {
if netID.String() == "" {
return
}
@@ -1957,7 +1957,7 @@ func CreateDefaultAclNetworkPolicies(netID models.NetworkID) {
CreateDefaultUserPolicies(netID)
}
func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) {
func getTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) {
tagNodesMap = make(map[models.TagID][]models.Node)
nodes, _ := GetNetworkNodes(netID.String())
netGwTag := models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.GwTagName))
@@ -1974,7 +1974,7 @@ func getTagMapWithNodesByNetwork(netID models.NetworkID, withStaticNodes bool) (
return addTagMapWithStaticNodes(netID, tagNodesMap)
}
func addTagMapWithStaticNodes(netID models.NetworkID,
func addTagMapWithStaticNodes(netID schema.NetworkID,
tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node {
extclients, err := GetNetworkExtClients(netID.String())
if err != nil {
+221 -217
View File
@@ -6,19 +6,21 @@ import (
"encoding/json"
"errors"
"fmt"
"net/mail"
"strings"
"time"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"gorm.io/datatypes"
"gorm.io/gorm"
"github.com/go-playground/validator/v10"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slog"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
)
const (
@@ -37,67 +39,25 @@ var ResetIDPSyncHook = func() {}
// HasSuperAdmin - checks if server has an superadmin/owner
func HasSuperAdmin() (bool, error) {
users, err := GetUsersDB()
if err != nil {
if database.IsEmptyRecord(err) {
return false, nil
}
return true, err
}
for _, user := range users {
if user.PlatformRoleID == models.SuperAdminRole {
return true, nil
}
}
return false, nil
}
// GetUsersDB - gets users
func GetUsersDB() ([]models.User, error) {
if servercfg.CacheEnabled() {
users := getUsersFromCache()
if len(users) != 0 {
return users, nil
}
}
var users []models.User
collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil {
return users, err
}
cacheMap := make(map[string]models.User, len(collection))
for _, value := range collection {
var user models.User
err = json.Unmarshal([]byte(value), &user)
if err != nil {
continue
}
users = append(users, user)
cacheMap[user.UserName] = user
}
if servercfg.CacheEnabled() {
loadUsersIntoCache(cacheMap)
}
return users, nil
return (&schema.User{}).SuperAdminExists(db.WithContext(context.TODO()))
}
// GetUsers - gets users
func GetUsers() ([]models.ReturnUser, error) {
dbUsers, err := GetUsersDB()
_users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return nil, err
}
users := make([]models.ReturnUser, 0, len(dbUsers))
for _, u := range dbUsers {
users = append(users, ToReturnUser(u))
users := make([]models.ReturnUser, len(_users))
for i, _user := range _users {
users[i] = ToReturnUser(&_user)
}
return users, nil
}
// IsOauthUser - returns
func IsOauthUser(user *models.User) error {
func IsOauthUser(user *schema.User) error {
var currentValue, err = FetchPassValue("")
if err != nil {
return err
@@ -130,63 +90,63 @@ func FetchPassValue(newValue string) (string, error) {
}
// CreateUser - creates a user
func CreateUser(user *models.User) error {
func CreateUser(_user *schema.User) error {
// check if user exists
if _, err := GetUser(user.UserName); err == nil {
userCheck := &schema.User{Username: _user.Username}
if err := userCheck.Get(db.WithContext(context.TODO())); err == nil {
return errors.New("user exists")
}
SetUserDefaults(user)
if err := IsGroupsValid(user.UserGroups); err != nil {
SetUserDefaults(_user)
if err := IsGroupsValid(_user.UserGroups.Data()); err != nil {
return errors.New("invalid groups: " + err.Error())
}
if err := IsNetworkRolesValid(user.NetworkRoles); err != nil {
return errors.New("invalid network roles: " + err.Error())
}
var err = ValidateUser(user)
var err = ValidateUser(_user)
if err != nil {
logger.Log(0, "failed to validate user", err.Error())
return err
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 5)
hash, err := bcrypt.GenerateFromPassword([]byte(_user.Password), 5)
if err != nil {
logger.Log(0, "error encrypting pass", err.Error())
return err
}
// set password to encrypted password
user.Password = string(hash)
user.AuthType = models.BasicAuth
if IsOauthUser(user) == nil {
user.AuthType = models.OAuth
_user.Password = string(hash)
_user.AuthType = schema.BasicAuth
if IsOauthUser(_user) == nil {
_user.AuthType = schema.OAuth
}
AddGlobalNetRolesToAdmins(user)
AddGlobalNetRolesToAdmins(_user)
// create user will always be called either from API or Dashboard.
_, err = CreateUserJWT(user.UserName, user.PlatformRoleID, DashboardApp)
_, err = CreateUserJWT(_user.Username, _user.PlatformRoleID, DashboardApp)
if err != nil {
logger.Log(0, "failed to generate token", err.Error())
return err
}
// connect db
data, err := json.Marshal(user)
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
err = _user.Create(dbctx)
if err != nil {
logger.Log(0, "failed to marshal", err.Error())
return err
}
err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME)
if err != nil {
logger.Log(0, "failed to insert user", err.Error())
return err
}
if servercfg.CacheEnabled() {
storeUserInCache(*user)
return fmt.Errorf("failed to create user %s: %v", _user.Username, err)
}
commit = true
return nil
}
// CreateSuperAdmin - creates an super admin user
func CreateSuperAdmin(u *models.User) error {
func CreateSuperAdmin(u *schema.User) error {
hassuperadmin, err := HasSuperAdmin()
if err != nil {
return err
@@ -194,37 +154,34 @@ func CreateSuperAdmin(u *models.User) error {
if hassuperadmin {
return errors.New("superadmin user already exists")
}
u.IsSuperAdmin = true
u.IsAdmin = true
u.PlatformRoleID = models.SuperAdminRole
u.PlatformRoleID = schema.SuperAdminRole
return CreateUser(u)
}
// VerifyAuthRequest - verifies an auth request
func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (string, error) {
var result models.User
if authRequest.UserName == "" {
return "", errors.New("username can't be empty")
} else if authRequest.Password == "" {
return "", errors.New("password can't be empty")
}
// Search DB for node with Mac Address. Ignore pending nodes (they should not be able to authenticate with API until approved).
record, err := database.FetchRecord(database.USERS_TABLE_NAME, authRequest.UserName)
_user := &schema.User{
Username: authRequest.UserName,
}
err := _user.Get(db.WithContext(context.TODO()))
if err != nil {
return "", errors.New("incorrect credentials")
}
if err = json.Unmarshal([]byte(record), &result); err != nil {
return "", errors.New("error unmarshalling user json: " + err.Error())
}
// compare password from request to stored password in database
// might be able to have a common hash (certificates?) and compare those so that a password isn't passed in in plain text...
// TODO: Consider a way of hashing the password client side before sending, or using certificates
if err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password)); err != nil {
if err = bcrypt.CompareHashAndPassword([]byte(_user.Password), []byte(authRequest.Password)); err != nil {
return "", errors.New("incorrect credentials")
}
if result.IsMFAEnabled {
if _user.IsMFAEnabled {
tokenString, err := CreatePreAuthToken(authRequest.UserName)
if err != nil {
slog.Error("error creating jwt", "error", err)
@@ -234,15 +191,15 @@ func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (strin
return tokenString, nil
} else {
// Create a new JWT for the node
tokenString, err := CreateUserJWT(authRequest.UserName, result.PlatformRoleID, appName)
tokenString, err := CreateUserJWT(authRequest.UserName, schema.UserRoleID(_user.PlatformRoleID), appName)
if err != nil {
slog.Error("error creating jwt", "error", err)
return "", err
}
// update last login time
result.LastLoginTime = time.Now().UTC()
err = UpsertUser(result)
_user.LastLoginAt = time.Now().UTC()
err = _user.Update(db.WithContext(context.TODO()))
if err != nil {
slog.Error("error upserting user", "error", err)
return "", err
@@ -253,44 +210,42 @@ func VerifyAuthRequest(authRequest models.UserAuthParams, appName string) (strin
}
// UpsertUser - updates user in the db
func UpsertUser(user models.User) error {
data, err := json.Marshal(&user)
if err != nil {
slog.Error("error marshalling user", "user", user.UserName, "error", err.Error())
return err
func UpsertUser(_user schema.User) error {
_existingUser := schema.User{Username: _user.Username}
// Check if user exists to preserve ID
err := _existingUser.Get(db.WithContext(context.TODO()))
if err == nil {
_user.ID = _existingUser.ID
return _user.Update(db.WithContext(context.TODO()))
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
slog.Error("error inserting user", "user", user.UserName, "error", err.Error())
return err
}
if servercfg.CacheEnabled() {
storeUserInCache(user)
}
return nil
return _user.Create(db.WithContext(context.TODO()))
}
// UpdateUser - updates a given user
func UpdateUser(userchange, user *models.User) (*models.User, error) {
func UpdateUser(userchange, _user *schema.User) (*schema.User, error) {
// check if user exists
if _, err := GetUser(user.UserName); err != nil {
return &models.User{}, err
userCheck := &schema.User{Username: _user.Username}
if err := userCheck.Get(db.WithContext(context.TODO())); err != nil {
return &schema.User{}, err
}
queryUser := user.UserName
if userchange.UserName != "" && user.UserName != userchange.UserName {
queryUser := _user.Username
if userchange.Username != "" && _user.Username != userchange.Username {
// check if username is available
if _, err := GetUser(userchange.UserName); err == nil {
return &models.User{}, errors.New("username exists already")
userCheck := &schema.User{Username: userchange.Username}
if err := userCheck.Get(db.WithContext(context.TODO())); err == nil {
return &schema.User{}, errors.New("username exists already")
}
if userchange.UserName == MasterUser {
return &models.User{}, errors.New("username not allowed")
if userchange.Username == MasterUser {
return &schema.User{}, errors.New("username not allowed")
}
user.UserName = userchange.UserName
_user.Username = userchange.Username
}
if userchange.Password != "" {
if len(userchange.Password) < 5 {
return &models.User{}, errors.New("password requires min 5 characters")
return &schema.User{}, errors.New("password requires min 5 characters")
}
// encrypt that password so we never see it again
hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
@@ -301,152 +256,201 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
// set password to encrypted password
userchange.Password = string(hash)
user.Password = userchange.Password
_user.Password = userchange.Password
}
validUserGroups := make(map[models.UserGroupID]struct{})
for userGroupID := range userchange.UserGroups {
validUserGroups := make(map[schema.UserGroupID]struct{})
for userGroupID := range userchange.UserGroups.Data() {
_, err := GetUserGroup(userGroupID)
if err == nil {
validUserGroups[userGroupID] = struct{}{}
}
}
userchange.UserGroups = validUserGroups
if err := IsNetworkRolesValid(userchange.NetworkRoles); err != nil {
return userchange, errors.New("invalid network roles: " + err.Error())
}
userchange.UserGroups = datatypes.NewJSONType(validUserGroups)
if userchange.DisplayName != "" {
if user.ExternalIdentityProviderID != "" &&
user.DisplayName != userchange.DisplayName {
if _user.ExternalIdentityProviderID != "" &&
_user.DisplayName != userchange.DisplayName {
return userchange, errors.New("display name cannot be updated for external user")
}
user.DisplayName = userchange.DisplayName
_user.DisplayName = userchange.DisplayName
}
if user.ExternalIdentityProviderID != "" &&
userchange.AccountDisabled != user.AccountDisabled {
if _user.ExternalIdentityProviderID != "" &&
userchange.AccountDisabled != _user.AccountDisabled {
return userchange, errors.New("account status cannot be updated for external user")
}
// Reset Gw Access for service users
go UpdateUserGwAccess(*user, *userchange)
go UpdateUserGwAccess(_user, userchange)
if userchange.PlatformRoleID != "" {
user.PlatformRoleID = userchange.PlatformRoleID
// TODO: remove once NMUI stops using these fields.
if user.PlatformRoleID == models.SuperAdminRole {
user.IsSuperAdmin = true
user.IsAdmin = true
} else if user.PlatformRoleID == models.AdminRole {
user.IsSuperAdmin = false
user.IsAdmin = true
_user.PlatformRoleID = userchange.PlatformRoleID
}
for groupID := range userchange.UserGroups.Data() {
_, ok := _user.UserGroups.Data()[groupID]
if !ok {
group, err := GetUserGroup(groupID)
if err != nil {
return userchange, err
}
if group.ExternalIdentityProviderID != "" {
return userchange, errors.New("cannot modify membership of external groups")
}
}
}
for groupID := range _user.UserGroups.Data() {
_, ok := userchange.UserGroups.Data()[groupID]
if !ok {
group, err := GetUserGroup(groupID)
if err != nil {
return userchange, err
}
if group.ExternalIdentityProviderID != "" {
return userchange, errors.New("cannot modify membership of external groups")
}
}
}
var updateMFA bool
if _user.IsMFAEnabled != userchange.IsMFAEnabled {
updateMFA = true
}
_user.IsMFAEnabled = userchange.IsMFAEnabled
var updateAccountStatus bool
if _user.AccountDisabled != userchange.AccountDisabled {
updateAccountStatus = true
}
_user.IsMFAEnabled = userchange.IsMFAEnabled
if !_user.IsMFAEnabled {
_user.TOTPSecret = ""
}
_user.UserGroups = userchange.UserGroups
AddGlobalNetRolesToAdmins(_user)
err := ValidateUser(_user)
if err != nil {
return &schema.User{}, err
}
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
logger.Log(1, "updated user", queryUser)
} else {
user.IsSuperAdmin = false
user.IsAdmin = false
db.FromContext(dbctx).Rollback()
}
}
}()
for groupID := range userchange.UserGroups {
_, ok := user.UserGroups[groupID]
if !ok {
group, err := GetUserGroup(groupID)
if err != nil {
return userchange, err
}
if group.ExternalIdentityProviderID != "" {
return userchange, errors.New("cannot modify membership of external groups")
}
}
}
for groupID := range user.UserGroups {
_, ok := userchange.UserGroups[groupID]
if !ok {
group, err := GetUserGroup(groupID)
if err != nil {
return userchange, err
}
if group.ExternalIdentityProviderID != "" {
return userchange, errors.New("cannot modify membership of external groups")
}
}
}
user.IsMFAEnabled = userchange.IsMFAEnabled
if !user.IsMFAEnabled {
user.TOTPSecret = ""
}
user.UserGroups = userchange.UserGroups
user.NetworkRoles = userchange.NetworkRoles
AddGlobalNetRolesToAdmins(user)
err := ValidateUser(user)
// Fetch existing user to get ID
_schemaUser := schema.User{Username: queryUser}
err = _schemaUser.Get(dbctx)
if err != nil {
return &models.User{}, err
return &schema.User{}, err
}
if err = database.DeleteRecord(database.USERS_TABLE_NAME, queryUser); err != nil {
return &models.User{}, err
}
data, err := json.Marshal(&user)
_user.ID = _schemaUser.ID
err = _user.Update(dbctx)
if err != nil {
return &models.User{}, err
return &schema.User{}, err
}
if err = database.Insert(user.UserName, string(data), database.USERS_TABLE_NAME); err != nil {
return &models.User{}, err
}
if servercfg.CacheEnabled() {
if queryUser != user.UserName {
deleteUserFromCache(queryUser)
if updateAccountStatus {
err = _user.UpdateAccountStatus(dbctx)
if err != nil {
return &schema.User{}, err
}
storeUserInCache(*user)
}
logger.Log(1, "updated user", queryUser)
return user, nil
if updateMFA {
err = _user.UpdateMFA(dbctx)
if err != nil {
return &schema.User{}, err
}
}
commit = true
return _user, nil
}
func validateUserName(user *schema.User) error {
var validationErr error
if len(user.Username) == 0 {
validationErr = errors.Join(validationErr, errors.New("username cannot be empty"))
} else if len(user.Username) <= 3 {
validationErr = errors.Join(validationErr, errors.New("username must have more than 3 characters"))
}
var isValidEmail bool
_, err := mail.ParseAddress(user.Username)
if err == nil {
isValidEmail = true
}
if !isValidEmail {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-."
for _, char := range user.Username {
if !strings.Contains(charset, strings.ToLower(string(char))) {
validationErr = errors.Join(validationErr, errors.New("invalid character(s) in username"))
break
}
}
}
return validationErr
}
// ValidateUser - validates a user model
func ValidateUser(user *models.User) error {
func ValidateUser(user *schema.User) error {
var validationErr error
// check if role is valid
_, err := GetRole(user.PlatformRoleID)
roleCheck := &schema.UserRole{ID: user.PlatformRoleID}
err := roleCheck.Get(db.WithContext(context.TODO()))
if err != nil {
return errors.New("failed to fetch platform role " + user.PlatformRoleID.String())
}
v := validator.New()
_ = v.RegisterValidation("in_charset", func(fl validator.FieldLevel) bool {
isgood := user.NameInCharSet()
return isgood
})
err = v.Struct(user)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
logger.Log(2, e.Error())
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
validationErr = errors.Join(validationErr, fmt.Errorf("invalid user role %s", user.PlatformRoleID))
}
return err
err = validateUserName(user)
if err != nil {
validationErr = errors.Join(validationErr, err)
}
if len(user.Password) < 5 {
validationErr = errors.Join(validationErr, errors.New("password must have a minimum of 5 characters"))
}
return validationErr
}
// DeleteUser - deletes a given user
func DeleteUser(user string) error {
if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
return errors.New("user does not exist")
_user := schema.User{
Username: user,
}
err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
err := _user.Delete(db.WithContext(context.TODO()))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("user does not exist")
}
return err
}
if servercfg.CacheEnabled() {
deleteUserFromCache(user)
}
go RemoveUserFromAclPolicy(user)
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
}
+19 -13
View File
@@ -116,14 +116,14 @@ func SetDNS() error {
return err
}
var corefilestring string
networks, err := GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
for _, net := range networks {
corefilestring = corefilestring + net.NetID + " "
dns, err := GetDNS(net.NetID)
corefilestring = corefilestring + net.Name + " "
dns, err := GetDNS(net.Name)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
@@ -229,7 +229,10 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {
if node.Network != network {
continue
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
continue
}
@@ -256,7 +259,10 @@ func GetGwDNS(node *models.Node) string {
if !servercfg.GetManageDNS() {
return ""
}
h, err := GetHost(node.HostID.String())
h := &schema.Host{
ID: node.HostID,
}
err := h.Get(db.WithContext(context.TODO()))
if err != nil {
return ""
}
@@ -340,12 +346,12 @@ func SetCorefile(domains string) error {
// GetAllDNS - gets all dns entries
func GetAllDNS() ([]models.DNSEntry, error) {
var dns []models.DNSEntry
networks, err := GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return []models.DNSEntry{}, err
}
for _, net := range networks {
netdns, err := GetDNS(net.NetID)
netdns, err := GetDNS(net.Name)
if err != nil {
return []models.DNSEntry{}, nil
}
@@ -405,7 +411,7 @@ func ValidateDNSCreate(entry models.DNSEntry) error {
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := GetParentNetwork(entry.Network)
err := (&schema.Network{Name: entry.Network}).Get(db.WithContext(context.TODO()))
return err == nil
})
@@ -437,7 +443,7 @@ func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
return err == nil && num == 0
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := GetParentNetwork(change.Network)
err := (&schema.Network{Name: change.Network}).Get(db.WithContext(context.TODO()))
return err == nil
})
@@ -488,7 +494,7 @@ func validateNameserverReq(ns *schema.Nameserver) error {
if len(ns.Servers) == 0 {
return errors.New("atleast one nameserver should be specified")
}
_, err := GetNetwork(ns.NetworkID)
err := (&schema.Network{Name: ns.NetworkID}).Get(db.WithContext(context.TODO()))
if err != nil {
return errors.New("invalid network id")
}
@@ -595,7 +601,7 @@ func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
return
}
func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
func getNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) {
if h.DNS != "yes" {
return
}
+14 -11
View File
@@ -16,7 +16,7 @@ import (
var ValidateEgressReq = validateEgressReq
var AssignVirtualRangeToEgress = func(nw *models.Network, eg *schema.Egress) error {
var AssignVirtualRangeToEgress = func(nw *schema.Network, eg *schema.Egress) error {
return nil
}
@@ -25,12 +25,12 @@ func validateEgressReq(e *schema.Egress) error {
return errors.New("network id is empty")
}
if e.Nat {
e.Mode = models.DirectNAT
e.Mode = schema.DirectNAT
} else {
e.Mode = ""
e.VirtualRange = ""
}
_, err := GetNetwork(e.Network)
err := (&schema.Network{Name: e.Network}).Get(db.WithContext(context.TODO()))
if err != nil {
return errors.New("failed to get network " + err.Error())
}
@@ -50,7 +50,7 @@ func validateEgressReq(e *schema.Egress) error {
return nil
}
func DoesUserHaveAccessToEgress(user *models.User, e *schema.Egress, acls []models.Acl) bool {
func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []models.Acl) bool {
if !e.Status {
return false
}
@@ -64,11 +64,11 @@ func DoesUserHaveAccessToEgress(user *models.User, e *schema.Egress, acls []mode
if _, ok := dstTags[e.ID]; ok || all {
// get all src tags
for _, srcAcl := range acl.Src {
if srcAcl.ID == models.UserAclID && srcAcl.Value == user.UserName {
if srcAcl.ID == models.UserAclID && srcAcl.Value == user.Username {
return true
} else if srcAcl.ID == models.UserGroupAclID {
// fetch all users in the group
if _, ok := user.UserGroups[models.UserGroupID(srcAcl.Value)]; ok {
if _, ok := user.UserGroups.Data()[schema.UserGroupID(srcAcl.Value)]; ok {
return true
}
}
@@ -253,7 +253,7 @@ func AddEgressInfoToPeerByAccess(node, targetNode *models.Node, eli []schema.Egr
}
}
func GetEgressDomainsByAccessForUser(user *models.User, network models.NetworkID) (domains []string) {
func GetEgressDomainsByAccessForUser(user *schema.User, network schema.NetworkID) (domains []string) {
acls := ListUserPolicies(network)
eli, _ := (&schema.Egress{Network: network.String()}).ListByNetwork(db.WithContext(context.TODO()))
defaultDevicePolicy, _ := GetDefaultPolicy(network, models.UserPolicy)
@@ -276,9 +276,9 @@ func GetEgressDomainsByAccessForUser(user *models.User, network models.NetworkID
}
func GetEgressDomainNSForNode(node *models.Node) (returnNsLi []models.Nameserver) {
acls := ListDevicePolicies(models.NetworkID(node.Network))
acls := ListDevicePolicies(schema.NetworkID(node.Network))
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
isDefaultPolicyActive := defaultDevicePolicy.Enabled
for _, e := range eli {
if !e.Status || e.Network != node.Network {
@@ -458,7 +458,7 @@ func RemoveNodeFromEgress(node models.Node) {
}
}
func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]struct{}, error) {
func GetEgressRanges(netID schema.NetworkID) (map[string][]string, map[string]struct{}, error) {
resultMap := make(map[string]struct{})
nodeEgressMap := make(map[string][]string)
@@ -496,7 +496,10 @@ func ListAllByRoutingNodeWithDomain(egs []schema.Egress, nodeID string) (egWithD
if err != nil {
return
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return
}
+22 -16
View File
@@ -72,8 +72,8 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
var result []string
eli, _ := (&schema.Egress{Network: client.Network}).ListByNetwork(db.WithContext(context.TODO()))
staticNode := client.ConvertToStaticNode()
userPolicies := ListUserPolicies(models.NetworkID(client.Network))
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(client.Network), models.UserPolicy)
userPolicies := ListUserPolicies(schema.NetworkID(client.Network))
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(client.Network), models.UserPolicy)
for _, eI := range eli {
if !eI.Status {
@@ -100,7 +100,8 @@ func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) {
result = append(result, rangesToBeAdded...)
} else {
if staticNode.IsUserNode && staticNode.StaticNode.OwnerID != "" {
user, err := GetUser(staticNode.StaticNode.OwnerID)
user := &schema.User{Username: staticNode.StaticNode.OwnerID}
err := user.Get(db.WithContext(context.TODO()))
if err != nil {
return []string{}, errors.New("user not found")
}
@@ -173,21 +174,21 @@ func DeleteExtClient(network string, clientid string, isUpdate bool) error {
}
if !isUpdate && extClient.RemoteAccessClientID != "" {
LogEvent(&models.Event{
Action: models.Disconnect,
Action: schema.Disconnect,
Source: models.Subject{
ID: extClient.OwnerID,
Name: extClient.OwnerID,
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: extClient.OwnerID,
Target: models.Subject{
ID: extClient.Network,
Name: extClient.Network,
Type: models.NetworkSub,
Type: schema.NetworkSub,
Info: extClient,
},
NetworkID: models.NetworkID(extClient.Network),
Origin: models.ClientApp,
NetworkID: schema.NetworkID(extClient.Network),
Origin: schema.ClientApp,
})
}
go RemoveNodeFromAclPolicy(extClient.ConvertToStaticNode())
@@ -343,12 +344,13 @@ func CreateExtClient(extclient *models.ExtClient) error {
extclient.ExtraAllowedIPs = []string{}
}
parentNetwork, err := GetNetwork(extclient.Network)
parentNetwork := &schema.Network{Name: extclient.Network}
err := parentNetwork.Get(db.WithContext(context.TODO()))
if err != nil {
return err
}
if extclient.Address == "" {
if parentNetwork.IsIPv4 == "yes" {
if parentNetwork.AddressRange != "" {
newAddress, err := UniqueAddress(extclient.Network, true)
if err != nil {
return err
@@ -358,7 +360,7 @@ func CreateExtClient(extclient *models.ExtClient) error {
}
if extclient.Address6 == "" {
if parentNetwork.IsIPv6 == "yes" {
if parentNetwork.AddressRange6 != "" {
addr6, err := UniqueAddress6(extclient.Network, true)
if err != nil {
return err
@@ -497,7 +499,7 @@ func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) {
// GetAllExtClients - gets all ext clients from DB
func GetAllExtClients() ([]models.ExtClient, error) {
var clients = []models.ExtClient{}
currentNetworks, err := GetNetworks()
currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil && database.IsEmptyRecord(err) {
return clients, nil
} else if err != nil {
@@ -505,7 +507,7 @@ func GetAllExtClients() ([]models.ExtClient, error) {
}
for i := range currentNetworks {
netName := currentNetworks[i].NetID
netName := currentNetworks[i].Name
netClients, err := GetNetworkExtClients(netName)
if err != nil {
continue
@@ -575,7 +577,10 @@ func GetExtPeers(node, peer *models.Node, addressIdentityMap map[string]models.P
if err != nil {
return peers, idsAndAddr, egressRoutes, err
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return peers, idsAndAddr, egressRoutes, err
}
@@ -764,7 +769,8 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) {
return
}
network, err := GetParentNetwork(client.Network)
network := &schema.Network{Name: client.Network}
err = network.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "Could not retrieve Ingress Gateway Network", client.Network)
return
@@ -788,7 +794,7 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) {
return
}
func GetStaticNodesByNetwork(network models.NetworkID, onlyWg bool) (staticNode []models.Node) {
func GetStaticNodesByNetwork(network schema.NetworkID, onlyWg bool) (staticNode []models.Node) {
extClients, err := GetAllExtClients()
if err != nil {
return
+36 -12
View File
@@ -8,10 +8,14 @@ import (
"sort"
"time"
"context"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)
@@ -77,7 +81,10 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
if err != nil {
return models.Node{}, err
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return models.Node{}, err
}
@@ -184,7 +191,10 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
if node.IsRelayed {
return models.Node{}, errors.New("gateway cannot be created on a relayed node")
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return models.Node{}, err
}
@@ -192,7 +202,8 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
return models.Node{}, errors.New("gateway can only be created on linux based node")
}
network, err := GetParentNetwork(netid)
network := &schema.Network{Name: netid}
err = network.Get(db.WithContext(context.TODO()))
if err != nil {
return models.Node{}, err
}
@@ -231,7 +242,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
node.Tags = make(map[models.TagID]struct{})
}
node.Tags[models.TagID(fmt.Sprintf("%s.%s", netid, models.GwTagName))] = struct{}{}
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), schema.NetworkID(node.Network))
node.LastEvaluatedAt = time.Now().UTC()
err = UpsertNode(&node)
if err != nil {
@@ -253,7 +264,7 @@ func GetIngressGwUsers(node models.Node) (models.IngressGwUsers, error) {
return gwUsers, err
}
for _, user := range users {
if user.PlatformRoleID != models.SuperAdminRole && user.PlatformRoleID != models.AdminRole {
if user.PlatformRoleID != schema.SuperAdminRole && user.PlatformRoleID != schema.AdminRole {
gwUsers.Users = append(gwUsers.Users, user)
}
}
@@ -284,7 +295,7 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error
delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName)))
node.IngressGatewayRange = ""
node.Metadata = ""
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), models.NetworkID(node.Network))
node.PostureChecksViolations, node.PostureCheckVolationSeverityLevel = CheckPostureViolations(GetPostureCheckDeviceInfoByNode(&node), schema.NetworkID(node.Network))
node.LastEvaluatedAt = time.Now().UTC()
err = UpsertNode(&node)
if err != nil {
@@ -319,18 +330,22 @@ func IsUserAllowedAccessToExtClient(username string, client models.ExtClient) bo
if username == MasterUser {
return true
}
user, err := GetUser(username)
user := &schema.User{Username: username}
err := user.Get(db.WithContext(context.TODO()))
if err != nil {
return false
}
if user.UserName != client.OwnerID {
if user.Username != client.OwnerID {
return false
}
return true
}
func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool) error {
inetHost, err := GetHost(inetNode.HostID.String())
inetHost := &schema.Host{
ID: inetNode.HostID,
}
err := inetHost.Get(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -352,7 +367,10 @@ func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq, update bool
if clientNode.IsFailOver || clientNode.IsAutoRelay {
return errors.New("failover node cannot be set to use internet gateway")
}
clientHost, err := GetHost(clientNode.HostID.String())
clientHost := &schema.Host{
ID: clientNode.HostID,
}
err = clientHost.Get(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -431,7 +449,10 @@ func UnsetInternetGw(node *models.Node) {
func SetDefaultGwForRelayedUpdate(relayed, relay models.Node, peerUpdate models.HostPeerUpdate) models.HostPeerUpdate {
if relay.InternetGwID != "" {
relayedHost, err := GetHost(relayed.HostID.String())
relayedHost := &schema.Host{
ID: relayed.HostID,
}
err := relayedHost.Get(db.WithContext(context.TODO()))
if err != nil {
return peerUpdate
}
@@ -452,7 +473,10 @@ func SetDefaultGw(node models.Node, peerUpdate models.HostPeerUpdate) models.Hos
if err != nil {
return peerUpdate
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return peerUpdate
}
+2 -2
View File
@@ -35,12 +35,12 @@ func TestMain(m *testing.M) {
}
func TestCheckPorts(t *testing.T) {
h := models.Host{
h := schema.Host{
ID: uuid.New(),
EndpointIP: net.ParseIP("192.168.1.1"),
ListenPort: 51821,
}
testHost := models.Host{
testHost := schema.Host{
ID: uuid.New(),
EndpointIP: net.ParseIP("192.168.1.1"),
ListenPort: 51830,
+55 -218
View File
@@ -1,8 +1,8 @@
package logic
import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"net"
@@ -11,10 +11,12 @@ import (
"sync"
"github.com/google/uuid"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"golang.org/x/crypto/bcrypt"
"golang.org/x/exp/slog"
"gorm.io/gorm"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
@@ -22,9 +24,7 @@ import (
)
var (
hostCacheMutex = &sync.RWMutex{}
hostsCacheMap = make(map[string]models.Host)
hostPortMutex = &sync.Mutex{}
hostPortMutex = &sync.Mutex{}
)
var (
@@ -34,99 +34,28 @@ var (
ErrInvalidHostID error = errors.New("invalid host id")
)
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network models.NetworkID) (v []models.Violation, level models.Severity) {
return []models.Violation{}, models.SeverityUnknown
var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network schema.NetworkID) (v []models.Violation, level schema.Severity) {
return []models.Violation{}, schema.SeverityUnknown
}
var GetPostureCheckDeviceInfoByNode = func(node *models.Node) (d models.PostureCheckDeviceInfo) {
return
}
func getHostsFromCache() (hosts []models.Host) {
hostCacheMutex.RLock()
for _, host := range hostsCacheMap {
hosts = append(hosts, host)
}
hostCacheMutex.RUnlock()
return
}
func getHostsMapFromCache() (hostsMap map[string]models.Host) {
hostCacheMutex.RLock()
hostsMap = hostsCacheMap
hostCacheMutex.RUnlock()
return
}
func getHostFromCache(hostID string) (host models.Host, ok bool) {
hostCacheMutex.RLock()
host, ok = hostsCacheMap[hostID]
hostCacheMutex.RUnlock()
return
}
func storeHostInCache(h models.Host) {
hostCacheMutex.Lock()
hostsCacheMap[h.ID.String()] = h
hostCacheMutex.Unlock()
}
func deleteHostFromCache(hostID string) {
hostCacheMutex.Lock()
delete(hostsCacheMap, hostID)
hostCacheMutex.Unlock()
}
func loadHostsIntoCache(hMap map[string]models.Host) {
hostCacheMutex.Lock()
hostsCacheMap = hMap
hostCacheMutex.Unlock()
}
const (
maxPort = 1<<16 - 1
minPort = 1025
)
// GetAllHosts - returns all hosts in flat list or error
func GetAllHosts() ([]models.Host, error) {
var currHosts []models.Host
if servercfg.CacheEnabled() {
currHosts := getHostsFromCache()
if len(currHosts) != 0 {
return currHosts, nil
}
}
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return nil, err
}
currHostsMap := make(map[string]models.Host)
if servercfg.CacheEnabled() {
defer loadHostsIntoCache(currHostsMap)
}
for k := range records {
var h models.Host
err = json.Unmarshal([]byte(records[k]), &h)
if err != nil {
return nil, err
}
currHosts = append(currHosts, h)
currHostsMap[h.ID.String()] = h
}
return currHosts, nil
}
// GetAllHostsWithStatus - returns all hosts with at least one
// node with given status.
func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {
hosts, err := GetAllHosts()
func GetAllHostsWithStatus(status models.NodeStatus) ([]schema.Host, error) {
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return nil, err
}
var validHosts []models.Host
var validHosts []schema.Host
for _, host := range hosts {
if len(host.Nodes) == 0 {
continue
@@ -146,44 +75,16 @@ func GetAllHostsWithStatus(status models.NodeStatus) ([]models.Host, error) {
}
// GetAllHostsAPI - get's all the hosts in an API usable format
func GetAllHostsAPI(hosts []models.Host) []models.ApiHost {
func GetAllHostsAPI(hosts []schema.Host) []models.ApiHost {
apiHosts := []models.ApiHost{}
for i := range hosts {
newApiHost := hosts[i].ConvertNMHostToAPI()
newApiHost := models.NewApiHostFromSchemaHost(&hosts[i])
apiHosts = append(apiHosts, *newApiHost)
}
return apiHosts[:]
}
// GetHostsMap - gets all the current hosts on machine in a map
func GetHostsMap() (map[string]models.Host, error) {
if servercfg.CacheEnabled() {
hostsMap := getHostsMapFromCache()
if len(hostsMap) != 0 {
return hostsMap, nil
}
}
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return nil, err
}
currHostMap := make(map[string]models.Host)
if servercfg.CacheEnabled() {
defer loadHostsIntoCache(currHostMap)
}
for k := range records {
var h models.Host
err = json.Unmarshal([]byte(records[k]), &h)
if err != nil {
return nil, err
}
currHostMap[h.ID.String()] = h
}
return currHostMap, nil
}
func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID) bool {
func DoesHostExistinTheNetworkAlready(h *schema.Host, network schema.NetworkID) bool {
if len(h.Nodes) > 0 {
for _, nodeID := range h.Nodes {
node, err := GetNodeByID(nodeID)
@@ -195,54 +96,11 @@ func DoesHostExistinTheNetworkAlready(h *models.Host, network models.NetworkID)
return false
}
// GetHost - gets a host from db given id
func GetHost(hostid string) (*models.Host, error) {
if servercfg.CacheEnabled() {
if host, ok := getHostFromCache(hostid); ok {
return &host, nil
}
}
record, err := database.FetchRecord(database.HOSTS_TABLE_NAME, hostid)
if err != nil {
return nil, err
}
var h models.Host
if err = json.Unmarshal([]byte(record), &h); err != nil {
return nil, err
}
if servercfg.CacheEnabled() {
storeHostInCache(h)
}
return &h, nil
}
// GetHostByPubKey - gets a host from db given pubkey
func GetHostByPubKey(hostPubKey string) (*models.Host, error) {
hosts, err := GetAllHosts()
if err != nil {
return nil, err
}
for _, host := range hosts {
if host.PublicKey.String() == hostPubKey {
return &host, nil
}
}
return nil, errors.New("host not found")
}
// CreateHost - creates a host if not exist
func CreateHost(h *models.Host) error {
hosts, hErr := GetAllHosts()
clients, cErr := GetAllExtClients()
if (hErr != nil && !database.IsEmptyRecord(hErr)) ||
(cErr != nil && !database.IsEmptyRecord(cErr)) ||
len(hosts)+len(clients) >= MachinesLimit {
return errors.New("free tier limits exceeded on machines")
}
_, err := GetHost(h.ID.String())
if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
func CreateHost(h *schema.Host) error {
_host := &schema.Host{ID: h.ID}
err := _host.Get(db.WithContext(context.TODO()))
if (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil) {
return ErrHostExists
}
@@ -269,7 +127,7 @@ func CreateHost(h *models.Host) error {
}
// UpdateHost - updates host data by field
func UpdateHost(newHost, currentHost *models.Host) {
func UpdateHost(newHost, currentHost *schema.Host) {
// unchangeable fields via API here
newHost.DaemonInstalled = currentHost.DaemonInstalled
newHost.OS = currentHost.OS
@@ -311,7 +169,7 @@ func UpdateHost(newHost, currentHost *models.Host) {
}
// UpdateHostFromClient - used for updating host on server with update recieved from client
func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool) {
func UpdateHostFromClient(newHost, currHost *schema.Host) (sendPeerUpdate bool) {
if newHost.PublicKey != currHost.PublicKey {
currHost.PublicKey = newHost.PublicKey
sendPeerUpdate = true
@@ -401,24 +259,12 @@ func UpdateHostFromClient(newHost, currHost *models.Host) (sendPeerUpdate bool)
}
// UpsertHost - upserts into DB a given host model, does not check for existence*
func UpsertHost(h *models.Host) error {
data, err := json.Marshal(h)
if err != nil {
return err
}
err = database.Insert(h.ID.String(), string(data), database.HOSTS_TABLE_NAME)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
storeHostInCache(*h)
}
return nil
func UpsertHost(h *schema.Host) error {
return h.Upsert(db.WithContext(context.TODO()))
}
// UpdateHostNode - handles updates from client nodes
func UpdateHostNode(h *models.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool) {
func UpdateHostNode(h *schema.Host, newNode *models.Node) (publishDeletedNodeUpdate, publishPeerUpdate bool) {
currentNode, err := GetNodeByID(newNode.ID.String())
if err != nil {
return
@@ -442,7 +288,7 @@ func UpdateHostNode(h *models.Host, newNode *models.Node) (publishDeletedNodeUpd
}
// RemoveHost - removes a given host from server
func RemoveHost(h *models.Host, forceDelete bool) error {
func RemoveHost(h *schema.Host, forceDelete bool) error {
if !forceDelete && len(h.Nodes) > 0 {
return fmt.Errorf("host still has associated nodes")
}
@@ -453,13 +299,10 @@ func RemoveHost(h *models.Host, forceDelete bool) error {
}
}
err := database.DeleteRecord(database.HOSTS_TABLE_NAME, h.ID.String())
err := h.Delete(db.WithContext(context.TODO()))
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteHostFromCache(h.ID.String())
}
go func() {
if servercfg.IsDNSMode() {
SetDNS()
@@ -469,21 +312,8 @@ func RemoveHost(h *models.Host, forceDelete bool) error {
return nil
}
// RemoveHostByID - removes a given host by id from server
func RemoveHostByID(hostID string) error {
err := database.DeleteRecord(database.HOSTS_TABLE_NAME, hostID)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteHostFromCache(hostID)
}
return nil
}
// UpdateHostNetwork - adds/deletes host from a network
func UpdateHostNetwork(h *models.Host, network string, add bool) (*models.Node, error) {
func UpdateHostNetwork(h *schema.Host, network string, add bool) (*models.Node, error) {
for _, nodeID := range h.Nodes {
node, err := GetNodeByID(nodeID)
if err != nil || node.PendingDelete {
@@ -513,7 +343,7 @@ func UpdateHostNetwork(h *models.Host, network string, add bool) (*models.Node,
// AssociateNodeToHost - associates and creates a node with a given host
// should be the only way nodes get created as of 0.18
func AssociateNodeToHost(n *models.Node, h *models.Host) error {
func AssociateNodeToHost(n *models.Node, h *schema.Host) error {
if len(h.ID.String()) == 0 || h.ID == uuid.Nil {
return ErrInvalidHostID
}
@@ -522,8 +352,8 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
if err != nil {
return err
}
currentHost, err := GetHost(h.ID.String())
if err != nil {
currentHost := &schema.Host{ID: h.ID}
if err = currentHost.Get(db.WithContext(context.TODO())); err != nil {
return err
}
h.HostPass = currentHost.HostPass
@@ -533,7 +363,7 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
// DissasociateNodeFromHost - deletes a node and removes from host nodes
// should be the only way nodes are deleted as of 0.18
func DissasociateNodeFromHost(n *models.Node, h *models.Host) error {
func DissasociateNodeFromHost(n *models.Node, h *schema.Host) error {
if len(h.ID.String()) == 0 || h.ID == uuid.Nil {
return ErrInvalidHostID
}
@@ -566,11 +396,16 @@ func DissasociateNodeFromHost(n *models.Node, h *models.Host) error {
}
// DisassociateAllNodesFromHost - deletes all nodes of the host
func DisassociateAllNodesFromHost(hostID string) error {
host, err := GetHost(hostID)
func DisassociateAllNodesFromHost(hostIDStr string) error {
hostID, err := uuid.Parse(hostIDStr)
if err != nil {
return err
}
host := &schema.Host{ID: hostID}
if err := host.Get(db.WithContext(context.TODO())); err != nil {
return err
}
for _, nodeID := range host.Nodes {
node, err := GetNodeByID(nodeID)
if err != nil {
@@ -588,9 +423,9 @@ func DisassociateAllNodesFromHost(hostID string) error {
}
// GetDefaultHosts - retrieve all hosts marked as default from DB
func GetDefaultHosts() []models.Host {
defaultHostList := []models.Host{}
hosts, err := GetAllHosts()
func GetDefaultHosts() []schema.Host {
defaultHostList := []schema.Host{}
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return defaultHostList
}
@@ -604,8 +439,8 @@ func GetDefaultHosts() []models.Host {
// GetHostNetworks - fetches all the networks
func GetHostNetworks(hostID string) []string {
currHost, err := GetHost(hostID)
if err != nil {
currHost := &schema.Host{ID: uuid.MustParse(hostID)}
if err := currHost.Get(db.WithContext(context.TODO())); err != nil {
return nil
}
nets := []string{}
@@ -620,14 +455,14 @@ func GetHostNetworks(hostID string) []string {
}
// GetRelatedHosts - fetches related hosts of a given host
func GetRelatedHosts(hostID string) []models.Host {
relatedHosts := []models.Host{}
func GetRelatedHosts(hostID string) []schema.Host {
relatedHosts := []schema.Host{}
networks := GetHostNetworks(hostID)
networkMap := make(map[string]struct{})
for _, network := range networks {
networkMap[network] = struct{}{}
}
hosts, err := GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err == nil {
for _, host := range hosts {
if host.ID.String() == hostID {
@@ -648,7 +483,7 @@ func GetRelatedHosts(hostID string) []models.Host {
// CheckHostPort checks host endpoints to ensures that hosts on the same server
// with the same endpoint have different listen ports
// in the case of 64535 hosts or more with same endpoint, ports will not be changed
func CheckHostPorts(h *models.Host) (changed bool) {
func CheckHostPorts(h *schema.Host) (changed bool) {
if h.IsStaticPort {
return false
}
@@ -658,7 +493,8 @@ func CheckHostPorts(h *models.Host) (changed bool) {
// Get the current host from database to check if it already has a valid port assigned
// This check happens before the mutex to avoid unnecessary locking
currentHost, err := GetHost(h.ID.String())
currentHost := &schema.Host{ID: h.ID}
err := currentHost.Get(db.WithContext(context.TODO()))
if err == nil && currentHost.ListenPort > 0 {
// If the host already has a port in the database, use that instead of the incoming port
// This prevents the host from being reassigned when the client sends the old port
@@ -679,7 +515,7 @@ func CheckHostPorts(h *models.Host) (changed bool) {
}
}()
hosts, err := GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return
}
@@ -735,7 +571,7 @@ func CheckHostPorts(h *models.Host) (changed bool) {
// Re-read hosts to get the latest state (in case another host just changed its port)
// This is important to avoid conflicts when multiple hosts are being processed
latestHosts, err := GetAllHosts()
latestHosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err == nil {
// Update portsInUse with latest state
for _, host := range latestHosts {
@@ -770,14 +606,15 @@ func CheckHostPorts(h *models.Host) (changed bool) {
}
// HostExists - checks if given host already exists
func HostExists(h *models.Host) bool {
_, err := GetHost(h.ID.String())
return (err != nil && !database.IsEmptyRecord(err)) || (err == nil)
func HostExists(h *schema.Host) bool {
_host := &schema.Host{ID: h.ID}
err := _host.Get(db.WithContext(context.TODO()))
return (err != nil && !errors.Is(err, gorm.ErrRecordNotFound)) || (err == nil)
}
// GetHostByNodeID - returns a host if found to have a node's ID, else nil
func GetHostByNodeID(id string) *models.Host {
hosts, err := GetAllHosts()
func GetHostByNodeID(id string) *schema.Host {
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return nil
}
+11 -11
View File
@@ -59,7 +59,7 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
}
// CreateUserJWT - creates a user jwt token
func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Time, tokenID string) (response string, err error) {
func CreateUserAccessJwtToken(username string, role schema.UserRoleID, d time.Time, tokenID string) (response string, err error) {
claims := &models.UserClaims{
UserName: username,
Role: role,
@@ -83,7 +83,7 @@ func CreateUserAccessJwtToken(username string, role models.UserRoleID, d time.Ti
}
// CreateUserJWT - creates a user jwt token
func CreateUserJWT(username string, role models.UserRoleID, appName string) (response string, err error) {
func CreateUserJWT(username string, role schema.UserRoleID, appName string) (response string, err error) {
duration := GetJwtValidityDuration()
if appName == NetclientApp || appName == NetmakerDesktopApp {
duration = GetJwtValidityDurationForClients()
@@ -187,14 +187,14 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
}
if token != nil && token.Valid {
var user *models.User
// check that user exists
user, err = GetUser(claims.UserName)
user := &schema.User{Username: claims.UserName}
err = user.Get(db.WithContext(context.TODO()))
if err != nil {
return "", err
}
if user.UserName != "" {
return user.UserName, nil
if user.Username != "" {
return user.Username, nil
}
if user.PlatformRoleID != claims.Role {
return "", Unauthorized_Err
@@ -232,15 +232,15 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
}
}
if token != nil && token.Valid {
var user *models.User
// check that user exists
user, err = GetUser(claims.UserName)
user := &schema.User{Username: claims.UserName}
err = user.Get(db.WithContext(context.TODO()))
if err != nil {
return "", false, false, err
}
if user.UserName != "" {
return user.UserName, user.PlatformRoleID == models.SuperAdminRole,
user.PlatformRoleID == models.AdminRole, nil
if user.Username != "" {
return user.Username, user.PlatformRoleID == schema.SuperAdminRole,
user.PlatformRoleID == schema.AdminRole, nil
}
err = errors.New("user does not exist")
}
+193 -278
View File
@@ -1,7 +1,7 @@
package logic
import (
"encoding/json"
"context"
"errors"
"fmt"
"net"
@@ -11,20 +11,20 @@ import (
"time"
"github.com/c-robinson/iplib"
validator "github.com/go-playground/validator/v10"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/validation"
"golang.org/x/exp/slog"
"gorm.io/gorm"
)
var (
networkCacheMutex = &sync.RWMutex{}
networkCacheMap = make(map[string]models.Network)
allocatedIpMap = make(map[string]map[string]net.IP)
)
@@ -38,14 +38,14 @@ func SetAllocatedIpMap() error {
allocatedIpMap = map[string]map[string]net.IP{}
}
currentNetworks, err := GetNetworks()
currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
for _, v := range currentNetworks {
pMap := map[string]net.IP{}
netName := v.NetID
netName := v.Name
//nodes
nodes, err := GetNetworkNodes(netName)
@@ -132,77 +132,16 @@ func RemoveNetworkFromAllocatedIpMap(networkName string) {
networkCacheMutex.Unlock()
}
func getNetworksFromCache() (networks []models.Network) {
networkCacheMutex.RLock()
for _, network := range networkCacheMap {
networks = append(networks, network)
}
networkCacheMutex.RUnlock()
return
}
func deleteNetworkFromCache(key string) {
networkCacheMutex.Lock()
delete(networkCacheMap, key)
networkCacheMutex.Unlock()
}
func getNetworkFromCache(key string) (network models.Network, ok bool) {
networkCacheMutex.RLock()
network, ok = networkCacheMap[key]
networkCacheMutex.RUnlock()
return
}
func storeNetworkInCache(key string, network models.Network) {
networkCacheMutex.Lock()
networkCacheMap[key] = network
networkCacheMutex.Unlock()
}
// GetNetworks - returns all networks from database
func GetNetworks() ([]models.Network, error) {
var networks []models.Network
if servercfg.CacheEnabled() {
networks := getNetworksFromCache()
if len(networks) != 0 {
return networks, nil
}
}
collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
if err != nil {
return networks, err
}
for _, value := range collection {
var network models.Network
if err := json.Unmarshal([]byte(value), &network); err != nil {
return networks, err
}
// add network our array
networks = append(networks, network)
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, network)
}
}
return networks, err
}
// DeleteNetwork - deletes a network
func DeleteNetwork(network string, force bool, done chan struct{}) error {
nodeCount, err := GetNetworkNonServerNodeCount(network)
if nodeCount == 0 || database.IsEmptyRecord(err) {
_network := &schema.Network{
Name: network,
}
// delete server nodes first then db records
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
deleteNetworkFromCache(network)
}
return nil
return _network.Delete(db.WithContext(context.TODO()))
}
// Remove All Nodes
@@ -211,8 +150,8 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
if err == nil {
for _, node := range nodes {
node := node
host, err := GetHost(node.HostID.String())
if err != nil {
host := &schema.Host{ID: node.HostID}
if err := host.Get(db.WithContext(context.TODO())); err != nil {
continue
}
if node.IsGw {
@@ -228,13 +167,13 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
logger.Log(1, "failed to remove the node acls during network delete for network,", network)
}
// delete server nodes first then db records
err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
_network := &schema.Network{
Name: network,
}
err = _network.Delete(db.WithContext(context.TODO()))
if err != nil {
return
}
if servercfg.CacheEnabled() {
deleteNetworkFromCache(network)
}
done <- struct{}{}
close(done)
}()
@@ -254,54 +193,89 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
return nil
}
// AssignVirtualNATDefaults determines safe defaults based on VPN CIDR
func AssignVirtualNATDefaults(network *schema.Network, vpnCIDR string) {
const (
cgnatCIDR = "100.64.0.0/10"
fallbackIPv4Pool = "198.18.0.0/15"
defaultIPv4SitePrefix = 24
)
// Parse CGNAT CIDR (should always succeed, but check for safety)
_, cgnatNet, err := net.ParseCIDR(cgnatCIDR)
if err != nil {
// Fallback to default pool if CGNAT parsing fails (shouldn't happen)
network.VirtualNATPoolIPv4 = fallbackIPv4Pool
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
return
}
var virtualIPv4Pool string
// Parse VPN CIDR - if it fails or is empty, use fallback
if vpnCIDR == "" {
virtualIPv4Pool = fallbackIPv4Pool
} else {
_, vpnNet, err := net.ParseCIDR(vpnCIDR)
if err != nil || vpnNet == nil {
// Invalid VPN CIDR, use fallback
virtualIPv4Pool = fallbackIPv4Pool
} else if !cidrOverlaps(vpnNet, cgnatNet) {
// Safe to reuse VPN CIDR for Virtual NAT
virtualIPv4Pool = vpnCIDR
} else {
// VPN is CGNAT — must not reuse
virtualIPv4Pool = fallbackIPv4Pool
}
}
network.VirtualNATPoolIPv4 = virtualIPv4Pool
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
}
// cidrOverlaps checks if two CIDR blocks overlap
func cidrOverlaps(a, b *net.IPNet) bool {
return a.Contains(b.IP) || b.Contains(a.IP)
}
// CreateNetwork - creates a network in database
func CreateNetwork(network models.Network) (models.Network, error) {
if network.AddressRange != "" {
normalizedRange, err := NormalizeCIDR(network.AddressRange)
func CreateNetwork(_network *schema.Network) error {
if _network.AddressRange != "" {
normalizedRange, err := NormalizeCIDR(_network.AddressRange)
if err != nil {
return models.Network{}, err
return err
}
network.AddressRange = normalizedRange
_network.AddressRange = normalizedRange
}
if network.AddressRange6 != "" {
normalizedRange, err := NormalizeCIDR(network.AddressRange6)
if _network.AddressRange6 != "" {
normalizedRange, err := NormalizeCIDR(_network.AddressRange6)
if err != nil {
return models.Network{}, err
return err
}
network.AddressRange6 = normalizedRange
_network.AddressRange6 = normalizedRange
}
if !IsNetworkCIDRUnique(network.GetNetworkNetworkCIDR4(), network.GetNetworkNetworkCIDR6()) {
return models.Network{}, errors.New("network cidr already in use")
if !IsNetworkCIDRUnique(GetNetworkNetworkCIDR4(_network), GetNetworkNetworkCIDR6(_network)) {
return errors.New("network cidr already in use")
}
network.SetDefaults()
network.SetNodesLastModified()
network.SetNetworkLastModified()
_network.NodesUpdatedAt = time.Now().UTC()
err := ValidateNetwork(&network, false)
err := ValidateNetwork(_network, false)
if err != nil {
//logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return models.Network{}, err
return err
}
data, err := json.Marshal(&network)
err = _network.Create(db.WithContext(context.TODO()))
if err != nil {
return models.Network{}, err
}
if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return models.Network{}, err
}
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, network)
return err
}
_, _ = CreateEnrollmentKey(
0,
time.Time{},
[]string{network.NetID},
[]string{network.NetID},
[]string{_network.Name},
[]string{_network.Name},
[]models.TagID{},
true,
uuid.Nil,
@@ -310,7 +284,22 @@ func CreateNetwork(network models.Network) (models.Network, error) {
false,
)
return network, nil
return nil
}
func GetNetworkNetworkCIDR4(network *schema.Network) *net.IPNet {
if network.AddressRange == "" {
return nil
}
_, netCidr, _ := net.ParseCIDR(network.AddressRange)
return netCidr
}
func GetNetworkNetworkCIDR6(network *schema.Network) *net.IPNet {
if network.AddressRange6 == "" {
return nil
}
_, netCidr, _ := net.ParseCIDR(network.AddressRange6)
return netCidr
}
// GetNetworkNonServerNodeCount - get number of network non server nodes
@@ -320,13 +309,13 @@ func GetNetworkNonServerNodeCount(networkName string) (int, error) {
}
func IsNetworkCIDRUnique(cidr4 *net.IPNet, cidr6 *net.IPNet) bool {
networks, err := GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return database.IsEmptyRecord(err)
return errors.Is(err, gorm.ErrRecordNotFound)
}
for _, network := range networks {
if intersect(network.GetNetworkNetworkCIDR4(), cidr4) ||
intersect(network.GetNetworkNetworkCIDR6(), cidr6) {
if intersect(GetNetworkNetworkCIDR4(&network), cidr4) ||
intersect(GetNetworkNetworkCIDR6(&network), cidr6) {
return false
}
}
@@ -340,55 +329,17 @@ func intersect(n1, n2 *net.IPNet) bool {
return n2.Contains(n1.IP) || n1.Contains(n2.IP)
}
// GetParentNetwork - get parent network
func GetParentNetwork(networkname string) (models.Network, error) {
var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil {
return network, err
}
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
return models.Network{}, err
}
return network, nil
}
// GetNetworkSettings - get parent network
func GetNetworkSettings(networkname string) (models.Network, error) {
var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil {
return network, err
}
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
return models.Network{}, err
}
return network, nil
}
// UniqueAddress - get a unique ipv4 address
func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
add := net.IP{}
var network models.Network
network, err := GetParentNetwork(networkName)
network := &schema.Network{Name: networkName}
err := network.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(0, "UniqueAddressServer encountered an error")
return add, err
}
if network.IsIPv4 == "no" {
if network.AddressRange == "" {
return add, fmt.Errorf("IPv4 not active on network %s", networkName)
}
//ensure AddressRange is valid
@@ -424,14 +375,14 @@ func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
// UniqueAddress - get a unique ipv4 address
func UniqueAddressDB(networkName string, reverse bool) (net.IP, error) {
add := net.IP{}
var network models.Network
network, err := GetParentNetwork(networkName)
network := &schema.Network{Name: networkName}
err := network.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(0, "UniqueAddressServer encountered an error")
return add, err
}
if network.IsIPv4 == "no" {
if network.AddressRange == "" {
return add, fmt.Errorf("IPv4 not active on network %s", networkName)
}
//ensure AddressRange is valid
@@ -524,12 +475,12 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) {
// UniqueAddress6DB - see if ipv6 address is unique
func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
add := net.IP{}
var network models.Network
network, err := GetParentNetwork(networkName)
network := &schema.Network{Name: networkName}
err := network.Get(db.WithContext(context.TODO()))
if err != nil {
return add, err
}
if network.IsIPv6 == "no" {
if network.AddressRange6 == "" {
return add, fmt.Errorf("IPv6 not active on network %s", networkName)
}
@@ -568,12 +519,12 @@ func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
// UniqueAddress6Cache - see if ipv6 address is unique using cache
func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) {
add := net.IP{}
var network models.Network
network, err := GetParentNetwork(networkName)
network := &schema.Network{Name: networkName}
err := network.Get(db.WithContext(context.TODO()))
if err != nil {
return add, err
}
if network.IsIPv6 == "no" {
if network.AddressRange6 == "" {
return add, fmt.Errorf("IPv6 not active on network %s", networkName)
}
@@ -610,60 +561,44 @@ func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) {
}
// IsNetworkNameUnique - checks to see if any other networks have the same name (id)
func IsNetworkNameUnique(network *models.Network) (bool, error) {
func IsNetworkNameUnique(network *schema.Network) (bool, error) {
_network := &schema.Network{
Name: network.Name,
}
err := _network.Get(db.WithContext(context.TODO()))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return true, nil
}
isunique := true
dbs, err := GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
return false, err
}
for i := 0; i < len(dbs); i++ {
if network.NetID == dbs[i].NetID {
isunique = false
}
}
return isunique, nil
return false, nil
}
func UpsertNetwork(network models.Network) error {
netData, err := json.Marshal(network)
if err != nil {
return err
}
err = database.Insert(network.NetID, string(netData), database.NETWORKS_TABLE_NAME)
if err != nil {
return err
}
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, network)
}
return nil
func UpsertNetwork(_network *schema.Network) error {
return _network.Update(db.WithContext(context.TODO()))
}
// UpdateNetwork - updates a network with another network's fields
func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) error {
func UpdateNetwork(currentNetwork, newNetwork *schema.Network) error {
if err := ValidateNetwork(newNetwork, true); err != nil {
return err
}
if newNetwork.NetID != currentNetwork.NetID {
return errors.New("failed to update network " + newNetwork.NetID + ", cannot change netid.")
if newNetwork.Name != currentNetwork.Name {
return errors.New("failed to update network " + newNetwork.Name + ", cannot change netid.")
}
featureFlags := GetFeatureFlags()
if featureFlags.EnableDeviceApproval {
currentNetwork.AutoJoin = newNetwork.AutoJoin
} else {
currentNetwork.AutoJoin = "true"
currentNetwork.AutoJoin = true
}
currentNetwork.AutoRemove = newNetwork.AutoRemove
currentNetwork.AutoRemoveThreshold = newNetwork.AutoRemoveThreshold
currentNetwork.AutoRemoveTags = newNetwork.AutoRemoveTags
currentNetwork.DefaultACL = newNetwork.DefaultACL
currentNetwork.NameServers = newNetwork.NameServers
// Validate and update Virtual NAT IPv4 settings
if newNetwork.VirtualNATPoolIPv4 != "" {
@@ -705,124 +640,104 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) e
currentNetwork.VirtualNATPoolIPv4 = newNetwork.VirtualNATPoolIPv4
currentNetwork.VirtualNATSitePrefixLenIPv4 = newNetwork.VirtualNATSitePrefixLenIPv4
}
data, err := json.Marshal(currentNetwork)
if err != nil {
return err
}
newNetwork.SetNetworkLastModified()
err = database.Insert(currentNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
if err == nil {
if servercfg.CacheEnabled() {
storeNetworkInCache(newNetwork.NetID, *currentNetwork)
}
}
return err
return currentNetwork.Update(db.WithContext(context.TODO()))
}
// GetNetwork - gets a network from database
func GetNetwork(networkname string) (models.Network, error) {
// validateNetName - checks if a netid of a network uses valid characters
func validateNetName(network *schema.Network) error {
var validationErr error
var network models.Network
if servercfg.CacheEnabled() {
if network, ok := getNetworkFromCache(networkname); ok {
return network, nil
}
if len(network.Name) == 0 {
validationErr = errors.Join(validationErr, errors.New("network name cannot be empty"))
}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
if err != nil {
return network, err
}
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
return models.Network{}, err
}
return network, nil
}
// NetIDInNetworkCharSet - checks if a netid of a network uses valid characters
func NetIDInNetworkCharSet(network *models.Network) bool {
if len(network.Name) > 32 {
validationErr = errors.Join(validationErr, errors.New("network name cannot be longer than 32 characters"))
}
charset := "abcdefghijklmnopqrstuvwxyz1234567890-_"
for _, char := range network.NetID {
for _, char := range network.Name {
if !strings.Contains(charset, string(char)) {
return false
validationErr = errors.Join(validationErr, errors.New("invalid character(s) in network name"))
break
}
}
return true
return validationErr
}
// Validate - validates fields of an network struct
func ValidateNetwork(network *models.Network, isUpdate bool) error {
v := validator.New()
_ = v.RegisterValidation("netid_valid", func(fl validator.FieldLevel) bool {
inCharSet := NetIDInNetworkCharSet(network)
if isUpdate {
return inCharSet
}
isFieldUnique, _ := IsNetworkNameUnique(network)
return isFieldUnique && inCharSet
})
//
_ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool {
return validation.CheckYesOrNo(fl)
})
err := v.Struct(network)
func ValidateNetwork(network *schema.Network, isUpdate bool) error {
var validationErr error
err := validateNetName(network)
if err != nil {
for _, e := range err.(validator.ValidationErrors) {
fmt.Println(e)
validationErr = errors.Join(validationErr, err)
}
if !isUpdate {
nameUnique, _ := IsNetworkNameUnique(network)
if !nameUnique {
validationErr = errors.Join(validationErr, errors.New("invalid network name"))
}
}
return err
}
if network.AddressRange != "" {
_, _, err = net.ParseCIDR(network.AddressRange)
if err != nil {
validationErr = errors.Join(validationErr, err)
}
}
// ParseNetwork - parses a network into a model
func ParseNetwork(value string) (models.Network, error) {
var network models.Network
err := json.Unmarshal([]byte(value), &network)
return network, err
if network.AddressRange6 != "" {
_, _, err = net.ParseCIDR(network.AddressRange6)
if err != nil {
validationErr = errors.Join(validationErr, err)
}
}
if network.DefaultKeepAlive > 1000 {
validationErr = errors.Join(validationErr, errors.New("default keep alive must be less than 1000"))
}
return validationErr
}
// SaveNetwork - save network struct to database
func SaveNetwork(network *models.Network) error {
data, err := json.Marshal(network)
if err != nil {
return err
func SaveNetwork(_network *schema.Network) error {
_existingNetwork := schema.Network{Name: _network.Name}
// Check if network exists to preserve ID
err := _existingNetwork.Get(db.WithContext(context.TODO()))
if err == nil {
_network.ID = _existingNetwork.ID
return _network.Update(db.WithContext(context.TODO()))
}
if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
return err
}
if servercfg.CacheEnabled() {
storeNetworkInCache(network.NetID, *network)
}
return nil
return _network.Create(db.WithContext(context.TODO()))
}
// NetworkExists - check if network exists
func NetworkExists(name string) (bool, error) {
var network string
var err error
if servercfg.CacheEnabled() {
if _, ok := getNetworkFromCache(name); ok {
return ok, nil
err := (&schema.Network{Name: name}).Get(db.WithContext(context.TODO()))
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
}
if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
return false, err
}
return len(network) > 0, nil
return true, nil
}
// SortNetworks - Sorts slice of Networks by their NetID alphabetically with numbers first
func SortNetworks(unsortedNetworks []models.Network) {
func SortNetworks(unsortedNetworks []schema.Network) {
sort.Slice(unsortedNetworks, func(i, j int) bool {
return unsortedNetworks[i].NetID < unsortedNetworks[j].NetID
return unsortedNetworks[i].Name < unsortedNetworks[j].Name
})
}
var NetworkHook models.HookFunc = func(params ...interface{}) error {
networks, err := GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -831,10 +746,10 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error {
return err
}
for _, network := range networks {
if network.AutoRemove == "false" || network.AutoRemoveThreshold == 0 {
if !network.AutoRemove || network.AutoRemoveThreshold == 0 {
continue
}
nodes := GetNetworkNodesMemory(allNodes, network.NetID)
nodes := GetNetworkNodesMemory(allNodes, network.Name)
for _, node := range nodes {
if !node.Connected {
continue
@@ -860,9 +775,9 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error {
node.PendingDelete = true
node.Action = models.NODE_DELETE
DeleteNodesCh <- &node
host, err := GetHost(node.HostID.String())
if err == nil && len(host.Nodes) == 0 {
RemoveHostByID(host.ID.String())
host := &schema.Host{ID: node.HostID}
if err := host.Get(db.WithContext(context.TODO())); err == nil && len(host.Nodes) == 0 {
(&schema.Host{ID: host.ID}).Delete(db.WithContext(context.TODO()))
}
}
}
+30 -29
View File
@@ -131,7 +131,7 @@ func GetNetworkNodes(network string) ([]models.Node, error) {
}
// GetHostNodes - fetches all nodes part of the host
func GetHostNodes(host *models.Host) []models.Node {
func GetHostNodes(host *schema.Host) []models.Node {
nodes := []models.Node{}
for _, nodeID := range host.Nodes {
node, err := GetNodeByID(nodeID)
@@ -201,7 +201,8 @@ func UpsertNode(newNode *models.Node) error {
// UpdateNode - takes a node and updates another node with it's values
func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
if newNode.Address.IP.String() != currentNode.Address.IP.String() {
if network, err := GetParentNetwork(newNode.Network); err == nil {
network := &schema.Network{Name: newNode.Network}
if err := network.Get(db.WithContext(context.TODO())); err == nil {
if !IsAddressInCIDR(newNode.Address.IP, network.AddressRange) {
return fmt.Errorf("invalid address provided; out of network range for node %s", newNode.ID)
}
@@ -319,7 +320,10 @@ func DeleteNode(node *models.Node, purge bool) error {
if alreadyDeleted {
logger.Log(1, "forcibly deleting node", node.ID.String())
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "no host found for node", node.ID.String(), "deleting..")
if delErr := DeleteNodeByID(node); delErr != nil {
@@ -428,7 +432,7 @@ func ValidateNode(node *models.Node, isUpdate bool) error {
return isFieldUnique
})
_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
_, err := GetNetworkByNode(node)
err := (&schema.Network{Name: node.Network}).Get(db.WithContext(context.TODO()))
return err == nil
})
_ = v.RegisterValidation("checkyesornoorunset", func(f1 validator.FieldLevel) bool {
@@ -485,7 +489,7 @@ func AddStaticNodestoList(nodes []models.Node) []models.Node {
continue
}
if node.IsIngressGateway {
nodes = append(nodes, GetStaticNodesByNetwork(models.NetworkID(node.Network), false)...)
nodes = append(nodes, GetStaticNodesByNetwork(schema.NetworkID(node.Network), false)...)
netMap[node.Network] = struct{}{}
}
}
@@ -497,7 +501,7 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
for _, node := range nodes {
if _, ok := aclDefaultPolicyStatusMap[node.Network]; !ok {
// check default policy if all allowed return true
defaultPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
aclDefaultPolicyStatusMap[node.Network] = defaultPolicy.Enabled
}
if statusCall {
@@ -511,24 +515,10 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
return
}
// GetNetworkByNode - gets the network model from a node
func GetNetworkByNode(node *models.Node) (models.Network, error) {
var network = models.Network{}
networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, node.Network)
if err != nil {
return network, err
}
if err = json.Unmarshal([]byte(networkData), &network); err != nil {
return models.Network{}, err
}
return network, nil
}
// SetNodeDefaults - sets the defaults of a node to avoid empty fields
func SetNodeDefaults(node *models.Node, resetConnected bool) {
parentNetwork, _ := GetNetworkByNode(node)
parentNetwork := &schema.Network{Name: node.Network}
_ = parentNetwork.Get(db.WithContext(context.TODO()))
_, cidr, err := net.ParseCIDR(parentNetwork.AddressRange)
if err == nil {
node.NetworkRange = *cidr
@@ -622,7 +612,10 @@ func GetAllNodesAPI(nodes []models.Node) []models.ApiNode {
for i := range nodes {
node := nodes[i]
if !node.IsStatic {
h, err := GetHost(node.HostID.String())
h := &schema.Host{
ID: node.HostID,
}
err := h.Get(db.WithContext(context.TODO()))
if err == nil {
node.Location = h.Location
node.CountryCode = h.CountryCode
@@ -643,7 +636,10 @@ func GetAllNodesAPIWithLocation(nodes []models.Node) []models.ApiNode {
if node.IsStatic {
newApiNode.Location = node.StaticNode.Location
} else {
host, _ := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
_ = host.Get(db.WithContext(context.TODO()))
newApiNode.Location = host.Location
}
@@ -694,7 +690,10 @@ func createNode(node *models.Node) error {
addressLock.Lock()
defer addressLock.Unlock()
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -702,7 +701,8 @@ func createNode(node *models.Node) error {
SetNodeDefaults(node, true)
defaultACLVal := acls.Allowed
parentNetwork, err := GetNetwork(node.Network)
parentNetwork := &schema.Network{Name: node.Network}
err = parentNetwork.Get(db.WithContext(context.TODO()))
if err == nil {
if parentNetwork.DefaultACL != "yes" {
defaultACLVal = acls.NotAllowed
@@ -714,7 +714,7 @@ func createNode(node *models.Node) error {
}
if node.Address.IP == nil {
if parentNetwork.IsIPv4 == "yes" {
if parentNetwork.AddressRange != "" {
if node.Address.IP, err = UniqueAddress(node.Network, false); err != nil {
return err
}
@@ -728,7 +728,7 @@ func createNode(node *models.Node) error {
return fmt.Errorf("invalid address: ipv4 %s is not unique", node.Address.String())
}
if node.Address6.IP == nil {
if parentNetwork.IsIPv6 == "yes" {
if parentNetwork.AddressRange6 != "" {
if node.Address6.IP, err = UniqueAddress6(node.Network, false); err != nil {
return err
}
@@ -836,7 +836,8 @@ func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error {
}
func ValidateEgressRange(netID string, ranges []string) error {
network, err := GetNetworkSettings(netID)
network := &schema.Network{Name: netID}
err := network.Get(db.WithContext(context.TODO()))
if err != nil {
slog.Error("error getting network with netid", "error", netID, err.Error)
return errors.New("error getting network with netid: " + netID + " " + err.Error())
+44 -29
View File
@@ -64,9 +64,9 @@ var (
)
// GetHostPeerInfo - fetches required peer info per network
func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
func GetHostPeerInfo(host *schema.Host) (models.HostPeerInfo, error) {
peerInfo := models.HostPeerInfo{
NetworkPeerIDs: make(map[models.NetworkID]models.PeerMap),
NetworkPeerIDs: make(map[schema.NetworkID]models.PeerMap),
}
allNodes, err := GetAllNodes()
if err != nil {
@@ -84,7 +84,7 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
continue
}
networkPeersInfo := make(models.PeerMap)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
for _, peer := range currentPeers {
@@ -94,7 +94,10 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
continue
}
peerHost, err := GetHost(peer.HostID.String())
peerHost := &schema.Host{
ID: peer.HostID,
}
err := peerHost.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(4, "no peer host", peer.HostID.String(), err.Error())
continue
@@ -134,13 +137,13 @@ func GetHostPeerInfo(host *models.Host) (models.HostPeerInfo, error) {
}
}
peerInfo.NetworkPeerIDs[models.NetworkID(node.Network)] = networkPeersInfo
peerInfo.NetworkPeerIDs[schema.NetworkID(node.Network)] = networkPeersInfo
}
return peerInfo, nil
}
// GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.Node,
func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node,
deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
if host == nil {
return models.HostPeerUpdate{}, errors.New("host is nil")
@@ -165,8 +168,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
HostNetworkInfo: models.HostInfoMap{},
ServerConfig: GetServerInfo(),
DnsNameservers: GetNameserversForHost(host),
AutoRelayNodes: make(map[models.NetworkID][]models.Node),
GwNodes: make(map[models.NetworkID][]models.Node),
AutoRelayNodes: make(map[schema.NetworkID][]models.Node),
GwNodes: make(map[schema.NetworkID][]models.Node),
AddressIdentityMap: make(map[string]models.PeerIdentity),
}
if host.DNS == "no" {
@@ -231,7 +234,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
hostPeerUpdate.Nodes = append(hostPeerUpdate.Nodes, node)
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
GetNodeEgressInfo(&node, eli, acls)
egsWithDomain := ListAllByRoutingNodeWithDomain(eli, node.ID.String())
@@ -243,8 +246,8 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
hostPeerUpdate.IsInternetGw = IsInternetGw(node)
}
hostPeerUpdate.DnsNameservers = append(hostPeerUpdate.DnsNameservers, GetEgressDomainNSForNode(&node)...)
defaultUserPolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy)
defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) ||
(!CheckIfAnyPolicyisUniDirectional(node, acls) &&
!(node.EgressDetails.IsEgressGateway && len(node.EgressDetails.EgressGatewayRanges) > 0)) {
@@ -273,11 +276,6 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
}
}
networkSettings, err := GetNetwork(node.Network)
if err != nil {
continue
}
hostPeerUpdate.NameServers = append(hostPeerUpdate.NameServers, networkSettings.NameServers...)
currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
for _, peer := range currentPeers {
if peer.ID.String() == node.ID.String() {
@@ -285,13 +283,16 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
continue
}
peerHost, err := GetHost(peer.HostID.String())
peerHost := &schema.Host{
ID: peer.HostID,
}
err := peerHost.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(4, "no peer host", peer.HostID.String(), err.Error())
continue
}
peerConfig := wgtypes.PeerConfig{
PublicKey: peerHost.PublicKey,
PublicKey: peerHost.PublicKey.Key,
PersistentKeepaliveInterval: &peerHost.PersistentKeepalive,
ReplaceAllowedIPs: true,
}
@@ -314,7 +315,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
// get relay host
failOverNode, err := GetNodeByID(peer.FailedOverBy.String())
if err == nil {
relayHost, err := GetHost(failOverNode.HostID.String())
relayHost := &schema.Host{
ID: failOverNode.HostID,
}
err := relayHost.Get(db.WithContext(context.TODO()))
if err == nil {
peerKey = relayHost.PublicKey.String()
}
@@ -324,7 +328,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
// get relay host
autoRelayNode, err := GetNodeByID(peerAutoRelayID)
if err == nil {
relayHost, err := GetHost(autoRelayNode.HostID.String())
relayHost := &schema.Host{
ID: autoRelayNode.HostID,
}
err = relayHost.Get(db.WithContext(context.TODO()))
if err == nil {
peerKey = relayHost.PublicKey.String()
}
@@ -334,7 +341,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
// get relay host
relayNode, err := GetNodeByID(peer.RelayedBy)
if err == nil {
relayHost, err := GetHost(relayNode.HostID.String())
relayHost := &schema.Host{
ID: relayNode.HostID,
}
err := relayHost.Get(db.WithContext(context.TODO()))
if err == nil {
peerKey = relayHost.PublicKey.String()
}
@@ -363,11 +373,11 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
if allowedToComm {
if peer.IsAutoRelay {
hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.AutoRelayNodes[models.NetworkID(peer.Network)],
hostPeerUpdate.AutoRelayNodes[schema.NetworkID(peer.Network)] = append(hostPeerUpdate.AutoRelayNodes[schema.NetworkID(peer.Network)],
peer)
}
if node.AutoAssignGateway && peer.IsGw {
hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)] = append(hostPeerUpdate.GwNodes[models.NetworkID(peer.Network)],
hostPeerUpdate.GwNodes[schema.NetworkID(peer.Network)] = append(hostPeerUpdate.GwNodes[schema.NetworkID(peer.Network)],
peer)
}
}
@@ -587,7 +597,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
Network: rangeI,
RouteMetric: 256,
Nat: true,
Mode: models.DirectNAT,
Mode: schema.DirectNAT,
})
}
inetEgressInfo := models.EgressInfo{
@@ -626,11 +636,14 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
hostPeerUpdate.Peers[i] = peer
}
if deletedNode != nil && host.OS != models.OS_Types.IoT {
peerHost, err := GetHost(deletedNode.HostID.String())
peerHost := &schema.Host{
ID: deletedNode.HostID,
}
err := peerHost.Get(db.WithContext(context.TODO()))
if err == nil && host.ID != peerHost.ID {
if _, ok := peerIndexMap[peerHost.PublicKey.String()]; !ok {
hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, wgtypes.PeerConfig{
PublicKey: peerHost.PublicKey,
PublicKey: peerHost.PublicKey.Key,
Remove: true,
})
}
@@ -662,7 +675,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
}
// GetPeerListenPort - given a host, retrieve it's appropriate listening port
func GetPeerListenPort(host *models.Host) int {
func GetPeerListenPort(host *schema.Host) int {
peerPort := host.ListenPort
if !host.IsStaticPort && host.WgPublicListenPort != 0 {
peerPort = host.WgPublicListenPort
@@ -742,8 +755,10 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
}
func GetEgressIPs(peer *models.Node) []net.IPNet {
peerHost, err := GetHost(peer.HostID.String())
peerHost := &schema.Host{
ID: peer.HostID,
}
err := peerHost.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(0, "error retrieving host for peer", peer.ID.String(), "host id", peer.HostID.String(), err.Error())
}
+2 -2
View File
@@ -6,7 +6,7 @@ import (
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
const (
@@ -17,7 +17,7 @@ const (
type CValue struct {
Network string `json:"network,omitempty"`
Value string `json:"value"`
Host models.Host `json:"host"`
Host schema.Host `json:"host"`
Pass string `json:"pass,omitempty"`
User string `json:"user,omitempty"`
ALL bool `json:"all,omitempty"`
+8 -5
View File
@@ -37,7 +37,10 @@ func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error)
if err != nil {
return returnnodes, models.Node{}, err
}
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err != nil {
return returnnodes, models.Node{}, err
}
@@ -120,7 +123,7 @@ func ValidateRelay(relay models.RelayRequest, update bool) error {
return errors.New("node is already acting as a relay")
}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
for _, relayedNodeID := range relay.RelayedNodes {
relayedNode, err := GetNodeByID(relayedNodeID)
if err != nil {
@@ -203,7 +206,7 @@ func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) {
func RelayedAllowedIPs(peer, node *models.Node) []net.IPNet {
var allowedIPs = []net.IPNet{}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := ListAclsByNetwork(schema.NetworkID(node.Network))
for _, relayedNodeID := range peer.RelayedNodes {
if node.ID.String() == relayedNodeID {
continue
@@ -237,9 +240,9 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe
return
}
serverSettings := GetServerSettings()
acls, _ := ListAclsByNetwork(models.NetworkID(relay.Network))
acls, _ := ListAclsByNetwork(schema.NetworkID(relay.Network))
eli, _ := (&schema.Egress{Network: relay.Network}).ListByNetwork(db.WithContext(context.TODO()))
defaultPolicy, _ := GetDefaultPolicy(models.NetworkID(relay.Network), models.DevicePolicy)
defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(relay.Network), models.DevicePolicy)
for _, peer := range peers {
if peer.ID == relayed.ID || peer.ID == relay.ID {
continue
+3 -1
View File
@@ -9,6 +9,7 @@ import (
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
)
@@ -35,7 +36,8 @@ func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
return
}
if username != MasterUser {
user, err := GetUser(username)
user := &schema.User{Username: username}
err = user.Get(r.Context())
if err != nil {
ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
return
-22
View File
@@ -5,21 +5,9 @@ import (
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/servercfg"
)
var (
// NetworksLimit - dummy var for community
NetworksLimit = 1000000000
// UsersLimit - dummy var for community
UsersLimit = 1000000000
// MachinesLimit - dummy var for community
MachinesLimit = 1000000000
// IngressesLimit - dummy var for community
IngressesLimit = 1000000000
// EgressesLimit - dummy var for community
EgressesLimit = 1000000000
// FreeTier - specifies if free tier
FreeTier = false
// DefaultTrialEndDate - is a placeholder date for not applicable trial end dates
DefaultTrialEndDate, _ = time.Parse("2006-Jan-02", "2021-Apr-01")
@@ -61,13 +49,3 @@ func StoreJWTSecret(privateKey string) error {
}
return database.Insert("nm-jwt-secret", string(data), database.SERVERCONF_TABLE_NAME)
}
// SetFreeTierLimits - sets limits for free tier
func SetFreeTierLimits() {
FreeTier = true
UsersLimit = servercfg.GetUserLimit()
NetworksLimit = servercfg.GetNetworkLimit()
MachinesLimit = servercfg.GetMachinesLimit()
IngressesLimit = servercfg.GetIngressLimit()
EgressesLimit = servercfg.GetEgressLimit()
}
+8 -4
View File
@@ -11,11 +11,15 @@ import (
"sync"
"time"
"context"
"github.com/gravitl/netmaker/config"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logic/acls"
"github.com/gravitl/netmaker/logic/acls/nodeacls"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
)
@@ -87,13 +91,13 @@ func UpsertServerSettings(s models.ServerSettings) error {
}
func setDefaultsforOldAclCfg() {
nets, _ := GetNetworks()
nets, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, netI := range nets {
if netI.DefaultACL != "yes" {
netI.DefaultACL = "yes"
UpsertNetwork(netI)
UpsertNetwork(&netI)
}
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.NetID))
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.Name))
if err != nil {
continue
}
@@ -105,7 +109,7 @@ func setDefaultsforOldAclCfg() {
}
networkACL.UpdateACL(id, aclNode)
}
networkACL.Save(acls.ContainerID(netI.NetID))
networkACL.Save(acls.ContainerID(netI.Name))
}
nodes, _ := GetAllNodes()
for _, node := range nodes {
+10 -4
View File
@@ -1,12 +1,15 @@
package logic
import (
"context"
"encoding/json"
"os"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
@@ -92,9 +95,9 @@ func FetchTelemetryData() telemetryData {
data.IsPro = servercfg.IsPro
data.ExtClients = getDBLength(database.EXT_CLIENT_TABLE_NAME)
data.Users = getDBLength(database.USERS_TABLE_NAME)
data.Networks = getDBLength(database.NETWORKS_TABLE_NAME)
data.Hosts = getDBLength(database.HOSTS_TABLE_NAME)
data.Users, _ = (&schema.User{}).Count(db.WithContext(context.TODO()))
data.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO()))
data.Hosts, _ = (&schema.Host{}).Count(db.WithContext(context.TODO()))
data.Version = servercfg.GetVersion()
data.Servers = getServerCount()
nodes, _ := GetAllNodes()
@@ -143,7 +146,10 @@ func setTelemetryTimestamp(telRecord *models.Telemetry) error {
func getClientCount(nodes []models.Node) clientCount {
var count clientCount
for _, node := range nodes {
host, err := GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
continue
}
+4 -9
View File
@@ -18,14 +18,8 @@ func GetCurrentServerUsage() (limits models.Usage) {
if cErr == nil {
limits.Clients = len(clients)
}
users, err := GetUsers()
if err == nil {
limits.Users = len(users)
}
networks, err := GetNetworks()
if err == nil {
limits.Networks = len(networks)
}
limits.Users, _ = (&schema.User{}).Count(db.WithContext(context.TODO()))
limits.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO()))
limits.Egresses, _ = (&schema.Egress{}).Count(db.WithContext(context.TODO()))
nodes, _ := GetAllNodes()
@@ -35,8 +29,9 @@ func GetCurrentServerUsage() (limits models.Usage) {
}
limits.NetworkUsage = make(map[string]models.NetworkUsage)
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, network := range networks {
limits.NetworkUsage[network.NetID] = models.NetworkUsage{}
limits.NetworkUsage[network.Name] = models.NetworkUsage{}
}
for _, node := range nodes {
+57 -115
View File
@@ -1,68 +1,51 @@
package logic
import (
"encoding/json"
"context"
"fmt"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
// Pre-Define Permission Templates for default Roles
var SuperAdminPermissionTemplate = models.UserRolePermissionTemplate{
ID: models.SuperAdminRole,
var SuperAdminPermissionTemplate = schema.UserRole{
ID: schema.SuperAdminRole,
Default: true,
FullAccess: true,
}
var AdminPermissionTemplate = models.UserRolePermissionTemplate{
ID: models.AdminRole,
var AdminPermissionTemplate = schema.UserRole{
ID: schema.AdminRole,
Default: true,
FullAccess: true,
}
var GetFilteredNodesByUserAccess = func(user models.User, nodes []models.Node) (filteredNodes []models.Node) {
var GetFilteredNodesByUserAccess = func(user *schema.User, nodes []models.Node) (filteredNodes []models.Node) {
return
}
var CreateRole = func(r models.UserRolePermissionTemplate) error {
var DeleteRole = func(r schema.UserRoleID, force bool) error {
return nil
}
var DeleteRole = func(r models.UserRoleID, force bool) error {
return nil
}
var FilterNetworksByRole = func(allnetworks []models.Network, user models.User) []models.Network {
var FilterNetworksByRole = func(allnetworks []schema.Network, user *schema.User) []schema.Network {
return allnetworks
}
var IsGroupsValid = func(groups map[models.UserGroupID]struct{}) error {
return nil
}
var IsGroupValid = func(groupID models.UserGroupID) error {
return nil
}
var IsNetworkRolesValid = func(networkRoles map[models.NetworkID]map[models.UserRoleID]struct{}) error {
var IsGroupsValid = func(groups map[schema.UserGroupID]struct{}) error {
return nil
}
var MigrateUserRoleAndGroups = func(u models.User) models.User {
return u
}
var MigrateToUUIDs = func() {}
var UpdateUserGwAccess = func(currentUser, changeUser models.User) {}
var UpdateRole = func(r models.UserRolePermissionTemplate) error { return nil }
var UpdateUserGwAccess = func(currentUser, changeUser *schema.User) {}
var InitialiseRoles = userRolesInit
var IntialiseGroups = func() {}
var DeleteNetworkRoles = func(netID string) {}
var CreateDefaultNetworkRolesAndGroups = func(netID models.NetworkID) {}
var CreateDefaultUserPolicies = func(netID models.NetworkID) {
var CreateDefaultNetworkRolesAndGroups = func(netID schema.NetworkID) {}
var CreateDefaultUserPolicies = func(netID schema.NetworkID) {
if netID.String() == "" {
return
}
@@ -95,96 +78,55 @@ var CreateDefaultUserPolicies = func(netID models.NetworkID) {
InsertAcl(defaultUserAcl)
}
}
var ListUserGroups = func() ([]models.UserGroup, error) { return nil, nil }
var GetUserGroupsInNetwork = func(netID models.NetworkID) (networkGrps map[models.UserGroupID]models.UserGroup) { return }
var GetUserGroup = func(groupId models.UserGroupID) (userGrps models.UserGroup, err error) { return }
var AddGlobalNetRolesToAdmins = func(u *models.User) {}
var GetUserGroup = func(groupId schema.UserGroupID) (userGrps schema.UserGroup, err error) { return }
var AddGlobalNetRolesToAdmins = func(u *schema.User) {}
var EmailInit = func() {}
// GetRole - fetches role template by id
func GetRole(roleID models.UserRoleID) (models.UserRolePermissionTemplate, error) {
// check if role already exists
data, err := database.FetchRecord(database.USER_PERMISSIONS_TABLE_NAME, roleID.String())
if err != nil {
return models.UserRolePermissionTemplate{}, err
}
ur := models.UserRolePermissionTemplate{}
err = json.Unmarshal([]byte(data), &ur)
if err != nil {
return ur, err
}
return ur, nil
}
// ListPlatformRoles - lists user platform roles permission templates
func ListPlatformRoles() ([]models.UserRolePermissionTemplate, error) {
data, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return []models.UserRolePermissionTemplate{}, err
}
userRoles := []models.UserRolePermissionTemplate{}
for _, dataI := range data {
userRole := models.UserRolePermissionTemplate{}
err := json.Unmarshal([]byte(dataI), &userRole)
if err != nil {
continue
}
if userRole.NetworkID != "" {
continue
}
userRoles = append(userRoles, userRole)
}
return userRoles, nil
}
func GetAllRsrcIDForRsrc(rsrc models.RsrcType) models.RsrcID {
func GetAllRsrcIDForRsrc(rsrc schema.RsrcType) schema.RsrcID {
switch rsrc {
case models.HostRsrc:
return models.AllHostRsrcID
case models.RelayRsrc:
return models.AllRelayRsrcID
case models.RemoteAccessGwRsrc:
return models.AllRemoteAccessGwRsrcID
case models.ExtClientsRsrc:
return models.AllExtClientsRsrcID
case models.InetGwRsrc:
return models.AllInetGwRsrcID
case models.EgressGwRsrc:
return models.AllEgressGwRsrcID
case models.NetworkRsrc:
return models.AllNetworkRsrcID
case models.EnrollmentKeysRsrc:
return models.AllEnrollmentKeysRsrcID
case models.UserRsrc:
return models.AllUserRsrcID
case models.DnsRsrc:
return models.AllDnsRsrcID
case models.FailOverRsrc:
return models.AllFailOverRsrcID
case models.AclRsrc:
return models.AllAclsRsrcID
case models.TagRsrc:
return models.AllTagsRsrcID
case models.PostureCheckRsrc:
return models.AllPostureCheckRsrcID
case models.NameserverRsrc:
return models.AllNameserverRsrcID
case models.JitAdminRsrc:
return models.AllJitAdminRsrcID
case models.JitUserRsrc:
return models.AllJitUserRsrcID
case models.UserActivityRsrc:
return models.AllUserActivityRsrcID
case models.TrafficFlow:
return models.AllTrafficFlowRsrcID
case schema.HostRsrc:
return schema.AllHostRsrcID
case schema.RelayRsrc:
return schema.AllRelayRsrcID
case schema.RemoteAccessGwRsrc:
return schema.AllRemoteAccessGwRsrcID
case schema.ExtClientsRsrc:
return schema.AllExtClientsRsrcID
case schema.InetGwRsrc:
return schema.AllInetGwRsrcID
case schema.EgressGwRsrc:
return schema.AllEgressGwRsrcID
case schema.NetworkRsrc:
return schema.AllNetworkRsrcID
case schema.EnrollmentKeysRsrc:
return schema.AllEnrollmentKeysRsrcID
case schema.UserRsrc:
return schema.AllUserRsrcID
case schema.DnsRsrc:
return schema.AllDnsRsrcID
case schema.FailOverRsrc:
return schema.AllFailOverRsrcID
case schema.AclRsrc:
return schema.AllAclsRsrcID
case schema.TagRsrc:
return schema.AllTagsRsrcID
case schema.PostureCheckRsrc:
return schema.AllPostureCheckRsrcID
case schema.NameserverRsrc:
return schema.AllNameserverRsrcID
case schema.JitAdminRsrc:
return schema.AllJitAdminRsrcID
case schema.JitUserRsrc:
return schema.AllJitUserRsrcID
case schema.UserActivityRsrc:
return schema.AllUserActivityRsrcID
case schema.TrafficFlow:
return schema.AllTrafficFlowRsrcID
}
return ""
}
func userRolesInit() {
d, _ := json.Marshal(SuperAdminPermissionTemplate)
database.Insert(SuperAdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)
d, _ = json.Marshal(AdminPermissionTemplate)
database.Insert(AdminPermissionTemplate.ID.String(), string(d), database.USER_PERMISSIONS_TABLE_NAME)
_ = SuperAdminPermissionTemplate.Upsert(db.WithContext(context.TODO()))
_ = AdminPermissionTemplate.Upsert(db.WithContext(context.TODO()))
}
+42 -106
View File
@@ -1,120 +1,59 @@
package logic
import (
"context"
"encoding/json"
"errors"
"sort"
"sync"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/schema"
"gorm.io/datatypes"
)
var (
userCacheMutex = &sync.RWMutex{}
usersCacheMap = make(map[string]models.User)
)
func getUserFromCache(username string) (models.User, bool) {
userCacheMutex.RLock()
user, ok := usersCacheMap[username]
userCacheMutex.RUnlock()
return user, ok
}
func getUsersFromCache() []models.User {
userCacheMutex.RLock()
users := make([]models.User, 0, len(usersCacheMap))
for _, user := range usersCacheMap {
users = append(users, user)
}
userCacheMutex.RUnlock()
return users
}
func storeUserInCache(user models.User) {
userCacheMutex.Lock()
usersCacheMap[user.UserName] = user
userCacheMutex.Unlock()
}
func deleteUserFromCache(username string) {
userCacheMutex.Lock()
delete(usersCacheMap, username)
userCacheMutex.Unlock()
}
func loadUsersIntoCache(users map[string]models.User) {
userCacheMutex.Lock()
usersCacheMap = users
userCacheMutex.Unlock()
}
// GetUser - gets a user
func GetUser(username string) (*models.User, error) {
if servercfg.CacheEnabled() {
if user, ok := getUserFromCache(username); ok {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
return &user, nil
}
}
var user models.User
record, err := database.FetchRecord(database.USERS_TABLE_NAME, username)
if err != nil {
return &user, err
}
if err = json.Unmarshal([]byte(record), &user); err != nil {
return &models.User{}, err
}
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
if servercfg.CacheEnabled() {
storeUserInCache(user)
}
return &user, err
}
// GetReturnUser - gets a user
func GetReturnUser(username string) (models.ReturnUser, error) {
u, err := GetUser(username)
_user := &schema.User{
Username: username,
}
err := _user.Get(db.WithContext(context.TODO()))
if err != nil {
return models.ReturnUser{}, err
}
return ToReturnUser(*u), nil
return ToReturnUser(_user), nil
}
// ToReturnUser - gets a user as a return user
func ToReturnUser(user models.User) models.ReturnUser {
func ToReturnUser(user *schema.User) models.ReturnUser {
return models.ReturnUser{
UserName: user.UserName,
UserName: user.Username,
ExternalIdentityProviderID: user.ExternalIdentityProviderID,
IsMFAEnabled: user.IsMFAEnabled,
DisplayName: user.DisplayName,
AccountDisabled: user.AccountDisabled,
IsAdmin: user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole,
IsSuperAdmin: user.PlatformRoleID == schema.SuperAdminRole,
AuthType: user.AuthType,
RemoteGwIDs: user.RemoteGwIDs,
UserGroups: user.UserGroups,
PlatformRoleID: user.PlatformRoleID,
IsSuperAdmin: user.PlatformRoleID == models.SuperAdminRole,
IsAdmin: user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole,
NetworkRoles: user.NetworkRoles,
LastLoginTime: user.LastLoginTime,
// no need to set. field not in use.
RemoteGwIDs: nil,
UserGroups: user.UserGroups.Data(),
PlatformRoleID: user.PlatformRoleID,
// no need to set. field not in use.
NetworkRoles: nil,
LastLoginTime: user.LastLoginAt,
CreatedBy: user.CreatedBy,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
}
// SetUserDefaults - sets the defaults of a user to avoid empty fields
func SetUserDefaults(user *models.User) {
if user.RemoteGwIDs == nil {
user.RemoteGwIDs = make(map[string]struct{})
}
if len(user.NetworkRoles) == 0 {
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
}
if len(user.UserGroups) == 0 {
user.UserGroups = make(map[models.UserGroupID]struct{})
func SetUserDefaults(user *schema.User) {
if len(user.UserGroups.Data()) == 0 {
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
}
}
@@ -127,16 +66,13 @@ func SortUsers(unsortedUsers []models.ReturnUser) {
// GetSuperAdmin - fetches superadmin user
func GetSuperAdmin() (models.ReturnUser, error) {
users, err := GetUsers()
_user := &schema.User{}
err := _user.GetSuperAdmin(db.WithContext(context.TODO()))
if err != nil {
return models.ReturnUser{}, err
}
for _, user := range users {
if user.PlatformRoleID == models.SuperAdminRole {
return user, nil
}
}
return models.ReturnUser{}, errors.New("superadmin not found")
return ToReturnUser(_user), nil
}
func InsertPendingUser(u *models.User) error {
@@ -177,8 +113,8 @@ func ListPendingReturnUsers() ([]models.ReturnUser, error) {
user := models.ReturnUser{}
err = json.Unmarshal([]byte(record), &user)
if err == nil {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
pendingUsers = append(pendingUsers, user)
}
}
@@ -195,25 +131,25 @@ func ListPendingUsers() ([]models.User, error) {
var user models.User
err = json.Unmarshal([]byte(record), &user)
if err == nil {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
user.IsSuperAdmin = user.PlatformRoleID == schema.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == schema.SuperAdminRole || user.PlatformRoleID == schema.AdminRole
pendingUsers = append(pendingUsers, user)
}
}
return pendingUsers, nil
}
func GetUserMap() (map[string]models.User, error) {
users, err := GetUsersDB()
if err != nil && !database.IsEmptyRecord(err) {
func GetUserMap() (map[string]schema.User, error) {
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return nil, err
}
userMap := make(map[string]models.User, len(users))
userMap := make(map[string]schema.User, len(users))
for _, user := range users {
user.IsSuperAdmin = user.PlatformRoleID == models.SuperAdminRole
user.IsAdmin = user.PlatformRoleID == models.SuperAdminRole || user.PlatformRoleID == models.AdminRole
userMap[user.UserName] = user
userMap[user.Username] = user
}
return userMap, nil
}
+9 -21
View File
@@ -2,10 +2,10 @@
package logic
import (
"context"
"crypto/rand"
"encoding/base32"
"encoding/base64"
"encoding/json"
"fmt"
"log/slog"
"net"
@@ -19,9 +19,9 @@ import (
"github.com/blang/semver"
"github.com/c-robinson/iplib"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
// IsBase64 - checks if a string is in base64 format
@@ -57,23 +57,11 @@ func IsAddressInCIDR(address net.IP, cidr string) bool {
// SetNetworkNodesLastModified - sets the network nodes last modified
func SetNetworkNodesLastModified(networkName string) error {
timestamp := time.Now().Unix()
network, err := GetParentNetwork(networkName)
if err != nil {
return err
_network := &schema.Network{
Name: networkName,
NodesUpdatedAt: time.Now(),
}
network.NodesLastModified = timestamp
data, err := json.Marshal(&network)
if err != nil {
return err
}
err = database.Insert(networkName, string(data), database.NETWORKS_TABLE_NAME)
if err != nil {
return err
}
return nil
return _network.UpdateNodesUpdatedAt(db.WithContext(context.TODO()))
}
// RandomString - returns a random string in a charset
@@ -270,7 +258,7 @@ func GetClientIP(r *http.Request) string {
}
// CompareIfaceSlices compares two slices of Iface for deep equality (order-sensitive)
func CompareIfaceSlices(a, b []models.Iface) bool {
func CompareIfaceSlices(a, b []schema.Iface) bool {
if len(a) != len(b) {
return false
}
@@ -281,7 +269,7 @@ func CompareIfaceSlices(a, b []models.Iface) bool {
}
return true
}
func compareIface(a, b models.Iface) bool {
func compareIface(a, b schema.Iface) bool {
return a.Name == b.Name &&
a.Address.IP.Equal(b.Address.IP) &&
a.Address.Mask.String() == b.Address.Mask.String() &&
+7 -6
View File
@@ -5,8 +5,10 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
const (
@@ -46,8 +48,8 @@ func CheckZombies(newnode *models.Node) {
// checkForZombieHosts - checks if new host has the same macAddress as an existing host
// if true, existing host is added to host zombie collection
func checkForZombieHosts(h *models.Host) {
hosts, err := GetAllHosts()
func checkForZombieHosts(h *schema.Host) {
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(3, "error retrieving all hosts", err.Error())
}
@@ -118,12 +120,11 @@ func ManageZombies(ctx context.Context) {
if len(hostZombies) > 0 {
logger.Log(3, "checking host zombies")
for i := len(hostZombies) - 1; i >= 0; i-- {
host, err := GetHost(hostZombies[i].String())
host := &schema.Host{ID: hostZombies[i]}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "error retrieving zombie host", err.Error())
if host != nil {
logger.Log(1, "deleting ", host.ID.String(), " from zombie list")
}
logger.Log(1, "deleting ", host.ID.String(), " from zombie list")
hostZombies = append(hostZombies[:i], hostZombies[i+1:]...)
continue
}
+8 -7
View File
@@ -53,6 +53,7 @@ var version = "v1.5.0"
func main() {
absoluteConfigPath := flag.String("c", "", "absolute path to configuration file")
flag.Parse()
setVerbosity()
setupConfig(*absoluteConfigPath)
servercfg.SetVersion(version)
fmt.Println(models.RetrieveLogo()) // print the logo
@@ -60,10 +61,6 @@ func main() {
logic.SetAllocatedIpMap()
defer logic.ClearAllocatedIpMap()
setGarbageCollection()
setVerbosity()
if servercfg.DeployedByOperator() && !servercfg.IsPro {
logic.SetFreeTierLimits()
}
defer db.CloseDB()
defer database.CloseDB()
@@ -122,15 +119,19 @@ func initialize() { // Client Mode Prereq Check
logger.FatalLog("error initializing database: ", err.Error())
}
err = migrate.ToSQLSchema()
if err != nil {
// we shouldn't allow user to use the product until the migration is successfully done.
panic(err)
}
initializeUUID()
//initialize cache
_, _ = logic.GetNetworks()
_, _ = logic.GetAllNodes()
_, _ = logic.GetAllHosts()
_, _ = logic.GetAllExtClients()
_ = logic.ListAcls()
_, _ = logic.GetAllEnrollmentKeys()
_, _ = logic.GetUsersDB()
_ = logic.CleanExpiredSSOStates()
migrate.Run()
+62 -303
View File
@@ -33,39 +33,24 @@ func Run() {
updateEnrollmentKeys()
assignSuperAdmin()
createDefaultTagsAndPolicies()
removeOldUserGrps()
migrateToUUIDs()
syncUsers()
updateHosts()
updateNodes()
updateAcls()
updateNewAcls()
logic.MigrateToGws()
migrateToEgressV1()
updateNetworks()
migrateNameservers()
resync()
deleteOldExtclients()
checkAndDeprecateOldAcls()
migrateJITEnabled()
}
func migrateJITEnabled() {
nets, _ := logic.GetNetworks()
for _, netI := range nets {
if netI.JITEnabled == "" {
netI.JITEnabled = "no"
logic.UpsertNetwork(netI)
}
}
}
func checkAndDeprecateOldAcls() {
// check if everything is allowed on old acl and disable old acls
nets, _ := logic.GetNetworks()
nets, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
disableOldAcls := true
for _, netI := range nets {
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.NetID))
networkACL, err := nodeacls.FetchAllACLs(nodeacls.NetworkID(netI.Name))
if err != nil {
continue
}
@@ -79,7 +64,7 @@ func checkAndDeprecateOldAcls() {
}
if disableOldAcls {
netI.DefaultACL = "yes"
logic.UpsertNetwork(netI)
logic.UpsertNetwork(&netI)
}
}
if disableOldAcls {
@@ -91,17 +76,6 @@ func checkAndDeprecateOldAcls() {
}
func updateNetworks() {
nets, _ := logic.GetNetworks()
for _, netI := range nets {
if netI.AutoJoin == "" {
netI.AutoJoin = "true"
logic.UpsertNetwork(netI)
}
if netI.AutoRemove == "" {
netI.AutoRemove = "false"
logic.UpsertNetwork(netI)
}
}
initializeVirtualNATSettings()
}
@@ -115,7 +89,7 @@ func initializeVirtualNATSettings() {
logger.Log(1, "Initializing Virtual NAT settings for existing networks")
defer logger.Log(1, "Completed initializing Virtual NAT settings for existing networks")
networks, err := logic.GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(0, "failed to get networks for Virtual NAT migration:", err.Error())
return
@@ -174,7 +148,7 @@ func initializeVirtualNATSettings() {
// If this network needs a unique pool, allocate one
if needsUniquePool {
uniquePool := allocateUniquePoolFromFallback(fallbackNet, poolPrefixLen, allocatedPools, network.NetID)
uniquePool := allocateUniquePoolFromFallback(fallbackNet, poolPrefixLen, allocatedPools, network.Name)
if uniquePool != "" {
vpnCIDR = uniquePool
allocatedPools[uniquePool] = struct{}{}
@@ -182,14 +156,14 @@ func initializeVirtualNATSettings() {
}
// Initialize virtual NAT defaults
network.AssignVirtualNATDefaults(vpnCIDR, network.NetID)
logic.AssignVirtualNATDefaults(&network, vpnCIDR)
// Save the updated network
if err := logic.UpsertNetwork(network); err != nil {
logger.Log(0, "failed to update network", network.NetID, "with Virtual NAT settings:", err.Error())
if err := logic.UpsertNetwork(&network); err != nil {
logger.Log(0, "failed to update network", network.Name, "with Virtual NAT settings:", err.Error())
continue
}
logger.Log(1, "initialized Virtual NAT settings for network", network.NetID, "pool:", network.VirtualNATPoolIPv4)
logger.Log(1, "initialized Virtual NAT settings for network", network.Name, "pool:", network.VirtualNATPoolIPv4)
}
}
@@ -282,116 +256,6 @@ func cidrOverlaps(a, b *net.IPNet) bool {
return a.Contains(b.IP) || b.Contains(a.IP)
}
func migrateNameservers() {
nets, _ := logic.GetNetworks()
user, err := logic.GetSuperAdmin()
if err != nil {
return
}
for _, netI := range nets {
_ = logic.CreateFallbackNameserver(netI.NetID)
_, cidr, err := net.ParseCIDR(netI.AddressRange)
if err != nil {
continue
}
ns := &schema.Nameserver{
NetworkID: netI.NetID,
}
nameservers, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
for _, nsI := range nameservers {
if len(nsI.Domains) != 0 {
for _, matchDomain := range nsI.MatchDomains {
nsI.Domains = append(nsI.Domains, schema.NameserverDomain{
Domain: matchDomain,
})
}
nsI.MatchDomains = []string{}
_ = nsI.Update(db.WithContext(context.TODO()))
}
}
if len(netI.NameServers) > 0 {
ns := schema.Nameserver{
ID: uuid.NewString(),
Name: "upstream nameservers",
NetworkID: netI.NetID,
Servers: []string{},
MatchAll: true,
Domains: []schema.NameserverDomain{
{
Domain: ".",
},
},
Tags: datatypes.JSONMap{
"*": struct{}{},
},
Nodes: make(datatypes.JSONMap),
Status: true,
CreatedBy: user.UserName,
}
for _, nsIP := range netI.NameServers {
if net.ParseIP(nsIP) == nil {
continue
}
if !cidr.Contains(net.ParseIP(nsIP)) {
ns.Servers = append(ns.Servers, nsIP)
}
}
ns.Create(db.WithContext(context.TODO()))
netI.NameServers = []string{}
logic.SaveNetwork(&netI)
}
}
nodes, _ := logic.GetAllNodes()
for _, node := range nodes {
if !node.IsGw {
continue
}
if node.IngressDNS != "" {
if (node.Address.IP != nil && node.Address.IP.String() == node.IngressDNS) ||
(node.Address6.IP != nil && node.Address6.IP.String() == node.IngressDNS) {
continue
}
if node.IngressDNS == "8.8.8.8" || node.IngressDNS == "1.1.1.1" || node.IngressDNS == "9.9.9.9" {
continue
}
h, err := logic.GetHost(node.HostID.String())
if err != nil {
continue
}
ns := schema.Nameserver{
ID: uuid.NewString(),
Name: fmt.Sprintf("%s gw nameservers", h.Name),
NetworkID: node.Network,
Servers: []string{node.IngressDNS},
MatchAll: true,
Domains: []schema.NameserverDomain{
{
Domain: ".",
},
},
Nodes: datatypes.JSONMap{
node.ID.String(): struct{}{},
},
Tags: make(datatypes.JSONMap),
Status: true,
CreatedBy: user.UserName,
}
ns.Create(db.WithContext(context.TODO()))
node.IngressDNS = ""
logic.UpsertNode(&node)
}
}
}
// removes if any stale configurations from previous run.
func resync() {
@@ -447,17 +311,18 @@ func assignSuperAdmin() {
createdSuperAdmin := false
owner := servercfg.GetOwnerEmail()
if owner != "" {
user, err := logic.GetUser(owner)
user := &schema.User{Username: owner}
err = user.Get(db.WithContext(context.TODO()))
if err != nil {
log.Fatal("error getting user", "user", owner, "error", err.Error())
}
user.PlatformRoleID = models.SuperAdminRole
user.PlatformRoleID = schema.SuperAdminRole
err = logic.UpsertUser(*user)
if err != nil {
log.Fatal(
"error updating user to superadmin",
"user",
user.UserName,
user.Username,
"error",
err.Error(),
)
@@ -466,7 +331,7 @@ func assignSuperAdmin() {
}
for _, u := range users {
var isAdmin bool
if u.PlatformRoleID == models.AdminRole {
if u.PlatformRoleID == schema.AdminRole {
isAdmin = true
}
if u.PlatformRoleID == "" && u.IsAdmin {
@@ -474,19 +339,19 @@ func assignSuperAdmin() {
}
if isAdmin {
user, err := logic.GetUser(u.UserName)
user := &schema.User{Username: u.UserName}
err = user.Get(db.WithContext(context.TODO()))
if err != nil {
slog.Error("error getting user", "user", u.UserName, "error", err.Error())
continue
}
user.PlatformRoleID = models.SuperAdminRole
user.IsSuperAdmin = true
user.PlatformRoleID = schema.SuperAdminRole
err = logic.UpsertUser(*user)
if err != nil {
slog.Error(
"error updating user to superadmin",
"user",
user.UserName,
user.Username,
"error",
err.Error(),
)
@@ -549,16 +414,16 @@ func updateEnrollmentKeys() {
existingTags[t] = struct{}{}
}
}
networks, _ := logic.GetNetworks()
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, network := range networks {
if _, ok := existingTags[network.NetID]; ok {
if _, ok := existingTags[network.Name]; ok {
continue
}
_, _ = logic.CreateEnrollmentKey(
0,
time.Time{},
[]string{network.NetID},
[]string{network.NetID},
[]string{network.Name},
[]string{network.Name},
[]models.TagID{},
true,
uuid.Nil,
@@ -566,57 +431,6 @@ func updateEnrollmentKeys() {
false,
false,
)
}
}
func removeOldUserGrps() {
rows, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
if err != nil {
return
}
for key, row := range rows {
userG := models.UserGroup{}
_ = json.Unmarshal([]byte(row), &userG)
if userG.ID == "" {
database.DeleteRecord(database.USER_GROUPS_TABLE_NAME, key)
}
}
}
func updateHosts() {
rows, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
if err != nil {
logger.Log(0, "failed to fetch database records for hosts")
}
for _, row := range rows {
var host models.Host
if err := json.Unmarshal([]byte(row), &host); err != nil {
logger.Log(0, "failed to unmarshal database row to host", "row", row)
continue
}
if host.PersistentKeepalive == 0 {
host.PersistentKeepalive = models.DefaultPersistentKeepAlive
if err := logic.UpsertHost(&host); err != nil {
logger.Log(0, "failed to upsert host", host.ID.String())
continue
}
}
if host.DNS == "" || (host.DNS != "yes" && host.DNS != "no") {
if logic.GetServerSettings().ManageDNS {
host.DNS = "yes"
} else {
host.DNS = "no"
}
if host.IsDefault {
host.DNS = "yes"
}
logic.UpsertHost(&host)
}
if host.IsDefault && !host.AutoUpdate {
host.AutoUpdate = true
logic.UpsertHost(&host)
}
}
}
@@ -633,7 +447,10 @@ func updateNodes() {
logic.UpsertNode(&node)
}
if node.IsIngressGateway {
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(db.WithContext(context.TODO()))
if err == nil {
go logic.DeleteRole(models.GetRAGRoleID(node.Network, host.ID.String()), true)
}
@@ -682,8 +499,8 @@ func updateAcls() {
if !logic.GetServerSettings().OldAClsSupport {
return
}
networks, err := logic.GetNetworks()
if err != nil && !database.IsEmptyRecord(err) {
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
slog.Error("acls migration failed. error getting networks", "error", err)
return
}
@@ -691,19 +508,19 @@ func updateAcls() {
// get current acls per network
for _, network := range networks {
var networkAcl acls.ACLContainer
networkAcl, err := networkAcl.Get(acls.ContainerID(network.NetID))
networkAcl, err := networkAcl.Get(acls.ContainerID(network.Name))
if err != nil {
if database.IsEmptyRecord(err) {
continue
}
slog.Error(fmt.Sprintf("error during acls migration. error getting acls for network: %s", network.NetID), "error", err)
slog.Error(fmt.Sprintf("error during acls migration. error getting acls for network: %s", network.Name), "error", err)
continue
}
// convert old acls to new acls with clients
// TODO: optimise O(n^2) operation
clients, err := logic.GetNetworkExtClients(network.NetID)
clients, err := logic.GetNetworkExtClients(network.Name)
if err != nil {
slog.Error(fmt.Sprintf("error during acls migration. error getting clients for network: %s", network.NetID), "error", err)
slog.Error(fmt.Sprintf("error during acls migration. error getting clients for network: %s", network.Name), "error", err)
continue
}
clientsIdMap := make(map[string]struct{})
@@ -749,7 +566,7 @@ func updateAcls() {
continue
}
if nodeAcl == nil {
slog.Warn("acls migration bad data: nil node acl", "node", id, "network", network.NetID)
slog.Warn("acls migration bad data: nil node acl", "node", id, "network", network.Name)
continue
}
nodeAcl[acls.AclID(client.ClientID)] = acls.Allowed
@@ -793,19 +610,19 @@ func updateAcls() {
}
// save new acls
slog.Debug(fmt.Sprintf("(migration) saving new acls for network: %s", network.NetID), "networkAcl", networkAcl)
if _, err := networkAcl.Save(acls.ContainerID(network.NetID)); err != nil {
slog.Error(fmt.Sprintf("error during acls migration. error saving new acls for network: %s", network.NetID), "error", err)
slog.Debug(fmt.Sprintf("(migration) saving new acls for network: %s", network.Name), "networkAcl", networkAcl)
if _, err := networkAcl.Save(acls.ContainerID(network.Name)); err != nil {
slog.Error(fmt.Sprintf("error during acls migration. error saving new acls for network: %s", network.Name), "error", err)
continue
}
slog.Info(fmt.Sprintf("(migration) successfully saved new acls for network: %s", network.NetID))
slog.Info(fmt.Sprintf("(migration) successfully saved new acls for network: %s", network.Name))
}
}
func updateNewAcls() {
if servercfg.IsPro {
userGroups, _ := logic.ListUserGroups()
userGroupMap := make(map[models.UserGroupID]models.UserGroup)
userGroups, _ := (&schema.UserGroup{}).ListAll(db.WithContext(context.TODO()))
userGroupMap := make(map[schema.UserGroupID]schema.UserGroup)
for _, userGroup := range userGroups {
userGroupMap[userGroup.ID] = userGroup
}
@@ -815,14 +632,14 @@ func updateNewAcls() {
aclSrc := make([]models.AclPolicyTag, 0)
for _, src := range acl.Src {
if src.ID == models.UserGroupAclID {
userGroup, ok := userGroupMap[models.UserGroupID(src.Value)]
userGroup, ok := userGroupMap[schema.UserGroupID(src.Value)]
if !ok {
// if the group doesn't exist, don't add it to the acl's src.
continue
} else {
_, allNetworkAccess := userGroup.NetworkRoles[models.AllNetworks]
_, allNetworkAccess := userGroup.NetworkRoles.Data()[schema.AllNetworks]
if !allNetworkAccess {
_, ok := userGroup.NetworkRoles[acl.NetworkID]
_, ok := userGroup.NetworkRoles.Data()[acl.NetworkID]
if !ok {
// if the group doesn't have permissions for the acl's
// network, don't add it to the acl's src.
@@ -863,107 +680,46 @@ func MigrateEmqx() {
}
func migrateToUUIDs() {
logic.MigrateToUUIDs()
}
func syncUsers() {
logger.Log(1, "Migrating Users (SyncUsers)")
defer logger.Log(1, "Completed migrating Users (SyncUsers)")
// create default network user roles for existing networks
if servercfg.IsPro {
networks, _ := logic.GetNetworks()
networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
for _, netI := range networks {
logic.CreateDefaultNetworkRolesAndGroups(models.NetworkID(netI.NetID))
logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(netI.Name))
}
}
users, err := logic.GetUsersDB()
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err == nil {
for _, user := range users {
user := user
needsUpdate := false
// Update admin flags based on platform role
if user.PlatformRoleID == models.AdminRole && !user.IsAdmin {
user.IsAdmin = true
user.IsSuperAdmin = false
needsUpdate = true
user.AuthType = schema.BasicAuth
if logic.IsOauthUser(&user) == nil {
user.AuthType = schema.OAuth
}
if user.PlatformRoleID == models.SuperAdminRole && !user.IsSuperAdmin {
user.IsSuperAdmin = true
user.IsAdmin = true
needsUpdate = true
}
if user.PlatformRoleID == models.PlatformUser || user.PlatformRoleID == models.ServiceUser {
if user.IsSuperAdmin || user.IsAdmin {
user.IsSuperAdmin = false
user.IsAdmin = false
needsUpdate = true
}
if len(user.UserGroups.Data()) == 0 {
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
}
if user.PlatformRoleID.String() != "" {
// Initialize maps if nil
if user.NetworkRoles == nil {
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
needsUpdate = true
}
if user.UserGroups == nil {
user.UserGroups = make(map[models.UserGroupID]struct{})
needsUpdate = true
}
// Migrate user roles and groups, then add global net roles
user = logic.MigrateUserRoleAndGroups(user)
logic.AddGlobalNetRolesToAdmins(&user)
needsUpdate = true
} else {
// Set auth type
user.AuthType = models.BasicAuth
if logic.IsOauthUser(&user) == nil {
user.AuthType = models.OAuth
}
if len(user.NetworkRoles) == 0 {
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
}
if len(user.UserGroups) == 0 {
user.UserGroups = make(map[models.UserGroupID]struct{})
}
// We reach here only if the platform role id has not been set.
//
// Thus, we use the boolean fields to assign the role.
if user.IsSuperAdmin {
user.PlatformRoleID = models.SuperAdminRole
} else if user.IsAdmin {
user.PlatformRoleID = models.AdminRole
} else {
user.PlatformRoleID = models.ServiceUser
}
logic.AddGlobalNetRolesToAdmins(&user)
user = logic.MigrateUserRoleAndGroups(user)
needsUpdate = true
}
// Only update user once after all changes are collected
if needsUpdate {
logic.UpsertUser(user)
}
logic.AddGlobalNetRolesToAdmins(&user)
logic.UpsertUser(user)
}
}
}
func createDefaultTagsAndPolicies() {
networks, err := logic.GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return
}
for _, network := range networks {
logic.CreateDefaultTags(models.NetworkID(network.NetID))
logic.CreateDefaultAclNetworkPolicies(models.NetworkID(network.NetID))
logic.CreateDefaultTags(schema.NetworkID(network.Name))
logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name))
// delete old remote access gws policy
logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.NetID, "all-remote-access-gws")})
logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.Name, "all-remote-access-gws")})
}
logic.MigrateAclPolicies()
if !servercfg.IsPro {
@@ -986,7 +742,10 @@ func migrateToEgressV1() {
}
for _, node := range nodes {
if node.IsEgressGateway {
_, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err := host.Get(db.WithContext(context.TODO()))
if err != nil {
continue
}
@@ -1020,7 +779,7 @@ func migrateToEgressV1() {
MetaData: "",
Default: false,
ServiceType: models.Any,
NetworkID: models.NetworkID(node.Network),
NetworkID: schema.NetworkID(node.Network),
Proto: models.ALL,
RuleType: models.DevicePolicy,
Src: []models.AclPolicyTag{
@@ -1049,7 +808,7 @@ func migrateToEgressV1() {
MetaData: "",
Default: false,
ServiceType: models.Any,
NetworkID: models.NetworkID(node.Network),
NetworkID: schema.NetworkID(node.Network),
Proto: models.ALL,
RuleType: models.UserPolicy,
Src: []models.AclPolicyTag{
-183
View File
@@ -1,183 +0,0 @@
package migrate
import (
"context"
"errors"
"fmt"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"gorm.io/gorm"
"os"
"path/filepath"
)
// ToSQLSchema migrates the data from key-value
// db to sql db.
//
// This function archives the old data and does not
// delete it.
//
// Based on the db server, the archival is done in the
// following way:
//
// 1. Sqlite: Moves the old data to a
// netmaker_archive.db file.
//
// 2. Postgres: Moves the data to a netmaker_archive
// schema within the same database.
func ToSQLSchema() error {
// initialize sql schema db.
err := db.InitializeDB(schema.ListModels()...)
if err != nil {
return err
}
// migrate, if not done already.
err = migrate()
if err != nil {
return err
}
// archive key-value schema db, if not done already.
// ignore errors.
_ = archive()
return nil
}
func migrate() error {
// begin a new transaction.
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
// check if migrated already.
migrationJob := &schema.Job{
ID: "migration-v1.0.0",
}
err := migrationJob.Get(dbctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// initialize key-value schema db.
err := database.InitializeDatabase()
if err != nil {
return err
}
defer database.CloseDB()
// migrate.
// TODO: add migration code.
// mark migration job completed.
err = migrationJob.Create(dbctx)
if err != nil {
return err
}
commit = true
}
return nil
}
func archive() error {
dbServer := servercfg.GetDB()
if dbServer != "sqlite" && dbServer != "postgres" {
return nil
}
// begin a new transaction.
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
// check if key-value schema db archived already.
archivalJob := &schema.Job{
ID: "archival-v1.0.0",
}
err := archivalJob.Get(dbctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// archive.
switch dbServer {
case "sqlite":
err = sqliteArchiveOldData()
default:
err = pgArchiveOldData()
}
if err != nil {
return err
}
// mark archival job completed.
err = archivalJob.Create(dbctx)
if err != nil {
return err
}
commit = true
} else {
// remove the residual
if dbServer == "sqlite" {
_ = os.Remove(filepath.Join("data", "netmaker.db"))
}
}
return nil
}
func sqliteArchiveOldData() error {
oldDBFilePath := filepath.Join("data", "netmaker.db")
archiveDBFilePath := filepath.Join("data", "netmaker_archive.db")
// check if netmaker_archive.db exist.
_, err := os.Stat(archiveDBFilePath)
if err == nil {
return nil
} else if !os.IsNotExist(err) {
return err
}
// rename old db file to netmaker_archive.db.
return os.Rename(oldDBFilePath, archiveDBFilePath)
}
func pgArchiveOldData() error {
_, err := database.PGDB.Exec("CREATE SCHEMA IF NOT EXISTS netmaker_archive")
if err != nil {
return err
}
for _, table := range database.Tables {
_, err := database.PGDB.Exec(
fmt.Sprintf(
"ALTER TABLE public.%s SET SCHEMA netmaker_archive",
table,
),
)
if err != nil {
return err
}
}
return nil
}
+23 -173
View File
@@ -16,6 +16,7 @@
package migrate
import (
"context"
"testing"
"time"
@@ -23,10 +24,10 @@ import (
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/datatypes"
)
// TestSyncUsersLargeScale tests syncUsers() with a large number of users
@@ -52,33 +53,25 @@ func TestSyncUsersLargeScale(t *testing.T) {
startCreate := time.Now()
for i := 0; i < numUsers; i++ {
user := models.User{
UserName: "testuser" + uuid.New().String()[:8],
user := schema.User{
Username: "testuser" + uuid.New().String()[:8],
Password: "testpassword123",
DisplayName: "Test User " + uuid.New().String()[:8],
IsAdmin: i%10 == 0, // 10% are admins
IsSuperAdmin: i%100 == 0, // 1% are super admins
PlatformRoleID: models.ServiceUser, // Most are service users
PlatformRoleID: schema.ServiceUser, // Most are service users
}
// Assign different platform roles
if i%100 == 0 {
user.PlatformRoleID = models.SuperAdminRole
user.PlatformRoleID = schema.SuperAdminRole
} else if i%10 == 0 {
user.PlatformRoleID = models.AdminRole
user.PlatformRoleID = schema.AdminRole
} else if i%5 == 0 {
user.PlatformRoleID = models.PlatformUser
}
// Some users have network roles
if i%3 == 0 {
user.NetworkRoles = make(map[models.NetworkID]map[models.UserRoleID]struct{})
user.NetworkRoles[models.NetworkID("test-network")] = make(map[models.UserRoleID]struct{})
user.PlatformRoleID = schema.PlatformUser
}
// Some users have user groups
if i%4 == 0 {
user.UserGroups = make(map[models.UserGroupID]struct{})
user.UserGroups = datatypes.NewJSONType(make(map[schema.UserGroupID]struct{}))
}
err := logic.UpsertUser(user)
@@ -89,7 +82,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
t.Logf("Created %d users in %v (avg: %v per user)", numUsers, createDuration, createDuration/time.Duration(numUsers))
// Verify users were created
users, err := logic.GetUsersDB()
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
require.NoError(t, err)
assert.GreaterOrEqual(t, len(users), numUsers, "Expected at least %d users", numUsers)
@@ -103,7 +96,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
syncDuration, len(users), syncDuration/time.Duration(len(users)))
// Verify users were migrated correctly
usersAfter, err := logic.GetUsersDB()
usersAfter, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
require.NoError(t, err)
assert.Equal(t, len(users), len(usersAfter), "User count should remain the same")
@@ -117,19 +110,7 @@ func TestSyncUsersLargeScale(t *testing.T) {
user := usersAfter[i]
// Verify platform role is set
assert.NotEmpty(t, user.PlatformRoleID.String(), "User %s should have PlatformRoleID", user.UserName)
// Verify admin flags match platform role
if user.PlatformRoleID == models.SuperAdminRole {
assert.True(t, user.IsSuperAdmin, "SuperAdmin user should have IsSuperAdmin=true")
assert.True(t, user.IsAdmin, "SuperAdmin user should have IsAdmin=true")
} else if user.PlatformRoleID == models.AdminRole {
assert.True(t, user.IsAdmin, "Admin user should have IsAdmin=true")
assert.False(t, user.IsSuperAdmin, "Admin user should not have IsSuperAdmin=true")
} else {
assert.False(t, user.IsSuperAdmin, "Non-admin user should not have IsSuperAdmin=true")
assert.False(t, user.IsAdmin, "Non-admin user should not have IsAdmin=true")
}
assert.NotEmpty(t, user.PlatformRoleID.String(), "User %s should have PlatformRoleID", user.Username)
// Verify user groups are initialized
assert.NotNil(t, user.UserGroups, "User should have UserGroups map")
@@ -166,20 +147,20 @@ func TestMigrateToUUIDsLargeScale(t *testing.T) {
startCreate := time.Now()
for i := 0; i < numUsers; i++ {
user := models.User{
UserName: "testuser" + uuid.New().String()[:8],
user := schema.User{
Username: "testuser" + uuid.New().String()[:8],
Password: "testpassword123",
DisplayName: "Test User " + uuid.New().String()[:8],
PlatformRoleID: models.ServiceUser,
UserGroups: make(map[models.UserGroupID]struct{}),
PlatformRoleID: schema.ServiceUser,
UserGroups: datatypes.NewJSONType(make(map[schema.UserGroupID]struct{})),
}
// Add some user groups with non-UUID IDs (to trigger migration)
if i%2 == 0 {
user.UserGroups[models.UserGroupID("old-group-1")] = struct{}{}
user.UserGroups.Data()[("old-group-1")] = struct{}{}
}
if i%3 == 0 {
user.UserGroups[models.UserGroupID("old-group-2")] = struct{}{}
user.UserGroups.Data()[("old-group-2")] = struct{}{}
}
err := logic.UpsertUser(user)
@@ -190,140 +171,9 @@ func TestMigrateToUUIDsLargeScale(t *testing.T) {
t.Logf("Created %d users in %v", numUsers, createDuration)
// Verify users were created
users, err := logic.GetUsersDB()
users, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
require.NoError(t, err)
assert.GreaterOrEqual(t, len(users), numUsers, "Expected at least %d users", numUsers)
// Test MigrateToUUIDs() performance
t.Log("Running MigrateToUUIDs() migration...")
startMigrate := time.Now()
migrateToUUIDs()
migrateDuration := time.Since(startMigrate)
t.Logf("MigrateToUUIDs() completed in %v for %d users (avg: %v per user)",
migrateDuration, len(users), migrateDuration/time.Duration(len(users)))
// Verify users still exist after migration
usersAfter, err := logic.GetUsersDB()
require.NoError(t, err)
assert.Equal(t, len(users), len(usersAfter), "User count should remain the same")
// Performance assertion - MigrateToUUIDs should complete in reasonable time
maxDuration := 30 * time.Second
assert.Less(t, migrateDuration, maxDuration,
"MigrateToUUIDs() took too long: %v (expected < %v)", migrateDuration, maxDuration)
t.Logf("✓ MigrateToUUIDs() performance test passed: %v for %d users", migrateDuration, len(users))
}
// TestSyncUsersCorrectness tests that syncUsers() correctly migrates user data
func TestSyncUsersCorrectness(t *testing.T) {
// Initialize test database
err := db.InitializeDB(schema.ListModels()...)
require.NoError(t, err)
defer db.CloseDB()
err = database.InitializeDatabase()
require.NoError(t, err)
defer database.CloseDB()
// Create test users with various states
testCases := []struct {
name string
user models.User
expectedRole models.UserRoleID
expectedAdmin bool
expectedSuper bool
}{
{
name: "user with AdminRole but IsAdmin=false",
user: models.User{
UserName: "admin1",
Password: "password",
PlatformRoleID: models.AdminRole,
IsAdmin: false,
IsSuperAdmin: false,
},
expectedRole: models.AdminRole,
expectedAdmin: true,
expectedSuper: false,
},
{
name: "user with SuperAdminRole but IsSuperAdmin=false",
user: models.User{
UserName: "superadmin1",
Password: "password",
PlatformRoleID: models.SuperAdminRole,
IsAdmin: false,
IsSuperAdmin: false,
},
expectedRole: models.SuperAdminRole,
expectedAdmin: true,
expectedSuper: true,
},
{
name: "user with IsSuperAdmin=true but no PlatformRoleID",
user: models.User{
UserName: "superadmin2",
Password: "password",
PlatformRoleID: "",
IsAdmin: true,
IsSuperAdmin: true,
},
expectedRole: models.SuperAdminRole,
expectedAdmin: true,
expectedSuper: true,
},
{
name: "user with IsAdmin=true but no PlatformRoleID",
user: models.User{
UserName: "admin2",
Password: "password",
PlatformRoleID: "",
IsAdmin: true,
IsSuperAdmin: false,
},
expectedRole: models.AdminRole,
expectedAdmin: true,
expectedSuper: false,
},
{
name: "regular user with no role",
user: models.User{
UserName: "user1",
Password: "password",
PlatformRoleID: "",
IsAdmin: false,
IsSuperAdmin: false,
},
expectedRole: models.ServiceUser,
expectedAdmin: false,
expectedSuper: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create user
err := logic.UpsertUser(tc.user)
require.NoError(t, err)
// Run migration
syncUsers()
// Verify migration
user, err := logic.GetUser(tc.user.UserName)
require.NoError(t, err)
assert.Equal(t, tc.expectedRole, user.PlatformRoleID, "PlatformRoleID should match")
assert.Equal(t, tc.expectedAdmin, user.IsAdmin, "IsAdmin should match")
assert.Equal(t, tc.expectedSuper, user.IsSuperAdmin, "IsSuperAdmin should match")
assert.NotNil(t, user.UserGroups, "UserGroups should be initialized")
assert.NotNil(t, user.NetworkRoles, "NetworkRoles should be initialized")
// Cleanup
_ = database.DeleteRecord(database.USERS_TABLE_NAME, tc.user.UserName)
})
}
}
// BenchmarkSyncUsers benchmarks syncUsers() performance
@@ -344,13 +194,13 @@ func BenchmarkSyncUsers(b *testing.B) {
// Create test users
numUsers := 1000
for i := 0; i < numUsers; i++ {
user := models.User{
UserName: "benchuser" + uuid.New().String()[:8],
user := schema.User{
Username: "benchuser" + uuid.New().String()[:8],
Password: "password",
PlatformRoleID: models.ServiceUser,
PlatformRoleID: schema.ServiceUser,
}
if i%10 == 0 {
user.PlatformRoleID = models.AdminRole
user.PlatformRoleID = schema.AdminRole
}
_ = logic.UpsertUser(user)
}
+354
View File
@@ -0,0 +1,354 @@
package migrate
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"gorm.io/datatypes"
"gorm.io/gorm"
)
// ToSQLSchema migrates the data from key-value
// db to sql db.
func ToSQLSchema() error {
// begin a new transaction.
dbctx := db.BeginTx(context.TODO())
commit := false
defer func() {
if commit {
db.FromContext(dbctx).Commit()
} else {
db.FromContext(dbctx).Rollback()
}
}()
// v1.5.1 migration includes migrating the users, groups, roles, networks and hosts tables.
// future table migrations should be made below this block,
// with a different version number and a similar check for whether the
// migration was already done.
migrationJob := &schema.Job{
ID: "migration-v1.5.1",
}
err := migrationJob.Get(dbctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
logger.Log(1, fmt.Sprintf("running migration job %s", migrationJob.ID))
// migrate.
err = migrateV1_5_1(dbctx)
if err != nil {
return err
}
// mark migration job completed.
err = migrationJob.Create(dbctx)
if err != nil {
return err
}
logger.Log(1, fmt.Sprintf("migration job %s completed", migrationJob.ID))
commit = true
} else {
logger.Log(1, fmt.Sprintf("migration job %s already completed, skipping", migrationJob.ID))
}
return nil
}
func migrateV1_5_1(ctx context.Context) error {
err := migrateUsers(ctx)
if err != nil {
return err
}
err = migrateNetworks(ctx)
if err != nil {
return err
}
err = migrateUserRoles(ctx)
if err != nil {
return err
}
err = migrateUserGroups(ctx)
if err != nil {
return err
}
return migrateHosts(ctx)
}
func migrateUsers(ctx context.Context) error {
records, err := database.FetchRecords(database.USERS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var user models.User
err = json.Unmarshal([]byte(record), &user)
if err != nil {
return err
}
platformRoleID := user.PlatformRoleID
if user.PlatformRoleID == "" {
if user.IsSuperAdmin {
platformRoleID = schema.SuperAdminRole
} else if user.IsAdmin {
platformRoleID = schema.AdminRole
} else {
platformRoleID = schema.ServiceUser
}
}
_user := schema.User{
ID: "",
Username: user.UserName,
DisplayName: user.DisplayName,
PlatformRoleID: platformRoleID,
ExternalIdentityProviderID: user.ExternalIdentityProviderID,
AccountDisabled: user.AccountDisabled,
AuthType: user.AuthType,
Password: user.Password,
IsMFAEnabled: user.IsMFAEnabled,
TOTPSecret: user.TOTPSecret,
LastLoginAt: user.LastLoginTime,
UserGroups: datatypes.NewJSONType(user.UserGroups),
CreatedBy: user.CreatedBy,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
logger.Log(4, fmt.Sprintf("migrating user %s", _user.Username))
err = _user.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating user %s failed: %v", _user.Username, err))
return err
}
}
return nil
}
func migrateNetworks(ctx context.Context) error {
records, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var network models.Network
err = json.Unmarshal([]byte(record), &network)
if err != nil {
return err
}
var autoJoin, autoRemove, jitEnabled bool
if network.AutoJoin == "false" {
autoJoin = false
} else {
autoJoin = true
}
if network.AutoRemove == "true" {
autoRemove = true
} else {
autoRemove = false
}
if network.JITEnabled == "yes" {
jitEnabled = true
} else {
jitEnabled = false
}
_network := &schema.Network{
ID: "",
Name: network.NetID,
AddressRange: network.AddressRange,
AddressRange6: network.AddressRange6,
DefaultKeepAlive: int(network.DefaultKeepalive),
DefaultACL: network.DefaultACL,
DefaultMTU: network.DefaultMTU,
AutoJoin: autoJoin,
AutoRemove: autoRemove,
AutoRemoveTags: network.AutoRemoveTags,
AutoRemoveThreshold: network.AutoRemoveThreshold,
JITEnabled: jitEnabled,
VirtualNATPoolIPv4: network.VirtualNATPoolIPv4,
VirtualNATSitePrefixLenIPv4: network.VirtualNATSitePrefixLenIPv4,
NodesUpdatedAt: time.Unix(network.NodesLastModified, 0),
CreatedBy: network.CreatedBy,
CreatedAt: network.CreatedAt,
UpdatedAt: time.Unix(network.NetworkLastModified, 0),
}
logger.Log(4, fmt.Sprintf("migrating network %s", _network.Name))
err = _network.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating network %s failed: %v", _network.Name, err))
return err
}
}
return nil
}
func migrateUserRoles(ctx context.Context) error {
records, err := database.FetchRecords(database.USER_PERMISSIONS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var _userRole schema.UserRole
err = json.Unmarshal([]byte(record), &_userRole)
if err != nil {
return err
}
logger.Log(4, fmt.Sprintf("migrating user role %s", _userRole.ID))
err = _userRole.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating user role %s failed: %v", _userRole.ID, err))
return err
}
}
return nil
}
func migrateUserGroups(ctx context.Context) error {
records, err := database.FetchRecords(database.USER_GROUPS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var _userGroup schema.UserGroup
err = json.Unmarshal([]byte(record), &_userGroup)
if err != nil {
return err
}
logger.Log(4, fmt.Sprintf("migrating user group %s", _userGroup.ID))
err = _userGroup.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating user group %s failed: %v", _userGroup.ID, err))
return err
}
}
return nil
}
func migrateHosts(ctx context.Context) error {
records, err := database.FetchRecords(database.HOSTS_TABLE_NAME)
if err != nil && !database.IsEmptyRecord(err) {
return err
}
for _, record := range records {
var host models.Host
err = json.Unmarshal([]byte(record), &host)
if err != nil {
return err
}
_host := &schema.Host{
ID: host.ID,
Verbosity: host.Verbosity,
FirewallInUse: host.FirewallInUse,
Version: host.Version,
IPForwarding: host.IPForwarding,
DaemonInstalled: host.DaemonInstalled,
AutoUpdate: host.AutoUpdate,
HostPass: host.HostPass,
Name: host.Name,
OS: host.OS,
OSFamily: host.OSFamily,
OSVersion: host.OSVersion,
KernelVersion: host.KernelVersion,
Interface: host.Interface,
Debug: host.Debug,
ListenPort: host.ListenPort,
WgPublicListenPort: host.WgPublicListenPort,
MTU: host.MTU,
PublicKey: schema.WgKey{
Key: host.PublicKey,
},
MacAddress: host.MacAddress,
TrafficKeyPublic: host.TrafficKeyPublic,
Nodes: host.Nodes,
Interfaces: host.Interfaces,
DefaultInterface: host.DefaultInterface,
EndpointIP: host.EndpointIP,
EndpointIPv6: host.EndpointIPv6,
IsDocker: host.IsDocker,
IsK8S: host.IsK8S,
IsStaticPort: host.IsStaticPort,
IsStatic: host.IsStatic,
IsDefault: host.IsDefault,
DNS: host.DNS,
NatType: host.NatType,
TurnEndpoint: nil,
PersistentKeepalive: host.PersistentKeepalive,
Location: host.Location,
CountryCode: host.CountryCode,
EnableFlowLogs: host.EnableFlowLogs,
}
if host.TurnEndpoint != nil {
_host.TurnEndpoint = &schema.AddrPort{
AddrPort: *host.TurnEndpoint,
}
}
if _host.PersistentKeepalive == 0 {
_host.PersistentKeepalive = models.DefaultPersistentKeepAlive
}
if _host.DNS == "" || (_host.DNS != "yes" && _host.DNS != "no") {
if logic.GetServerSettings().ManageDNS {
_host.DNS = "yes"
} else {
_host.DNS = "no"
}
if _host.IsDefault {
_host.DNS = "yes"
}
}
if _host.IsDefault && !_host.AutoUpdate {
_host.AutoUpdate = true
}
logger.Log(4, fmt.Sprintf("migrating host %s", _host.ID))
err = _host.Create(ctx)
if err != nil {
logger.Log(4, fmt.Sprintf("migrating host %s failed: %v", _host.ID, err))
return err
}
}
return nil
}
+3 -1
View File
@@ -3,6 +3,8 @@ package models
import (
"net"
"time"
"github.com/gravitl/netmaker/schema"
)
// AllowedTrafficDirection - allowed direction of traffic
@@ -84,7 +86,7 @@ type Acl struct {
Default bool `json:"default"`
MetaData string `json:"meta_data"`
Name string `json:"name"`
NetworkID NetworkID `json:"network_id"`
NetworkID schema.NetworkID `json:"network_id"`
RuleType AclPolicyType `json:"policy_type"`
Src []AclPolicyTag `json:"src_type"`
Dst []AclPolicyTag `json:"dst_type"`
+6 -4
View File
@@ -4,6 +4,8 @@ import (
"net"
"strings"
"time"
"github.com/gravitl/netmaker/schema"
)
// ApiHost - the host struct for API usage
@@ -47,8 +49,8 @@ type ApiIface struct {
AddressString string `json:"addressString"`
}
// Host.ConvertNMHostToAPI - converts a Netmaker host to an API editable host
func (h *Host) ConvertNMHostToAPI() *ApiHost {
// NewApiHostFromSchemaHost - converts a Netmaker host to an API editable host
func NewApiHostFromSchemaHost(h *schema.Host) *ApiHost {
a := ApiHost{}
a.Debug = h.Debug
a.EndpointIP = h.EndpointIP.String()
@@ -96,8 +98,8 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
// APIHost.ConvertAPIHostToNMHost - convert's a given apihost struct to
// a Host struct
func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *Host) *Host {
h := Host{}
func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *schema.Host) *schema.Host {
h := schema.Host{}
h.ID = currentHost.ID
h.HostPass = currentHost.HostPass
h.DaemonInstalled = currentHost.DaemonInstalled
+2 -1
View File
@@ -5,6 +5,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/schema"
"golang.org/x/exp/slog"
)
@@ -69,7 +70,7 @@ type ApiNode struct {
Location string `json:"location"`
Country string `json:"country"`
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
}
+13 -18
View File
@@ -1,23 +1,18 @@
package models
type EgressNATMode string
const (
VirtualNAT EgressNATMode = "virtual_nat"
DirectNAT EgressNATMode = "direct_nat"
)
import "github.com/gravitl/netmaker/schema"
type EgressReq struct {
ID string `json:"id"`
Name string `json:"name"`
Network string `json:"network"`
Description string `json:"description"`
Nodes map[string]int `json:"nodes"`
Tags map[string]int `json:"tags"`
Range string `json:"range"`
Domain string `json:"domain"`
Nat bool `json:"nat"`
Mode EgressNATMode `json:"mode"`
Status bool `json:"status"`
IsInetGw bool `json:"is_internet_gateway"`
ID string `json:"id"`
Name string `json:"name"`
Network string `json:"network"`
Description string `json:"description"`
Nodes map[string]int `json:"nodes"`
Tags map[string]int `json:"tags"`
Range string `json:"range"`
Domain string `json:"domain"`
Nat bool `json:"nat"`
Mode schema.EgressNATMode `json:"mode"`
Status bool `json:"status"`
IsInetGw bool `json:"is_internet_gateway"`
}
+2 -1
View File
@@ -6,6 +6,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/schema"
)
const (
@@ -75,7 +76,7 @@ type APIEnrollmentKey struct {
// RegisterResponse - the response to a successful enrollment register
type RegisterResponse struct {
ServerConf ServerConfig `json:"server_config"`
RequestedHost Host `json:"requested_host"`
RequestedHost schema.Host `json:"requested_host"`
}
// EnrollmentKey.IsValid - checks if the key is still valid to use
+8 -80
View File
@@ -1,84 +1,12 @@
package models
type Action string
const (
Create Action = "CREATE"
Update Action = "UPDATE"
Delete Action = "DELETE"
DeleteAll Action = "DELETE_ALL"
Login Action = "LOGIN"
LogOut Action = "LOGOUT"
Connect Action = "CONNECT"
Sync Action = "SYNC"
RefreshKey Action = "REFRESH_KEY"
RefreshAllKeys Action = "REFRESH_ALL_KEYS"
SyncAll Action = "SYNC_ALL"
UpgradeAll Action = "UPGRADE_ALL"
Disconnect Action = "DISCONNECT"
JoinHostToNet Action = "JOIN_HOST_TO_NETWORK"
RemoveHostFromNet Action = "REMOVE_HOST_FROM_NETWORK"
EnableMFA Action = "ENABLE_MFA"
DisableMFA Action = "DISABLE_MFA"
EnforceMFA Action = "ENFORCE_MFA"
UnenforceMFA Action = "UNENFORCE_MFA"
EnableBasicAuth Action = "ENABLE_BASIC_AUTH"
DisableBasicAuth Action = "DISABLE_BASIC_AUTH"
EnableTelemetry Action = "ENABLE_TELEMETRY"
DisableTelemetry Action = "DISABLE_TELEMETRY"
UpdateClientSettings Action = "UPDATE_CLIENT_SETTINGS"
UpdateAuthenticationSecuritySettings Action = "UPDATE_AUTHENTICATION_SECURITY_SETTINGS"
UpdateMonitoringAndDebuggingSettings Action = "UPDATE_MONITORING_AND_DEBUGGING_SETTINGS"
UpdateSMTPSettings Action = "UPDATE_EMAIL_SETTINGS"
UpdateIDPSettings Action = "UPDATE_IDP_SETTINGS"
EnableFlowLogs Action = "ENABLE_FLOW_LOGS"
DisableFlowLogs Action = "DISABLE_FLOW_LOGS"
GatewayAssign Action = "GATEWAY_ASSIGN"
GatewayUnAssign Action = "GATEWAY_UNASSIGN"
)
type SubjectType string
const (
UserSub SubjectType = "USER"
UserAccessTokenSub SubjectType = "USER_ACCESS_TOKEN"
DeviceSub SubjectType = "DEVICE"
NodeSub SubjectType = "NODE"
GatewaySub SubjectType = "GATEWAY"
SettingSub SubjectType = "SETTING"
AclSub SubjectType = "ACL"
TagSub SubjectType = "TAG"
UserRoleSub SubjectType = "USER_ROLE"
UserGroupSub SubjectType = "USER_GROUP"
UserInviteSub SubjectType = "USER_INVITE"
PendingUserSub SubjectType = "PENDING_USER"
EgressSub SubjectType = "EGRESS"
NetworkSub SubjectType = "NETWORK"
DashboardSub SubjectType = "DASHBOARD"
EnrollmentKeySub SubjectType = "ENROLLMENT_KEY"
ClientAppSub SubjectType = "CLIENT-APP"
NameserverSub SubjectType = "NAMESERVER"
PostureCheckSub SubjectType = "POSTURE_CHECK"
)
func (sub SubjectType) String() string {
return string(sub)
}
type Origin string
const (
Dashboard Origin = "DASHBOARD"
Api Origin = "API"
NMCTL Origin = "NMCTL"
ClientApp Origin = "CLIENT-APP"
)
import "github.com/gravitl/netmaker/schema"
type Subject struct {
ID string `json:"id"`
Name string `json:"name"`
Type SubjectType `json:"subject_type"`
Info interface{} `json:"info"`
ID string `json:"id"`
Name string `json:"name"`
Type schema.SubjectType `json:"subject_type"`
Info interface{} `json:"info"`
}
type Diff struct {
@@ -87,11 +15,11 @@ type Diff struct {
}
type Event struct {
Action Action
Action schema.Action
Source Subject
Origin Origin
Origin schema.Origin
Target Subject
TriggeredBy string
NetworkID NetworkID
NetworkID schema.NetworkID
Diff Diff
}
+3 -1
View File
@@ -3,6 +3,8 @@ package models
import (
"sync"
"time"
"github.com/gravitl/netmaker/schema"
)
// ExtClient - struct for external clients
@@ -37,7 +39,7 @@ type ExtClient struct {
Country string `json:"country"`
Location string `json:"location"` //format: lat,long
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
JITExpiresAt *time.Time `json:"jit_expires_at,omitempty" bson:"jit_expires_at,omitempty"` // JIT grant expiry time (nil if JIT not enabled or user is admin)
Mutex *sync.Mutex `json:"-"`
+9 -8
View File
@@ -6,6 +6,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/schema"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -63,7 +64,7 @@ type Host struct {
MacAddress net.HardwareAddr `json:"macaddress" yaml:"macaddress"`
TrafficKeyPublic []byte `json:"traffickeypublic" yaml:"traffickeypublic"`
Nodes []string `json:"nodes" yaml:"nodes"`
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
DefaultInterface string `json:"defaultinterface" yaml:"defaultinterface"`
EndpointIP net.IP `json:"endpointip" yaml:"endpointip"`
EndpointIPv6 net.IP `json:"endpointipv6" yaml:"endpointipv6"`
@@ -150,7 +151,7 @@ const (
// HostUpdate - struct for host update
type HostUpdate struct {
Action HostMqAction
Host Host
Host schema.Host
Node Node
Signal Signal
EgressDomain EgressDomain
@@ -182,10 +183,10 @@ type Signal struct {
// RegisterMsg - login message struct for hosts to join via SSO login
type RegisterMsg struct {
RegisterHost Host `json:"host"`
Network string `json:"network,omitempty"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
JoinAll bool `json:"join_all,omitempty"`
Relay string `json:"relay,omitempty"`
RegisterHost schema.Host `json:"host"`
Network string `json:"network,omitempty"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
JoinAll bool `json:"join_all,omitempty"`
Relay string `json:"relay,omitempty"`
}
+7 -5
View File
@@ -2,6 +2,8 @@ package models
import (
"time"
"github.com/gravitl/netmaker/schema"
)
// Metrics - metrics struct
@@ -48,11 +50,11 @@ type HostInfoMap map[string]HostNetworkInfo
// HostNetworkInfo - holds info related to host networking (used for client side peer calculations)
type HostNetworkInfo struct {
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
ListenPort int `json:"listen_port" yaml:"listen_port"`
IsStaticPort bool `json:"is_static_port"`
IsStatic bool `json:"is_static"`
Version string `json:"version"`
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
ListenPort int `json:"listen_port" yaml:"listen_port"`
IsStaticPort bool `json:"is_static_port"`
IsStatic bool `json:"is_static"`
Version string `json:"version"`
}
// PeerMap - peer map for ids and addresses in metrics
+28 -27
View File
@@ -3,11 +3,12 @@ package models
import (
"net"
"github.com/gravitl/netmaker/schema"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type HostPeerInfo struct {
NetworkPeerIDs map[NetworkID]PeerMap `json:"network_peers"`
NetworkPeerIDs map[schema.NetworkID]PeerMap `json:"network_peers"`
}
type PeerType int
@@ -27,37 +28,37 @@ type PeerIdentity struct {
// HostPeerUpdate - struct for host peer updates
type HostPeerUpdate struct {
Host Host `json:"host"`
Nodes []Node `json:"nodes"`
ChangeDefaultGw bool `json:"change_default_gw"`
DefaultGwIp net.IP `json:"default_gw_ip"`
IsInternetGw bool `json:"is_inet_gw"`
NodeAddrs []net.IPNet `json:"nodes_addrs"`
Server string `json:"server"`
ServerVersion string `json:"serverversion"`
ServerAddrs []ServerAddr `json:"serveraddrs"`
NodePeers []wgtypes.PeerConfig `json:"node_peers"`
Peers []wgtypes.PeerConfig `json:"host_peers"`
PeerIDs PeerMap `json:"peerids"`
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty"`
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
FwUpdate FwUpdate `json:"fw_update"`
ReplacePeers bool `json:"replace_peers"`
NameServers []string `json:"name_servers"`
DnsNameservers []Nameserver `json:"dns_nameservers"`
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
AutoRelayNodes map[NetworkID][]Node `json:"auto_relay_nodes"`
GwNodes map[NetworkID][]Node `json:"gw_nodes"`
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
Host schema.Host `json:"host"`
Nodes []Node `json:"nodes"`
ChangeDefaultGw bool `json:"change_default_gw"`
DefaultGwIp net.IP `json:"default_gw_ip"`
IsInternetGw bool `json:"is_inet_gw"`
NodeAddrs []net.IPNet `json:"nodes_addrs"`
Server string `json:"server"`
ServerVersion string `json:"serverversion"`
ServerAddrs []ServerAddr `json:"serveraddrs"`
NodePeers []wgtypes.PeerConfig `json:"node_peers"`
Peers []wgtypes.PeerConfig `json:"host_peers"`
PeerIDs PeerMap `json:"peerids"`
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty"`
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
FwUpdate FwUpdate `json:"fw_update"`
ReplacePeers bool `json:"replace_peers"`
NameServers []string `json:"name_servers"`
DnsNameservers []Nameserver `json:"dns_nameservers"`
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
AutoRelayNodes map[schema.NetworkID][]Node `json:"auto_relay_nodes"`
GwNodes map[schema.NetworkID][]Node `json:"gw_nodes"`
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
ServerConfig
OldPeerUpdateFields
}
type EgressDomain struct {
ID string `json:"id"`
Node Node `json:"node"`
Host Host `json:"host"`
Domain string `json:"domain"`
ID string `json:"id"`
Node Node `json:"node"`
Host schema.Host `json:"host"`
Domain string `json:"domain"`
}
type Nameserver struct {
IPs []string `json:"ips"`
+3 -143
View File
@@ -1,7 +1,6 @@
package models
import (
"net"
"time"
)
@@ -13,15 +12,9 @@ type Network struct {
NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=32,netid_valid"`
NodesLastModified int64 `json:"nodeslastmodified" bson:"nodeslastmodified" swaggertype:"primitive,integer" format:"int64"`
NetworkLastModified int64 `json:"networklastmodified" bson:"networklastmodified" swaggertype:"primitive,integer" format:"int64"`
DefaultInterface string `json:"defaultinterface" bson:"defaultinterface" validate:"min=1,max=35"`
DefaultListenPort int32 `json:"defaultlistenport,omitempty" bson:"defaultlistenport,omitempty" validate:"omitempty,min=1024,max=65535"`
NodeLimit int32 `json:"nodelimit" bson:"nodelimit"`
DefaultPostDown string `json:"defaultpostdown" bson:"defaultpostdown"`
DefaultKeepalive int32 `json:"defaultkeepalive" bson:"defaultkeepalive" validate:"omitempty,max=1000"`
AllowManualSignUp string `json:"allowmanualsignup" bson:"allowmanualsignup" validate:"checkyesorno"`
IsIPv4 string `json:"isipv4" bson:"isipv4" validate:"checkyesorno"`
IsIPv6 string `json:"isipv6" bson:"isipv6" validate:"checkyesorno"`
DefaultUDPHolePunch string `json:"defaultudpholepunch" bson:"defaultudpholepunch" validate:"checkyesorno"`
DefaultMTU int32 `json:"defaultmtu" bson:"defaultmtu"`
DefaultACL string `json:"defaultacl" bson:"defaultacl" yaml:"defaultacl" validate:"checkyesorno"`
NameServers []string `json:"dns_nameservers"`
@@ -33,145 +26,12 @@ type Network struct {
// VirtualNATPoolIPv4 is the IPv4 CIDR pool from which virtual NAT ranges are allocated for egress gateways
VirtualNATPoolIPv4 string `json:"virtual_nat_pool_ipv4"`
// VirtualNATSitePrefixLenIPv4 is the prefix length (e.g., 24) for individual site allocations from the IPv4 virtual NAT pool
VirtualNATSitePrefixLenIPv4 int `json:"virtual_nat_site_prefixlen_ipv4"`
VirtualNATSitePrefixLenIPv4 int `json:"virtual_nat_site_prefixlen_ipv4"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time
}
// SaveData - sensitive fields of a network that should be kept the same
type SaveData struct { // put sensitive fields here
NetID string `json:"netid" bson:"netid" validate:"required,min=1,max=32,netid_valid"`
}
// Network.SetNodesLastModified - sets nodes last modified on network, depricated
func (network *Network) SetNodesLastModified() {
network.NodesLastModified = time.Now().Unix()
}
// Network.SetNetworkLastModified - sets network last modified time
func (network *Network) SetNetworkLastModified() {
network.NetworkLastModified = time.Now().Unix()
}
// Network.SetDefaults - sets default values for a network struct
func (network *Network) SetDefaults() (upsert bool) {
if network.DefaultUDPHolePunch == "" {
network.DefaultUDPHolePunch = "no"
upsert = true
}
if network.DefaultInterface == "" {
if len(network.NetID) < 33 {
network.DefaultInterface = "nm-" + network.NetID
} else {
network.DefaultInterface = network.NetID
}
upsert = true
}
if network.DefaultListenPort == 0 {
network.DefaultListenPort = 51821
upsert = true
}
if network.NodeLimit == 0 {
network.NodeLimit = 999999999
upsert = true
}
if network.DefaultKeepalive == 0 {
network.DefaultKeepalive = 20
upsert = true
}
if network.AllowManualSignUp == "" {
network.AllowManualSignUp = "no"
upsert = true
}
if network.IsIPv4 == "" {
network.IsIPv4 = "yes"
upsert = true
}
if network.IsIPv6 == "" {
network.IsIPv6 = "no"
upsert = true
}
if network.DefaultMTU == 0 {
network.DefaultMTU = 1280
upsert = true
}
if network.DefaultACL == "" {
network.DefaultACL = "yes"
upsert = true
}
if network.JITEnabled == "" {
network.JITEnabled = "no"
upsert = true
}
return
}
// AssignVirtualNATDefaults determines safe defaults based on VPN CIDR
func (network *Network) AssignVirtualNATDefaults(vpnCIDR string, networkID string) {
const (
cgnatCIDR = "100.64.0.0/10"
fallbackIPv4Pool = "198.18.0.0/15"
defaultIPv4SitePrefix = 24
)
// Parse CGNAT CIDR (should always succeed, but check for safety)
_, cgnatNet, err := net.ParseCIDR(cgnatCIDR)
if err != nil {
// Fallback to default pool if CGNAT parsing fails (shouldn't happen)
network.VirtualNATPoolIPv4 = fallbackIPv4Pool
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
return
}
var virtualIPv4Pool string
// Parse VPN CIDR - if it fails or is empty, use fallback
if vpnCIDR == "" {
virtualIPv4Pool = fallbackIPv4Pool
} else {
_, vpnNet, err := net.ParseCIDR(vpnCIDR)
if err != nil || vpnNet == nil {
// Invalid VPN CIDR, use fallback
virtualIPv4Pool = fallbackIPv4Pool
} else if !cidrOverlaps(vpnNet, cgnatNet) {
// Safe to reuse VPN CIDR for Virtual NAT
virtualIPv4Pool = vpnCIDR
} else {
// VPN is CGNAT — must not reuse
virtualIPv4Pool = fallbackIPv4Pool
}
}
network.VirtualNATPoolIPv4 = virtualIPv4Pool
network.VirtualNATSitePrefixLenIPv4 = defaultIPv4SitePrefix
}
func cidrOverlaps(a, b *net.IPNet) bool {
if a == nil || b == nil {
return false
}
return a.Contains(b.IP) || b.Contains(a.IP)
}
func (network *Network) GetNetworkNetworkCIDR4() *net.IPNet {
if network.AddressRange == "" {
return nil
}
_, netCidr, _ := net.ParseCIDR(network.AddressRange)
return netCidr
}
func (network *Network) GetNetworkNetworkCIDR6() *net.IPNet {
if network.AddressRange6 == "" {
return nil
}
_, netCidr, _ := net.ParseCIDR(network.AddressRange6)
return netCidr
}
type NetworkStatResp struct {
Network
Hosts int `json:"hosts"`
}
+8 -13
View File
@@ -9,6 +9,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/schema"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -55,14 +56,7 @@ var seededRand *rand.Rand = rand.New(
type NodeCheckin struct {
Version string
Connected bool
Ifaces []Iface
}
// Iface struct for local interfaces of a node
type Iface struct {
Name string `json:"name"`
Address net.IPNet `json:"address"`
AddressString string `json:"addressString"`
Ifaces []schema.Iface
}
// CommonNode - represents a commonn node data elements shared by netmaker and netclient
@@ -127,7 +121,7 @@ type Node struct {
Mutex *sync.Mutex `json:"-"`
EgressDetails EgressDetails `json:"-"`
PostureChecksViolations []Violation `json:"posture_check_violations"`
PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"`
PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"`
LastEvaluatedAt time.Time `json:"last_evaluated_at"`
Location string `json:"location"` // Format: "lat,lon"
CountryCode string `json:"country_code"`
@@ -148,7 +142,7 @@ type LegacyNode struct {
Address string `json:"address" bson:"address" yaml:"address" validate:"omitempty,ipv4"`
Address6 string `json:"address6" bson:"address6" yaml:"address6" validate:"omitempty,ipv6"`
LocalAddress string `json:"localaddress" bson:"localaddress" yaml:"localaddress" validate:"omitempty"`
Interfaces []Iface `json:"interfaces" yaml:"interfaces"`
Interfaces []schema.Iface `json:"interfaces" yaml:"interfaces"`
Name string `json:"name" bson:"name" yaml:"name" validate:"omitempty,max=62,in_charset"`
NetworkSettings Network `json:"networksettings" bson:"networksettings" yaml:"networksettings" validate:"-"`
ListenPort int32 `json:"listenport" bson:"listenport" yaml:"listenport" validate:"omitempty,numeric,min=1024,max=65535"`
@@ -541,10 +535,10 @@ func (node *Node) DoesACLDeny() bool {
return node.DefaultACL == "no"
}
func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) {
func (ln *LegacyNode) ConvertToNewNode() (*schema.Host, *Node) {
var node Node
//host:= logic.GetHost(node.HostID)
var host Host
var host schema.Host
if host.ID.String() == "" {
host.ID = uuid.New()
host.FirewallInUse = ln.FirewallInUse
@@ -554,7 +548,8 @@ func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) {
host.Name = ln.Name
host.ListenPort = int(ln.ListenPort)
host.MTU = int(ln.MTU)
host.PublicKey, _ = wgtypes.ParseKey(ln.PublicKey)
pubkey, _ := wgtypes.ParseKey(ln.PublicKey)
host.PublicKey = schema.WgKey{Key: pubkey}
host.MacAddress, _ = net.ParseMAC(ln.MacAddress)
host.TrafficKeyPublic = ln.TrafficKeys.Mine
id, _ := uuid.Parse(ln.ID)
+45 -58
View File
@@ -2,10 +2,10 @@ package models
import (
"net"
"strings"
"time"
jwt "github.com/golang-jwt/jwt/v4"
"github.com/gravitl/netmaker/schema"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -141,6 +141,14 @@ type SuccessResponse struct {
Response interface{}
}
type PaginatedResponse struct {
Data interface{} `json:"data"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Total int `json:"total"`
TotalPages int `json:"total_pages"`
}
// DisplayKey - what is displayed for key
type DisplayKey struct {
Name string `json:"name" bson:"name"`
@@ -194,12 +202,12 @@ type EgressRangeMetric struct {
// from. Might not be always set.
EgressID string `json:"-"`
// EgressName is the name of the egress gateway identified by EgressID. Might not be always set.
EgressName string `json:"-"`
Network string `json:"network"`
VirtualNetwork string `json:"virtual_network"`
RouteMetric uint32 `json:"route_metric"` // preffered range 1-999
Nat bool `json:"nat"`
Mode EgressNATMode `json:"nat_mode"`
EgressName string `json:"-"`
Network string `json:"network"`
VirtualNetwork string `json:"virtual_network"`
RouteMetric uint32 `json:"route_metric"` // preffered range 1-999
Nat bool `json:"nat"`
Mode schema.EgressNATMode `json:"nat_mode"`
}
// EgressGatewayRequest - egress gateway request
@@ -268,31 +276,31 @@ type TrafficKeys struct {
// HostPull - response of a host's pull
type HostPull struct {
Host Host `json:"host" yaml:"host"`
Nodes []Node `json:"nodes" yaml:"nodes"`
Peers []wgtypes.PeerConfig `json:"peers" yaml:"peers"`
ServerConfig ServerConfig `json:"server_config" yaml:"server_config"`
PeerIDs PeerMap `json:"peer_ids,omitempty" yaml:"peer_ids,omitempty"`
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty" yaml:"host_network_info,omitempty"`
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
FwUpdate FwUpdate `json:"fw_update"`
ChangeDefaultGw bool `json:"change_default_gw"`
DefaultGwIp net.IP `json:"default_gw_ip"`
IsInternetGw bool `json:"is_inet_gw"`
EndpointDetection bool `json:"endpoint_detection"`
NameServers []string `json:"name_servers"`
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
DnsNameservers []Nameserver `json:"dns_nameservers"`
AutoRelayNodes map[NetworkID][]Node `json:"auto_relay_nodes"`
GwNodes map[NetworkID][]Node `json:"gw_nodes"`
ReplacePeers bool `json:"replace_peers"`
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
Host schema.Host `json:"host" yaml:"host"`
Nodes []Node `json:"nodes" yaml:"nodes"`
Peers []wgtypes.PeerConfig `json:"peers" yaml:"peers"`
ServerConfig ServerConfig `json:"server_config" yaml:"server_config"`
PeerIDs PeerMap `json:"peer_ids,omitempty" yaml:"peer_ids,omitempty"`
HostNetworkInfo HostInfoMap `json:"host_network_info,omitempty" yaml:"host_network_info,omitempty"`
EgressRoutes []EgressNetworkRoutes `json:"egress_network_routes"`
FwUpdate FwUpdate `json:"fw_update"`
ChangeDefaultGw bool `json:"change_default_gw"`
DefaultGwIp net.IP `json:"default_gw_ip"`
IsInternetGw bool `json:"is_inet_gw"`
EndpointDetection bool `json:"endpoint_detection"`
NameServers []string `json:"name_servers"`
EgressWithDomains []EgressDomain `json:"egress_with_domains"`
DnsNameservers []Nameserver `json:"dns_nameservers"`
AutoRelayNodes map[schema.NetworkID][]Node `json:"auto_relay_nodes"`
GwNodes map[schema.NetworkID][]Node `json:"gw_nodes"`
ReplacePeers bool `json:"replace_peers"`
AddressIdentityMap map[string]PeerIdentity `json:"address_identity_map"`
}
// NodeGet - struct for a single node get response
type NodeGet struct {
Node Node `json:"node" bson:"node" yaml:"node"`
Host Host `json:"host" yaml:"host"`
Host schema.Host `json:"host" yaml:"host"`
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
HostPeers []wgtypes.PeerConfig `json:"host_peers" bson:"host_peers" yaml:"host_peers"`
ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"`
@@ -302,7 +310,7 @@ type NodeGet struct {
// NodeJoinResponse data returned to node in response to join
type NodeJoinResponse struct {
Node Node `json:"node" bson:"node" yaml:"node"`
Host Host `json:"host" yaml:"host"`
Host schema.Host `json:"host" yaml:"host"`
ServerConfig ServerConfig `json:"serverconfig" bson:"serverconfig" yaml:"serverconfig"`
Peers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
}
@@ -336,17 +344,6 @@ type ServerConfig struct {
OldAClsSupport bool `json:"-"`
}
// User.NameInCharset - returns if name is in charset below or not
func (user *User) NameInCharSet() bool {
charset := "abcdefghijklmnopqrstuvwxyz1234567890-."
for _, char := range user.UserName {
if !strings.Contains(charset, strings.ToLower(string(char))) {
return false
}
}
return true
}
// ServerIDs - struct to hold server ids.
type ServerIDs struct {
ServerIDs []string `json:"server_ids"`
@@ -354,9 +351,9 @@ type ServerIDs struct {
// JoinData - struct to hold data required for node to join a network on server
type JoinData struct {
Host Host `json:"host" yaml:"host"`
Node Node `json:"node" yaml:"node"`
Key string `json:"key" yaml:"key"`
Host schema.Host `json:"host" yaml:"host"`
Node Node `json:"node" yaml:"node"`
Key string `json:"key" yaml:"key"`
}
// HookFunc - function type for hooks that can accept optional parameters
@@ -484,23 +481,13 @@ type PostureCheckDeviceInfo struct {
AutoUpdate bool
Tags map[TagID]struct{}
IsUser bool
UserGroups map[UserGroupID]struct{}
UserGroups map[schema.UserGroupID]struct{}
}
type Violation struct {
CheckID string `json:"check_id"`
Name string `json:"name"`
Attribute string `json:"attribute"`
Message string `json:"message"`
Severity Severity `json:"severity"`
CheckID string `json:"check_id"`
Name string `json:"name"`
Attribute string `json:"attribute"`
Message string `json:"message"`
Severity schema.Severity `json:"severity"`
}
type Severity int
const (
SeverityUnknown Severity = iota
SeverityLow
SeverityMedium
SeverityHigh
SeverityCritical
)
+12 -10
View File
@@ -3,6 +3,8 @@ package models
import (
"fmt"
"time"
"github.com/gravitl/netmaker/schema"
)
type TagID string
@@ -21,19 +23,19 @@ func (t Tag) GetIDFromName() string {
}
type Tag struct {
ID TagID `json:"id"`
TagName string `json:"tag_name"`
Network NetworkID `json:"network"`
ColorCode string `json:"color_code"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
ID TagID `json:"id"`
TagName string `json:"tag_name"`
Network schema.NetworkID `json:"network"`
ColorCode string `json:"color_code"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
}
type CreateTagReq struct {
TagName string `json:"tag_name"`
Network NetworkID `json:"network"`
ColorCode string `json:"color_code"`
TaggedNodes []ApiNode `json:"tagged_nodes"`
TagName string `json:"tag_name"`
Network schema.NetworkID `json:"network"`
ColorCode string `json:"color_code"`
TaggedNodes []ApiNode `json:"tagged_nodes"`
}
type TagListResp struct {
+49 -187
View File
@@ -5,35 +5,17 @@ import (
"time"
jwt "github.com/golang-jwt/jwt/v4"
"github.com/gravitl/netmaker/schema"
)
type NetworkID string
type RsrcType string
type RsrcID string
type UserRoleID string
type UserGroupID string
type AuthType string
type TokenType string
var (
BasicAuth AuthType = "basic_auth"
OAuth AuthType = "oauth"
)
func (r RsrcType) String() string {
return string(r)
}
func (rid RsrcID) String() string {
return string(rid)
}
func GetRAGRoleName(netID, hostName string) string {
return fmt.Sprintf("netID-%s-rag-%s", netID, hostName)
}
func GetRAGRoleID(netID, hostID string) UserRoleID {
return UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
func GetRAGRoleID(netID, hostID string) schema.UserRoleID {
return schema.UserRoleID(fmt.Sprintf("netID-%s-rag-%s", netID, hostID))
}
func (t TokenType) String() string {
@@ -45,169 +27,49 @@ var (
AccessTokenType TokenType = "access_token"
)
var RsrcTypeMap = map[RsrcType]struct{}{
HostRsrc: {},
RelayRsrc: {},
RemoteAccessGwRsrc: {},
ExtClientsRsrc: {},
InetGwRsrc: {},
EgressGwRsrc: {},
NetworkRsrc: {},
EnrollmentKeysRsrc: {},
UserRsrc: {},
AclRsrc: {},
DnsRsrc: {},
FailOverRsrc: {},
}
const AllNetworks NetworkID = "all_networks"
const (
HostRsrc RsrcType = "host"
RelayRsrc RsrcType = "relay"
RemoteAccessGwRsrc RsrcType = "remote_access_gw"
GatewayRsrc RsrcType = "gateway"
ExtClientsRsrc RsrcType = "extclient"
InetGwRsrc RsrcType = "inet_gw"
EgressGwRsrc RsrcType = "egress"
NetworkRsrc RsrcType = "network"
EnrollmentKeysRsrc RsrcType = "enrollment_key"
UserRsrc RsrcType = "user"
AclRsrc RsrcType = "acl"
TagRsrc RsrcType = "tag"
DnsRsrc RsrcType = "dns"
NameserverRsrc RsrcType = "nameserver"
FailOverRsrc RsrcType = "fail_over"
MetricRsrc RsrcType = "metric"
PostureCheckRsrc RsrcType = "posturecheck"
JitAdminRsrc RsrcType = "jit_admin"
JitUserRsrc RsrcType = "jit_user"
UserActivityRsrc RsrcType = "user_activity"
TrafficFlow RsrcType = "traffic_flow"
)
const (
AllHostRsrcID RsrcID = "all_host"
AllRelayRsrcID RsrcID = "all_relay"
AllRemoteAccessGwRsrcID RsrcID = "all_remote_access_gw"
AllExtClientsRsrcID RsrcID = "all_extclients"
AllInetGwRsrcID RsrcID = "all_inet_gw"
AllEgressGwRsrcID RsrcID = "all_egress"
AllNetworkRsrcID RsrcID = "all_network"
AllEnrollmentKeysRsrcID RsrcID = "all_enrollment_key"
AllUserRsrcID RsrcID = "all_user"
AllDnsRsrcID RsrcID = "all_dns"
AllFailOverRsrcID RsrcID = "all_fail_over"
AllAclsRsrcID RsrcID = "all_acl"
AllTagsRsrcID RsrcID = "all_tag"
AllPostureCheckRsrcID RsrcID = "all_posturecheck"
AllNameserverRsrcID RsrcID = "all_nameserver"
AllJitAdminRsrcID RsrcID = "all_jit_admin"
AllJitUserRsrcID RsrcID = "all_jit_user"
AllUserActivityRsrcID RsrcID = "all_user_activity"
AllTrafficFlowRsrcID RsrcID = "all_traffic_flow"
)
// Pre-Defined User Roles
const (
SuperAdminRole UserRoleID = "super-admin"
AdminRole UserRoleID = "admin"
ServiceUser UserRoleID = "service-user"
PlatformUser UserRoleID = "platform-user"
Auditor UserRoleID = "auditor"
NetworkAdmin UserRoleID = "network-admin"
NetworkUser UserRoleID = "network-user"
)
func (r UserRoleID) String() string {
return string(r)
}
func (g UserGroupID) String() string {
return string(g)
}
func (n NetworkID) String() string {
return string(n)
}
type RsrcPermissionScope struct {
Create bool `json:"create"`
Read bool `json:"read"`
Update bool `json:"update"`
Delete bool `json:"delete"`
VPNaccess bool `json:"vpn_access"`
SelfOnly bool `json:"self_only"`
}
type UserRolePermissionTemplate struct {
ID UserRoleID `json:"id"`
Name string `json:"name"`
Default bool `json:"default"`
MetaData string `json:"meta_data"`
DenyDashboardAccess bool `json:"deny_dashboard_access"`
FullAccess bool `json:"full_access"`
NetworkID NetworkID `json:"network_id"`
NetworkLevelAccess map[RsrcType]map[RsrcID]RsrcPermissionScope `json:"network_level_access"`
GlobalLevelAccess map[RsrcType]map[RsrcID]RsrcPermissionScope `json:"global_level_access"`
}
type CreateGroupReq struct {
Group UserGroup `json:"user_group"`
Members []string `json:"members"`
}
type UserGroup struct {
ID UserGroupID `json:"id"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
Default bool `json:"default"`
Name string `json:"name"`
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
ColorCode string `json:"color_code"`
MetaData string `json:"meta_data"`
}
// User struct - struct for Users
type User struct {
UserName string `json:"username" bson:"username" validate:"min=3,in_charset|email"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
IsMFAEnabled bool `json:"is_mfa_enabled"`
TOTPSecret string `json:"totp_secret"`
DisplayName string `json:"display_name"`
AccountDisabled bool `json:"account_disabled"`
Password string `json:"password" bson:"password" validate:"required,min=5"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` // deprecated
IsSuperAdmin bool `json:"issuperadmin"` // deprecated
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
AuthType AuthType `json:"auth_type"`
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
PlatformRoleID UserRoleID `json:"platform_role_id"`
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
LastLoginTime time.Time `json:"last_login_time"`
}
type ReturnUserWithRolesAndGroups struct {
ReturnUser
PlatformRole UserRolePermissionTemplate `json:"platform_role"`
UserGroups map[UserGroupID]UserGroup `json:"user_group_ids"`
UserName string `json:"username" bson:"username" validate:"min=3,in_charset|email"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
IsMFAEnabled bool `json:"is_mfa_enabled"`
TOTPSecret string `json:"totp_secret"`
DisplayName string `json:"display_name"`
AccountDisabled bool `json:"account_disabled"`
Password string `json:"password" bson:"password" validate:"required,min=5"`
IsAdmin bool `json:"isadmin" bson:"isadmin"` // deprecated
IsSuperAdmin bool `json:"issuperadmin"` // deprecated
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
AuthType schema.AuthType `json:"auth_type"`
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
PlatformRoleID schema.UserRoleID `json:"platform_role_id"`
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
LastLoginTime time.Time `json:"last_login_time"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ReturnUser - return user struct
type ReturnUser struct {
UserName string `json:"username"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
IsMFAEnabled bool `json:"is_mfa_enabled"`
DisplayName string `json:"display_name"`
AccountDisabled bool `json:"account_disabled"`
IsAdmin bool `json:"isadmin"`
IsSuperAdmin bool `json:"issuperadmin"`
AuthType AuthType `json:"auth_type"`
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
PlatformRoleID UserRoleID `json:"platform_role_id"`
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
LastLoginTime time.Time `json:"last_login_time"`
NumAccessTokens int `json:"num_access_tokens"`
UserName string `json:"username"`
ExternalIdentityProviderID string `json:"external_identity_provider_id"`
IsMFAEnabled bool `json:"is_mfa_enabled"`
DisplayName string `json:"display_name"`
AccountDisabled bool `json:"account_disabled"`
IsAdmin bool `json:"isadmin"`
IsSuperAdmin bool `json:"issuperadmin"`
AuthType schema.AuthType `json:"auth_type"`
RemoteGwIDs map[string]struct{} `json:"remote_gw_ids"` // deprecated
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
PlatformRoleID schema.UserRoleID `json:"platform_role_id"`
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
LastLoginTime time.Time `json:"last_login_time"`
NumAccessTokens int `json:"num_access_tokens"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// UserAuthParams - user auth params struct
@@ -234,7 +96,7 @@ type UserTOTPVerificationParams struct {
// UserClaims - user claims struct
type UserClaims struct {
Role UserRoleID
Role schema.UserRoleID
UserName string
Api string
TokenType TokenType
@@ -243,20 +105,20 @@ type UserClaims struct {
}
type InviteUsersReq struct {
UserEmails []string `json:"user_emails"`
PlatformRoleID string `json:"platform_role_id"`
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
UserEmails []string `json:"user_emails"`
PlatformRoleID string `json:"platform_role_id"`
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
}
// UserInvite - model for user invite
type UserInvite struct {
Email string `json:"email"`
PlatformRoleID string `json:"platform_role_id"`
UserGroups map[UserGroupID]struct{} `json:"user_group_ids"`
NetworkRoles map[NetworkID]map[UserRoleID]struct{} `json:"network_roles"`
InviteCode string `json:"invite_code"`
InviteURL string `json:"invite_url"`
Email string `json:"email"`
PlatformRoleID string `json:"platform_role_id"`
UserGroups map[schema.UserGroupID]struct{} `json:"user_group_ids"`
NetworkRoles map[schema.NetworkID]map[schema.UserRoleID]struct{} `json:"network_roles"`
InviteCode string `json:"invite_code"`
InviteURL string `json:"invite_url"`
}
// UserMapping - user ip map with groups
+12 -9
View File
@@ -1,17 +1,20 @@
package mq
import (
"context"
"encoding/json"
"net"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/google/uuid"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/hostactions"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"github.com/gravitl/netmaker/utils"
"golang.org/x/exp/slog"
@@ -60,8 +63,8 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
if ifaceDelta { // reduce number of unneeded updates, by only sending on iface changes
if !newNode.Connected {
err = PublishDeletedNodePeerUpdate(&newNode)
host, err := logic.GetHost(newNode.HostID.String())
if err != nil {
host := &schema.Host{ID: newNode.HostID}
if err := host.Get(db.WithContext(context.TODO())); err != nil {
slog.Error("failed to get host for the node", "nodeid", newNode.ID.String(), "error", err)
return
}
@@ -87,8 +90,8 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
slog.Error("error getting host.ID sent on ", "topic", msg.Topic(), "error", err)
return
}
currentHost, err := logic.GetHost(id)
if err != nil {
currentHost := &schema.Host{ID: uuid.MustParse(id)}
if err := currentHost.Get(db.WithContext(context.TODO())); err != nil {
slog.Error("error getting host", "id", id, "error", err)
return
}
@@ -159,7 +162,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
}
}
func DeleteAndCleanupHost(h *models.Host) {
func DeleteAndCleanupHost(h *schema.Host) {
if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
// delete EMQX credentials for host
if err := emqx.DeleteEmqxUser(h.ID.String()); err != nil {
@@ -180,7 +183,7 @@ func DeleteAndCleanupHost(h *models.Host) {
slog.Error("failed to delete all nodes of host", "id", h.ID, "error", err)
return
}
if err := logic.RemoveHostByID(h.ID.String()); err != nil {
if err := (&schema.Host{ID: h.ID}).Delete(db.WithContext(context.TODO())); err != nil {
slog.Error("failed to delete host", "id", h.ID, "error", err)
return
}
@@ -209,8 +212,8 @@ func SignalPeer(signal models.Signal) {
}
signal.NetworkID = node.Network
signal.IsPro = servercfg.IsPro
peerHost, err := logic.GetHost(signal.ToHostID)
if err != nil {
peerHost := &schema.Host{ID: uuid.MustParse(signal.ToHostID)}
if err := peerHost.Get(db.WithContext(context.TODO())); err != nil {
slog.Error("failed to signal, peer host not found", "error", err)
return
}
@@ -255,7 +258,7 @@ func ClientPeerUpdate(client mqtt.Client, msg mqtt.Message) {
slog.Info("sent peer updates after signal received from", "id", id)
}
func HandleHostCheckin(h, currentHost *models.Host) bool {
func HandleHostCheckin(h, currentHost *schema.Host) bool {
if h == nil {
return false
}
+5 -2
View File
@@ -2,6 +2,7 @@ package mq
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -11,9 +12,11 @@ import (
"time"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)
@@ -77,7 +80,7 @@ func SendPullSYN() error {
if err != nil {
return err
}
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -125,7 +128,7 @@ func KickOutClients() error {
if err != nil {
return err
}
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
slog.Error("failed to migrate emqx: ", "error", err)
return err
+14 -12
View File
@@ -1,6 +1,7 @@
package mq
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -8,9 +9,11 @@ import (
"time"
"github.com/google/uuid"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/exp/slog"
)
@@ -25,7 +28,7 @@ func PublishPeerUpdate(replacePeers bool) error {
sendDNSSync()
}
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "err getting all hosts", err.Error())
return err
@@ -38,7 +41,7 @@ func PublishPeerUpdate(replacePeers bool) error {
for _, host := range hosts {
host := host
time.Sleep(5 * time.Millisecond)
go func(host models.Host) {
go func(host schema.Host) {
if err = PublishSingleHostPeerUpdate(&host, allNodes, nil, nil, replacePeers, nil); err != nil {
id := host.Name
if host.ID != uuid.Nil {
@@ -59,7 +62,7 @@ func PublishDeletedNodePeerUpdate(delNode *models.Node) error {
return nil
}
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "err getting all hosts", err.Error())
return err
@@ -84,7 +87,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
return nil
}
hosts, err := logic.GetAllHosts()
hosts, err := (&schema.Host{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
logger.Log(1, "err getting all hosts", err.Error())
return err
@@ -105,7 +108,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
}
// PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host
func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool, wg *sync.WaitGroup) error {
func PublishSingleHostPeerUpdate(host *schema.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool, wg *sync.WaitGroup) error {
if wg != nil {
defer wg.Done()
}
@@ -136,8 +139,8 @@ func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, dele
// NodeUpdate -- publishes a node update
func NodeUpdate(node *models.Node) error {
host, err := logic.GetHost(node.HostID.String())
if err != nil {
host := &schema.Host{ID: node.HostID}
if err := host.Get(db.WithContext(context.TODO())); err != nil {
return nil
}
if !servercfg.IsMessageQueueBackend() {
@@ -265,16 +268,15 @@ func SendDNSSyncByNetwork(network string) error {
}
func sendDNSSync() error {
networks, err := logic.GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err == nil && len(networks) > 0 {
for _, v := range networks {
k, err := logic.GetDNS(v.NetID)
k = append(k, logic.EgressDNs(v.NetID)...)
k, err := logic.GetDNS(v.Name)
k = append(k, logic.EgressDNs(v.Name)...)
if err == nil && len(k) > 0 {
err = PushSyncDNS(k)
if err != nil {
slog.Warn("error publishing dns entry data for network ", v.NetID, err.Error())
slog.Warn("error publishing dns entry data for network ", v.Name, err.Error())
}
}
}
+8 -5
View File
@@ -3,6 +3,7 @@ package mq
import (
"bytes"
"compress/gzip"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@@ -13,13 +14,15 @@ import (
"strings"
"time"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/netclient/ncutils"
"github.com/gravitl/netmaker/schema"
"golang.org/x/exp/slog"
)
func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
func decryptMsgWithHost(host *schema.Host, msg []byte) ([]byte, error) {
if host.OS == models.OS_Types.IoT { // just pass along IoT messages
return msg, nil
}
@@ -44,8 +47,8 @@ func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
if len(msg) <= 24 { // make sure message is of appropriate length
return nil, fmt.Errorf("received invalid message from broker %v", msg)
}
host, err := logic.GetHost(node.HostID.String())
if err != nil {
host := &schema.Host{ID: node.HostID}
if err := host.Get(db.WithContext(context.TODO())); err != nil {
return nil, err
}
@@ -105,7 +108,7 @@ func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
return ciphertext, nil
}
func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
func encryptMsg(host *schema.Host, msg []byte) ([]byte, error) {
if host.OS == models.OS_Types.IoT {
return msg, nil
}
@@ -133,7 +136,7 @@ func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
}
func publish(host *models.Host, dest string, msg []byte) error {
func publish(host *schema.Host, dest string, msg []byte) error {
var encrypted []byte
var encryptErr error
+25 -21
View File
@@ -9,14 +9,15 @@ import (
"net/http"
"strings"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/microsoft"
"gorm.io/gorm"
)
var azure_ad_functions = map[string]interface{}{
@@ -90,9 +91,10 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(content.UserPrincipalName)
user := &schema.User{Username: content.UserPrincipalName}
err = user.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
if inviteExists {
// create user
user, err := proLogic.PrepareOauthUserFromInvite(in)
@@ -100,7 +102,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
user.UserName = content.UserPrincipalName
user.Username = content.UserPrincipalName
user.ExternalIdentityProviderID = string(content.ID)
if err = logic.CreateUser(&user); err != nil {
handleSomethingWentWrong(w)
@@ -114,9 +116,9 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
UserName: content.UserPrincipalName,
ExternalIdentityProviderID: string(content.ID),
AuthType: models.OAuth,
AuthType: schema.OAuth,
})
if err != nil {
handleSomethingWentWrong(w)
@@ -132,14 +134,15 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
} else {
// if user exists, then ensure user's auth type is
// oauth before proceeding.
if user.AuthType == models.BasicAuth {
if user.AuthType == schema.BasicAuth {
logger.Log(0, "invalid auth type: basic_auth")
handleAuthTypeMismatch(w)
return
}
}
user, err = logic.GetUser(content.UserPrincipalName)
user = &schema.User{Username: content.UserPrincipalName}
err = user.Get(r.Context())
if err != nil {
handleOauthUserNotFound(w)
return
@@ -150,7 +153,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}
userRole, err := logic.GetRole(user.PlatformRoleID)
userRole := &schema.UserRole{ID: user.PlatformRoleID}
err = userRole.Get(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
@@ -175,23 +179,23 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Login,
Action: schema.Login,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: models.DashboardSub.String(),
Name: models.DashboardSub.String(),
Type: models.DashboardSub,
Info: user,
ID: schema.DashboardSub.String(),
Name: schema.DashboardSub.String(),
Type: schema.DashboardSub,
Info: logic.ToReturnUser(user),
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(1, "completed azure OAuth sigin in for", user.UserName)
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+user.UserName, http.StatusPermanentRedirect)
logger.Log(1, "completed azure OAuth sigin in for", user.Username)
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+user.Username, http.StatusPermanentRedirect)
}
func getAzureUserInfo(state string, code string) (*OAuthUser, error) {
+28 -23
View File
@@ -9,14 +9,16 @@ import (
"net/http"
"strings"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
"gorm.io/gorm"
)
var github_functions = map[string]interface{}{
@@ -91,29 +93,30 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
return
}
// if user exists with provider ID, convert them into email ID
user, err := logic.GetUser(content.Login)
user := &schema.User{Username: content.Login}
err = user.Get(r.Context())
if err == nil {
// if user exists, then ensure user's auth type is
// oauth before proceeding.
if user.AuthType == models.BasicAuth {
if user.AuthType == schema.BasicAuth {
logger.Log(0, "invalid auth type: basic_auth")
handleAuthTypeMismatch(w)
return
}
// checks if user exists with email
_, err := logic.GetUser(content.Email)
emailCheck := &schema.User{Username: content.Email}
err = emailCheck.Get(r.Context())
if err != nil {
user.UserName = content.Email
user.Username = content.Email
user.ExternalIdentityProviderID = content.Login
_ = logic.DeleteUser(content.Login)
_ = logic.UpsertUser(*user)
_ = user.Update(db.WithContext(context.TODO()))
}
}
_, err = logic.GetUser(content.Email)
emailCheck := &schema.User{Username: content.Email}
err = emailCheck.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
if inviteExists {
// create user
user, err := proLogic.PrepareOauthUserFromInvite(in)
@@ -136,7 +139,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: models.OAuth,
AuthType: schema.OAuth,
})
if err != nil {
handleSomethingWentWrong(w)
@@ -150,7 +153,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
return
}
}
user, err = logic.GetUser(content.Email)
user = &schema.User{Username: content.Email}
err = user.Get(r.Context())
if err != nil {
handleOauthUserNotFound(w)
return
@@ -161,7 +165,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
return
}
userRole, err := logic.GetRole(user.PlatformRoleID)
userRole := &schema.UserRole{ID: user.PlatformRoleID}
err = userRole.Get(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
@@ -186,20 +191,20 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Login,
Action: schema.Login,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: models.DashboardSub.String(),
Name: models.DashboardSub.String(),
Type: models.DashboardSub,
Info: user,
ID: schema.DashboardSub.String(),
Name: schema.DashboardSub.String(),
Type: schema.DashboardSub,
Info: logic.ToReturnUser(user),
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(1, "completed github OAuth sigin in for", content.Email)
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
+23 -18
View File
@@ -3,20 +3,22 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"gorm.io/gorm"
)
var google_functions = map[string]interface{}{
@@ -93,9 +95,10 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(content.Email)
user := &schema.User{Username: content.Email}
err = user.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
if inviteExists {
// create user
user, err := proLogic.PrepareOauthUserFromInvite(in)
@@ -108,7 +111,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
handleSomethingWentWrong(w)
return
}
logic.DeleteUserInvite(user.UserName)
logic.DeleteUserInvite(user.Username)
logic.DeletePendingUser(content.Email)
} else {
if !isEmailAllowed(content.Email) {
@@ -118,7 +121,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: models.OAuth,
AuthType: schema.OAuth,
})
if err != nil {
handleSomethingWentWrong(w)
@@ -135,14 +138,15 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
} else {
// if user exists, then ensure user's auth type is
// oauth before proceeding.
if user.AuthType == models.BasicAuth {
if user.AuthType == schema.BasicAuth {
logger.Log(0, "invalid auth type: basic_auth")
handleAuthTypeMismatch(w)
return
}
}
user, err = logic.GetUser(content.Email)
user = &schema.User{Username: content.Email}
err = user.Get(r.Context())
if err != nil {
logger.Log(0, "error fetching user: ", err.Error())
handleOauthUserNotFound(w)
@@ -154,7 +158,8 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
return
}
userRole, err := logic.GetRole(user.PlatformRoleID)
userRole := &schema.UserRole{ID: user.PlatformRoleID}
err = userRole.Get(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
@@ -180,20 +185,20 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
}
logic.LogEvent(&models.Event{
Action: models.Login,
Action: schema.Login,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: models.DashboardSub.String(),
Name: models.DashboardSub.String(),
Type: models.DashboardSub,
Info: user,
ID: schema.DashboardSub.String(),
Name: schema.DashboardSub.String(),
Type: schema.DashboardSub,
Info: logic.ToReturnUser(user),
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(1, "completed google OAuth sigin in for", content.Email)
+8 -5
View File
@@ -2,14 +2,16 @@ package auth
import (
"bytes"
"errors"
"fmt"
"net/http"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
"gorm.io/gorm"
)
// HandleHeadlessSSOCallback - handle OAuth callback for headless logins such as Netmaker CLI
@@ -60,13 +62,14 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
handleOauthUserSignUpApprovalPending(w)
return
}
user, err := logic.GetUser(userClaims.getUserName())
user := &schema.User{Username: userClaims.getUserName()}
err = user.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
err = logic.InsertPendingUser(&models.User{
UserName: userClaims.getUserName(),
ExternalIdentityProviderID: string(userClaims.ID),
AuthType: models.OAuth,
AuthType: schema.OAuth,
})
if err != nil {
handleSomethingWentWrong(w)
@@ -88,7 +91,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
return
}
jwt, jwtErr := logic.VerifyAuthRequest(models.UserAuthParams{
UserName: user.UserName,
UserName: user.Username,
Password: newPass,
}, logic.NetclientApp)
if jwtErr != nil {
+23 -18
View File
@@ -2,19 +2,21 @@ package auth
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"golang.org/x/oauth2"
"gorm.io/gorm"
)
const OIDC_TIMEOUT = 10 * time.Second
@@ -102,9 +104,10 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(content.Email)
user := &schema.User{Username: content.Email}
err = user.Get(r.Context())
if err != nil {
if database.IsEmptyRecord(err) { // user must not exist, so try to make one
if errors.Is(err, gorm.ErrRecordNotFound) { // user must not exist, so try to make one
if inviteExists {
// create user
user, err := proLogic.PrepareOauthUserFromInvite(in)
@@ -117,7 +120,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
handleSomethingWentWrong(w)
return
}
logic.DeleteUserInvite(user.UserName)
logic.DeleteUserInvite(user.Username)
logic.DeletePendingUser(content.Email)
} else {
if !isEmailAllowed(content.Email) {
@@ -127,7 +130,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
err = logic.InsertPendingUser(&models.User{
UserName: content.Email,
ExternalIdentityProviderID: string(content.ID),
AuthType: models.OAuth,
AuthType: schema.OAuth,
})
if err != nil {
handleSomethingWentWrong(w)
@@ -143,14 +146,15 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
} else {
// if user exists, then ensure user's auth type is
// oauth before proceeding.
if user.AuthType == models.BasicAuth {
if user.AuthType == schema.BasicAuth {
logger.Log(0, "invalid auth type: basic_auth")
handleAuthTypeMismatch(w)
return
}
}
user, err = logic.GetUser(content.Email)
user = &schema.User{Username: content.Email}
err = user.Get(r.Context())
if err != nil {
handleOauthUserNotFound(w)
return
@@ -161,7 +165,8 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
return
}
userRole, err := logic.GetRole(user.PlatformRoleID)
userRole := &schema.UserRole{ID: user.PlatformRoleID}
err = userRole.Get(r.Context())
if err != nil {
handleSomethingWentWrong(w)
return
@@ -186,20 +191,20 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Login,
Action: schema.Login,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: models.DashboardSub.String(),
Name: models.DashboardSub.String(),
Type: models.DashboardSub,
Info: user,
ID: schema.DashboardSub.String(),
Name: schema.DashboardSub.String(),
Type: schema.DashboardSub,
Info: logic.ToReturnUser(user),
},
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
logger.Log(1, "completed OIDC OAuth signin in for", content.Email)
http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
+4 -4
View File
@@ -8,9 +8,8 @@ import (
"github.com/gorilla/mux"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/logic/pro/netcache"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
var (
@@ -69,12 +68,13 @@ func HandleHostSSOCallback(w http.ResponseWriter, r *http.Request) {
return
}
// check if user exists
user, err := logic.GetUser(userClaims.getUserName())
user := &schema.User{Username: userClaims.getUserName()}
err = user.Get(r.Context())
if err != nil {
handleOauthUserNotFound(w)
return
}
if user.PlatformRoleID != models.AdminRole && user.PlatformRoleID != models.SuperAdminRole {
if user.PlatformRoleID != schema.AdminRole && user.PlatformRoleID != schema.SuperAdminRole {
response := returnErrTemplate(userClaims.getUserName(), "only admin users can register using SSO", state, reqKeyIf)
w.WriteHeader(http.StatusForbidden)
w.Write(response)
+42 -35
View File
@@ -7,7 +7,7 @@ import (
"sync"
"time"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/logic"
"github.com/gravitl/netmaker/models"
@@ -17,7 +17,9 @@ import (
"github.com/gravitl/netmaker/pro/idp/google"
"github.com/gravitl/netmaker/pro/idp/okta"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
"gorm.io/datatypes"
)
var (
@@ -126,8 +128,8 @@ func SyncFromIDP() error {
}
func syncUsers(idpUsers []idp.User) error {
dbUsers, err := logic.GetUsersDB()
if err != nil && !database.IsEmptyRecord(err) {
dbUsers, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -141,9 +143,9 @@ func syncUsers(idpUsers []idp.User) error {
idpUsersMap[user.Username] = struct{}{}
}
dbUsersMap := make(map[string]models.User)
dbUsersMap := make(map[string]*schema.User)
for _, user := range dbUsers {
dbUsersMap[user.UserName] = user
dbUsersMap[user.Username] = &user
}
filters := logic.GetServerSettings().UserFilters
@@ -151,8 +153,10 @@ func syncUsers(idpUsers []idp.User) error {
for _, user := range idpUsers {
if user.AccountArchived {
// delete the user if it has been archived.
user := dbUsersMap[user.Username]
_ = deleteAndCleanUpUser(&user)
user, ok := dbUsersMap[user.Username]
if ok {
_ = deleteAndCleanUpUser(user)
}
continue
}
@@ -172,14 +176,14 @@ func syncUsers(idpUsers []idp.User) error {
dbUser, ok := dbUsersMap[user.Username]
if !ok {
// create the user only if it doesn't exist.
err = logic.CreateUser(&models.User{
UserName: user.Username,
err = logic.CreateUser(&schema.User{
Username: user.Username,
ExternalIdentityProviderID: user.ID,
DisplayName: user.DisplayName,
AccountDisabled: user.AccountDisabled,
Password: password,
AuthType: models.OAuth,
PlatformRoleID: models.ServiceUser,
AuthType: schema.OAuth,
PlatformRoleID: schema.ServiceUser,
})
if err != nil {
return err
@@ -191,7 +195,7 @@ func syncUsers(idpUsers []idp.User) error {
// created. Now, since the user is created, the pending user
// can be deleted.
_ = logic.DeletePendingUser(user.Username)
} else if dbUser.AuthType == models.OAuth {
} else if dbUser.AuthType == schema.OAuth {
if dbUser.AccountDisabled != user.AccountDisabled ||
dbUser.DisplayName != user.DisplayName ||
dbUser.ExternalIdentityProviderID != user.ID {
@@ -200,7 +204,7 @@ func syncUsers(idpUsers []idp.User) error {
dbUser.DisplayName = user.DisplayName
dbUser.ExternalIdentityProviderID = user.ID
err = logic.UpsertUser(dbUser)
err = logic.UpsertUser(*dbUser)
if err != nil {
return err
}
@@ -213,10 +217,10 @@ func syncUsers(idpUsers []idp.User) error {
for _, user := range dbUsersMap {
if user.ExternalIdentityProviderID != "" {
if _, ok := idpUsersMap[user.UserName]; !ok {
if _, ok := idpUsersMap[user.Username]; !ok {
// delete the user if it has been deleted on idp
// or is filtered out.
err = deleteAndCleanUpUser(&user)
err = deleteAndCleanUpUser(user)
if err != nil {
return err
}
@@ -228,13 +232,13 @@ func syncUsers(idpUsers []idp.User) error {
}
func syncGroups(idpGroups []idp.Group) error {
dbGroups, err := proLogic.ListUserGroups()
if err != nil && !database.IsEmptyRecord(err) {
dbGroups, err := (&schema.UserGroup{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
dbUsers, err := logic.GetUsersDB()
if err != nil && !database.IsEmptyRecord(err) {
dbUsers, err := (&schema.User{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -243,17 +247,17 @@ func syncGroups(idpGroups []idp.Group) error {
idpGroupsMap[group.ID] = struct{}{}
}
dbGroupsMap := make(map[string]models.UserGroup)
dbGroupsMap := make(map[string]schema.UserGroup)
for _, group := range dbGroups {
if group.ExternalIdentityProviderID != "" {
dbGroupsMap[group.ExternalIdentityProviderID] = group
}
}
dbUsersMap := make(map[string]models.User)
dbUsersMap := make(map[string]*schema.User)
for _, user := range dbUsers {
if user.ExternalIdentityProviderID != "" {
dbUsersMap[user.ExternalIdentityProviderID] = user
dbUsersMap[user.ExternalIdentityProviderID] = &user
}
}
@@ -261,7 +265,7 @@ func syncGroups(idpGroups []idp.Group) error {
filters := logic.GetServerSettings().GroupFilters
networks, err := logic.GetNetworks()
networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
if err != nil {
return err
}
@@ -269,7 +273,7 @@ func syncGroups(idpGroups []idp.Group) error {
var aclsUpdated bool
var acls []models.Acl
for _, network := range networks {
aclID := fmt.Sprintf("%s.%s-grp", network.NetID, models.NetworkUser)
aclID := fmt.Sprintf("%s.%s-grp", network.Name, schema.NetworkUser)
acl, err := logic.GetAcl(aclID)
if err == nil {
acls = append(acls, acl)
@@ -295,7 +299,7 @@ func syncGroups(idpGroups []idp.Group) error {
dbGroup.ExternalIdentityProviderID = group.ID
dbGroup.Name = group.Name
dbGroup.Default = false
dbGroup.NetworkRoles = map[models.NetworkID]map[models.UserRoleID]struct{}{}
dbGroup.NetworkRoles = datatypes.NewJSONType(schema.NetworkRoles{})
err := proLogic.CreateUserGroup(&dbGroup)
if err != nil {
return err
@@ -323,27 +327,30 @@ func syncGroups(idpGroups []idp.Group) error {
for _, user := range dbUsers {
// use dbGroup.Name because the group name may have been changed on idp.
_, inNetmakerGroup := user.UserGroups[dbGroup.ID]
_, inNetmakerGroup := user.UserGroups.Data()[dbGroup.ID]
_, inIDPGroup := groupMembersMap[user.ExternalIdentityProviderID]
if inNetmakerGroup && !inIDPGroup {
// use dbGroup.Name because the group name may have been changed on idp.
delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups, dbGroup.ID)
delete(dbUsersMap[user.ExternalIdentityProviderID].UserGroups.Data(), dbGroup.ID)
modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
}
if !inNetmakerGroup && inIDPGroup {
// use dbGroup.Name because the group name may have been changed on idp.
dbUsersMap[user.ExternalIdentityProviderID].UserGroups[dbGroup.ID] = struct{}{}
dbUsersMap[user.ExternalIdentityProviderID].UserGroups.Data()[dbGroup.ID] = struct{}{}
modifiedUsers[user.ExternalIdentityProviderID] = struct{}{}
}
}
}
for userID := range modifiedUsers {
err = logic.UpsertUser(dbUsersMap[userID])
if err != nil {
return err
user, ok := dbUsersMap[userID]
if ok {
err = logic.UpsertUser(*user)
if err != nil {
return err
}
}
}
@@ -456,8 +463,8 @@ func filterGroupsByMembers(idpGroups []idp.Group, idpUsers []idp.User) []idp.Gro
// TODO: deduplicate
// The cyclic import between the package logic and mq requires this
// function to be duplicated in multiple places.
func deleteAndCleanUpUser(user *models.User) error {
err := logic.DeleteUser(user.UserName)
func deleteAndCleanUpUser(user *schema.User) error {
err := logic.DeleteUser(user.Username)
if err != nil {
return err
}
@@ -469,7 +476,7 @@ func deleteAndCleanUpUser(user *models.User) error {
return
}
for _, extclient := range extclients {
if extclient.OwnerID == user.UserName {
if extclient.OwnerID == user.Username {
err = logic.DeleteExtClientAndCleanup(extclient)
if err == nil {
_ = mq.PublishDeletedClientPeerUpdate(&extclient)
@@ -477,7 +484,7 @@ func deleteAndCleanUpUser(user *models.User) error {
}
}
go logic.DeleteUserInvite(user.UserName)
go logic.DeleteUserInvite(user.Username)
go mq.PublishPeerUpdate(false)
if servercfg.IsDNSMode() {
go logic.SetDNS()
+15 -7
View File
@@ -65,7 +65,7 @@ func getAutoRelayGws(w http.ResponseWriter, r *http.Request) {
)
return
}
defaultPolicy, err := logic.GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
@@ -205,7 +205,10 @@ func autoRelayME(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -243,7 +246,7 @@ func autoRelayME(w http.ResponseWriter, r *http.Request) {
return
}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
logic.GetNodeEgressInfo(&node, eli, acls)
logic.GetNodeEgressInfo(&peerNode, eli, acls)
logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
@@ -370,8 +373,10 @@ func autoRelayMEUpdate(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -532,7 +537,10 @@ func checkautoRelayCtx(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -568,7 +576,7 @@ func checkautoRelayCtx(w http.ResponseWriter, r *http.Request) {
return
}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
logic.GetNodeEgressInfo(&node, eli, acls)
logic.GetNodeEgressInfo(&peerNode, eli, acls)
logic.GetNodeEgressInfo(&autoRelayNode, eli, acls)
+5 -4
View File
@@ -66,7 +66,7 @@ func listNetworkActivity(w http.ResponseWriter, r *http.Request) {
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
pageSize, _ := strconv.Atoi(r.URL.Query().Get("per_page"))
ctx := db.WithContext(r.Context())
netActivity, err := (&schema.Event{NetworkID: models.NetworkID(netID)}).ListByNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
netActivity, err := (&schema.Event{NetworkID: schema.NetworkID(netID)}).ListByNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
if err != nil {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusInternalServerError,
@@ -100,12 +100,13 @@ func listUserActivity(w http.ResponseWriter, r *http.Request) {
})
return
}
caller, err := logic.GetUser(r.Header.Get("user"))
caller := &schema.User{Username: r.Header.Get("user")}
err := caller.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
if caller.UserName != username && caller.PlatformRoleID != models.SuperAdminRole && caller.PlatformRoleID != models.AdminRole {
if caller.Username != username && caller.PlatformRoleID != schema.SuperAdminRole && caller.PlatformRoleID != schema.AdminRole {
logic.ReturnErrorResponse(w, r, models.ErrorResponse{
Code: http.StatusForbidden,
Message: "you are not authorized to view this user's activity",
@@ -190,7 +191,7 @@ func listActivity(w http.ResponseWriter, r *http.Request) {
}
}
var events []schema.Event
e := &schema.Event{TriggeredBy: username, NetworkID: models.NetworkID(network)}
e := &schema.Event{TriggeredBy: username, NetworkID: schema.NetworkID(network)}
if username != "" && network != "" {
events, err = e.ListByUserAndNetwork(db.SetPagination(ctx, page, pageSize), fromDate, toDate)
} else if username != "" && network == "" {
+10 -4
View File
@@ -142,7 +142,10 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -179,7 +182,7 @@ func failOverME(w http.ResponseWriter, r *http.Request) {
return
}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
logic.GetNodeEgressInfo(&node, eli, acls)
logic.GetNodeEgressInfo(&peerNode, eli, acls)
logic.GetNodeEgressInfo(&failOverNode, eli, acls)
@@ -295,7 +298,10 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
host, err := logic.GetHost(node.HostID.String())
host := &schema.Host{
ID: node.HostID,
}
err = host.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -331,7 +337,7 @@ func checkfailOverCtx(w http.ResponseWriter, r *http.Request) {
return
}
eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
acls, _ := logic.ListAclsByNetwork(models.NetworkID(node.Network))
acls, _ := logic.ListAclsByNetwork(schema.NetworkID(node.Network))
logic.GetNodeEgressInfo(&node, eli, acls)
logic.GetNodeEgressInfo(&peerNode, eli, acls)
logic.GetNodeEgressInfo(&failOverNode, eli, acls)
+87 -78
View File
@@ -1,6 +1,7 @@
package controllers
import (
"context"
"encoding/json"
"errors"
"net/http"
@@ -76,7 +77,8 @@ func handleJIT(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err := user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
@@ -93,7 +95,7 @@ func handleJIT(w http.ResponseWriter, r *http.Request) {
}
// handleJITGet - handles GET requests for JIT status/requests
func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
statusFilter := r.URL.Query().Get("status") // "pending", "approved", "denied", "expired", or empty for all
// Parse pagination parameters (default to 0, db.SetPagination will apply defaults)
@@ -121,19 +123,19 @@ func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user
totalPages = 1
}
response := map[string]interface{}{
"data": requests,
"page": page,
"per_page": pageSize,
"total": total,
"total_pages": totalPages,
response := models.PaginatedResponse{
Data: requests,
Page: page,
PerPage: pageSize,
Total: int(total),
TotalPages: totalPages,
}
logic.ReturnSuccessResponseWithJson(w, r, response, "fetched JIT requests")
}
// handleJITPost - handles POST requests for JIT operations
func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
var req models.JITOperationRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -157,7 +159,7 @@ func handleJITPost(w http.ResponseWriter, r *http.Request, networkID string, use
}
// handleEnableJIT - enables JIT on a network
func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
// Check if user is admin
if !proLogic.IsNetworkAdmin(user, networkID) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can enable JIT"), "forbidden"))
@@ -170,27 +172,27 @@ func handleEnableJIT(w http.ResponseWriter, r *http.Request, networkID string, u
}
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: networkID,
Name: networkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(networkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(networkID),
Origin: schema.Dashboard,
})
logic.ReturnSuccessResponse(w, r, "JIT enabled on network")
}
// handleDisableJIT - disables JIT on a network
func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) {
func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User) {
// Check if user is admin
if !proLogic.IsNetworkAdmin(user, networkID) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can disable JIT"), "forbidden"))
@@ -203,27 +205,27 @@ func handleDisableJIT(w http.ResponseWriter, r *http.Request, networkID string,
}
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: networkID,
Name: networkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(networkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(networkID),
Origin: schema.Dashboard,
})
logic.ReturnSuccessResponse(w, r, "JIT disabled on network")
}
// handleApproveRequest - approves a JIT request
func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID string, user *models.User, requestID string, expiresAtEpoch int64) {
func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User, requestID string, expiresAtEpoch int64) {
// Check if user is admin
if !proLogic.IsNetworkAdmin(user, networkID) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can approve requests"), "forbidden"))
@@ -250,40 +252,41 @@ func handleApproveRequest(w http.ResponseWriter, r *http.Request, networkID stri
return
}
grant, req, err := proLogic.ApproveJITRequest(requestID, expiresAt, user.UserName)
grant, req, err := proLogic.ApproveJITRequest(requestID, expiresAt, user.Username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
// Send approval email to user
go func() {
network, _ := logic.GetNetwork(networkID)
network := &schema.Network{Name: networkID}
_ = network.Get(r.Context())
if err := email.SendJITApprovalEmail(grant, req, network); err != nil {
slog.Error("failed to send approval notification", "error", err)
}
}()
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: requestID,
Name: networkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(networkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(networkID),
Origin: schema.Dashboard,
})
logic.ReturnSuccessResponseWithJson(w, r, grant, "JIT request approved")
}
// handleDenyRequest - denies a JIT request
func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string, user *models.User, requestID string) {
func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string, user *schema.User, requestID string) {
// Check if user is admin
if !proLogic.IsNetworkAdmin(user, networkID) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only network admins can deny requests"), "forbidden"))
@@ -295,7 +298,7 @@ func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string,
return
}
request, err := proLogic.DenyJITRequest(requestID, user.UserName)
request, err := proLogic.DenyJITRequest(requestID, user.Username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -303,27 +306,28 @@ func handleDenyRequest(w http.ResponseWriter, r *http.Request, networkID string,
// Send denial email to requester
go func() {
network, _ := logic.GetNetwork(networkID)
network := &schema.Network{Name: networkID}
_ = network.Get(db.WithContext(context.TODO()))
if err := email.SendJITDeniedEmail(request, network); err != nil {
slog.Error("failed to send JIT denied notification", "error", err)
}
}()
logic.LogEvent(&models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: requestID,
Name: networkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(networkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(networkID),
Origin: schema.Dashboard,
})
logic.ReturnSuccessResponse(w, r, "JIT request denied")
@@ -361,7 +365,8 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err := user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
@@ -418,9 +423,10 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
// Send email notification to user
if revokedRequest != nil {
network, err := logic.GetNetwork(networkID)
network := &schema.Network{Name: networkID}
err := network.Get(r.Context())
if err == nil {
if err := email.SendJITExpirationEmail(&grant, revokedRequest, network, true, user.UserName); err != nil {
if err := email.SendJITExpirationEmail(&grant, revokedRequest, network, true, user.Username); err != nil {
slog.Warn("failed to send revocation email", "grant_id", grantID, "user", revokedRequest.UserName, "error", err)
}
}
@@ -432,20 +438,20 @@ func deleteJITGrant(w http.ResponseWriter, r *http.Request) {
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: grantID,
Name: networkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(networkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(networkID),
Origin: schema.Dashboard,
})
logic.ReturnSuccessResponse(w, r, "JIT grant revoked")
@@ -473,24 +479,25 @@ func getUserJITNetworks(w http.ResponseWriter, r *http.Request) {
return
}
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err := user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
}
// Get all networks user has access to
allNetworks, err := logic.GetNetworks()
allNetworks, err := (&schema.Network{}).ListAll(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
// Filter networks by user role
userNetworks := logic.FilterNetworksByRole(allNetworks, *user)
userNetworks := logic.FilterNetworksByRole(allNetworks, user)
// Build response with JIT status for each network
networksWithJITStatus, err := proLogic.GetUserJITNetworksStatus(userNetworks, user.UserName)
networksWithJITStatus, err := proLogic.GetUserJITNetworksStatus(userNetworks, user.Username)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
@@ -525,7 +532,8 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
}
network := r.URL.Query().Get("network")
user, err := logic.GetUser(username)
user := &schema.User{Username: username}
err := user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
return
@@ -550,17 +558,17 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
return
}
// Check if user has access to the network by role
allNetworks, err := logic.GetNetworks()
allNetworks, err := (&schema.Network{}).ListAll(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
return
}
// Filter networks by user role
userNetworks := logic.FilterNetworksByRole(allNetworks, *user)
userNetworks := logic.FilterNetworksByRole(allNetworks, user)
hasAccess := false
for _, network := range userNetworks {
if network.NetID == req.NetworkID {
if network.Name == req.NetworkID {
hasAccess = true
break
}
@@ -572,7 +580,7 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
}
// Create the JIT request
request, err := proLogic.CreateJITRequest(req.NetworkID, user.UserName, req.Reason)
request, err := proLogic.CreateJITRequest(req.NetworkID, user.Username, req.Reason)
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
@@ -580,27 +588,28 @@ func requestJITAccess(w http.ResponseWriter, r *http.Request) {
// Send email notifications to network admins
go func() {
network, _ := logic.GetNetwork(req.NetworkID)
network := &schema.Network{Name: req.NetworkID}
_ = network.Get(r.Context())
if err := email.SendJITRequestEmails(request, network); err != nil {
slog.Error("failed to send JIT request notifications", "error", err)
}
}()
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: user.UserName,
Name: user.UserName,
Type: models.UserSub,
ID: user.Username,
Name: user.Username,
Type: schema.UserSub,
},
TriggeredBy: user.UserName,
TriggeredBy: user.Username,
Target: models.Subject{
ID: request.ID,
Name: req.NetworkID,
Type: models.NetworkSub,
Type: schema.NetworkSub,
},
NetworkID: models.NetworkID(req.NetworkID),
Origin: models.ClientApp,
NetworkID: schema.NetworkID(req.NetworkID),
Origin: schema.ClientApp,
})
logic.ReturnSuccessResponseWithJson(w, r, request, "JIT access request created")
+17 -17
View File
@@ -92,20 +92,20 @@ func createPostureCheck(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: pc.Name,
Type: models.PostureCheckSub,
Type: schema.PostureCheckSub,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(pc.NetworkID),
Origin: schema.Dashboard,
})
go mq.PublishPeerUpdate(false)
@@ -131,7 +131,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), logic.BadReq))
return
}
_, err := logic.GetNetwork(network)
err := (&schema.Network{Name: network}).Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network not found"), logic.BadReq))
return
@@ -151,7 +151,7 @@ func listPostureChecks(w http.ResponseWriter, r *http.Request) {
logic.ReturnSuccessResponseWithJson(w, r, pc, "fetched posture check")
return
}
pc := schema.PostureCheck{NetworkID: models.NetworkID(network)}
pc := schema.PostureCheck{NetworkID: schema.NetworkID(network)}
list, err := pc.ListByNetwork(db.WithContext(r.Context()))
if err != nil {
logic.ReturnErrorResponse(
@@ -202,24 +202,24 @@ func updatePostureCheck(w http.ResponseWriter, r *http.Request) {
updateStatus = true
}
event := &models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: updatePc.Name,
Type: models.PostureCheckSub,
Type: schema.PostureCheckSub,
},
Diff: models.Diff{
Old: pc,
New: updatePc,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(pc.NetworkID),
Origin: schema.Dashboard,
}
pc.Tags = updatePc.Tags
pc.UserGroups = updatePc.UserGroups
@@ -278,20 +278,20 @@ func deletePostureCheck(w http.ResponseWriter, r *http.Request) {
return
}
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: pc.ID,
Name: pc.Name,
Type: models.PostureCheckSub,
Type: schema.PostureCheckSub,
},
NetworkID: models.NetworkID(pc.NetworkID),
Origin: models.Dashboard,
NetworkID: schema.NetworkID(pc.NetworkID),
Origin: schema.Dashboard,
Diff: models.Diff{
Old: pc,
New: nil,
+19 -17
View File
@@ -15,6 +15,7 @@ import (
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/mq"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
)
func TagHandlers(r *mux.Router) {
@@ -44,12 +45,12 @@ func getTags(w http.ResponseWriter, r *http.Request) {
return
}
// check if network exists
_, err := logic.GetNetwork(netID)
err := (&schema.Network{Name: netID}).Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
tags, err := proLogic.ListTagsWithNodes(models.NetworkID(netID))
tags, err := proLogic.ListTagsWithNodes(schema.NetworkID(netID))
if err != nil {
logger.Log(0, r.Header.Get("user"), "failed to get all network tag entries: ", err.Error())
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -77,13 +78,14 @@ func createTag(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
user, err := logic.GetUser(r.Header.Get("user"))
user := &schema.User{Username: r.Header.Get("user")}
err = user.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
return
}
// check if tag network exists
_, err = logic.GetNetwork(req.Network.String())
err = (&schema.Network{Name: req.Network.String()}).Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("failed to get network details for "+req.Network.String()), "badrequest"))
return
@@ -93,7 +95,7 @@ func createTag(w http.ResponseWriter, r *http.Request) {
ID: models.TagID(fmt.Sprintf("%s.%s", req.Network, req.TagName)),
TagName: req.TagName,
Network: req.Network,
CreatedBy: user.UserName,
CreatedBy: user.Username,
ColorCode: req.ColorCode,
CreatedAt: time.Now().UTC(),
}
@@ -138,20 +140,20 @@ func createTag(w http.ResponseWriter, r *http.Request) {
}
}()
logic.LogEvent(&models.Event{
Action: models.Create,
Action: schema.Create,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: tag.ID.String(),
Name: tag.TagName,
Type: models.TagSub,
Type: schema.TagSub,
},
NetworkID: tag.Network,
Origin: models.Dashboard,
Origin: schema.Dashboard,
})
go mq.PublishPeerUpdate(false)
@@ -189,23 +191,23 @@ func updateTag(w http.ResponseWriter, r *http.Request) {
return
}
e := &models.Event{
Action: models.Update,
Action: schema.Update,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: tag.ID.String(),
Name: tag.TagName,
Type: models.TagSub,
Type: schema.TagSub,
},
Diff: models.Diff{
Old: tag,
},
NetworkID: tag.Network,
Origin: models.Dashboard,
Origin: schema.Dashboard,
}
updateTag.NewName = strings.TrimSpace(updateTag.NewName)
var newID models.TagID
@@ -290,20 +292,20 @@ func deleteTag(w http.ResponseWriter, r *http.Request) {
mq.PublishPeerUpdate(false)
}()
logic.LogEvent(&models.Event{
Action: models.Delete,
Action: schema.Delete,
Source: models.Subject{
ID: r.Header.Get("user"),
Name: r.Header.Get("user"),
Type: models.UserSub,
Type: schema.UserSub,
},
TriggeredBy: r.Header.Get("user"),
Target: models.Subject{
ID: tag.ID.String(),
Name: tag.TagName,
Type: models.TagSub,
Type: schema.TagSub,
},
NetworkID: tag.Network,
Origin: models.Dashboard,
Origin: schema.Dashboard,
Diff: models.Diff{
Old: tag,
New: nil,
+299 -200
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -3,8 +3,8 @@ package email
import (
"fmt"
"github.com/gravitl/netmaker/models"
proLogic "github.com/gravitl/netmaker/pro/logic"
"github.com/gravitl/netmaker/schema"
"github.com/gravitl/netmaker/servercfg"
)
@@ -38,7 +38,7 @@ func (invite UserInvitedMail) GetBody(info Notification) string {
WithHtml("<br>").
WithHtml(fmt.Sprintf("<li><a href=\"%s\">Download the Netmaker Desktop App</a>.</li>", downloadLink))
if invite.PlatformRoleID == models.AdminRole.String() || invite.PlatformRoleID == models.PlatformUser.String() {
if invite.PlatformRoleID == schema.AdminRole.String() || invite.PlatformRoleID == schema.PlatformUser.String() {
content = content.
WithHtml("<br>").
WithHtml(fmt.Sprintf("<li>Access the <a href=\"%s\">Netmaker Dashboard</a> - use it to manage your network settings and view network status.</li>", dashboardURL))
+5 -6
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
@@ -14,11 +13,11 @@ type JITApprovedMail struct {
BodyBuilder EmailBodyBuilder
Grant *schema.JITGrant
Request *schema.JITRequest
Network models.Network
Network *schema.Network
}
// SendJITApprovalEmail - sends email notification to user when JIT request is approved
func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, network models.Network) error {
func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, network *schema.Network) error {
mail := JITApprovedMail{
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
Grant: grant,
@@ -40,17 +39,17 @@ func SendJITApprovalEmail(grant *schema.JITGrant, request *schema.JITRequest, ne
// GetSubject - gets the subject of the email
func (mail JITApprovedMail) GetSubject(info Notification) string {
return fmt.Sprintf("JIT Access Approved: %s", mail.Network.NetID)
return fmt.Sprintf("JIT Access Approved: %s", mail.Network.Name)
}
// GetBody - gets the body of the email
func (mail JITApprovedMail) GetBody(info Notification) string {
content := mail.BodyBuilder.
WithHeadline("JIT Access Approved").
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been approved.", mail.Network.NetID)).
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been approved.", mail.Network.Name)).
WithParagraph("Access Details:").
WithHtml("<ul>").
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
WithHtml(fmt.Sprintf("<li><strong>Granted At:</strong> %s</li>", formatUTCTime(mail.Grant.GrantedAt))).
WithHtml(fmt.Sprintf("<li><strong>Expires At:</strong> %s</li>", formatUTCTime(mail.Grant.ExpiresAt))).
WithHtml(fmt.Sprintf("<li><strong>Approved By:</strong> %s</li>", mail.Request.ApprovedBy)).
+5 -6
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
@@ -13,11 +12,11 @@ import (
type JITDeniedMail struct {
BodyBuilder EmailBodyBuilder
Request *schema.JITRequest
Network models.Network
Network *schema.Network
}
// SendJITDeniedEmail - sends email notification to user when JIT request is denied
func SendJITDeniedEmail(request *schema.JITRequest, network models.Network) error {
func SendJITDeniedEmail(request *schema.JITRequest, network *schema.Network) error {
mail := JITDeniedMail{
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
Request: request,
@@ -38,17 +37,17 @@ func SendJITDeniedEmail(request *schema.JITRequest, network models.Network) erro
// GetSubject - gets the subject of the email
func (mail JITDeniedMail) GetSubject(info Notification) string {
return fmt.Sprintf("JIT Access Request Denied: %s", mail.Network.NetID)
return fmt.Sprintf("JIT Access Request Denied: %s", mail.Network.Name)
}
// GetBody - gets the body of the email
func (mail JITDeniedMail) GetBody(info Notification) string {
content := mail.BodyBuilder.
WithHeadline("JIT Access Request Denied").
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been denied.", mail.Network.NetID)).
WithParagraph(fmt.Sprintf("Your request for Just-In-Time access to network <strong>%s</strong> has been denied.", mail.Network.Name)).
WithParagraph("Request Details:").
WithHtml("<ul>").
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
WithHtml(fmt.Sprintf("<li><strong>Requested At:</strong> %s</li>", formatUTCTime(mail.Request.RequestedAt))).
WithHtml(fmt.Sprintf("<li><strong>Denied At:</strong> %s</li>", formatUTCTime(mail.Request.ApprovedAt))).
WithHtml(fmt.Sprintf("<li><strong>Denied By:</strong> %s</li>", mail.Request.ApprovedBy)).
+7 -8
View File
@@ -6,7 +6,6 @@ import (
"log/slog"
"time"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
)
@@ -15,14 +14,14 @@ type JITExpiredMail struct {
BodyBuilder EmailBodyBuilder
Grant *schema.JITGrant
Request *schema.JITRequest
Network models.Network
Network *schema.Network
IsRevoked bool
RevokedBy string // set when IsRevoked is true
}
// SendJITExpirationEmail - sends email notification to user when JIT grant expires or is revoked
// revokedBy is the username of the admin who revoked the grant; empty when grant expired naturally
func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest, network models.Network, isRevoked bool, revokedBy string) error {
func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest, network *schema.Network, isRevoked bool, revokedBy string) error {
mail := JITExpiredMail{
BodyBuilder: &EmailBodyBuilderWithH1HeadlineAndImage{},
Grant: grant,
@@ -47,9 +46,9 @@ func SendJITExpirationEmail(grant *schema.JITGrant, request *schema.JITRequest,
// GetSubject - gets the subject of the email
func (mail JITExpiredMail) GetSubject(info Notification) string {
if mail.IsRevoked {
return fmt.Sprintf("JIT Access Revoked: %s", mail.Network.NetID)
return fmt.Sprintf("JIT Access Revoked: %s", mail.Network.Name)
}
return fmt.Sprintf("JIT Access Expired: %s", mail.Network.NetID)
return fmt.Sprintf("JIT Access Expired: %s", mail.Network.Name)
}
// GetBody - gets the body of the email
@@ -57,10 +56,10 @@ func (mail JITExpiredMail) GetBody(info Notification) string {
var headline, message string
if mail.IsRevoked {
headline = "JIT Access Revoked"
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has been revoked by an administrator.", mail.Network.NetID)
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has been revoked by an administrator.", mail.Network.Name)
} else {
headline = "JIT Access Expired"
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has expired.", mail.Network.NetID)
message = fmt.Sprintf("Your Just-In-Time access to network <strong>%s</strong> has expired.", mail.Network.Name)
}
builder := mail.BodyBuilder.
@@ -68,7 +67,7 @@ func (mail JITExpiredMail) GetBody(info Notification) string {
WithParagraph(message).
WithParagraph("Access Details:").
WithHtml("<ul>").
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.NetID)).
WithHtml(fmt.Sprintf("<li><strong>Network:</strong> %s</li>", mail.Network.Name)).
WithHtml(fmt.Sprintf("<li><strong>Granted At:</strong> %s</li>", formatUTCTime(mail.Grant.GrantedAt))).
WithHtml(fmt.Sprintf("<li><strong>Expired At:</strong> %s</li>", formatUTCTime(mail.Grant.ExpiresAt)))
if mail.IsRevoked && mail.RevokedBy != "" {

Some files were not shown because too many files have changed in this diff Show More