00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifndef LOCKFREE_DICTIONARY_H_
00024 #define LOCKFREE_DICTIONARY_H_
00025
#include <stdint.h>
#include <amino/smr.h>
00027
00028 namespace amino {
00029
00030 using namespace internal;
00031
// Tagged-pointer helpers.
//
// Bit 0 of a DictNode* stored in DictNode::next[] or of a Value<V>* stored in
// DictNode::value is used as a deletion mark. The integer round-trip goes
// through uintptr_t: `long` is narrower than a pointer on LLP64 platforms
// (e.g. 64-bit Windows) and would silently truncate the address.
//
// COPY_NODE and RELEASE_NODE are identity macros; READ_NODE yields NULL for a
// marked pointer. GET_UNMARKED / GET_MARKED (and the _VALUE variants) name
// DictNode<K, V> / Value<V>, so they are only usable where K and V are in scope.
#define COPY_NODE(node) (node)
#define READ_NODE(node) (IS_MARKED(node) ? NULL : (node))
#define RELEASE_NODE(node) (node)
#define GET_UNMARKED(p) (reinterpret_cast<DictNode<K, V>*>(reinterpret_cast<uintptr_t>(p) & ~static_cast<uintptr_t>(3)))
#define GET_UNMARKED_VALUE(p) (reinterpret_cast<Value<V>*>(reinterpret_cast<uintptr_t>(p) & ~static_cast<uintptr_t>(3)))

#define GET_MARKED(p) (reinterpret_cast<DictNode<K, V>*>(reinterpret_cast<uintptr_t>(p) | static_cast<uintptr_t>(1)))

#define GET_MARKED_VALUE(p) (reinterpret_cast<Value<V>*>(reinterpret_cast<uintptr_t>(p) | static_cast<uintptr_t>(1)))

#define IS_MARKED(p) (reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(1))

// Maximum skip-list height (size of DictNode::next) and the probability with
// which randomLevel() promotes a new node by one more level.
#define MAXLEVEL 10
#define SLCONST 0.5
00046
/**
 * Heap-allocated box around a stored value.
 *
 * Nodes keep a pointer to a Value instead of the value itself, so the
 * pointer can be swapped with a single compare-and-swap and its low bit can
 * carry the "logically deleted" mark.
 */
template<typename E>
struct Value {
    /// The wrapped payload.
    E v;

    /// Default-constructs the payload.
    Value() {
    }

    /// Copy-constructs the payload from @p value.
    Value(const E& value)
        : v(value) {
    }
};
00057
00058 template<typename K, typename V> struct DictNode {
00059 int level;
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073 int validLevel;
00074 int version;
00075 atomic<Value<V>*> value;
00076 DictNode<K, V>* prev;
00077 atomic<DictNode<K, V>*> next[MAXLEVEL];
00078 K key;
00079
00080 DictNode() :
00081 level(-1), validLevel(0), version(0), prev(NULL),key() {
00082 value.store(NULL);
00083 for (int i = 0; i < MAXLEVEL; ++i) {
00084 next[i].store(NULL, memory_order_relaxed);
00085 }
00086 }
00087
00088 DictNode(int l, K k, V v) :
00089 level(-1), validLevel(0), version(0), prev(NULL), key(k) {
00090 value.store(new Value<V> (v), memory_order_relaxed);
00091 for (int i = 0; i < MAXLEVEL; ++i) {
00092 next[i].store(NULL, memory_order_relaxed);
00093 }
00094 }
00095
00096 ~DictNode() {
00097 delete GET_UNMARKED_VALUE(value.load(memory_order_relaxed));
00098 }
00099
00100 };
00101
00102 template<typename K, typename V> class LockFreeDictionary {
00103 private:
00104 DictNode<K, V>* head;
00105 DictNode<K, V>* tail;
00106 DictNode<K, V>* INVALID;
00107
00108 SMR<DictNode<K, V>, MAXLEVEL>* mm;
00109 typedef typename SMR<DictNode<K, V>, MAXLEVEL>::HP_Rec HazardP;
00110
00111 public:
00112 LockFreeDictionary() {
00113 mm = getSMR<DictNode<K, V>, MAXLEVEL> ();
00114 head = new DictNode<K, V> ();
00115 head->validLevel = MAXLEVEL - 1;
00116 tail = new DictNode<K, V> ();
00117 tail->validLevel = MAXLEVEL - 1;
00118 INVALID = new DictNode<K, V> ();
00119
00120 for (int i = 0; i < MAXLEVEL; ++i) {
00121 head->next[i].store(tail, memory_order_relaxed);
00122 }
00123
00124 srand( time(NULL));
00125 }
00126
00127 ~LockFreeDictionary() {
00128 cout << "size = " << size() << endl;
00129
00130 DictNode<K,V> *curr = head->next[0].load(memory_order_relaxed);
00131 DictNode<K,V> *tmp = NULL;
00132 int count = 0;
00133 while (tail != curr) {
00134 ++count ;
00135
00136 tmp = curr;
00137 curr = curr->next[0].load(memory_order_relaxed);
00138 delete tmp;
00139 }
00140
00141 cout << "count " << count << endl;
00142
00143
00144
00145 delete head;
00146 delete tail;
00147 delete INVALID;
00148 }
00149
00150 bool empty() {
00151 return head->next[0].load(memory_order_relaxed) == tail;
00152 }
00153
00154 int size() {
00155 int size = 0;
00156 for (DictNode<K,V>* itr = head->next[0].load(memory_order_relaxed); itr
00157 != tail; itr = itr->next[0].load(memory_order_relaxed)) {
00158 ++size;
00159 }
00160 return size;
00161 }
00162
00168 bool insert(const K& key, const V& value) {
00169 DictNode<K, V>* node1;
00170 DictNode<K, V>* node2;
00171 Value<V>* value2;
00172 int curLevel = randomLevel();
00173
00174
00175
00176
00177
00178 HazardP * hp = mm->getHPRec();
00179 #ifdef FREELIST
00180 DictNode<K, V>* newNode = mm->newNode(hp);
00181 newNode->level = curLevel;
00182 newNode->key = key;
00183 newNode->value.store(new Value<V> (value), memory_order_relaxed);
00184 #else
00185 DictNode<K, V>* newNode = new DictNode<K, V> (curLevel, key, value);
00186 #endif
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197 DictNode<K, V>* savedNodes[MAXLEVEL + 1];
00198 savedNodes[MAXLEVEL] = head;
00199 for (int i = MAXLEVEL - 1; i >= 0; --i) {
00200 savedNodes[i] = searchLevel(savedNodes[i + 1], i, key);
00201 if ((i < MAXLEVEL - 1) && (i >= curLevel - 1)) {
00202
00203 mm->retire(hp,savedNodes[i + 1]);
00204 }
00205 }
00206 node1 = savedNodes[0];
00207
00208
00209
00210 while (true) {
00211
00212 node2 = scanKey(node1, 0, key);
00213 value2 = node2->value.load(memory_order_relaxed);
00214
00215
00216
00217
00218
00219
00220 if (!IS_MARKED(value2) && (node2->value.load(memory_order_relaxed) != NULL
00221 && node2->key == key)) {
00222 if (node2->value.compare_swap(value2, new Value<V> (value))) {
00223
00224
00225 mm->retire(hp,node1);
00226 mm->retire(hp,node2);
00227 for (int i = 1; i < curLevel; ++i) {
00228
00229 mm->retire(hp,savedNodes[i]);
00230 }
00231
00232
00233 delete value2;
00234 mm->delNode(hp,newNode);
00235
00236 return true;
00237 } else {
00238
00239 mm->retire(hp,node2);
00240 continue;
00241 }
00242 }
00243
00244
00245
00246 newNode->next[0] = node2;
00247
00248 mm->retire(hp,node2);
00249 if (node1->next[0].compare_swap(node2, newNode)) {
00250
00251 mm->retire(hp,node1);
00252 break;
00253 }
00254
00255 }
00256 ++newNode->version;
00257 newNode->validLevel = 1;
00258
00259 for (int i = 1; i < curLevel; ++i) {
00260 node1 = savedNodes[i];
00261 while (true) {
00262 node2 = scanKey(node1, i, key);
00263 newNode->next[i] = node2;
00264
00265 mm->retire(hp,node2);
00266
00267
00268
00269 if (IS_MARKED(newNode->value.load(memory_order_relaxed))) {
00270
00271 mm->retire(hp,node1);
00272 break;
00273 }
00274 if (node1->next[i].compare_swap(node2, newNode)) {
00275 newNode->validLevel = i + 1;
00276
00277 mm->retire(hp,node1);
00278 break;
00279 }
00280
00281 }
00282 }
00283 if (IS_MARKED(newNode->value.load(memory_order_relaxed))) {
00284 newNode = helpDelete(newNode, 0);
00285 }
00286
00287 mm->retire(hp,newNode);
00288
00289 return true;
00290 }
00291
00292 bool findKey(const K& key, V& value) {
00293 HazardP * hp = mm->getHPRec();
00294
00295
00296 try_again:
00297 DictNode<K, V>* last = COPY_NODE(head);
00298
00299 mm->employ(hp,0,head);
00300 if(last != head) {
00301 goto try_again;
00302 }
00303
00304 DictNode<K, V>* node1;
00305 DictNode<K, V>* node2;
00306
00307 for (int i = MAXLEVEL - 1; i >= 0; --i) {
00308 node1 = searchLevel(last, i, key);
00309
00310 mm->retire(hp,last);
00311 last = node1;
00312 }
00313 node2 = scanKey(last, 0, key);
00314
00315 mm->retire(hp,last);
00316 Value<V>* result = node2->value.load(memory_order_relaxed);
00317 if ((node2->key != key) || (IS_MARKED(result))) {
00318
00319 mm->retire(hp,node2);
00320 return false;
00321 }
00322 value = result->v;
00323
00324 mm->retire(hp,node2);
00325 return true;
00326 }
00327
00328 bool deleteKey(const K& key, V& value) {
00329 return _delete(key, false, value);
00330 }
00331
00332 bool findValue(const V& value, K& key) {
00333 return fDValue(value, false, key);
00334 }
00335
00336 bool deleteValue(const V& value, K& key) {
00337 return fDValue(value, true, key);
00338 }
00339
00340 void dumpQueue() {
00341 cout << "dumping queue...................." << endl;
00342 cout << "head:" << head << "->";
00343 for (DictNode<K, V>* itr = head->next[0].load(memory_order_relaxed); itr
00344 != tail; itr = itr->next[0].load(memory_order_relaxed)) {
00345 cout << GET_UNMARKED(itr)->value.load(memory_order_relaxed)->v << "->";
00346 }
00347 cout << tail << ":tail" << endl;
00348 }
00349
00350 private:
00351
00352
00353
00354
00355 DictNode<K, V>* searchLevel(DictNode<K, V>*& last, int curLevel,
00356 const K& expectedKey) {
00357 HazardP * hp = mm->getHPRec();
00358
00359 try_again:
00360 DictNode<K, V>* curNode = last;
00361 DictNode<K, V>* stop = NULL;
00362 DictNode<K, V>* nextNode = NULL;
00363 int count1 = 0;
00364 int count2 = 0;
00365 int count3 = 0;
00366 int count4 = 0;
00367 int count5 = 0;
00368
00369 assert(last != tail);
00370
00371 while (true) {
00372
00373 nextNode = GET_UNMARKED(curNode->next[curLevel].load(memory_order_relaxed));
00374 ++count1;
00375 if (count1 > 10000) {
00376 cout << "searchLevel loop1: count1 = " << count1
00377 << ", count2 = " << count2 << ", count3 = " << count3
00378 << ", count4 = " << count4 << ", count5 = " << count5
00379 << endl;
00380 exit(1);
00381
00382
00383
00384
00385 }
00386
00387
00388
00389 if (NULL == nextNode) {
00390
00391 ++count2;
00392 if (curNode == last) {
00393 last = helpDelete(last, curLevel);
00394 }
00395 curNode = last;
00396 }
00397
00398
00399
00400 else if ((nextNode != head) && (nextNode == tail || nextNode->key
00401 >= expectedKey)) {
00402 assert(nextNode != head);
00403
00404
00405 ++count3;
00406
00407
00408
00409
00410
00411
00412
00413
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430
00431
00432 if ((curNode->validLevel > curLevel || curNode == last
00433 || curNode == stop) && ((curNode != tail) && (curNode
00434 == head || curNode->key < expectedKey))
00435 && (last == head || curNode == tail || curNode->key
00436 >= last->key)) {
00437 if (curNode->validLevel <= curLevel) {
00438
00439 mm->retire(hp,curNode);
00440
00441 curNode = COPY_NODE(last);
00442 mm->employ(hp,0,last);
00443 if(curNode != last) {
00444 goto try_again;
00445 }
00446
00447 nextNode = scanKey(curNode, curLevel, expectedKey);
00448
00449 mm->retire(hp,nextNode);
00450 }
00451 return curNode;
00452 }
00453
00454 mm->retire(hp,curNode);
00455
00456 stop = curNode;
00457 if (IS_MARKED(last->value.load(memory_order_relaxed))) {
00458 last = helpDelete(last, curLevel);
00459 }
00460 curNode = last;
00461 }
00462
00463
00464
00465 else if (last != tail && nextNode != head && (last == head
00466 || nextNode == tail || nextNode->key >= last->key)) {
00467
00468 assert(last != tail);
00469 ++count4;
00470 curNode = nextNode;
00471 }
00472
00473
00474
00475
00476 else {
00477
00478 ++count5;
00479 if (IS_MARKED(last->value.load(memory_order_relaxed))) {
00480 last = helpDelete(last, curLevel);
00481 }
00482 curNode = last;
00483 }
00484 }
00485 }
00486
00487 bool _delete(const K& key, bool delval, V& value) {
00488
00489
00490 HazardP * hp = mm->getHPRec();
00491
00492 DictNode<K, V>* node1;
00493 DictNode<K, V>* node2;
00494 DictNode<K, V>* prev;
00495 DictNode<K, V>* last;
00496 DictNode<K, V>* savedNodes[MAXLEVEL + 1];
00497 Value<V>* tmpValue = NULL;
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508 try_again:
00509 savedNodes[MAXLEVEL] = head;
00510 for (int i = MAXLEVEL - 1; i >= 0; --i) {
00511 savedNodes[i] = searchLevel(savedNodes[i + 1], i, key);
00512 }
00513 node1 = scanKey(savedNodes[0], 0, key);
00514
00515
00516 if (node1 == tail) {
00517 cout << "didn't find" << endl;
00518 return false;
00519 }
00520
00521 int count1 = 0;
00522 while (true) {
00523 ++count1;
00524 if (count1 > 10000) {
00525 cout << "_delete loop1 : " << count1 << endl;
00526 }
00527
00528 if (!delval) {
00529 tmpValue = node1->value.load(memory_order_relaxed);
00530 value = tmpValue->v;
00531 }
00532
00533
00534
00535
00536 if ((node1->value.load(memory_order_relaxed) != NULL && node1->key
00537 == key) && (!delval || node1->value.load(
00538 memory_order_relaxed)->v == value) && (!IS_MARKED(tmpValue))) {
00539
00540
00541
00542
00543
00544
00545
00546
00547 if (node1->value.compare_swap(tmpValue, GET_MARKED_VALUE(tmpValue))) {
00548 int index = (node1->level - 1)/2;
00549 node1->prev = COPY_NODE(savedNodes[index]);
00550
00551 mm->employ(hp,0,savedNodes[index]);
00552 if(node1->prev != savedNodes[index]) {
00553 goto try_again;
00554 }
00555
00556 break;
00557 } else {
00558 continue;
00559 }
00560 }
00561
00562 mm->retire(hp,node1);
00563 for (int i = 0; i < MAXLEVEL; ++i) {
00564
00565 mm->retire(hp,savedNodes[i]);
00566 }
00567 return false;
00568 }
00569
00570
00571
00572
00573
00574 for (int i = 0; i < node1->level; ++i) {
00575 do {
00576 node2 = node1->next[i].load(memory_order_relaxed);
00577 } while (!IS_MARKED(node2) && !node1->next[i].compare_swap(node2, GET_MARKED(node2)));
00578 }
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590 int count2 = 0;
00591 for (int i = node1->level - 1; i >= 0; --i) {
00592 prev = savedNodes[i];
00593 while (true) {
00594 ++count2;
00595 if (count2 > 10000) {
00596 cout << "_delete loop2 : " << count2 << endl;
00597 }
00598
00599 if (node1->next[i].load(memory_order_relaxed) == INVALID) {
00600 break;
00601 }
00602 last = scanKey(prev, i, node1->key);
00603
00604 mm->retire(hp,last);
00605 if ((last != node1) || (node1->next[i].load(
00606 memory_order_relaxed) == INVALID)) {
00607 break;
00608 }
00609 if (prev->next[i].compare_swap(node1, GET_UNMARKED(node1->next[i].load(memory_order_relaxed)))) {
00610 node1->next[i].store(INVALID, memory_order_relaxed);
00611 break;
00612 }
00613 if (node1->next[i].load(memory_order_relaxed) == INVALID) {
00614 break;
00615 }
00616
00617 }
00618
00619 mm->retire(hp,prev);
00620 }
00621 for (int i = node1->level; i < MAXLEVEL; ++i) {
00622
00623 mm->retire(hp,savedNodes[i]);
00624 }
00625
00626
00627
00628 mm->delNode(node1);
00629 return true;
00630 }
00631
00632
00633
00634
00635
00636
00637
00638
00639 bool fDValue(const V& value, bool del, K& key) {
00640 HazardP * hp = mm->getHPRec();
00641 DictNode<K, V>* node1;
00642 DictNode<K, V>* node2;
00643 bool ok;
00644 int version;
00645 int version2;
00646 K key2;
00647
00648 int jump = 16;
00649 try_again:
00650 DictNode<K, V>* last = COPY_NODE(head);
00651
00652 mm->employ(hp,0,head);
00653 if(last != head) {
00654 goto try_again;
00655 }
00656
00657 next_jump: node1 = last;
00658 K key1 = node1->key;
00659 int step = 0;
00660
00661 while (true) {
00662 ok = false;
00663
00664
00665
00666
00667
00668
00669
00670 version = node1->version;
00671 node2 = node1->next[0].load(memory_order_relaxed);
00672 if (!IS_MARKED(node2) && (node2 != NULL)) {
00673 version2 = node2->version;
00674 key2 = node2->key;
00675
00676 assert(node1 != tail);
00677 assert(node2 != head);
00678
00679
00680 if ((node1->key == key1) && (node1->validLevel > 0)
00681 && (node1->next[0].load(memory_order_relaxed) == node2
00682 && (node1->version == version) && (node2->key
00683 == key2) && (node2->validLevel > 0)
00684 && (node2->version == version2))) {
00685 ok = true;
00686 }
00687 }
00688
00689
00690
00691
00692 if (!ok) {
00693 node1 = node2 = readNext(last, 0);
00694 key1 = key2 = node2->key;
00695 version2 = node2->version;
00696
00697 mm->retire(hp,last);
00698 last = node2;
00699 step = 0;
00700 }
00701
00702
00703
00704
00705
00706 if (node2 == tail) {
00707
00708
00709 mm->retire(hp,last);
00710 return false;
00711 }
00712
00713
00714
00715
00716 if (node2->value.load(memory_order_relaxed)->v == value) {
00717 if (node2->version == version2) {
00718
00719
00720
00721
00722
00723
00724 if (del) {
00725
00726 V tmpValue = value;
00727
00728 bool result = _delete(key2, true, tmpValue);
00729
00730
00731 if (result && tmpValue == value) {
00732
00733 mm->retire(hp,last);
00734 key = key2;
00735
00736 return true;
00737 }
00738 } else {
00739
00740
00741 mm->retire(hp,last);
00742 key = key2;
00743
00744 return true;
00745 }
00746 }
00747 }
00748
00749
00750
00751
00752 else if (++step >= jump) {
00753
00754
00755
00756
00757
00758
00759 if ((node2->validLevel == 0) || (node2->key != key2)) {
00760
00761 mm->retire(hp,node2);
00762 node2 = readNext(last, 0);
00763 if (jump >= 4) {
00764 jump /= 2;
00765 }
00766 } else {
00767 jump += jump / 2;
00768 }
00769
00770 mm->retire(hp,last);
00771 last = node2;
00772 goto next_jump;
00773 } else {
00774 key1 = key2;
00775 node1 = node2;
00776 }
00777 }
00778 }
00790 DictNode<K, V>* readNext(DictNode<K, V>*& node, int curLevel) {
00791
00792 DictNode<K, V>* nextNode;
00793 DictNode<K, V>* tmpNode;
00794
00795 assert(node != tail);
00796 if (IS_MARKED(node->value.load(memory_order_relaxed))) {
00797 node = helpDelete(node, curLevel);
00798 }
00799 tmpNode = node->next[curLevel].load(memory_order_relaxed);
00800 nextNode = READ_NODE(tmpNode);
00801 while (NULL == nextNode) {
00802 node = helpDelete(node, curLevel);
00803 tmpNode = node->next[curLevel].load(memory_order_relaxed);
00804 nextNode = READ_NODE(tmpNode);
00805 }
00806 assert(!IS_MARKED(nextNode));
00807 return nextNode;
00808 }
00809
00816 DictNode<K, V>* scanKey(DictNode<K, V>*& node, int curLevel,
00817 const K& expectedKey) {
00818 HazardP * hp = mm->getHPRec();
00819
00820 DictNode<K, V>* nextNode = readNext(node, curLevel);
00821 while ((nextNode != tail) && ((nextNode == head) || nextNode->key
00822 < expectedKey)) {
00823
00824 mm->retire(hp,node);
00825 node = nextNode;
00826 nextNode = readNext(node, curLevel);
00827 }
00828 return nextNode;
00829 }
00830
00858 DictNode<K, V>* helpDelete(DictNode<K, V>* node, int curLevel) {
00859 assert(node != tail && node != head);
00860 HazardP * hp = mm->getHPRec();
00861
00862 DictNode<K, V>* prevNode;
00863 DictNode<K, V>* node1;
00864 DictNode<K, V>* node2;
00865 DictNode<K, V>* last;
00866
00867 try_again:
00868
00869
00870
00871
00872
00873 for (int i = curLevel; i < node->level; ++i) {
00874
00875
00876
00877
00878 do {
00879 node2 = node->next[i].load(memory_order_relaxed);
00880 } while (!IS_MARKED(node2) && (!node->next[i].compare_swap(node2, GET_MARKED(node2))));
00881 }
00882
00883
00884
00885
00886
00887
00888 prevNode = node->prev;
00889 if ((prevNode == NULL) || (curLevel >= prevNode->validLevel)) {
00890 prevNode = COPY_NODE(head);
00891
00892 mm->employ(hp,0,head);
00893 if(prevNode != head) {
00894 goto try_again;
00895 }
00896 } else {
00897
00898
00899 }
00900
00901
00902
00903
00904
00905 int count1 = 0;
00906 while (true) {
00907 ++count1;
00908 if (count1 > 10000) {
00909 cout << "helpDelete loop1 " << count1 << endl;
00910 exit(1);
00911 }
00912 if (INVALID == node->next[curLevel].load(memory_order_relaxed)) {
00913 break;
00914 }
00915
00916 for (int i = prevNode->validLevel - 1; i >= curLevel; --i) {
00917 node1 = searchLevel(prevNode, i, node->key);
00918
00919 mm->retire(hp,prevNode);
00920 prevNode = node1;
00921 }
00922 last = scanKey(prevNode, curLevel, node->key);
00923
00924 mm->retire(hp,last);
00925 if (last != node || INVALID == node->next[curLevel].load(
00926 memory_order_relaxed)) {
00927 break;
00928 }
00929
00930 if (prevNode->next[curLevel].compare_swap(node, GET_UNMARKED(
00931 node->next[curLevel].load(memory_order_relaxed)))) {
00932 node->next[curLevel].store(INVALID, memory_order_relaxed);
00933 break;
00934 }
00935 if (INVALID == node->next[curLevel].load(memory_order_relaxed)) {
00936 break;
00937 }
00938
00939
00940 }
00941
00942
00943 mm->retire(hp,node);
00944 return prevNode;
00945 }
00946
00947 int randomLevel() {
00948 int v = 1;
00949 while (((static_cast<double> (rand()) / RAND_MAX) < SLCONST) && (v
00950 < MAXLEVEL - 1)) {
00951 ++v;
00952 }
00953
00954 return v;
00955 }
00956
00957 };
00958 }
00959
00960 #endif