diff --git a/internal/modules/hash/commands.go b/internal/modules/hash/commands.go index b05b871..24e6790 100644 --- a/internal/modules/hash/commands.go +++ b/internal/modules/hash/commands.go @@ -780,6 +780,27 @@ func handleHEXPIRE(params internal.HandlerFuncParams) ([]byte, error) { return []byte(resp), nil } +func handleHEXPIREAT(params internal.HandlerFuncParams) ([]byte, error) { + keys, err := hexpireatKeyFunc(params.Command) + if err != nil { + return nil, err + } + cmdargs := keys.WriteKeys[1:] + epoch, err := strconv.ParseInt(cmdargs[0], 10, 64) + if err != nil { + return nil, errors.New(fmt.Sprintf("seconds must be integer, was provided %q", cmdargs[0])) + } + + if params.GetClock().Now().Unix() > epoch { + params.Command[2] = "0" + return handleHEXPIRE(params) + } + + expireAt := epoch - params.GetClock().Now().Unix() + params.Command[2] = strconv.FormatInt(expireAt, 10) + return handleHEXPIRE(params) +} + func handleHTTL(params internal.HandlerFuncParams) ([]byte, error) { keys, err := httlKeyFunc(params.Command) if err != nil { @@ -1083,6 +1104,15 @@ Return the string length of the values stored at the specified fields. 0 if the KeyExtractionFunc: hexpireKeyFunc, HandlerFunc: handleHEXPIRE, }, + { + Command: "hexpireat", + Module: constants.HashModule, + Categories: []string{constants.HashCategory, constants.WriteCategory, constants.FastCategory}, + Description: `(HEXPIREAT key unix-time-seconds [NX | XX | GT | LT] FIELDS numfields field [field ...]) Sets the exact expiration time from now of a field in a hash.`, + Sync: true, + KeyExtractionFunc: hexpireatKeyFunc, + HandlerFunc: handleHEXPIREAT, + }, { Command: "httl", Module: constants.HashModule, diff --git a/internal/modules/hash/commands_test.go b/internal/modules/hash/commands_test.go index f44250c..f67b5b7 100644 --- a/internal/modules/hash/commands_test.go +++ b/internal/modules/hash/commands_test.go @@ -2243,6 +2243,314 @@ func Test_Hash(t *testing.T) { } + }) + t.Run("Test_HandleHEXPIREAT", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue hash.Hash + command []string + expectedValue string + expectedError error + }{ + + { + name: "1. Set expiration for all keys in hash, no options.", + key: "HexpireAtKey1", + presetValue: hash.Hash{ + "HexpireK1Field1": hash.HashValue{ + Value: "default1", + }, + "HexpireK1Field2": hash.HashValue{ + Value: "default2", + }, + "HexpireK1Field3": hash.HashValue{ + Value: "default3", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey1", strconv.FormatInt(mockClock.Now().Unix()+5, 10), "FIELDS", "3", "HexpireK1Field1", "HexpireK1Field2", "HexpireK1Field3"}, + expectedValue: "[1 1 1]", + expectedError: nil, + }, + { + name: "2. Set expiration for one key in hash, no options.", + key: "HexpireAtKey2", + presetValue: hash.Hash{ + "HexpireK2Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey2", strconv.FormatInt(mockClock.Now().Unix()+5, 10), "FIELDS", "1", "HexpireK2Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "3. Set expiration, expireTime already populated, no options.", + key: "HexpireAtKey3", + presetValue: hash.Hash{ + "HexpireK3Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey3", strconv.FormatInt(mockClock.Now().Unix()+100, 10), "FIELDS", "1", "HexpireK3Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "4. Set expiration, option NX with no expire time currently set.", + key: "HexpireAtKey4", + presetValue: hash.Hash{ + "HexpireK4Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey4", strconv.FormatInt(mockClock.Now().Unix()+5, 10), "NX", "FIELDS", "1", "HexpireK4Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "5. Set expiration, option NX with an expire time already set.", + key: "HexpireAtKey5", + presetValue: hash.Hash{ + "HexpireK5Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey5", strconv.FormatInt(mockClock.Now().Unix()+100, 10), "NX", "FIELDS", "1", "HexpireK5Field1"}, + expectedValue: "[0]", + expectedError: nil, + }, + { + name: "6. Set expiration, option XX with no expire time currently set.", + key: "HexpireAtKey6", + presetValue: hash.Hash{ + "HexpireK6Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey6", strconv.FormatInt(mockClock.Now().Unix()+5, 10), "XX", "FIELDS", "1", "HexpireK6Field1"}, + expectedValue: "[0]", + expectedError: nil, + }, + { + name: "7. Set expiration, option XX with expire time already set.", + key: "HexpireAtKey7", + presetValue: hash.Hash{ + "HexpireK7Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey7", strconv.FormatInt(mockClock.Now().Unix()+100, 10), "XX", "FIELDS", "1", "HexpireK7Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "8. Set expiration, option GT with expire time less than one provided.", + key: "HexpireAtKey8", + presetValue: hash.Hash{ + "HexpireK8Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey8", strconv.FormatInt(mockClock.Now().Unix()+1000, 10), "GT", "FIELDS", "1", "HexpireK8Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "9. Set expiration, option GT with expire time greater than one provided.", + key: "HexpireAtKey9", + presetValue: hash.Hash{ + "HexpireK9Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey9", strconv.FormatInt(mockClock.Now().Unix()+100, 10), "GT", "FIELDS", "1", "HexpireK9Field1"}, + expectedValue: "[0]", + expectedError: nil, + }, + { + name: "10. Set expiration, option LT with expire time less than one provided.", + key: "HexpireAtKey10", + presetValue: hash.Hash{ + "HexpireK10Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey10", strconv.FormatInt(mockClock.Now().Unix()+1000, 10), "LT", "FIELDS", "1", "HexpireK10Field1"}, + expectedValue: "[0]", + expectedError: nil, + }, + { + name: "11. Set expiration, option LT with expire time greater than one provided.", + key: "HexpireAtKey11", + presetValue: hash.Hash{ + "HexpireK11Field1": hash.HashValue{ + Value: "default1", + ExpireAt: mockClock.Now().Add(500 * time.Second), + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey11", strconv.FormatInt(mockClock.Now().Unix()+100, 10), "LT", "FIELDS", "1", "HexpireK11Field1"}, + expectedValue: "[1]", + expectedError: nil, + }, + { + name: "12. Set expiration, provide 0 seconds.", + key: "HexpireAtKey12", + presetValue: hash.Hash{ + "HexpireK12Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey12", strconv.FormatInt(mockClock.Now().Unix()-10, 10), "FIELDS", "1", "HexpireK12Field1"}, + expectedValue: "[2]", + expectedError: nil, + }, + { + name: "13. Attempt to set expiration for non existent key.", + key: "HexpireAtKeyNOTEXIST", + presetValue: nil, + command: []string{"HEXPIREAT", "HexpireAtKeyNOTEXIST", strconv.FormatInt(mockClock.Now().Unix()+10, 10), "FIELDS", "1", "HexpireKNEField1"}, + expectedValue: "[-2]", + expectedError: nil, + }, + { + name: "14. Attempt to set expiration for field that doesn't exist.", + key: "HexpireAtKey14", + presetValue: hash.Hash{ + "HexpireK14Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey14", strconv.FormatInt(mockClock.Now().Unix()+10, 10), "FIELDS", "2", "HexpireK14BadField1", "HexpireK14Field1"}, + expectedValue: "[-2 1]", + expectedError: nil, + }, + { + name: "15. Set expiration, command wrong length.", + key: "HexpireAtKey15", + presetValue: hash.Hash{ + "HexpireK15Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey15", strconv.FormatInt(mockClock.Now().Unix()+10, 10), "1", "HexpireK15Field1"}, + expectedError: errors.New("Error wrong number of arguments"), + }, + { + name: "16. Set expiration, command filed numfields is not a number.", + key: "HexpireAtKey16", + presetValue: hash.Hash{ + "HexpireK16Field1": hash.HashValue{ + Value: "default1", + }, + }, + command: []string{"HEXPIREAT", "HexpireAtKey16", strconv.FormatInt(mockClock.Now().Unix()+10, 10), "FIELDS", "one", "HexpireK16Field1"}, + expectedError: errors.New("Error numberfields must be integer, was provided \"one\""), + }, + } + + for _, test := range tests { + + t.Run(test.name, func(t *testing.T) { + // set key with preset value + if test.presetValue != nil { + var command []resp.Value + var expected string + + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value.Value.(string))}..., + ) + } + expected = strconv.Itoa(len(test.presetValue)) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + + } + + // preset Expire Time + for field, value := range test.presetValue { + if value.ExpireAt != (time.Time{}) { + cmd := []resp.Value{ + resp.StringValue("HEXPIRE"), + resp.StringValue(test.key), + resp.StringValue("500"), + resp.StringValue("FIELDS"), + resp.StringValue("1"), + resp.StringValue(field), + } + + if err = client.WriteArray(cmd); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if res.String() != "[1]" { + t.Errorf("Error presetting expire time - Key: %s, Field: %s, response: %s", test.key, field, res.String()) + } + } + } + + // run HEXPIREAT command + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error()) + } + return + } + + if res.String() != test.expectedValue { + t.Errorf("expected response %q, got %q", test.expectedValue, res.String()) + } + + }) + + } + }) t.Run("Test_HandleHTTL", func(t *testing.T) { @@ -2504,9 +2812,9 @@ func Test_Hash(t *testing.T) { _ = conn.Close() }() client := resp.NewConn(conn) - + const fixedTimestamp = 1136189545000 - + tests := []struct { name string key string @@ -2589,12 +2897,12 @@ func Test_Hash(t *testing.T) { expectedError: errors.New("expire time must be integer, was provided \"notanumber\""), }, } - + for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.presetValue != nil { var command []resp.Value - + switch v := test.presetValue.(type) { case string: command = []resp.Value{ @@ -2608,14 +2916,14 @@ func Test_Hash(t *testing.T) { command = append(command, resp.StringValue(key), resp.StringValue(value.Value.(string))) } } - + if err = client.WriteArray(command); err != nil { t.Error(err) } if _, _, err = client.ReadValue(); err != nil { t.Error(err) } - + if test.setExpire { if hash, ok := test.presetValue.(hash.Hash); ok { for field := range hash { @@ -2637,7 +2945,7 @@ func Test_Hash(t *testing.T) { } } } - + // Execute HPEXPIRETIME command command := make([]resp.Value, len(test.command)) for i, v := range test.command { @@ -2646,19 +2954,19 @@ func Test_Hash(t *testing.T) { if err = client.WriteArray(command); err != nil { t.Error(err) } - + resp, _, err := client.ReadValue() if err != nil { t.Error(err) } - + if test.expectedError != nil { if !strings.Contains(resp.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error %q, got %q", test.expectedError.Error(), resp.Error()) } return } - + if resp.String() != test.expectedValue { t.Errorf("Expected value %q but got %q", test.expectedValue, resp.String()) } @@ -2677,9 +2985,9 @@ func Test_Hash(t *testing.T) { _ = conn.Close() }() client := resp.NewConn(conn) - + const fixedTimestamp = 1136189545 - + tests := []struct { name string key string @@ -2762,12 +3070,12 @@ func Test_Hash(t *testing.T) { expectedError: errors.New("expire time must be integer, was provided \"notanumber\""), }, } - + for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.presetValue != nil { var command []resp.Value - + switch v := test.presetValue.(type) { case string: command = []resp.Value{ @@ -2781,14 +3089,14 @@ func Test_Hash(t *testing.T) { command = append(command, resp.StringValue(key), resp.StringValue(value.Value.(string))) } } - + if err = client.WriteArray(command); err != nil { t.Error(err) } if _, _, err = client.ReadValue(); err != nil { t.Error(err) } - + if test.setExpire { if hash, ok := test.presetValue.(hash.Hash); ok { for field := range hash { @@ -2810,7 +3118,7 @@ func Test_Hash(t *testing.T) { } } } - + command := make([]resp.Value, len(test.command)) for i, v := range test.command { command[i] = resp.StringValue(v) @@ -2818,19 +3126,19 @@ func Test_Hash(t *testing.T) { if err = client.WriteArray(command); err != nil { t.Error(err) } - + resp, _, err := client.ReadValue() if err != nil { t.Error(err) } - + if test.expectedError != nil { if !strings.Contains(resp.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error %q, got %q", test.expectedError.Error(), resp.Error()) } return } - + if resp.String() != test.expectedValue { t.Errorf("Expected value %q but got %q", test.expectedValue, resp.String()) } diff --git a/internal/modules/hash/key_funcs.go b/internal/modules/hash/key_funcs.go index d313f01..e362c85 100644 --- a/internal/modules/hash/key_funcs.go +++ b/internal/modules/hash/key_funcs.go @@ -183,6 +183,18 @@ func hexpireKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { }, nil } +func hexpireatKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { + if len(cmd) < 6 { + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) + } + + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: make([]string, 0), + WriteKeys: cmd[1:], + }, nil +} + func httlKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) { if len(cmd) < 5 { return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) diff --git a/sugardb/api_hash.go b/sugardb/api_hash.go index 536b063..5449fd5 100644 --- a/sugardb/api_hash.go +++ b/sugardb/api_hash.go @@ -397,6 +397,45 @@ func (server *SugarDB) HExpire(key string, seconds int, ExOpt ExpireOptions, fie return internal.ParseIntegerArrayResponse(b) } +// HExpireAt sets the expiration for the provided field(s) in a hash map to a specific epoch time. +// +// Parameters: +// +// `key` - string - the key to the hash map. +// +// `epoch` - int - Unix timestamp in seconds when the key should expire. +// +// `ExOpt` - ExpireOptions - One of NX, XX, GT, LT. +// +// `fields` - ...string - a list of fields to set expiration of. +// +// Returns: an integer array representing the outcome of the commmand for each field. +// - Integer reply: -2 if no such field exists in the provided hash key, or the provided key does not exist. +// - Integer reply: 0 if the specified NX | XX | GT | LT condition has not been met. +// - Integer reply: 1 if the expiration time was set/updated. +// - Integer reply: 2 when HEXPIRE/HPEXPIRE is called with 0 seconds +// +// Errors: +// +// "value of key is not a hash" - when the provided key is not a hash. +func (server *SugarDB) HExpireAt(key string, epoch int, ExOpt ExpireOptions, fields ...string) ([]int, error) { + cmd := []string{"HEXPIREAT", key, fmt.Sprintf("%v", epoch)} + if ExOpt != nil { + ExpireOption := fmt.Sprintf("%v", ExOpt) + cmd = append(cmd, ExpireOption) + } + + numFields := fmt.Sprintf("%v", len(fields)) + fieldsArray := append([]string{"FIELDS", numFields}, fields...) + + cmd = append(cmd, fieldsArray...) + b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false, true) + if err != nil { + return nil, err + } + return internal.ParseIntegerArrayResponse(b) +} + // HTTL gets the expiration for the provided field(s) in a hash map. // // Parameters: diff --git a/sugardb/api_hash_test.go b/sugardb/api_hash_test.go index 359a65f..ccd0864 100644 --- a/sugardb/api_hash_test.go +++ b/sugardb/api_hash_test.go @@ -16,11 +16,12 @@ package sugardb import ( "context" - "github.com/echovault/sugardb/internal/modules/hash" "reflect" "slices" "testing" "time" + + "github.com/echovault/sugardb/internal/modules/hash" ) func TestSugarDB_Hash(t *testing.T) { @@ -1070,6 +1071,121 @@ func TestSugarDB_Hash(t *testing.T) { } }) + t.Run("TestSugarDB_HExpireAt", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + presetValue interface{} + key string + fields []string + expireOption ExpireOptions + want []int + wantErr bool + }{ + { + name: "1. Set Expiration from existing hash.", + key: "HExpireAtKey1", + presetValue: hash.Hash{ + "field1": {Value: "value1"}, + "field2": {Value: 365}, + "field3": {Value: 3.142}, + }, + fields: []string{"field1", "field2", "field3"}, + want: []int{1, 1, 1}, + wantErr: false, + }, + { + name: "2. Return -2 when attempting to get from non-existed key", + presetValue: nil, + key: "HExpireAtKey2", + fields: []string{"field1"}, + want: []int{-2}, + wantErr: false, + }, + { + name: "3. Error when trying to get from a value that is not a hash map", + presetValue: "Default Value", + key: "HExpireAtKey3", + fields: []string{"field1"}, + want: nil, + wantErr: true, + }, + { + name: "4. Set Expiration with option NX.", + key: "HExpireAtKey4", + presetValue: hash.Hash{ + "field1": {Value: "value1"}, + "field2": {Value: 365}, + "field3": {Value: 3.142}, + }, + fields: []string{"field1", "field2", "field3"}, + expireOption: NX, + want: []int{1, 1, 1}, + wantErr: false, + }, + { + name: "5. Set Expiration with option XX.", + key: "HExpireAtKey5", + presetValue: hash.Hash{ + "field1": {Value: "value1"}, + "field2": {Value: 365}, + "field3": {Value: 3.142}, + }, + fields: []string{"field1", "field2", "field3"}, + expireOption: XX, + want: []int{0, 0, 0}, + wantErr: false, + }, + { + name: "6. Set Expiration with option GT.", + key: "HExpireAtKey6", + presetValue: hash.Hash{ + "field1": {Value: "value1"}, + "field2": {Value: 365}, + "field3": {Value: 3.142}, + }, + fields: []string{"field1", "field2", "field3"}, + expireOption: GT, + want: []int{0, 0, 0}, + wantErr: false, + }, + { + name: "7. Set Expiration with option LT.", + key: "HExpireAtKey7", + presetValue: hash.Hash{ + "field1": {Value: "value1"}, + "field2": {Value: 365}, + "field3": {Value: 3.142}, + }, + fields: []string{"field1", "field2", "field3"}, + expireOption: LT, + want: []int{1, 1, 1}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.presetValue != nil { + err := presetValue(server, context.Background(), tt.key, tt.presetValue) + if err != nil { + t.Error(err) + return + } + } + got, err := server.HExpireAt(tt.key, int(time.Now().Unix())+5, tt.expireOption, tt.fields...) + if (err != nil) != tt.wantErr { + t.Errorf("HExpireAt() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("HExpireAt() got = %v, want %v", got, tt.want) + } + }) + } + }) + t.Run("TestSugarDB_HTTL", func(t *testing.T) { t.Parallel()