00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef AMINO_LOCKFREE_SET_H
00023 #define AMINO_LOCKFREE_SET_H
00024
00025 #include "ordered_list.h"
00026 using namespace std;
00027
00028 namespace amino {
/**
 * @brief Payload stored in the ordered list that backs amino::Set.
 *
 * Pairs a user element with an unsigned sort key. Nodes are ordered first by
 * key and, for equal keys, by element (see operator>=).
 *
 * @tparam KeyType element type; must support operator== and operator>=.
 */
template<typename KeyType> class SetNode {
public:
    /** The stored element (value-initialized for key-only nodes). */
    KeyType element;

    /**
     * Sort key. Set fills this with bit-reversed hash values
     * (Set::regularKey for elements, Set::dummyKey for bucket markers).
     */
    unsigned int key;

    /** Creates a node holding @p elem with sort key @p k. */
    SetNode(const KeyType& elem, unsigned int k) :
        element(elem), key(k) {
    }

    /**
     * Creates a node carrying only sort key @p k; the element is
     * value-initialized, i.e. KeyType().
     */
    explicit SetNode(unsigned int k) :
        element(), key(k) {
    }

    /**
     * Default constructor. The key is zero-initialized rather than left
     * indeterminate, so copying or comparing a default node is well-defined.
     */
    SetNode() :
        element(), key(0) {
    }

    /** Member-wise copy assignment; returns *this as conventional. */
    SetNode& operator=(const SetNode& node) {
        key = node.key;
        element = node.element;
        return *this;
    }

    /** Two nodes are equal when both key and element are equal. */
    bool operator==(const SetNode<KeyType>& rightHand) const {
        return this->key == rightHand.key && this->element == rightHand.element;
    }

    /**
     * Lexicographic (key, element) comparison: the key decides unless the
     * keys are equal, in which case the elements are compared.
     */
    bool operator>=(const SetNode<KeyType>& rightHand) const {
        if (this->key == rightHand.key)
            return this->element >= rightHand.element;
        return this->key >= rightHand.key;
    }
};
00081
00097 template<typename KeyType> class Set {
00098 private:
00099
00100 static const int DEFAULT_ARRAY_SIZE = 512;
00101
00102
00103 static const int DEFAULT_SEGMENT_SIZE = 64;
00104
00105
00106 static const int MINIMAL_SEGMENT_SIZE = 8;
00107
00108
00109 static const float MAX_LOAD = 0.75;
00110
00111
00112 OrderedList<SetNode<KeyType> > *oList;
00113
00114
00115 atomic<atomic<NodeType<SetNode<KeyType> > *> *> *mainArray;
00116
00117 typedef NodeType<SetNode<KeyType> >* node_ptr;
00118 typedef atomic<atomic<node_ptr> *> array_t;
00119
00120
00121 atomic<int> _count;
00122
00123
00124 int segmentSize;
00125
00126
00127 float loadFactor;
00128
00129
00130 atomic<int> _size;
00131
00132
00133 Set(const Set&) {
00134 }
00135 Set& operator=(const Set&) {
00136 }
00137 public:
00148 explicit Set(int expectedSetSize = DEFAULT_SEGMENT_SIZE
00149 * DEFAULT_ARRAY_SIZE, float expectedLoadFactor = MAX_LOAD) {
00150
00151 segmentSize = (getLargestValue(expectedSetSize / DEFAULT_ARRAY_SIZE))
00152 << 1;
00153 if (segmentSize < MINIMAL_SEGMENT_SIZE)
00154 segmentSize = MINIMAL_SEGMENT_SIZE;
00155
00156 loadFactor = expectedLoadFactor;
00157
00158 _size = 2;
00159 _count = 0;
00160
00161 oList = new OrderedList<SetNode<KeyType> > ;
00162
00163
00164 mainArray = new array_t[DEFAULT_ARRAY_SIZE];
00165 for (int i = 0; i < DEFAULT_ARRAY_SIZE; ++i) {
00166 mainArray[i] = NULL;
00167 }
00168
00169
00170 SetNode<KeyType> dummy(0);
00171 node_ptr address = NULL;
00172 if ((address = oList->add_returnAddress(dummy, &oList->head)) != NULL) {
00173 set_bucket(0, address);
00174 }
00175 assert(address != NULL);
00176 assert(get_bucket(0).load(memory_order_relaxed) != NULL);
00177 }
00178
00182 ~Set() {
00183 for (int i = 0; i < DEFAULT_ARRAY_SIZE; ++i)
00184 delete[] mainArray[i].load(memory_order_relaxed);
00185 delete[] mainArray;
00186 delete oList;
00187 }
00188
00194 bool empty() {
00195 return _count.load(memory_order_relaxed) == 0;
00196 }
00197
00204 unsigned int size() {
00205 return _count.load(memory_order_relaxed);
00206 }
00207
00216 bool insert(KeyType element) {
00217
00218 int key = hash_function(element);
00219
00220 unsigned int bucket = key % _size.load(memory_order_relaxed);
00221 assert(binary_reverse(binary_reverse(bucket)) == bucket);
00222
00223 if (get_bucket(bucket).load(memory_order_relaxed) == NULL) {
00224 initialize_bucket(bucket);
00225 }
00226
00227 assert(get_bucket(bucket).load(memory_order_relaxed) != NULL);
00228
00229 SetNode<KeyType> node = SetNode<KeyType> (element, regularKey(key));
00230
00231
00232 atomic<node_ptr> start = get_bucket(bucket);
00233
00234
00235 if (!oList->add(node, &start)) {
00236 return false;
00237 }
00238
00239 int old_size = _size.load(memory_order_relaxed);
00240
00241 ++_count;
00242
00243 if (_count.load(memory_order_relaxed) / old_size > loadFactor && old_size
00244 < DEFAULT_ARRAY_SIZE * segmentSize)
00245 _size.compare_swap(old_size, 2 * old_size);
00246 return true;
00247 }
00248
00257 bool remove(KeyType element) {
00258
00259 int key = hash_function(element);
00260
00261 unsigned int bucket = key % _size.load(memory_order_relaxed);
00262
00263
00264 if (get_bucket(bucket).load(memory_order_relaxed) == NULL)
00265 initialize_bucket(bucket);
00266 assert(get_bucket(bucket).load(memory_order_relaxed) != NULL);
00267
00268
00269 SetNode<KeyType> node = SetNode<KeyType> (element, regularKey(key));
00270
00271 atomic<node_ptr> start = get_bucket(bucket);
00272
00273
00274 if (!oList->remove(node, &start))
00275 return false;
00276
00277 --_count;
00278 return true;
00279 }
00280
00289 bool search(const KeyType& element) {
00290
00291 unsigned int key = hash_function(element);
00292
00293 unsigned int bucket = key % _size.load(memory_order_relaxed);
00294
00295 if (get_bucket(bucket).load(memory_order_relaxed) == NULL)
00296 initialize_bucket(bucket);
00297 assert(get_bucket(bucket).load(memory_order_relaxed) != NULL);
00298
00299
00300 SetNode<KeyType> node(element, regularKey(key));
00301
00302 atomic<node_ptr> start = get_bucket(bucket);
00303
00304
00305 return oList->search(node, &start);
00306 }
00307
00308 private:
00317 unsigned int getLargestValue(unsigned int n) {
00318 n |= (n >> 1);
00319 n |= (n >> 2);
00320 n |= (n >> 4);
00321 n |= (n >> 8);
00322 n |= (n >> 16);
00323 return n - (n >> 1);
00324 }
00325
00334 unsigned int binary_reverse(unsigned int n) {
00335 n = ((n & 0x55555555) << 1) | ((n >> 1) & 0x55555555);
00336 n = ((n & 0x33333333) << 2) | ((n >> 2) & 0x33333333);
00337 n = ((n & 0x0f0f0f0f) << 4) | ((n >> 4) & 0x0f0f0f0f);
00338 n = (n << 24) | ((n & 0xff00) << 8) | ((n >> 8) & 0xff00) | (n >> 24);
00339 return n;
00340 }
00341
00350 unsigned int dummyKey(unsigned int key) {
00351 return binary_reverse(key);
00352 }
00353
00362 unsigned int regularKey(unsigned int key) {
00363 return binary_reverse(key | 0x80000000);
00364 }
00365
00374 unsigned int get_parent(unsigned int bucket) {
00375 return bucket - getLargestValue(bucket);
00376 }
00377
00385 void initialize_bucket(unsigned int bucket) {
00386 unsigned int parent = get_parent(bucket);
00387
00388 if (get_bucket(parent).load(memory_order_relaxed) == NULL) {
00389 initialize_bucket(parent);
00390 }
00391 assert(get_bucket(parent).load(memory_order_relaxed) != NULL);
00392
00393 SetNode<KeyType> dummy = SetNode<KeyType> (dummyKey(bucket));
00394
00395 assert(binary_reverse(dummyKey(bucket)) == bucket);
00396
00397 atomic<node_ptr> start = get_bucket(parent);
00398
00399 node_ptr address = oList->add_returnAddress(dummy, &start);
00400
00401
00402 assert(oList->search(dummy, &start));
00403
00404 set_bucket(bucket, address);
00405 assert(get_bucket(bucket).load(memory_order_relaxed) == address);
00406 }
00407
00416 atomic<node_ptr> get_bucket(unsigned int bucket) {
00417 int local_size = _size.load(memory_order_relaxed);
00418 bucket %= local_size;
00419
00420 int segment = bucket / segmentSize;
00421
00422
00423 if (mainArray[segment].load(memory_order_relaxed) == NULL) {
00424 set_bucket(bucket, NULL);
00425 }
00426 return mainArray[segment].load(memory_order_relaxed)[bucket % segmentSize];
00427 }
00428
00437 void set_bucket(int bucket, node_ptr head) {
00438 int segment = bucket / segmentSize;
00439
00440
00441 if (mainArray[segment].load(memory_order_relaxed) == NULL) {
00442 atomic<node_ptr>* new_segment = new atomic<node_ptr> [segmentSize];
00443 for (int i = 0; i < segmentSize; ++i) {
00444 new_segment[i] = NULL;
00445 }
00446 atomic<node_ptr> *UNINITIALIZE = NULL;
00447
00448 if (!mainArray[segment].compare_swap(UNINITIALIZE, new_segment))
00449 delete[] new_segment;
00450 }
00451
00452 if (head != NULL) {
00453 node_ptr BUCKET_UNINITIALIZE = NULL;
00454 mainArray[segment].load(memory_order_relaxed)[bucket % segmentSize].compare_swap(
00455 BUCKET_UNINITIALIZE, head);
00456 assert(get_bucket(bucket).load(memory_order_relaxed) == head);
00457 }
00458 }
00459 };
00460 }
00461
/**
 * @brief Hash hook used by amino::Set; specialize it for every KeyType.
 *
 * The primary template has no usable hash and throws std::runtime_error
 * when called for a type without a specialization.
 *
 * NOTE(review): this template is declared after amino::Set in this header,
 * and Set calls it unqualified with an argument of template-parameter type.
 * Under two-phase lookup such a call only sees declarations visible at
 * Set's definition plus ADL (which finds nothing for fundamental types or
 * for a global function taking std::string), so a forward declaration of
 * this template before namespace amino appears to be required on
 * conforming compilers — confirm.
 */
