1
0
Fork 0

Use multi-threaded quicksort for single-pass sorts

This commit is contained in:
Joris van Rantwijk 2022-06-26 20:51:35 +02:00
parent ae33feaca4
commit 2977f50539
2 changed files with 460 additions and 95 deletions

View File

@ -3,7 +3,7 @@
# #
CXX = g++ CXX = g++
CXXFLAGS = -Wall -O2 CXXFLAGS = -Wall -O2 -pthread
# -fsanitize=address -fsanitize=undefined # -fsanitize=address -fsanitize=undefined
SRCDIR = src SRCDIR = src

View File

@ -25,17 +25,24 @@
#include <fcntl.h> #include <fcntl.h>
#include <getopt.h> #include <getopt.h>
#include <inttypes.h> #include <inttypes.h>
#include <limits.h>
#include <string.h> #include <string.h>
#include <time.h> #include <time.h>
#include <unistd.h> #include <unistd.h>
#include <algorithm> #include <algorithm>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <mutex>
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <system_error> #include <system_error>
#include <thread>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
@ -48,6 +55,7 @@
/* Default number of sorting threads. */ /* Default number of sorting threads. */
#define DEFAULT_THREADS 1 #define DEFAULT_THREADS 1
#define MAX_THREADS 128
/* Align buffer sizes and I/O on this number of records. /* Align buffer sizes and I/O on this number of records.
For efficiency, I/O should be done in multiples of 4096 bytes. */ For efficiency, I/O should be done in multiples of 4096 bytes. */
@ -72,6 +80,9 @@ struct SortContext
/** Maximum number of arrays to merge in one step. */ /** Maximum number of arrays to merge in one step. */
unsigned int branch_factor; unsigned int branch_factor;
/** Number of threads for parallel sorting. */
unsigned int num_threads;
/** True to eliminate duplicate records. */ /** True to eliminate duplicate records. */
bool flag_unique; bool flag_unique;
@ -174,6 +185,152 @@ private:
}; };
/** Thread pool for parallel sorting and background I/O. */
class ThreadPool
{
public:
typedef std::future<void> FutureType;
/** Initialize thread pool with the specified number of threads. */
explicit ThreadPool(unsigned int num_threads);
// Prevent copying and assignment.
ThreadPool(const ThreadPool&) = delete;
ThreadPool& operator=(const ThreadPool&) = delete;
/**
* Destructor: Wait until pending tasks are finished, then join threads.
*/
~ThreadPool();
/**
* Submit a function to be executed in the thread pool.
* @param f Function which takes no arguments and returns void.
* @return A future that becomes ready when the function completes.
*/
template <class FunctionType>
FutureType submit(FunctionType f)
{
std::packaged_task<void()> task(f);
FutureType fut(task.get_future());
std::lock_guard<std::mutex> lock(m_mutex);
m_queue.push_back(std::move(task));
m_cond.notify_all();
return fut;
}
private:
/** Worker function that runs in each thread. */
void worker_function();
std::mutex m_mutex;
std::condition_variable m_cond;
std::vector<std::thread> m_threads;
bool m_stop_flag;
std::deque<std::packaged_task<void()>> m_queue;
};
// Constructor.
ThreadPool::ThreadPool(unsigned int num_threads)
: m_stop_flag(false)
{
// Create the worker threads.
for (unsigned int i = 0; i < num_threads; i++) {
m_threads.emplace_back(&ThreadPool::worker_function, this);
}
}
// Destructor.
ThreadPool::~ThreadPool()
{
// Set stop flag.
{
std::unique_lock<std::mutex> lock(m_mutex);
m_stop_flag = true;
m_cond.notify_all();
}
// Join the worker threads.
for (auto& thread : m_threads) {
thread.join();
}
}
// Worker function that runs in each thread.
void ThreadPool::worker_function()
{
// Main loop.
while (true) {
std::packaged_task<void()> task;
// Get a pending task from the queue.
{
std::unique_lock<std::mutex> lock(m_mutex);
// Wait until the queue is non-empty or the stop flag is raised.
while (m_queue.empty() && !m_stop_flag) {
m_cond.wait(lock);
}
// Exit if the stop flag is raised (thread pool is shutting down).
if (m_stop_flag) {
break;
}
// Get oldest pending task.
task = std::move(m_queue.front());
m_queue.pop_front();
}
// Execute the pending task.
task();
}
}
/** Helper class to wait until multiple "std::future"s are ready. */
class CompletionTracker
{
public:
typedef std::future<void> FutureType;
/** Add a future to wait for. */
void add(FutureType&& fut)
{
std::unique_lock<std::mutex> lock(m_mutex);
m_futures.emplace_back(std::move(fut));
}
/** Wait until all futures have completed. */
void wait()
{
while (true) {
FutureType fut;
// Get the next future from the queue.
{
std::unique_lock<std::mutex> lock(m_mutex);
if (m_futures.empty()) {
// No more futures. We are done waiting.
break;
}
fut = std::move(m_futures.front());
m_futures.pop_front();
}
// Wait on this future.
fut.wait();
}
}
private:
std::mutex m_mutex;
std::deque<FutureType> m_futures;
};
/** /**
* Binary file access. * Binary file access.
* *
@ -868,6 +1025,115 @@ void insertion_sort_records(
} }
/**
* Partition the array into two parts.
*
* This is a helper function for quicksort_records().
*
* After partitioning, all of the "num_left" first records will be
* less-than-or-equal to all of the "num_right" last records.
*
* After partitioning, either
* num_left + num_right == num_records
* or
* num_left + num_right == num_records - 1
*
* In the second case, the record in the middle is already in its final
* position in the array.
*
* @param range_start Pointer to first element of the array.
* @param record_size Record size in bytes.
* @param num_records Number of records in the array.
* @param[out] num_left Number of records in the left half.
* @param[out] num_right Number of records in the right half.
*/
inline void quicksort_partition_records(
unsigned char * range_begin,
size_t record_size,
size_t num_records,
size_t& num_left,
size_t& num_right)
{
// Initialize pointers to start, end and middle of range.
unsigned char * left_ptr = range_begin;
unsigned char * right_ptr = range_begin + (num_records - 1) * record_size;
unsigned char * pivot_ptr = range_begin + (num_records / 2) * record_size;
// Sort the first, middle and last records such that they are
// in proper order with respect to each other.
if (record_compare(pivot_ptr, left_ptr, record_size) < 0) {
record_swap(left_ptr, pivot_ptr, record_size);
}
if (record_compare(right_ptr, pivot_ptr, record_size) < 0) {
record_swap(pivot_ptr, right_ptr, record_size);
if (record_compare(pivot_ptr, left_ptr, record_size) < 0) {
record_swap(left_ptr, pivot_ptr, record_size);
}
}
// The median of the three records we examined is now in the
// middle of the range, pointed to by pivot_ptr.
// This is not necessarily the final location of that element.
// The first and last record of the range are now on the proper
// side of the partition. No need to examine them again.
left_ptr += record_size;
right_ptr -= record_size;
// Partition the rest of the array based on comparing to the pivot.
while (true) {
// Skip left-side records that are less than the pivot.
while (record_compare(left_ptr, pivot_ptr, record_size) < 0) {
left_ptr += record_size;
}
// Skip right-side records that are greater than the pivot.
while (record_compare(pivot_ptr, right_ptr, record_size) < 0) {
right_ptr -= record_size;
}
// Stop when the pointers meet.
if (left_ptr >= right_ptr) {
break;
}
// Swap the records that are on the wrong sides.
record_swap(left_ptr, right_ptr, record_size);
// If we moved the pivot, update its pointer so it keeps
// pointing to the pivot value.
if (pivot_ptr == left_ptr) {
pivot_ptr = right_ptr;
} else if (pivot_ptr == right_ptr) {
pivot_ptr = left_ptr;
}
// Do not compare the swapped elements again.
left_ptr += record_size;
right_ptr -= record_size;
// Stop when pointers cross.
// (Pointers equal is not good enough at this point, because
// we won't know on which side the pointed record belongs.)
if (left_ptr > right_ptr) {
break;
}
}
// If pointers are equal, they must both be pointing to a pivot.
// Bump both pointers so they correctly delineate the new
// subranges. The record where the pointers meet is already in
// its final position.
if (left_ptr == right_ptr) {
left_ptr += record_size;
right_ptr -= record_size;
}
// Determine the number of elements in the left and right subranges.
num_left = (right_ptr + record_size - range_begin) / record_size;
num_right = num_records - (left_ptr - range_begin) / record_size;
}
/** /**
@ -925,96 +1191,25 @@ void quicksort_records(
continue; continue;
} }
// Initialize pointers to start, end and middle of range. // Partition the array into two parts.
unsigned char * left_ptr = range_begin; size_t num_left, num_right;
unsigned char * right_ptr = quicksort_partition_records(
range_begin + (range_num_records - 1) * record_size; range_begin,
unsigned char * pivot_ptr = record_size,
range_begin + (range_num_records / 2) * record_size; range_num_records,
num_left,
// Sort the first, middle and last records such that they are num_right);
// in proper order with respect to each other.
if (record_compare(pivot_ptr, left_ptr, record_size) < 0) {
record_swap(left_ptr, pivot_ptr, record_size);
}
if (record_compare(right_ptr, pivot_ptr, record_size) < 0) {
record_swap(pivot_ptr, right_ptr, record_size);
if (record_compare(pivot_ptr, left_ptr, record_size) < 0) {
record_swap(left_ptr, pivot_ptr, record_size);
}
}
// The median of the three records we examined is now in the
// middle of the range, pointed to by pivot_ptr.
// This is not necessarily the final location of that element.
// The first and last record of the range are now on the proper
// side of the partition. No need to examine them again.
left_ptr += record_size;
right_ptr -= record_size;
// Partition the rest of the array based on comparing to the pivot.
while (true) {
// Skip left-side records that are less than the pivot.
while (record_compare(left_ptr, pivot_ptr, record_size) < 0) {
left_ptr += record_size;
}
// Skip right-side records that are greater than the pivot.
while (record_compare(pivot_ptr, right_ptr, record_size) < 0) {
right_ptr -= record_size;
}
// Stop when the pointers meet.
if (left_ptr >= right_ptr) {
break;
}
// Swap the records that are on the wrong sides.
record_swap(left_ptr, right_ptr, record_size);
// If we moved the pivot, update its pointer so it keeps
// pointing to the pivot value.
if (pivot_ptr == left_ptr) {
pivot_ptr = right_ptr;
} else if (pivot_ptr == right_ptr) {
pivot_ptr = left_ptr;
}
// Do not compare the swapped elements again.
left_ptr += record_size;
right_ptr -= record_size;
// Stop when pointers cross.
// (Pointers equal is not good enough at this point, because
// we won't know on which side the pointed record belongs.)
if (left_ptr > right_ptr) {
break;
}
}
// If pointers are equal, they must both be pointing to a pivot.
// Bump both pointers so they correctly delineate the new
// subranges. The record where the pointers meet is already in
// its final position.
if (left_ptr == right_ptr) {
left_ptr += record_size;
right_ptr -= record_size;
}
// Push left subrange on the stack, if it meets the size threshold. // Push left subrange on the stack, if it meets the size threshold.
size_t num_left =
(right_ptr + record_size - range_begin) / record_size;
if (num_left > insertion_sort_threshold) { if (num_left > insertion_sort_threshold) {
stack.emplace_back(range_begin, num_left, depth_limit - 1); stack.emplace_back(range_begin, num_left, depth_limit - 1);
} }
// Push right subrange on the stack, if it meets the size threshold. // Push right subrange on the stack, if it meets the size threshold.
size_t num_right =
range_num_records - (left_ptr - range_begin) / record_size;
if (num_right > insertion_sort_threshold) { if (num_right > insertion_sort_threshold) {
stack.emplace_back(left_ptr, num_right, depth_limit - 1); unsigned char * right_half =
range_begin + (range_num_records - num_right) * record_size;
stack.emplace_back(right_half, num_right, depth_limit - 1);
} }
} }
@ -1027,6 +1222,126 @@ void quicksort_records(
} }
/** Helper function for quicksort_records_parallel(). */
void quicksort_records_parallel_step(
unsigned char * range_begin,
size_t record_size,
size_t num_records,
size_t parallel_size_threshold,
unsigned int parallel_depth_limit,
ThreadPool * thread_pool,
CompletionTracker * completion_tracker)
{
while (true) {
// If the range is below threshold, or recursion is too deep,
// handle this part of the array entirely in this thread.
if (num_records <= parallel_size_threshold
|| parallel_depth_limit == 0) {
quicksort_records(range_begin, record_size, num_records);
break;
}
parallel_depth_limit--;
// Partition the array into two parts.
size_t num_left, num_right;
quicksort_partition_records(
range_begin,
record_size,
num_records,
num_left,
num_right);
// Submit the largest of the two subranges to the thread pool.
// We will handle the other subrange within this thread.
unsigned char * right_half =
range_begin + (num_records - num_right) * record_size;
if (num_left >= num_right) {
// Submit left subrange.
completion_tracker->add(
thread_pool->submit(
std::bind(
quicksort_records_parallel_step,
range_begin,
record_size,
num_left,
parallel_size_threshold,
parallel_depth_limit,
thread_pool,
completion_tracker)));
// Continue with right subrange in this thread.
range_begin = right_half;
num_records = num_right;
} else {
// Submit right subrange.
completion_tracker->add(
thread_pool->submit(
std::bind(
quicksort_records_parallel_step,
right_half,
record_size,
num_right,
parallel_size_threshold,
parallel_depth_limit,
thread_pool,
completion_tracker)));
// Continue with left subrange in this thread.
num_records = num_left;
}
}
}
/**
* Sort an array of records using in-place quicksort.
*
* Use multiple threads to parallelize the sort process.
*/
void quicksort_records_parallel(
unsigned char * buffer,
size_t record_size,
size_t num_records,
unsigned int num_threads,
ThreadPool * thread_pool)
{
// Small fragments should not be further distributed between threads.
size_t parallel_size_threshold =
std::max(size_t(1024), num_records / num_threads / 4);
// Stop parallel processing past a certain nesting depth.
// This is necessary to avoid quadratic run time.
unsigned int parallel_depth_limit = 2;
for (unsigned int nn = num_threads; nn > 1; nn >>= 1) {
parallel_depth_limit += 2;
}
// Tracker to determine when all sort tasks have finished.
CompletionTracker completion_tracker;
// Submit the full array to the thread pool.
completion_tracker.add(
thread_pool->submit(
std::bind(
quicksort_records_parallel_step,
buffer,
record_size,
num_records,
parallel_size_threshold,
parallel_depth_limit,
thread_pool,
&completion_tracker)));
// Wait until all sort tasks have finished.
completion_tracker.wait();
}
/** Sort the specified block of records (in-place). */ /** Sort the specified block of records (in-place). */
void sort_records( void sort_records(
unsigned char * buffer, unsigned char * buffer,
@ -1115,12 +1430,26 @@ void single_pass(
timer.stop(); timer.stop();
log(ctx, " t = %.3f seconds\n", timer.value()); log(ctx, " t = %.3f seconds\n", timer.value());
// TODO : multi-threaded sorting with thread pool // Set up thread pool.
std::unique_ptr<ThreadPool> thread_pool;
if (ctx.num_threads > 1) {
log(ctx, "creating thread pool with %u threads\n", ctx.num_threads);
thread_pool.reset(new ThreadPool(ctx.num_threads));
}
// Sort records in memory buffer. // Sort records in memory buffer.
log(ctx, "sorting records\n"); log(ctx, "sorting records using %u threads\n", ctx.num_threads);
timer.start(); timer.start();
sort_records(buffer.data(), ctx.record_size, num_records); if (ctx.num_threads > 1) {
quicksort_records_parallel(
buffer.data(),
ctx.record_size,
num_records,
ctx.num_threads,
thread_pool.get());
} else {
quicksort_records(buffer.data(), ctx.record_size, num_records);
}
timer.stop(); timer.stop();
log(ctx, " t = %.3f seconds\n", timer.value()); log(ctx, " t = %.3f seconds\n", timer.value());
@ -1785,6 +2114,31 @@ std::string get_default_tmpdir(void)
} }
/** Parse an unsigned integer. */
bool parse_uint(const char *argstr, unsigned int& value)
{
char *endptr;
errno = 0;
long t = strtol(argstr, &endptr, 10);
if (endptr == argstr || endptr[0] != '\0') {
return false;
}
if (errno != 0) {
return false;
}
if (t < 0 || (unsigned long)t > UINT_MAX) {
return false;
}
value = t;
return true;
}
/** /**
* Parse a memory size specification. * Parse a memory size specification.
* *
@ -1841,7 +2195,8 @@ void usage()
" -u, --unique eliminate duplicates after sorting\n" " -u, --unique eliminate duplicates after sorting\n"
" --memory=<n>M use at most <n> MiByte RAM (default: %d)\n" " --memory=<n>M use at most <n> MiByte RAM (default: %d)\n"
" --memory=<n>G use at most <n> GiByte RAM\n" " --memory=<n>G use at most <n> GiByte RAM\n"
" --branch=B merge at most B arrays in one step (default: %d)\n" " --branch=N merge N subarrays in one step (default: %d)\n"
" --threads=N use N threads for parallel sorting (default: %d)\n"
" --temporary-directory=DIR write temporary file to the specified\n" " --temporary-directory=DIR write temporary file to the specified\n"
" directory (default: $TMPDIR)\n" " directory (default: $TMPDIR)\n"
"\n" "\n"
@ -1850,7 +2205,8 @@ void usage()
"created with the same size as the input/output files.\n" "created with the same size as the input/output files.\n"
"\n", "\n",
DEFAULT_MEMORY_SIZE_MBYTE, DEFAULT_MEMORY_SIZE_MBYTE,
DEFAULT_BRANCH_FACTOR); DEFAULT_BRANCH_FACTOR,
DEFAULT_THREADS);
} }
@ -1864,6 +2220,7 @@ int main(int argc, char **argv)
{ "unique", 0, NULL, 'u' }, { "unique", 0, NULL, 'u' },
{ "memory", 1, NULL, 'M' }, { "memory", 1, NULL, 'M' },
{ "branch", 1, NULL, 'B' }, { "branch", 1, NULL, 'B' },
{ "threads", 1, NULL, 'J' },
{ "temporary-directory", 1, NULL, 'T' }, { "temporary-directory", 1, NULL, 'T' },
{ "verbose", 0, NULL, 'v' }, { "verbose", 0, NULL, 'v' },
{ "help", 0, NULL, 'h' }, { "help", 0, NULL, 'h' },
@ -1871,17 +2228,17 @@ int main(int argc, char **argv)
}; };
bool flag_unique = false; bool flag_unique = false;
bool flag_verbose = false; bool flag_verbose = false;
int record_size = 0; unsigned int record_size = 0;
unsigned int branch_factor = DEFAULT_BRANCH_FACTOR;
unsigned int num_threads = DEFAULT_THREADS;
uint64_t memory_size = uint64_t(DEFAULT_MEMORY_SIZE_MBYTE) * 1024 * 1024; uint64_t memory_size = uint64_t(DEFAULT_MEMORY_SIZE_MBYTE) * 1024 * 1024;
int branch_factor = DEFAULT_BRANCH_FACTOR;
std::string tempdir = get_default_tmpdir(); std::string tempdir = get_default_tmpdir();
int opt; int opt;
while ((opt = getopt_long(argc, argv, "s:T:uv", longopts, NULL)) != -1) { while ((opt = getopt_long(argc, argv, "s:T:uvh", longopts, NULL)) != -1) {
switch (opt) { switch (opt) {
case 's': case 's':
record_size = atoi(optarg); if (!parse_uint(optarg, record_size) || record_size < 1) {
if (record_size < 1) {
fprintf(stderr, fprintf(stderr,
"ERROR: Invalid record size (must be at least 1)\n"); "ERROR: Invalid record size (must be at least 1)\n");
return EXIT_FAILURE; return EXIT_FAILURE;
@ -1897,13 +2254,20 @@ int main(int argc, char **argv)
} }
break; break;
case 'B': case 'B':
branch_factor = atoi(optarg); if (!parse_uint(optarg, branch_factor) || branch_factor < 2) {
if (branch_factor < 2) {
fprintf(stderr, fprintf(stderr,
"ERROR: Invalid radix value, must be at least 2\n"); "ERROR: Invalid radix value, must be at least 2\n");
return EXIT_FAILURE; return EXIT_FAILURE;
} }
break; break;
case 'J':
if (!parse_uint(optarg, num_threads)
|| num_threads < 1
|| num_threads > MAX_THREADS) {
fprintf(stderr, "ERROR: Invalid number of threads\n");
return EXIT_FAILURE;
}
break;
case 'T': case 'T':
tempdir = optarg; tempdir = optarg;
break; break;
@ -1953,6 +2317,7 @@ int main(int argc, char **argv)
ctx.record_size = record_size; ctx.record_size = record_size;
ctx.memory_size = memory_size; ctx.memory_size = memory_size;
ctx.branch_factor = branch_factor; ctx.branch_factor = branch_factor;
ctx.num_threads = num_threads;
ctx.flag_unique = flag_unique; ctx.flag_unique = flag_unique;
ctx.flag_verbose = flag_verbose; ctx.flag_verbose = flag_verbose;
ctx.temporary_directory = tempdir; ctx.temporary_directory = tempdir;