00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #if !defined(INCLUDED_CONTAINER_HASHTABLE_H)
00023 #define INCLUDED_CONTAINER_HASHTABLE_H
00024
#include <cstddef>
#include <algorithm>
#include <functional>
#include <iterator>
#include <utility>
#include "debugging/debugging.h"
00029
00030
namespace HashTableDetail {
/// \brief Returns the smallest power of two that is greater than or equal to
/// \p size (1 when \p size is 0 or 1).
///
/// When no such power is representable in std::size_t, the largest
/// representable power of two is returned instead of overflowing to zero,
/// which would otherwise make the doubling loop never terminate.
inline std::size_t next_power_of_two(std::size_t size) {
	const std::size_t highest = ~(~std::size_t(0) >> 1);
	if (size > highest) {
		return highest;
	}
	std::size_t result = 1;
	while (result < size) {
		result <<= 1;
	}
	return result;
}

/// \brief Link fields of an intrusive circular doubly-linked list node.
/// A standalone instance linked to itself acts as the list sentinel.
struct BucketNodeBase {
	BucketNodeBase* next;
	BucketNodeBase* prev;
};

/// \brief Makes \p self an empty list (a sentinel whose links point to itself).
inline void list_initialise(BucketNodeBase& self) {
	self.next = self.prev = &self;
}

/// \brief Exchanges the lists headed by the sentinels \p self and \p other,
/// re-pointing the boundary nodes of each list at their new sentinel.
inline void list_swap(BucketNodeBase& self, BucketNodeBase& other) {
	BucketNodeBase tmp(self);
	if (other.next == &other) {
		list_initialise(self);
	} else {
		self = other;
		self.next->prev = self.prev->next = &self;
	}
	// tmp is a copy of self, so it points back at self iff self's list was empty.
	if (tmp.next == &self) {
		list_initialise(other);
	} else {
		other = tmp;
		other.next->prev = other.prev->next = &other;
	}
}

/// \brief Inserts \p node into a list immediately before \p next.
inline void node_link(BucketNodeBase* node, BucketNodeBase* next) {
	node->next = next;
	node->prev = next->prev;
	next->prev = node;
	node->prev->next = node;
}
/// \brief Detaches \p node from its neighbours; the node's own link fields
/// are left unchanged.
inline void node_unlink(BucketNodeBase* node) {
	node->prev->next = node->next;
	node->next->prev = node->prev;
}

/// \brief The key/value pair stored in each node. The key is const so it
/// cannot be modified through an iterator.
template<typename Key, typename Value>
struct KeyValue {
	const Key key;
	Value value;

	KeyValue(const Key& key_, const Value& value_)
		: key(key_), value(value_) {
	}
};

/// \brief A list node holding a key/value pair and the cached hash of its key.
template<typename Key, typename Value, typename Hash>
struct BucketNode : public BucketNodeBase {
	Hash m_hash;
	KeyValue<Key, Value> m_value;

	BucketNode(Hash hash, const Key& key, const Value& value)
		: m_hash(hash), m_value(key, value) {
	}
	BucketNode* getNext() const {
		return static_cast<BucketNode*>(next);
	}
	BucketNode* getPrev() const {
		return static_cast<BucketNode*>(prev);
	}
};

/// \brief Forward iterator that walks a list of BucketNode objects via their
/// \c next links, exposing each node's KeyValue.
template<typename Key, typename Value, typename Hash>
class BucketIterator {
	typedef BucketNode<Key, Value, Hash> Node;
	Node* m_node;

	void increment() {
		m_node = m_node->getNext();
	}

public:
	typedef std::forward_iterator_tag iterator_category;
	typedef std::ptrdiff_t difference_type;
	typedef difference_type distance_type;
	typedef KeyValue<Key, Value> value_type;
	typedef value_type* pointer;
	typedef value_type& reference;

	/// Constructs a singular iterator (null node); forward iterators must be
	/// default-constructible. It must not be dereferenced or incremented.
	BucketIterator() : m_node(0) {
	}
	BucketIterator(Node* node) : m_node(node) {
	}

