From 6749fb45167aacd5adc6bab2e3b2f75cd47baeba Mon Sep 17 00:00:00 2001 From: abhishek9686 Date: Fri, 19 Jan 2024 14:51:51 +0530 Subject: [PATCH] add trial license logic --- database/database.go | 44 ++++++------- logic/telemetry.go | 12 ++-- logic/timer.go | 7 ++- logic/traffic.go | 4 +- pro/initialize.go | 37 ++++++++--- pro/trial.go | 146 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 209 insertions(+), 41 deletions(-) create mode 100644 pro/trial.go diff --git a/database/database.go b/database/database.go index c51340ab..dc6385b3 100644 --- a/database/database.go +++ b/database/database.go @@ -124,29 +124,29 @@ func InitializeDatabase() error { } func createTables() { - createTable(NETWORKS_TABLE_NAME) - createTable(NODES_TABLE_NAME) - createTable(CERTS_TABLE_NAME) - createTable(DELETED_NODES_TABLE_NAME) - createTable(USERS_TABLE_NAME) - createTable(DNS_TABLE_NAME) - createTable(EXT_CLIENT_TABLE_NAME) - createTable(PEERS_TABLE_NAME) - createTable(SERVERCONF_TABLE_NAME) - createTable(SERVER_UUID_TABLE_NAME) - createTable(GENERATED_TABLE_NAME) - createTable(NODE_ACLS_TABLE_NAME) - createTable(SSO_STATE_CACHE) - createTable(METRICS_TABLE_NAME) - createTable(NETWORK_USER_TABLE_NAME) - createTable(USER_GROUPS_TABLE_NAME) - createTable(CACHE_TABLE_NAME) - createTable(HOSTS_TABLE_NAME) - createTable(ENROLLMENT_KEYS_TABLE_NAME) - createTable(HOST_ACTIONS_TABLE_NAME) + CreateTable(NETWORKS_TABLE_NAME) + CreateTable(NODES_TABLE_NAME) + CreateTable(CERTS_TABLE_NAME) + CreateTable(DELETED_NODES_TABLE_NAME) + CreateTable(USERS_TABLE_NAME) + CreateTable(DNS_TABLE_NAME) + CreateTable(EXT_CLIENT_TABLE_NAME) + CreateTable(PEERS_TABLE_NAME) + CreateTable(SERVERCONF_TABLE_NAME) + CreateTable(SERVER_UUID_TABLE_NAME) + CreateTable(GENERATED_TABLE_NAME) + CreateTable(NODE_ACLS_TABLE_NAME) + CreateTable(SSO_STATE_CACHE) + CreateTable(METRICS_TABLE_NAME) + CreateTable(NETWORK_USER_TABLE_NAME) + CreateTable(USER_GROUPS_TABLE_NAME) + CreateTable(CACHE_TABLE_NAME) + CreateTable(HOSTS_TABLE_NAME) + CreateTable(ENROLLMENT_KEYS_TABLE_NAME) + CreateTable(HOST_ACTIONS_TABLE_NAME) } -func createTable(tableName string) error { +func CreateTable(tableName string) error { return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName) } @@ -194,7 +194,7 @@ func DeleteAllRecords(tableName string) error { if err != nil { return err } - err = createTable(tableName) + err = CreateTable(tableName) if err != nil { return err } diff --git a/logic/telemetry.go b/logic/telemetry.go index e4d48030..12b7035c 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -32,12 +32,12 @@ func sendTelemetry() error { return nil } - var telRecord, err = fetchTelemetryRecord() + var telRecord, err = FetchTelemetryRecord() if err != nil { return err } // get telemetry data - d, err := fetchTelemetryData() + d, err := FetchTelemetryData() if err != nil { return err } @@ -71,8 +71,8 @@ func sendTelemetry() error { }) } -// fetchTelemetry - fetches telemetry data: count of various object types in DB -func fetchTelemetryData() (telemetryData, error) { +// FetchTelemetryData - fetches telemetry data: count of various object types in DB +func FetchTelemetryData() (telemetryData, error) { var data telemetryData data.IsPro = servercfg.IsPro @@ -138,8 +138,8 @@ func getClientCount(nodes []models.Node) clientCount { return count } -// fetchTelemetryRecord - get the existing UUID and Timestamp from the DB -func fetchTelemetryRecord() (models.Telemetry, error) { +// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB +func FetchTelemetryRecord() (models.Telemetry, error) { var rawData string var telObj models.Telemetry var err error diff --git a/logic/timer.go b/logic/timer.go index 2d0fbb6e..db36f579 100644 --- a/logic/timer.go +++ b/logic/timer.go @@ -3,11 +3,12 @@ package logic import ( "context" "fmt" - "github.com/gravitl/netmaker/logger" - "golang.org/x/exp/slog" "sync" "time" + "github.com/gravitl/netmaker/logger" + "golang.org/x/exp/slog" + "github.com/gravitl/netmaker/models" ) @@ -24,7 +25,7 @@ var HookManagerCh = make(chan models.HookDetails, 3) // TimerCheckpoint - Checks if 24 hours has passed since telemetry was last sent. If so, sends telemetry data to posthog func TimerCheckpoint() error { // get the telemetry record in the DB, which contains a timestamp - telRecord, err := fetchTelemetryRecord() + telRecord, err := FetchTelemetryRecord() if err != nil { return err } diff --git a/logic/traffic.go b/logic/traffic.go index 596bd737..3c065c29 100644 --- a/logic/traffic.go +++ b/logic/traffic.go @@ -2,7 +2,7 @@ package logic // RetrievePrivateTrafficKey - retrieves private key of server func RetrievePrivateTrafficKey() ([]byte, error) { - var telRecord, err = fetchTelemetryRecord() + var telRecord, err = FetchTelemetryRecord() if err != nil { return nil, err } @@ -12,7 +12,7 @@ func RetrievePrivateTrafficKey() ([]byte, error) { // RetrievePublicTrafficKey - retrieves public key of server func RetrievePublicTrafficKey() ([]byte, error) { - var telRecord, err = fetchTelemetryRecord() + var telRecord, err = FetchTelemetryRecord() if err != nil { return nil, err } diff --git a/pro/initialize.go b/pro/initialize.go index 32c89857..ac338110 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -4,6 +4,8 @@ package pro import ( + "time" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -17,6 +19,7 @@ import ( // InitPro - Initialize Pro Logic func InitPro() { servercfg.IsPro = true + proLogic.InitTrial() models.SetLogo(retrieveProLogo()) controller.HttpMiddlewares = append( controller.HttpMiddlewares, @@ -31,18 +34,36 @@ func InitPro() { ) logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() { // == License Handling == - ClearLicenseCache() - if err := ValidateLicense(); err != nil { - slog.Error(err.Error()) - return + enableLicenseHook := false + trialEndDate, err := getTrialEndDate() + if err != nil { + slog.Error("failed to get trial end date", "error", err) + enableLicenseHook = true } - slog.Info("proceeding with Paid Tier license") - logic.SetFreeTierForTelemetry(false) - // == End License Handling == - AddLicenseHooks() + // check if trial ended + if time.Now().After(trialEndDate) { + // trial ended already + enableLicenseHook = true + } + if enableLicenseHook { + slog.Info("starting license checker") + ClearLicenseCache() + if err := ValidateLicense(); err != nil { + slog.Error(err.Error()) + return + } + slog.Info("proceeding with Paid Tier license") + logic.SetFreeTierForTelemetry(false) + // == End License Handling == + AddLicenseHooks() + } else { + addTrialLicenseHook() + } + if servercfg.GetServerConfig().RacAutoDisable { AddRacHooks() } + }) logic.ResetFailOver = proLogic.ResetFailOver logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer diff --git a/pro/trial.go b/pro/trial.go new file mode 100644 index 00000000..63e4c83e --- /dev/null +++ b/pro/trial.go @@ -0,0 +1,146 @@ +//go:build ee +// +build ee + +package pro + +import ( + "crypto/rand" + "encoding/json" + "errors" + "time" + + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/netclient/ncutils" + "golang.org/x/crypto/nacl/box" + "golang.org/x/exp/slog" +) + +type TrialInfo struct { + PrivKey []byte `json:"priv_key"` + PubKey []byte `json:"pub_key"` + Secret string `json:"secret"` +} + +func addTrialLicenseHook() { + logic.HookManagerCh <- models.HookDetails{ + Hook: TrialLicenseHook, + Interval: time.Hour, + } +} + +type TrialDates struct { + TrialStartedAt time.Time `json:"trial_started_at"` + TrialEndsAt time.Time `json:"trial_ends_at"` +} + +const trial_table_name = "trial" + +const trial_data_key = "trialdata" + +// store trial date +func InitTrial() error { + telData, err := logic.FetchTelemetryData() + if err != nil { + return err + } + if telData.Hosts > 0 || telData.Networks > 0 || telData.Users > 0 { + return nil + } + err = database.CreateTable(trial_table_name) + if err != nil { + slog.Error("failed to create table", "table name", trial_table_name, "err", err.Error()) + return err + } + // setup encryption keys + trafficPubKey, trafficPrivKey, err := box.GenerateKey(rand.Reader) // generate traffic keys + if err != nil { + return err + } + tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey) + if err != nil { + return err + } + + tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey) + if err != nil { + return err + } + trialDates := TrialDates{ + TrialStartedAt: time.Now(), + TrialEndsAt: time.Now().Add(time.Hour * 24 * 30), + } + t := TrialInfo{ + PrivKey: tPriv, + PubKey: tPub, + } + tel, err := logic.FetchTelemetryRecord() + if err != nil { + return err + } + + trialDatesData, err := json.Marshal(trialDates) + if err != nil { + return err + } + trialDatesSecret, err := ncutils.BoxEncrypt(trialDatesData, (*[32]byte)(tel.TrafficKeyPub), (*[32]byte)(t.PrivKey)) + if err != nil { + return err + } + t.Secret = string(trialDatesSecret) + trialData, err := json.Marshal(t) + if err != nil { + return err + } + err = database.Insert(trial_data_key, string(trialData), trial_table_name) + if err != nil { + return err + } + return nil +} + +func TrialLicenseHook() error { + endDate, err := getTrialEndDate() + if err != nil { + logger.FatalLog0("failed to trial end date", err.Error()) + } + if time.Now().After(endDate) { + logger.FatalLog0("***IMPORTANT: Your Trial Has Ended, to continue using pro version, please visit https://app.netmaker.io/ and create on-prem tenant to obtain a license***\nIf you wish to downgrade to community version, please run this command `/root/nm-quick.sh -d`") + + } + return nil +} + +// get trial date +func getTrialEndDate() (time.Time, error) { + record, err := database.FetchRecord(trial_table_name, trial_data_key) + if err != nil { + return time.Time{}, err + } + var trialInfo TrialInfo + err = json.Unmarshal([]byte(record), &trialInfo) + if err != nil { + return time.Time{}, err + } + tel, err := logic.FetchTelemetryRecord() + if err != nil { + return time.Time{}, err + } + // decrypt secret + secretDecrypt, err := ncutils.BoxDecrypt([]byte(trialInfo.Secret), (*[32]byte)(trialInfo.PubKey), (*[32]byte)(tel.TrafficKeyPriv)) + if err != nil { + return time.Time{}, err + } + trialDates := TrialDates{} + err = json.Unmarshal(secretDecrypt, &trialDates) + if err != nil { + return time.Time{}, err + } + if trialDates.TrialEndsAt.IsZero() { + return time.Time{}, errors.New("invalid date") + } + return trialDates.TrialEndsAt, nil + +}