template<typename KeyType> int hash_function(KeyType element) {
    (void) element;
    throw std::runtime_error("Need hash function");
}

// Explicit specializations are ordinary functions for ODR purposes, so when
// defined in a header they must be `inline` to avoid multiple-definition
// link errors across translation units.

/** Identity hash for int. */
template<> inline int hash_function<int> (int element) {
    return element;
}

/** Reinterprets the value as int. */
template<> inline int hash_function<unsigned int> (unsigned int element) {
    return static_cast<int> (element);
}

/**
 * Folds the upper half into the lower one, so values that differ only in
 * their high bits do not all collide when long is wider than int. Two
 * 16-bit shifts keep the expression well-defined for a 32-bit long, where
 * the folded-in part is simply zero and the result matches a plain cast.
 */
template<> inline int hash_function<long> (long element) {
    unsigned long bits = static_cast<unsigned long> (element);
    return static_cast<int> (bits ^ ((bits >> 16) >> 16));
}

/** Same folding as the long specialization. */
template<> inline int hash_function<unsigned long> (unsigned long element) {
    return static_cast<int> (element ^ ((element >> 16) >> 16));
}

/** Widens the character code to int. */
template<> inline int hash_function<char> (char element) {
    return static_cast<int> (element);
}

/** Widens the character code to int. */
template<> inline int hash_function<unsigned char> (unsigned char element) {
    return static_cast<int> (element);
}

/** Widens the character code to int. */
template<> inline int hash_function<signed char> (signed char element) {
    return static_cast<int> (element);
}

/** Widens the value to int. */
template<> inline int hash_function<short> (short element) {
    return static_cast<int> (element);
}

/** Polynomial hash with multiplier 5 over the string's characters. */
template<> inline int hash_function<std::string> (std::string element) {
    unsigned long result = 0;
    for (std::string::const_iterator iter = element.begin();
            iter != element.end(); ++iter)
        result = 5 * result + *iter;

    return static_cast<int> (result);
}
00521
00522 #endif