| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT license. |
| 3 | |
| 4 | #pragma once |
| 5 | |
| 6 | #include <atomic> |
| 7 | #include <cassert> |
| 8 | #include <cinttypes> |
| 9 | #include <cstdint> |
| 10 | #include <deque> |
| 11 | #include <fstream> |
| 12 | #include <iostream> |
| 13 | #include <string> |
| 14 | #include <thread> |
| 15 | |
| 16 | #include "core/auto_ptr.h" |
| 17 | #include "core/faster.h" |
| 18 | #include "core/thread.h" |
| 19 | #include "sum_store.h" |
| 20 | |
| 21 | namespace sum_store { |
| 22 | |
| 23 | class ConcurrentRecoveryTest { |
| 24 | public: |
| 25 | static constexpr uint64_t kNumUniqueKeys = (1L << 22); |
| 26 | static constexpr uint64_t kKeySpace = (1L << 14); |
| 27 | static constexpr uint64_t kNumOps = (1L << 25); |
| 28 | static constexpr uint64_t kRefreshInterval = (1L << 8); |
| 29 | static constexpr uint64_t kCompletePendingInterval = (1L << 12); |
| 30 | static constexpr uint64_t kCheckpointInterval = (1L << 22); |
| 31 | |
| 32 | ConcurrentRecoveryTest(store_t& store_, size_t num_threads_) |
| 33 | : store{ store_ } |
| 34 | , num_threads{ num_threads_ } |
| 35 | , num_active_threads{ 0 } |
| 36 | , num_checkpoints{ 0 } { |
| 37 | } |
| 38 | |
| 39 | private: |
| 40 | static void PopulateWorker(store_t* store, size_t thread_idx, |
| 41 | std::atomic<size_t>* num_active_threads, size_t num_threads, |
| 42 | std::atomic<uint32_t>* num_checkpoints) { |
| 43 | auto callback = [](IAsyncContext* ctxt, Status result) { |
| 44 | CallbackContext<RmwContext> context{ ctxt }; |
| 45 | assert(result == Status::Ok); |
| 46 | }; |
| 47 | |
| 48 | auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) { |
| 49 | if(result != Status::Ok) { |
| 50 | printf("Thread %" PRIu32 " reports checkpoint failed.\n" , |
| 51 | Thread::id()); |
| 52 | } else { |
| 53 | printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n" , |
| 54 | Thread::id(), persistent_serial_num); |
| 55 | } |
| 56 | }; |
| 57 | |
| 58 | // Register thread with the store |
| 59 | store->StartSession(); |
| 60 | |
| 61 | ++(*num_active_threads); |
| 62 | |
| 63 | // Process the batch of input data |
| 64 | for(size_t idx = 0; idx < kNumOps; ++idx) { |
| 65 | RmwContext context{ idx % kNumUniqueKeys, 1 }; |
| 66 | store->Rmw(context, callback, idx); |
| 67 | if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) { |
| 68 | Guid token; |
| 69 | if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) { |
| 70 | printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n" , |
| 71 | Thread::id(), ++(*num_checkpoints), token.ToString().c_str()); |
| 72 | } |
| 73 | } |
| 74 | if(idx % kCompletePendingInterval == 0) { |
| 75 | store->CompletePending(false); |
| 76 | } else if(idx % kRefreshInterval == 0) { |
| 77 | store->Refresh(); |
| 78 | } |
| 79 | } |
| 80 | |
| 81 | // Make sure operations are completed |
| 82 | store->CompletePending(true); |
| 83 | |
| 84 | // Deregister thread from FASTER |
| 85 | store->StopSession(); |
| 86 | |
| 87 | printf("Populate successful on thread %" PRIu32 ".\n" , Thread::id()); |
| 88 | } |
| 89 | |
| 90 | public: |
| 91 | void Populate() { |
| 92 | std::deque<std::thread> threads; |
| 93 | for(size_t idx = 0; idx < num_threads; ++idx) { |
| 94 | threads.emplace_back(&PopulateWorker, &store, idx, &num_active_threads, num_threads, |
| 95 | &num_checkpoints); |
| 96 | } |
| 97 | for(auto& thread : threads) { |
| 98 | thread.join(); |
| 99 | } |
| 100 | // Verify the records. |
| 101 | auto callback = [](IAsyncContext* ctxt, Status result) { |
| 102 | CallbackContext<ReadContext> context{ ctxt }; |
| 103 | assert(result == Status::Ok); |
| 104 | }; |
| 105 | // Create array for reading |
| 106 | auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys); |
| 107 | std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
| 108 | |
| 109 | // Register with thread |
| 110 | store.StartSession(); |
| 111 | |
| 112 | // Issue read requests |
| 113 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
| 114 | ReadContext context{ AdId{ idx }, read_results.get() + idx }; |
| 115 | store.Read(context, callback, idx); |
| 116 | } |
| 117 | |
| 118 | // Complete all pending requests |
| 119 | store.CompletePending(true); |
| 120 | |
| 121 | // Release |
| 122 | store.StopSession(); |
| 123 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
| 124 | uint64_t expected_result = (num_threads * kNumOps) / kNumUniqueKeys; |
| 125 | if(read_results.get()[idx] != expected_result) { |
| 126 | printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n" , |
| 127 | idx, |
| 128 | expected_result, |
| 129 | read_results.get()[idx]); |
| 130 | } |
| 131 | } |
| 132 | } |
| 133 | |
| 134 | void RecoverAndTest(const Guid& index_token, const Guid& hybrid_log_token) { |
| 135 | auto callback = [](IAsyncContext* ctxt, Status result) { |
| 136 | CallbackContext<ReadContext> context{ ctxt }; |
| 137 | assert(result == Status::Ok); |
| 138 | }; |
| 139 | |
| 140 | // Recover |
| 141 | uint32_t version; |
| 142 | std::vector<Guid> session_ids; |
| 143 | FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version, |
| 144 | session_ids); |
| 145 | if(result != FASTER::core::Status::Ok) { |
| 146 | printf("Recovery failed with error %u\n" , static_cast<uint8_t>(result)); |
| 147 | exit(1); |
| 148 | } |
| 149 | |
| 150 | std::vector<uint64_t> serial_nums; |
| 151 | for(const auto& session_id : session_ids) { |
| 152 | serial_nums.push_back(store.ContinueSession(session_id)); |
| 153 | store.StopSession(); |
| 154 | } |
| 155 | |
| 156 | // Create array for reading |
| 157 | auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys); |
| 158 | std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
| 159 | |
| 160 | // Register with thread |
| 161 | store.StartSession(); |
| 162 | |
| 163 | // Issue read requests |
| 164 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
| 165 | ReadContext context{ AdId{ idx}, read_results.get() + idx }; |
| 166 | store.Read(context, callback, idx); |
| 167 | } |
| 168 | |
| 169 | // Complete all pending requests |
| 170 | store.CompletePending(true); |
| 171 | |
| 172 | // Release |
| 173 | store.StopSession(); |
| 174 | |
| 175 | // Test outputs |
| 176 | // Compute expected array |
| 177 | auto expected_results = alloc_aligned<uint64_t>(64, |
| 178 | sizeof(uint64_t) * kNumUniqueKeys); |
| 179 | std::memset(expected_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
| 180 | |
| 181 | // Sessions that were active during checkpoint: |
| 182 | for(uint64_t serial_num : serial_nums) { |
| 183 | for(uint64_t idx = 0; idx <= serial_num; ++idx) { |
| 184 | ++expected_results.get()[idx % kNumUniqueKeys]; |
| 185 | } |
| 186 | } |
| 187 | // Sessions that were finished at time of checkpoint. |
| 188 | size_t num_completed = num_threads - serial_nums.size(); |
| 189 | for(size_t thread_idx = 0; thread_idx < num_completed; ++thread_idx) { |
| 190 | uint64_t serial_num = kNumOps; |
| 191 | for(uint64_t idx = 0; idx < serial_num; ++idx) { |
| 192 | ++expected_results.get()[idx % kNumUniqueKeys]; |
| 193 | } |
| 194 | } |
| 195 | |
| 196 | // Assert if expected is same as found |
| 197 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
| 198 | if(expected_results.get()[idx] != read_results.get()[idx]) { |
| 199 | printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n" , |
| 200 | idx, |
| 201 | expected_results.get()[idx], |
| 202 | read_results.get()[idx]); |
| 203 | } |
| 204 | } |
| 205 | printf("Test successful\n" ); |
| 206 | } |
| 207 | |
| 208 | static void ContinueWorker(store_t* store, size_t thread_idx, |
| 209 | std::atomic<size_t>* num_active_threads, size_t num_threads, |
| 210 | std::atomic<uint32_t>* num_checkpoints, Guid guid) { |
| 211 | auto callback = [](IAsyncContext* ctxt, Status result) { |
| 212 | CallbackContext<RmwContext> context{ ctxt }; |
| 213 | assert(result == Status::Ok); |
| 214 | }; |
| 215 | |
| 216 | |
| 217 | auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) { |
| 218 | if(result != Status::Ok) { |
| 219 | printf("Thread %" PRIu32 " reports checkpoint failed.\n" , |
| 220 | Thread::id()); |
| 221 | } else { |
| 222 | printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n" , |
| 223 | Thread::id(), persistent_serial_num); |
| 224 | } |
| 225 | }; |
| 226 | |
| 227 | // Register thread with the store |
| 228 | uint64_t start_num = store->ContinueSession(guid); |
| 229 | |
| 230 | ++(*num_active_threads); |
| 231 | |
| 232 | // Process the batch of input data |
| 233 | for(size_t idx = start_num + 1; idx < kNumOps; ++idx) { |
| 234 | RmwContext context{ idx % kNumUniqueKeys, 1 }; |
| 235 | store->Rmw(context, callback, idx); |
| 236 | if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) { |
| 237 | Guid token; |
| 238 | if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) { |
| 239 | printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n" , |
| 240 | Thread::id(), ++(*num_checkpoints), token.ToString().c_str()); |
| 241 | } |
| 242 | } |
| 243 | if(idx % kCompletePendingInterval == 0) { |
| 244 | store->CompletePending(false); |
| 245 | } else if(idx % kRefreshInterval == 0) { |
| 246 | store->Refresh(); |
| 247 | } |
| 248 | } |
| 249 | |
| 250 | // Make sure operations are completed |
| 251 | store->CompletePending(true); |
| 252 | |
| 253 | // Deregister thread from FASTER |
| 254 | store->StopSession(); |
| 255 | |
| 256 | printf("Populate successful on thread %" PRIu32 ".\n" , Thread::id()); |
| 257 | } |
| 258 | |
| 259 | void Continue(const Guid& index_token, const Guid& hybrid_log_token) { |
| 260 | // Recover |
| 261 | printf("Recovering version (index_token = %s, hybrid_log_token = %s)\n" , |
| 262 | index_token.ToString().c_str(), hybrid_log_token.ToString().c_str()); |
| 263 | uint32_t version; |
| 264 | std::vector<Guid> session_ids; |
| 265 | FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version, |
| 266 | session_ids); |
| 267 | if(result != FASTER::core::Status::Ok) { |
| 268 | printf("Recovery failed with error %u\n" , static_cast<uint8_t>(result)); |
| 269 | exit(1); |
| 270 | } else { |
| 271 | printf("Recovery Done!\n" ); |
| 272 | } |
| 273 | |
| 274 | num_checkpoints.store(version); |
| 275 | // Some threads may have already completed. |
| 276 | num_threads = session_ids.size(); |
| 277 | |
| 278 | std::deque<std::thread> threads; |
| 279 | for(size_t idx = 0; idx < num_threads; ++idx) { |
| 280 | threads.emplace_back(&ContinueWorker, &store, idx, &num_active_threads, num_threads, |
| 281 | &num_checkpoints, session_ids[idx]); |
| 282 | } |
| 283 | for(auto& thread : threads) { |
| 284 | thread.join(); |
| 285 | } |
| 286 | } |
| 287 | |
| 288 | store_t& store; |
| 289 | size_t num_threads; |
| 290 | std::atomic<size_t> num_active_threads; |
| 291 | std::atomic<uint32_t> num_checkpoints; |
| 292 | }; |
| 293 | |
| 294 | } // namespace sum_store |
| 295 | |