From 2977f505391a3ac82f8bcf12f8979c9dfee54f9c Mon Sep 17 00:00:00 2001 From: Joris van Rantwijk Date: Sun, 26 Jun 2022 20:51:35 +0200 Subject: [PATCH] Use multi-threaded quicksort for single-pass sorts --- Makefile | 2 +- src/sortbin.cpp | 553 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 460 insertions(+), 95 deletions(-) diff --git a/Makefile b/Makefile index b748433..28784ab 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # CXX = g++ -CXXFLAGS = -Wall -O2 +CXXFLAGS = -Wall -O2 -pthread # -fsanitize=address -fsanitize=undefined SRCDIR = src diff --git a/src/sortbin.cpp b/src/sortbin.cpp index b669640..23935e5 100644 --- a/src/sortbin.cpp +++ b/src/sortbin.cpp @@ -25,17 +25,24 @@ #include #include #include +#include #include #include #include #include +#include +#include +#include +#include #include #include +#include #include #include #include #include +#include #include #include @@ -48,6 +55,7 @@ /* Default number of sorting threads. */ #define DEFAULT_THREADS 1 +#define MAX_THREADS 128 /* Align buffer sizes and I/O on this number of records. For efficiency, I/O should be done in multiples of 4096 bytes. */ @@ -72,6 +80,9 @@ struct SortContext /** Maximum number of arrays to merge in one step. */ unsigned int branch_factor; + /** Number of threads for parallel sorting. */ + unsigned int num_threads; + /** True to eliminate duplicate records. */ bool flag_unique; @@ -174,6 +185,152 @@ private: }; +/** Thread pool for parallel sorting and background I/O. */ +class ThreadPool +{ +public: + typedef std::future FutureType; + + /** Initialize thread pool with the specified number of threads. */ + explicit ThreadPool(unsigned int num_threads); + + // Prevent copying and assignment. + ThreadPool(const ThreadPool&) = delete; + ThreadPool& operator=(const ThreadPool&) = delete; + + /** + * Destructor: Wait until pending tasks are finished, then join threads. + */ + ~ThreadPool(); + + /** + * Submit a function to be executed in the thread pool. + * @param f Function which takes no arguments and returns void. + * @return A future that becomes ready when the function completes. + */ + template + FutureType submit(FunctionType f) + { + std::packaged_task task(f); + FutureType fut(task.get_future()); + std::lock_guard lock(m_mutex); + m_queue.push_back(std::move(task)); + m_cond.notify_all(); + return fut; + } + +private: + /** Worker function that runs in each thread. */ + void worker_function(); + + std::mutex m_mutex; + std::condition_variable m_cond; + std::vector m_threads; + bool m_stop_flag; + std::deque> m_queue; +}; + +// Constructor. +ThreadPool::ThreadPool(unsigned int num_threads) + : m_stop_flag(false) +{ + // Create the worker threads. + for (unsigned int i = 0; i < num_threads; i++) { + m_threads.emplace_back(&ThreadPool::worker_function, this); + } +} + +// Destructor. +ThreadPool::~ThreadPool() +{ + // Set stop flag. + { + std::unique_lock lock(m_mutex); + m_stop_flag = true; + m_cond.notify_all(); + } + + // Join the worker threads. + for (auto& thread : m_threads) { + thread.join(); + } +} + +// Worker function that runs in each thread. +void ThreadPool::worker_function() +{ + // Main loop. + while (true) { + + std::packaged_task task; + + // Get a pending task from the queue. + { + std::unique_lock lock(m_mutex); + + // Wait until the queue is non-empty or the stop flag is raised. + while (m_queue.empty() && !m_stop_flag) { + m_cond.wait(lock); + } + + // Exit if the stop flag is raised (thread pool is shutting down). + if (m_stop_flag) { + break; + } + + // Get oldest pending task. + task = std::move(m_queue.front()); + m_queue.pop_front(); + } + + // Execute the pending task. + task(); + } +} + + +/** Helper class to wait until multiple "std::future"s are ready. */ +class CompletionTracker +{ +public: + typedef std::future FutureType; + + /** Add a future to wait for. */ + void add(FutureType&& fut) + { + std::unique_lock lock(m_mutex); + m_futures.emplace_back(std::move(fut)); + } + + /** Wait until all futures have completed. */ + void wait() + { + while (true) { + FutureType fut; + + // Get the next future from the queue. + { + std::unique_lock lock(m_mutex); + if (m_futures.empty()) { + // No more futures. We are done waiting. + break; + } + + fut = std::move(m_futures.front()); + m_futures.pop_front(); + } + + // Wait on this future. + fut.wait(); + } + } + +private: + std::mutex m_mutex; + std::deque m_futures; +}; + + /** * Binary file access. * @@ -868,6 +1025,115 @@ void insertion_sort_records( } +/** + * Partition the array into two parts. + * + * This is a helper function for quicksort_records(). + * + * After partitioning, all of the "num_left" first records will be + * less-than-or-equal to all of the "num_right" last records. + * + * After partitioning, either + * num_left + num_right == num_records + * or + * num_left + num_right == num_records - 1 + * + * In the second case, the record in the middle is already in its final + * position in the array. + * + * @param range_start Pointer to first element of the array. + * @param record_size Record size in bytes. + * @param num_records Number of records in the array. + * @param[out] num_left Number of records in the left half. + * @param[out] num_right Number of records in the right half. + */ +inline void quicksort_partition_records( + unsigned char * range_begin, + size_t record_size, + size_t num_records, + size_t& num_left, + size_t& num_right) +{ + // Initialize pointers to start, end and middle of range. + unsigned char * left_ptr = range_begin; + unsigned char * right_ptr = range_begin + (num_records - 1) * record_size; + unsigned char * pivot_ptr = range_begin + (num_records / 2) * record_size; + + // Sort the first, middle and last records such that they are + // in proper order with respect to each other. + if (record_compare(pivot_ptr, left_ptr, record_size) < 0) { + record_swap(left_ptr, pivot_ptr, record_size); + } + if (record_compare(right_ptr, pivot_ptr, record_size) < 0) { + record_swap(pivot_ptr, right_ptr, record_size); + if (record_compare(pivot_ptr, left_ptr, record_size) < 0) { + record_swap(left_ptr, pivot_ptr, record_size); + } + } + + // The median of the three records we examined is now in the + // middle of the range, pointed to by pivot_ptr. + // This is not necessarily the final location of that element. + + // The first and last record of the range are now on the proper + // side of the partition. No need to examine them again. + left_ptr += record_size; + right_ptr -= record_size; + + // Partition the rest of the array based on comparing to the pivot. + while (true) { + + // Skip left-side records that are less than the pivot. + while (record_compare(left_ptr, pivot_ptr, record_size) < 0) { + left_ptr += record_size; + } + + // Skip right-side records that are greater than the pivot. + while (record_compare(pivot_ptr, right_ptr, record_size) < 0) { + right_ptr -= record_size; + } + + // Stop when the pointers meet. + if (left_ptr >= right_ptr) { + break; + } + + // Swap the records that are on the wrong sides. + record_swap(left_ptr, right_ptr, record_size); + + // If we moved the pivot, update its pointer so it keeps + // pointing to the pivot value. + if (pivot_ptr == left_ptr) { + pivot_ptr = right_ptr; + } else if (pivot_ptr == right_ptr) { + pivot_ptr = left_ptr; + } + + // Do not compare the swapped elements again. + left_ptr += record_size; + right_ptr -= record_size; + + // Stop when pointers cross. + // (Pointers equal is not good enough at this point, because + // we won't know on which side the pointed record belongs.) + if (left_ptr > right_ptr) { + break; + } + } + + // If pointers are equal, they must both be pointing to a pivot. + // Bump both pointers so they correctly delineate the new + // subranges. The record where the pointers meet is already in + // its final position. + if (left_ptr == right_ptr) { + left_ptr += record_size; + right_ptr -= record_size; + } + + // Determine the number of elements in the left and right subranges. + num_left = (right_ptr + record_size - range_begin) / record_size; + num_right = num_records - (left_ptr - range_begin) / record_size; +} /** @@ -925,96 +1191,25 @@ void quicksort_records( continue; } - // Initialize pointers to start, end and middle of range. - unsigned char * left_ptr = range_begin; - unsigned char * right_ptr = - range_begin + (range_num_records - 1) * record_size; - unsigned char * pivot_ptr = - range_begin + (range_num_records / 2) * record_size; - - // Sort the first, middle and last records such that they are - // in proper order with respect to each other. - if (record_compare(pivot_ptr, left_ptr, record_size) < 0) { - record_swap(left_ptr, pivot_ptr, record_size); - } - if (record_compare(right_ptr, pivot_ptr, record_size) < 0) { - record_swap(pivot_ptr, right_ptr, record_size); - if (record_compare(pivot_ptr, left_ptr, record_size) < 0) { - record_swap(left_ptr, pivot_ptr, record_size); - } - } - - // The median of the three records we examined is now in the - // middle of the range, pointed to by pivot_ptr. - // This is not necessarily the final location of that element. - - // The first and last record of the range are now on the proper - // side of the partition. No need to examine them again. - left_ptr += record_size; - right_ptr -= record_size; - - // Partition the rest of the array based on comparing to the pivot. - while (true) { - - // Skip left-side records that are less than the pivot. - while (record_compare(left_ptr, pivot_ptr, record_size) < 0) { - left_ptr += record_size; - } - - // Skip right-side records that are greater than the pivot. - while (record_compare(pivot_ptr, right_ptr, record_size) < 0) { - right_ptr -= record_size; - } - - // Stop when the pointers meet. - if (left_ptr >= right_ptr) { - break; - } - - // Swap the records that are on the wrong sides. - record_swap(left_ptr, right_ptr, record_size); - - // If we moved the pivot, update its pointer so it keeps - // pointing to the pivot value. - if (pivot_ptr == left_ptr) { - pivot_ptr = right_ptr; - } else if (pivot_ptr == right_ptr) { - pivot_ptr = left_ptr; - } - - // Do not compare the swapped elements again. - left_ptr += record_size; - right_ptr -= record_size; - - // Stop when pointers cross. - // (Pointers equal is not good enough at this point, because - // we won't know on which side the pointed record belongs.) - if (left_ptr > right_ptr) { - break; - } - } - - // If pointers are equal, they must both be pointing to a pivot. - // Bump both pointers so they correctly delineate the new - // subranges. The record where the pointers meet is already in - // its final position. - if (left_ptr == right_ptr) { - left_ptr += record_size; - right_ptr -= record_size; - } + // Partition the array into two parts. + size_t num_left, num_right; + quicksort_partition_records( + range_begin, + record_size, + range_num_records, + num_left, + num_right); // Push left subrange on the stack, if it meets the size threshold. - size_t num_left = - (right_ptr + record_size - range_begin) / record_size; if (num_left > insertion_sort_threshold) { stack.emplace_back(range_begin, num_left, depth_limit - 1); } // Push right subrange on the stack, if it meets the size threshold. - size_t num_right = - range_num_records - (left_ptr - range_begin) / record_size; if (num_right > insertion_sort_threshold) { - stack.emplace_back(left_ptr, num_right, depth_limit - 1); + unsigned char * right_half = + range_begin + (range_num_records - num_right) * record_size; + stack.emplace_back(right_half, num_right, depth_limit - 1); } } @@ -1027,6 +1222,126 @@ void quicksort_records( } +/** Helper function for quicksort_records_parallel(). */ +void quicksort_records_parallel_step( + unsigned char * range_begin, + size_t record_size, + size_t num_records, + size_t parallel_size_threshold, + unsigned int parallel_depth_limit, + ThreadPool * thread_pool, + CompletionTracker * completion_tracker) +{ + while (true) { + + // If the range is below threshold, or recursion is too deep, + // handle this part of the array entirely in this thread. + if (num_records <= parallel_size_threshold + || parallel_depth_limit == 0) { + quicksort_records(range_begin, record_size, num_records); + break; + } + + parallel_depth_limit--; + + // Partition the array into two parts. + size_t num_left, num_right; + quicksort_partition_records( + range_begin, + record_size, + num_records, + num_left, + num_right); + + // Submit the largest of the two subranges to the thread pool. + // We will handle the other subrange within this thread. + unsigned char * right_half = + range_begin + (num_records - num_right) * record_size; + if (num_left >= num_right) { + + // Submit left subrange. + completion_tracker->add( + thread_pool->submit( + std::bind( + quicksort_records_parallel_step, + range_begin, + record_size, + num_left, + parallel_size_threshold, + parallel_depth_limit, + thread_pool, + completion_tracker))); + + // Continue with right subrange in this thread. + range_begin = right_half; + num_records = num_right; + + } else { + + // Submit right subrange. + completion_tracker->add( + thread_pool->submit( + std::bind( + quicksort_records_parallel_step, + right_half, + record_size, + num_right, + parallel_size_threshold, + parallel_depth_limit, + thread_pool, + completion_tracker))); + + // Continue with left subrange in this thread. + num_records = num_left; + } + } +} + + +/** + * Sort an array of records using in-place quicksort. + * + * Use multiple threads to parallelize the sort process. + */ +void quicksort_records_parallel( + unsigned char * buffer, + size_t record_size, + size_t num_records, + unsigned int num_threads, + ThreadPool * thread_pool) +{ + // Small fragments should not be further distributed between threads. + size_t parallel_size_threshold = + std::max(size_t(1024), num_records / num_threads / 4); + + // Stop parallel processing past a certain nesting depth. + // This is necessary to avoid quadratic run time. + unsigned int parallel_depth_limit = 2; + for (unsigned int nn = num_threads; nn > 1; nn >>= 1) { + parallel_depth_limit += 2; + } + + // Tracker to determine when all sort tasks have finished. + CompletionTracker completion_tracker; + + // Submit the full array to the thread pool. + completion_tracker.add( + thread_pool->submit( + std::bind( + quicksort_records_parallel_step, + buffer, + record_size, + num_records, + parallel_size_threshold, + parallel_depth_limit, + thread_pool, + &completion_tracker))); + + // Wait until all sort tasks have finished. + completion_tracker.wait(); +} + + /** Sort the specified block of records (in-place). */ void sort_records( unsigned char * buffer, @@ -1115,12 +1430,26 @@ void single_pass( timer.stop(); log(ctx, " t = %.3f seconds\n", timer.value()); -// TODO : multi-threaded sorting with thread pool + // Set up thread pool. + std::unique_ptr thread_pool; + if (ctx.num_threads > 1) { + log(ctx, "creating thread pool with %u threads\n", ctx.num_threads); + thread_pool.reset(new ThreadPool(ctx.num_threads)); + } // Sort records in memory buffer. - log(ctx, "sorting records\n"); + log(ctx, "sorting records using %u threads\n", ctx.num_threads); timer.start(); - sort_records(buffer.data(), ctx.record_size, num_records); + if (ctx.num_threads > 1) { + quicksort_records_parallel( + buffer.data(), + ctx.record_size, + num_records, + ctx.num_threads, + thread_pool.get()); + } else { + quicksort_records(buffer.data(), ctx.record_size, num_records); + } timer.stop(); log(ctx, " t = %.3f seconds\n", timer.value()); @@ -1785,6 +2114,31 @@ std::string get_default_tmpdir(void) } +/** Parse an unsigned integer. */ +bool parse_uint(const char *argstr, unsigned int& value) +{ + char *endptr; + + errno = 0; + long t = strtol(argstr, &endptr, 10); + + if (endptr == argstr || endptr[0] != '\0') { + return false; + } + + if (errno != 0) { + return false; + } + + if (t < 0 || (unsigned long)t > UINT_MAX) { + return false; + } + + value = t; + return true; +} + + /** * Parse a memory size specification. * @@ -1841,7 +2195,8 @@ void usage() " -u, --unique eliminate duplicates after sorting\n" " --memory=M use at most MiByte RAM (default: %d)\n" " --memory=G use at most GiByte RAM\n" - " --branch=B merge at most B arrays in one step (default: %d)\n" + " --branch=N merge N subarrays in one step (default: %d)\n" + " --threads=N use N threads for parallel sorting (default: %d)\n" " --temporary-directory=DIR write temporary file to the specified\n" " directory (default: $TMPDIR)\n" "\n" @@ -1850,7 +2205,8 @@ void usage() "created with the same size as the input/output files.\n" "\n", DEFAULT_MEMORY_SIZE_MBYTE, - DEFAULT_BRANCH_FACTOR); + DEFAULT_BRANCH_FACTOR, + DEFAULT_THREADS); } @@ -1864,6 +2220,7 @@ int main(int argc, char **argv) { "unique", 0, NULL, 'u' }, { "memory", 1, NULL, 'M' }, { "branch", 1, NULL, 'B' }, + { "threads", 1, NULL, 'J' }, { "temporary-directory", 1, NULL, 'T' }, { "verbose", 0, NULL, 'v' }, { "help", 0, NULL, 'h' }, @@ -1871,17 +2228,17 @@ int main(int argc, char **argv) }; bool flag_unique = false; bool flag_verbose = false; - int record_size = 0; + unsigned int record_size = 0; + unsigned int branch_factor = DEFAULT_BRANCH_FACTOR; + unsigned int num_threads = DEFAULT_THREADS; uint64_t memory_size = uint64_t(DEFAULT_MEMORY_SIZE_MBYTE) * 1024 * 1024; - int branch_factor = DEFAULT_BRANCH_FACTOR; std::string tempdir = get_default_tmpdir(); int opt; - while ((opt = getopt_long(argc, argv, "s:T:uv", longopts, NULL)) != -1) { + while ((opt = getopt_long(argc, argv, "s:T:uvh", longopts, NULL)) != -1) { switch (opt) { case 's': - record_size = atoi(optarg); - if (record_size < 1) { + if (!parse_uint(optarg, record_size) || record_size < 1) { fprintf(stderr, "ERROR: Invalid record size (must be at least 1)\n"); return EXIT_FAILURE; @@ -1897,13 +2254,20 @@ int main(int argc, char **argv) } break; case 'B': - branch_factor = atoi(optarg); - if (branch_factor < 2) { + if (!parse_uint(optarg, branch_factor) || branch_factor < 2) { fprintf(stderr, "ERROR: Invalid radix value, must be at least 2\n"); return EXIT_FAILURE; } break; + case 'J': + if (!parse_uint(optarg, num_threads) + || num_threads < 1 + || num_threads > MAX_THREADS) { + fprintf(stderr, "ERROR: Invalid number of threads\n"); + return EXIT_FAILURE; + } + break; case 'T': tempdir = optarg; break; @@ -1953,6 +2317,7 @@ int main(int argc, char **argv) ctx.record_size = record_size; ctx.memory_size = memory_size; ctx.branch_factor = branch_factor; + ctx.num_threads = num_threads; ctx.flag_unique = flag_unique; ctx.flag_verbose = flag_verbose; ctx.temporary_directory = tempdir;