about summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/server/hashtable.h30
1 files changed, 21 insertions, 9 deletions
diff --git a/src/server/hashtable.h b/src/server/hashtable.h
index 0989a79..786d8f5 100644
--- a/src/server/hashtable.h
+++ b/src/server/hashtable.h
@@ -3,7 +3,9 @@
 #include <algorithm>
 #include <iostream>
 #include <list>
+#include <mutex>
 #include <optional>
+#include <shared_mutex>
 #include <vector>
 
 template <typename K, typename V>
@@ -12,12 +14,16 @@ public:
     HashTable(size_t size)
         : size { size }
         , table(size)
+        , bucket_mutexes(size)
     {
     }
 
     bool insert(K key, V value)
     {
-        std::list<std::pair<K, V>>& list = get_bucket(key);
+        size_t index = get_bucket_index(key);
+        std::unique_lock<std::shared_mutex> lock(bucket_mutexes.at(index));
+
+        std::list<std::pair<K, V>>& list = table.at(index);
 
         if (bucket_contains_key(list, key)) {
             return false;
@@ -30,7 +36,10 @@ public:
 
     std::optional<V> get(K key)
     {
-        std::list<std::pair<K, V>>& list = get_bucket(key);
+        size_t index = get_bucket_index(key);
+        std::shared_lock<std::shared_mutex> lock(bucket_mutexes.at(index));
+
+        std::list<std::pair<K, V>>& list = table.at(index);
 
         auto iter = bucket_find_key(list, key);
         if (iter != list.end()) {
@@ -42,7 +51,10 @@ public:
 
     bool remove(K key)
     {
-        std::list<std::pair<K, V>>& list = get_bucket(key);
+        size_t index = get_bucket_index(key);
+        std::unique_lock<std::shared_mutex> lock(bucket_mutexes.at(index));
+
+        std::list<std::pair<K, V>>& list = table.at(index);
 
         auto iter = bucket_find_key(list, key);
         if (iter != list.end()) {
@@ -57,11 +69,13 @@ public:
     {
         size_t index { 0 };
         for (auto bucket : table) {
-            std::cout << "Bucket " << index++ << ": [";
+            std::cout << "Bucket " << index << ": [";
+            std::shared_lock<std::shared_mutex> lock(bucket_mutexes.at(index));
             for (auto pair : bucket) {
                 std::cout << "(" << pair.first << ", " << pair.second << ")";
             }
             std::cout << "]" << "\n";
+            ++index;
         }
     }
 
@@ -70,6 +84,8 @@ private:
 
     std::vector<std::list<std::pair<K, V>>> table;
 
+    std::vector<std::shared_mutex> bucket_mutexes;
+
     std::hash<K> hash_function;
 
     auto bucket_find_key(std::list<std::pair<K, V>>& list, K key)
@@ -84,9 +100,5 @@ private:
         return list.begin() != list.end() && bucket_find_key(list, key) != list.end();
     }
 
-    std::list<std::pair<K, V>>& get_bucket(K key)
-    {
-        size_t index = hash_function(key) % size;
-        return table.at(index);
-    }
+    size_t get_bucket_index(K key) { return hash_function(key) % size; }
 };