diff --git a/echovault/api_string_test.go b/echovault/api_string_test.go index 41e810a..2c21972 100644 --- a/echovault/api_string_test.go +++ b/echovault/api_string_test.go @@ -323,13 +323,27 @@ func TestEchoVault_APPEND(t *testing.T) { wantErr bool }{ { - name: "Return the correct string length for appended value", + name: "Test APPEND with no preset value", + key: "key1", + value: "Hello ", + want: 6, + wantErr: false, + }, + { + name: "Test APPEND with preset value", presetValue: "Hello ", - key: "key1", + key: "key2", value: "World", want: 11, wantErr: false, }, + { + name: "Test APPEND with integer preset value", + key: "key3", + presetValue: 10, + value: "Hello ", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/modules/string/commands.go b/internal/modules/string/commands.go index 1ab2550..a353f94 100644 --- a/internal/modules/string/commands.go +++ b/internal/modules/string/commands.go @@ -182,16 +182,19 @@ func handleAppend(params internal.HandlerFuncParams) ([]byte, error) { }); err != nil { return nil, err } - return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(value), value)), nil + return []byte(fmt.Sprintf(":%d\r\n", len(value))), nil + } + currentValue, ok := params.GetValues(params.Context, []string{key})[key].(string) + if !ok { + return nil, fmt.Errorf("Value at key %s is not a string", key) } - currentValue := params.GetValues(params.Context, []string{key})[key] newValue := fmt.Sprintf("%v%s", currentValue, value) if err = params.SetValues(params.Context, map[string]interface{}{ key: internal.AdaptType(newValue), }); err != nil { return nil, err } - return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(newValue), newValue)), nil + return []byte(fmt.Sprintf(":%d\r\n", len(newValue))), nil } func Commands() []internal.Command { diff --git a/internal/modules/string/commands_test.go b/internal/modules/string/commands_test.go index c92e419..8b91ebf 100644 --- a/internal/modules/string/commands_test.go +++ b/internal/modules/string/commands_test.go @@ -16,14 +16,15 @@ package str_test import ( "errors" + "strconv" + "strings" + "testing" + "github.com/echovault/echovault/echovault" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" - "strconv" - "strings" - "testing" ) func Test_String(t *testing.T) { @@ -450,4 +451,106 @@ func Test_String(t *testing.T) { }) } }) + + t.Run("Test_HandleAppend", func(t *testing.T) { + t.Parallel() + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int + expectedError error + }{ + { + name: "Test APPEND with no preset value", + key: "AppendKey1", + command: []string{"APPEND", "AppendKey1", "Hello"}, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "Test APPEND with preset value", + key: "AppendKey2", + presetValue: "Hello ", + command: []string{"APPEND", "AppendKey2", "World"}, + expectedResponse: 11, + expectedError: nil, + }, + { + name: "Test APPEND with integer preset value", + key: "AppendKey4", + presetValue: 10, + command: []string{"APPEND", "AppendKey4", "World"}, + expectedResponse: 0, + expectedError: errors.New("Value at key AppendKey4 is not a string"), + }, + { + name: "Command too short", + command: []string{"APPEND", "AppendKey5"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "Command too long", + command: []string{"APPEND", "AppendKey5", "new value", "extra value"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.AnyValue(test.presetValue), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + }) }