	/// Returns the node this iterator refers to.
	Node* node() const {
		return m_node;
	}

	bool operator==(const BucketIterator& other) const {
		return m_node == other.m_node;
	}
	bool operator!=(const BucketIterator& other) const {
		return !operator==(other);
	}
	BucketIterator& operator++() {
		increment();
		return *this;
	}
	BucketIterator operator++(int) {
		BucketIterator tmp = *this;
		increment();
		return tmp;
	}
	value_type& operator*() const {
		return m_node->m_value;
	}
	value_type* operator->() const {
		return &(operator*());
	}
};
}
00149
00150
00164 template < typename Key, typename Value, typename Hasher, typename KeyEqual = std::equal_to<Key> >
00165 class HashTable : private KeyEqual, private Hasher {
00166 typedef typename Hasher::hash_type hash_type;
00167 typedef HashTableDetail::KeyValue<Key, Value> KeyValue;
00168 typedef HashTableDetail::BucketNode<Key, Value, hash_type> BucketNode;
00169
00170 inline BucketNode* node_create(hash_type hash, const Key& key, const Value& value) {
00171 return new BucketNode(hash, key, value);
00172 }
00173 inline void node_destroy(BucketNode* node) {
00174 delete node;
00175 }
00176
00177 typedef BucketNode* Bucket;
00178
00179 static Bucket* buckets_new(std::size_t count) {
00180 Bucket* buckets = new Bucket[count];
00181 std::uninitialized_fill(buckets, buckets + count, Bucket(0));
00182 return buckets;
00183 }
00184 static void buckets_delete(Bucket* buckets) {
00185 delete[] buckets;
00186 }
00187
00188 std::size_t m_bucketCount;
00189 Bucket* m_buckets;
00190 std::size_t m_size;
00191 HashTableDetail::BucketNodeBase m_list;
00192
00193 BucketNode* getFirst() {
00194 return static_cast<BucketNode*>(m_list.next);
00195 }
00196 BucketNode* getLast() {
00197 return static_cast<BucketNode*>(&m_list);
00198 }
00199
00200 public:
00201
00202 typedef KeyValue value_type;
00203 typedef HashTableDetail::BucketIterator<Key, Value, hash_type> iterator;
00204
00205 private:
00206
00207 void initialise() {
00208 list_initialise(m_list);
00209 }
00210 hash_type hashKey(const Key& key) {
00211 return Hasher::operator()(key);
00212 }
00213
00214 std::size_t getBucketId(hash_type hash) const {
00215 return hash & (m_bucketCount - 1);
00216 }
00217 Bucket& getBucket(hash_type hash) {
00218 return m_buckets[getBucketId(hash)];
00219 }
00220 BucketNode* bucket_find(Bucket bucket, hash_type hash, const Key& key) {
00221 std::size_t bucketId = getBucketId(hash);
00222 for (iterator i(bucket); i != end(); ++i) {
00223 hash_type nodeHash = i.node()->m_hash;
00224
00225 if (getBucketId(nodeHash) != bucketId) {
00226 return 0;
00227 }
00228
00229 if (nodeHash == hash && KeyEqual::operator()((*i).key, key)) {
00230 return i.node();
00231 }
00232 }
00233 return 0;
00234 }
00235 BucketNode* bucket_insert(Bucket& bucket, BucketNode* node) {
00236
00237 node_link(node, bucket_next(bucket));
00238 bucket = node;
00239 return node;
00240 }
00241 BucketNode* bucket_next(Bucket& bucket) {
00242 Bucket* end = m_buckets + m_bucketCount;
00243 for (Bucket* i = &bucket; i != end; ++i) {
00244 if (*i != 0) {
00245 return *i;
00246 }
00247 }
00248 return getLast();
00249 }
00250
00251 void buckets_resize(std::size_t count) {
00252 BucketNode* first = getFirst();
00253 BucketNode* last = getLast();
00254
00255 buckets_delete(m_buckets);
00256
00257 m_bucketCount = count;
00258
00259 m_buckets = buckets_new(m_bucketCount);
00260 initialise();
00261
00262 for (BucketNode* i = first; i != last;) {
00263 BucketNode* node = i;
00264 i = i->getNext();
00265 bucket_insert(getBucket((*node).m_hash), node);
00266 }
00267 }
00268 void size_increment() {
00269 if (m_size == m_bucketCount) {
00270 buckets_resize(m_bucketCount == 0 ? 8 : m_bucketCount << 1);
00271 }
00272 ++m_size;
00273 }
00274 void size_decrement() {
00275 --m_size;
00276 }
00277
00278 HashTable(const HashTable& other);
00279 HashTable& operator=(const HashTable& other);
00280 public:
00281 HashTable() : m_bucketCount(0), m_buckets(0), m_size(0) {
00282 initialise();
00283 }
00284 HashTable(std::size_t bucketCount) : m_bucketCount(HashTableDetail::next_power_of_two(bucketCount)), m_buckets(buckets_new(m_bucketCount)), m_size(0) {
00285 initialise();
00286 }
00287 ~HashTable() {
00288 for (BucketNode* i = getFirst(); i != getLast();) {
00289 BucketNode* node = i;
00290 i = i->getNext();
00291 node_destroy(node);
00292 }
00293 buckets_delete(m_buckets);
00294 }
00295
00296 iterator begin() {
00297 return iterator(getFirst());
00298 }
00299 iterator end() {
00300 return iterator(getLast());
00301 }
00302
00303 bool empty() const {
00304 return m_size == 0;
00305 }
00306 std::size_t size() const {
00307 return m_size;
00308 }
00309
00311 iterator find(const Key& key) {
00312 hash_type hash = hashKey(key);
00313 if (m_bucketCount != 0) {
00314 Bucket bucket = getBucket(hash);
00315 if (bucket != 0) {
00316 BucketNode* node = bucket_find(bucket, hash, key);
00317 if (node != 0) {
00318 return iterator(node);
00319 }
00320 }
00321 }
00322
00323 return end();
00324 }
00326 iterator insert(const Key& key, const Value& value) {
00327 hash_type hash = hashKey(key);
00328 if (m_bucketCount != 0) {
00329 Bucket& bucket = getBucket(hash);
00330 if (bucket != 0) {
00331 BucketNode* node = bucket_find(bucket, hash, key);
00332 if (node != 0) {
00333 return iterator(node);
00334 }
00335 }
00336 }
00337
00338 size_increment();
00339 return iterator(bucket_insert(getBucket(hash), node_create(hash, key, value)));
00340 }
00341
00345 void erase(iterator i) {
00346 Bucket& bucket = getBucket(i.node()->m_hash);
00347 BucketNode* node = i.node();
00348
00349
00350 if (bucket == node) {
00351 bucket = (node->getNext() == getLast() || &getBucket(node->getNext()->m_hash) != &bucket) ? 0 : node->getNext();
00352 }
00353
00354 node_unlink(node);
00355 ASSERT_MESSAGE(node != 0, "tried to erase a non-existent key/value");
00356 node_destroy(node);
00357
00358 size_decrement();
00359 }
00360
00362 Value& operator[](const Key& key) {
00363 hash_type hash = hashKey(key);
00364 if (m_bucketCount != 0) {
00365 Bucket& bucket = getBucket(hash);
00366 if (bucket != 0) {
00367 BucketNode* node = bucket_find(bucket, hash, key);
00368 if (node != 0) {
00369 return node->m_value.value;
00370 }
00371 }
00372 }
00373 size_increment();
00374 return bucket_insert(getBucket(hash), node_create(hash, key, Value()))->m_value.value;
00375 }
00377 void erase(const Key& key) {
00378 erase(find(key));
00379 }
00381 void swap(HashTable& other) {
00382 std::swap(m_buckets, other.m_buckets);
00383 std::swap(m_bucketCount, other.m_bucketCount);
00384 std::swap(m_size, other.m_size);
00385 HashTableDetail::list_swap(m_list, other.m_list);
00386 }
00388 void clear() {
00389 HashTable tmp;
00390 tmp.swap(*this);
00391 }
00392 };
00393
00394 #endif