From d12a41374b394bb19b36f294dc682288cc60c0f2 Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Mon, 17 Aug 2020 03:00:29 +0800 Subject: [PATCH] fix a dead lock bug --- build-linux.sh | 3 + build.sh | 4 + src/datastruct/lock/lock_map.go | 60 ++++++---- src/db/hash_test.go | 200 ++++++++++++++++++++++++++++++++ src/db/string_test.go | 2 +- 5 files changed, 244 insertions(+), 25 deletions(-) create mode 100755 build-linux.sh create mode 100755 build.sh create mode 100644 src/db/hash_test.go diff --git a/build-linux.sh b/build-linux.sh new file mode 100755 index 0000000..6d3b6b2 --- /dev/null +++ b/build-linux.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env bash + +GOOS=linux GOARCH=amd64 go build -o target/godis-linux ./src/cmd \ No newline at end of file diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..ab3e900 --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + + +go build -i -o target/godis-darwin ./src/cmd diff --git a/src/datastruct/lock/lock_map.go b/src/datastruct/lock/lock_map.go index f927dda..36a4e90 100644 --- a/src/datastruct/lock/lock_map.go +++ b/src/datastruct/lock/lock_map.go @@ -70,44 +70,56 @@ func (locks *Locks)RUnLock(key string) { mu.RUnlock() } +func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 { + indexMap := make(map[uint32]bool) + for _, key := range keys { + index := locks.spread(fnv32(key)) + indexMap[index] = true + } + indices := make([]uint32, 0, len(indexMap)) + for index := range indexMap { + indices = append(indices, index) + } + sort.Slice(indices, func(i, j int) bool { + if !reverse { + return indices[i] < indices[j] + } else { + return indices[i] > indices[j] + } + }) + return indices +} + func (locks *Locks)Locks(keys ...string) { - keySlice := make(sort.StringSlice, len(keys)) - copy(keySlice, keys) - sort.Sort(keySlice) - for _, key := range keySlice { - locks.Lock(key) + indices := locks.toLockIndices(keys, false) + for _, index := range indices { + mu := locks.table[index] + mu.Lock() } } func (locks *Locks)RLocks(keys ...string) { - keySlice := make(sort.StringSlice, len(keys)) - copy(keySlice, keys) - sort.Sort(keySlice) - for _, key := range keySlice { - locks.RLock(key) + indices := locks.toLockIndices(keys, false) + for _, index := range indices { + mu := locks.table[index] + mu.RLock() } } func (locks *Locks)UnLocks(keys ...string) { - size := len(keys) - keySlice := make(sort.StringSlice, size) - copy(keySlice, keys) - sort.Sort(keySlice) - for i := size - 1; i >= 0; i-- { - key := keySlice[i] - locks.UnLock(key) + indices := locks.toLockIndices(keys, true) + for _, index := range indices { + mu := locks.table[index] + mu.Unlock() } } func (locks *Locks)RUnLocks(keys ...string) { - size := len(keys) - keySlice := make(sort.StringSlice, size) - copy(keySlice, keys) - sort.Sort(keySlice) - for i := size - 1; i >= 0; i-- { - key := keySlice[i] - locks.RUnLock(key) + indices := locks.toLockIndices(keys, true) + for _, index := range indices { + mu := locks.table[index] + mu.RUnlock() } } diff --git a/src/db/hash_test.go b/src/db/hash_test.go new file mode 100644 index 0000000..64fc71c --- /dev/null +++ b/src/db/hash_test.go @@ -0,0 +1,200 @@ +package db + +import ( + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/redis/reply" + "math/rand" + "strconv" + "testing" +) + +func TestHSet(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // test hset + key := strconv.FormatInt(int64(rand.Int()), 10) + values := make(map[string][]byte, size) + for i := 0; i < size; i++ { + value := strconv.FormatInt(int64(rand.Int()), 10) + field := strconv.Itoa(i) + values[field] = []byte(value) + result := HSet(testDB, toArgs(key, field, value)) + if intResult, _ := result.(*reply.IntReply); intResult.Code != int64(1) { + t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) + } + } + + // test hget and hexists + for field, v := range values { + actual := HGet(testDB, toArgs(key, field)) + expected := reply.MakeBulkReply(v) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) + } + actual = HExists(testDB, toArgs(key, field)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(1) { + t.Error(fmt.Sprintf("expected %d, actually %d", 1, intResult.Code)) + } + } + + // test hlen + actual := HLen(testDB, toArgs(key)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(values)) { + t.Error(fmt.Sprintf("expected %d, actually %d", len(values), intResult.Code)) + } +} + +func TestHDel(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // set values + key := strconv.FormatInt(int64(rand.Int()), 10) + fields := make([]string, size) + for i := 0; i < size; i++ { + value := strconv.FormatInt(int64(rand.Int()), 10) + field := strconv.Itoa(i) + fields[i] = field + HSet(testDB, toArgs(key, field, value)) + } + + // test HDel + args := []string{key} + args = append(args, fields...) + actual := HDel(testDB, toArgs(args...)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(len(fields)) { + t.Error(fmt.Sprintf("expected %d, actually %d", len(fields), intResult.Code)) + } + + actual = HLen(testDB, toArgs(key)) + if intResult, _ := actual.(*reply.IntReply); intResult.Code != int64(0) { + t.Error(fmt.Sprintf("expected %d, actually %d", 0, intResult.Code)) + } +} + +func TestHMSet(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // test hset + key := strconv.FormatInt(int64(rand.Int()), 10) + fields := make([]string, size) + values := make([]string, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + fields[i] = strconv.FormatInt(int64(rand.Int()), 10) + values[i] = strconv.FormatInt(int64(rand.Int()), 10) + setArgs = append(setArgs, fields[i], values[i]) + } + result := HMSet(testDB, toArgs(setArgs...)) + if _, ok := result.(*reply.OkReply); !ok { + t.Error(fmt.Sprintf("expected ok, actually %s", string(result.ToBytes()))) + } + + // test HMGet + getArgs := []string{key} + getArgs = append(getArgs, fields...) + actual := HMGet(testDB, toArgs(getArgs...)) + expected := reply.MakeMultiBulkReply(toArgs(values...)) + if !utils.BytesEquals(actual.ToBytes(), expected.ToBytes()) { + t.Error(fmt.Sprintf("expected %s, actually %s", string(expected.ToBytes()), string(actual.ToBytes()))) + } +} + +func TestHGetAll(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + fields := make([]string, size) + valueSet := make(map[string]bool, size) + valueMap := make(map[string]string) + all := make([]string, 0) + for i := 0; i < size; i++ { + fields[i] = strconv.FormatInt(int64(rand.Int()), 10) + value := strconv.FormatInt(int64(rand.Int()), 10) + all = append(all, fields[i], value) + valueMap[fields[i]] = value + valueSet[value] = true + HSet(testDB, toArgs(key, fields[i], value)) + } + + // test HGetAll + result := HGetAll(testDB, toArgs(key)) + multiBulk, ok := result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if 2*len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", 2*len(fields), len(multiBulk.Args))) + } + for i := range fields { + field := string(multiBulk.Args[2*i]) + actual := string(multiBulk.Args[2*i+1]) + expected, ok := valueMap[field] + if !ok { + t.Error(fmt.Sprintf("unexpected field %s", field)) + continue + } + if actual != expected { + t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual)) + } + } + + // test HKeys + result = HKeys(testDB, toArgs(key)) + multiBulk, ok = result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) + } + for _, v := range multiBulk.Args { + field := string(v) + if _, ok := valueMap[field]; !ok { + t.Error(fmt.Sprintf("unexpected field %s", field)) + } + } + + // test HVals + result = HVals(testDB, toArgs(key)) + multiBulk, ok = result.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected MultiBulkReply, actually %s", string(result.ToBytes()))) + } + if len(fields) != len(multiBulk.Args) { + t.Error(fmt.Sprintf("expected %d items , actually %d ", len(fields), len(multiBulk.Args))) + } + for _, v := range multiBulk.Args { + value := string(v) + _, ok := valueSet[value] + if !ok { + t.Error(fmt.Sprintf("unexpected value %s", value)) + } + } +} + +func TestHIncrBy(t *testing.T) { + FlushAll(testDB, [][]byte{}) + + key := strconv.FormatInt(int64(rand.Int()), 10) + result := HIncrBy(testDB, toArgs(key, "a", "1")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1" { + t.Error(fmt.Sprintf("expected %s, actually %s", "1", string(bulkResult.Arg))) + } + result = HIncrBy(testDB, toArgs(key, "a", "1")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2" { + t.Error(fmt.Sprintf("expected %s, actually %s", "2", string(bulkResult.Arg))) + } + + result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "1.2" { + t.Error(fmt.Sprintf("expected %s, actually %s", "1.2", string(bulkResult.Arg))) + } + result = HIncrByFloat(testDB, toArgs(key, "b", "1.2")) + if bulkResult, _ := result.(*reply.BulkReply); string(bulkResult.Arg) != "2.4" { + t.Error(fmt.Sprintf("expected %s, actually %s", "2.4", string(bulkResult.Arg))) + } +} diff --git a/src/db/string_test.go b/src/db/string_test.go index 9761c7d..792fb9c 100644 --- a/src/db/string_test.go +++ b/src/db/string_test.go @@ -95,7 +95,7 @@ func TestMSet(t *testing.T) { size := 10 keys := make([]string, size) values := make([][]byte, size) - args := make([]string, size*2)[0:0] + args := make([]string, 0, size*2) for i := 0; i < size; i++ { keys[i] = strconv.FormatInt(int64(rand.Int()), 10) value := strconv.FormatInt(int64(rand.Int()), 10)