Commit 1deac0a

parallelMergeInPlace, parallelEqual
1 parent 3718f5e commit 1deac0a

1 file changed: parallel_algorithm.d (256 additions, 2 deletions)
@@ -1,6 +1,6 @@
 import std.typetuple, std.parallelism, std.range, std.functional,
     std.algorithm, std.stdio, std.array, std.traits, std.conv,
-    core.stdc.string;
+    core.stdc.string, core.atomic;

 version(unittest) {
     import std.random, std.typecons, std.math;
@@ -361,6 +361,82 @@ unittest {
     assert(buf == [1, 2, 2, 4, 4, 6, 8, 8, 10, 12, 16, 32]);
 }

+// Merges the sorted subranges range[0..middle] and range[middle..$] in
+// place, recursing in parallel on pool once a subproblem has at least
+// minParallel elements.
+void parallelMergeInPlace(alias pred = "a < b", R)(
+    R range,
+    size_t middle,
+    size_t minParallel = 1024,
+    TaskPool pool = null
+) {
+    if(pool is null) pool = taskPool;
+    immutable rlen = range.length;
+    alias binaryFun!(pred) comp;
+
+    // Number of elements of sorted data strictly less than value.
+    static size_t largestLess(T)(T[] data, T value) {
+        return assumeSorted!(comp)(data).lowerBound(value).length;
+    }
+
+    // Index of the first element of sorted data greater than value.
+    static size_t smallestGr(T)(T[] data, T value) {
+        return data.length -
+            assumeSorted!(comp)(data).upperBound(value).length;
+    }
+
+    if (range.length < 2 || middle == 0 || middle == range.length) {
+        return;
+    }
+
+    if (range.length == 2) {
+        if(comp(range[1], range[0])) {
+            swap(range[0], range[1]);
+        }
+
+        return;
+    }
+
+    // Split the longer sorted run at its midpoint, then binary search the
+    // other run for the matching split point around that pivot.
+    size_t half1, half2;
+
+    if (middle > range.length - middle) {
+        half1 = middle / 2;
+        auto pivot = range[half1];
+        half2 = largestLess(range[middle..rlen], pivot);
+    } else {
+        half2 = (range.length - middle) / 2;
+        auto pivot = range[half2 + middle];
+        half1 = smallestGr(range[0..middle], pivot);
+    }
+
+    // Rotate so the two low pieces precede the two high pieces, leaving
+    // two smaller, independent merge problems.
+    bringToFront(range[half1..middle], range[middle..middle + half2]);
+    size_t newMiddle = half1 + half2;
+
+    auto left = range[0..newMiddle];
+    auto right = range[newMiddle..range.length];
+
+    if(left.length >= minParallel) {
+        auto leftTask = scopedTask!(parallelMergeInPlace!(pred, R))
+            (left, half1, minParallel, pool);
+        pool.put(leftTask);
+        parallelMergeInPlace!(pred, R)
+            (right, half2 + middle - newMiddle, minParallel, pool);
+        leftTask.yieldForce();
+    } else {
+        parallelMergeInPlace!(pred, R)(left, half1, minParallel, pool);
+        parallelMergeInPlace!(pred, R)
+            (right, half2 + middle - newMiddle, minParallel, pool);
+    }
+}
+
+unittest {
+    auto arr = new int[10_000];
+
+    // Make sure both the serial and the parallel code paths work by also
+    // running with a minParallel too large for parallelism to kick in.
+    foreach(minParallel; [64, 20_000]) {
+        copy(iota(0, 10_000, 2), arr[0..5_000]);
+        copy(iota(1, 10_000, 2), arr[5_000..$]);
+        parallelMergeInPlace(arr, 5_000, minParallel);
+        assert(equal(arr, iota(10_000)), to!string(minParallel));
+    }
+}
+
 // In a few implementations we need to create custom ranges to be reduced.
 // std.parallelism.reduce checks for a random access range to conform to
 // Phobos conventions but only actually uses opIndex and length.
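parallelMergeInPlace above is exactly the merge step a parallel merge sort needs: split, sort the halves concurrently, then merge in place. The sketch below shows one way to drive it; parallelMergeSort, its serial cutoff, and the reuse of minParallel are assumptions added here for illustration and are not part of this commit.

// Hypothetical driver, not part of this commit: sort each half (the left
// half in a pool task), then stitch the halves back together with
// parallelMergeInPlace.
void parallelMergeSort(alias pred = "a < b", T)(
    T[] data,
    size_t minParallel = 1024,
    TaskPool pool = null
) {
    if(pool is null) pool = taskPool;
    if(data.length < 2) return;

    // Below the cutoff a plain serial sort is cheaper than task overhead.
    if(data.length < minParallel) {
        sort!(pred)(data);
        return;
    }

    immutable middle = data.length / 2;
    auto leftTask = scopedTask!(parallelMergeSort!(pred, T))
        (data[0..middle], minParallel, pool);
    pool.put(leftTask);
    parallelMergeSort!(pred, T)(data[middle..$], minParallel, pool);
    leftTask.yieldForce();

    parallelMergeInPlace!(pred)(data, middle, minParallel, pool);
}

With this in place, parallelMergeSort(arr) would sort arr on the default task pool.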
@@ -498,6 +574,104 @@ unittest {
     assert(parallelCount!"a == 2"([1, 2, 1, 2, 3]) == 2);
 }

+void parallelAdjacentDifference(alias pred = "a - b", R1, R2)(
+    R1 input,
+    R2 output,
+    TaskPool pool = null,
+    size_t workUnitSize = size_t.max
+) if(allSatisfy!(isRandomAccessRange, TypeTuple!(R1, R2)) &&
+    hasAssignableElements!(R2) &&
+    is(typeof(binaryFun!pred(R1.init[0], R1.init[1])) : ElementType!R2)
+) {
+    static size_t getLength(R)(ref R range) {
+        static if(is(typeof(range.length) : size_t)) {
+            return range.length;
+        } else {
+            return size_t.max;
+        }
+    }
+
+    // getLength(output) + 1 because we need one less element in output
+    // than we had in input.
+    immutable minLen = min(getLength(input), getLength(output) + 1);
+    if(pool is null) pool = taskPool;
+
+    if(workUnitSize == size_t.max) {
+        workUnitSize = pool.defaultWorkUnitSize(minLen - 1);
+    }
+
+    // Using parallel foreach to iterate over individual elements is too
+    // slow b/c of delegate overhead for such fine grained parallelism.
+    // Use parallel foreach to iterate over slices and handle them serially.
+    auto sliceStarts = iota(0, minLen - 1, workUnitSize);
+
+    foreach(startIndex; pool.parallel(sliceStarts, 1)) {
+        immutable endIndex = min(startIndex + workUnitSize, minLen - 1);
+
+        // This avoids some indirection and seems to be faster.
+        auto ip = input;
+        auto op = output;
+
+        foreach(i; startIndex..endIndex) {
+            op[i] = binaryFun!pred(ip[i + 1], ip[i]);
+        }
+    }
+}
+
+unittest {
+    auto input = [1, 2, 4, 8, 16, 32];
+    auto output = new int[5];
+    parallelAdjacentDifference(input, output);
+    assert(output == [1, 2, 4, 8, 16]);
+}
+
+bool parallelEqual(alias pred = "a == b", R1, R2)(
+    R1 range1,
+    R2 range2,
+    size_t workUnitSize = size_t.max,
+    TaskPool pool = null
+)
+if(isRandomAccessRange!R1 && isRandomAccessRange!R2 &&
+    hasLength!R1 && hasLength!R2) {
+
+    if(range1.length != range2.length) return false;
+    immutable len = range1.length;
+
+    if(pool is null) pool = taskPool;
+    if(workUnitSize == size_t.max) {
+        workUnitSize = pool.defaultWorkUnitSize(len);
+    }
+
+    auto chunks1 = std.range.chunks(range1, workUnitSize);
+    auto chunks2 = std.range.chunks(range2, workUnitSize);
+    assert(chunks1.length == chunks2.length);
+    immutable nChunks = chunks1.length;
+
+    // Shared so worker threads can publish a mismatch via atomicStore.
+    shared bool ret = true;
+
+    try {
+        foreach(chunkIndex; pool.parallel(iota(nChunks), 1)) {
+            auto c1 = chunks1[chunkIndex];
+            auto c2 = chunks2[chunkIndex];
+            if(!std.algorithm.equal!pred(c1, c2)) {
+                atomicStore(ret, false);
+                break;
+            }
+        }
+    } catch(ParallelForeachError) {
+        // Ignore it.  It's thrown because we tried to break out of a
+        // parallel foreach loop.
+    }
+
+    return ret;
+}
+
+unittest {
+    assert(parallelEqual([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]));
+    assert(!parallelEqual([1, 2, 3, 4, 5, 6], [1, 3, 3, 4, 5, 6]));
+    assert(!parallelEqual([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 7]));
+}
+
 //////////////////////////////////////////////////////////////////////////////
 // Benchmarks
 //////////////////////////////////////////////////////////////////////////////
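Both new routines take their predicate as a template alias, so "difference" and "equality" can be any pairwise operation. A hypothetical usage sketch follows; the sample values, the ratio predicate, and the sign-flip predicate are illustrative only and not part of this commit (approxEqual comes from the std.math import already pulled in under version(unittest)).

// Hypothetical usage sketch, not part of this commit.
unittest {
    // Any binary operation works as the "difference"; here, ratios of
    // adjacent elements instead of subtraction.
    auto prices = [100.0, 110.0, 99.0, 132.0];
    auto ratios = new double[3];
    parallelAdjacentDifference!"a / b"(prices, ratios);
    assert(approxEqual(ratios, [1.1, 0.9, 4.0 / 3.0]));

    // parallelEqual with a custom predicate and an explicit work unit size
    // of two elements per chunk.
    auto x = [1, -2, 3, -4];
    auto y = [-1, 2, -3, 4];
    assert(parallelEqual!"a == -b"(x, y, 2));
}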
@@ -525,6 +699,42 @@ void mergeBenchmark() {
     assert(equal(buf, sort(a ~ b)));
 }

+void mergeInPlaceBenchmark() {
+    enum N = 8192;
+    enum nIter = 100;
+    auto ab = new float[2 * N];
+    auto a = ab[0..$ / 2];
+    auto b = ab[$ / 2..$];
+
+    foreach(ref elem; ab) elem = uniform(0f, 1f);
+
+    auto sw = StopWatch(AutoStart.no);
+    foreach(i; 0..nIter) {
+        randomShuffle(ab);
+        sort(a);
+        sort(b);
+        sw.start();
+
+        // Disable parallelism by setting minParallel bigger than ab.length.
+        parallelMergeInPlace(ab, N, 2 * N + 1);
+        sw.stop();
+    }
+    writeln("Serial In-place Merge: ", sw.peek.msecs);
+
+    sw.reset();
+    foreach(i; 0..nIter) {
+        randomShuffle(ab);
+        sort(a);
+        sort(b);
+        sw.start();
+
+        // Use default minParallel.
+        parallelMergeInPlace(ab, N);
+        sw.stop();
+    }
+    writeln("Parallel In-place Merge: ", sw.peek.msecs);
+}
+
 void sortBenchmark() {
     enum N = 32768;
     enum nIter = 100;
@@ -563,7 +773,7 @@ void dotProdBenchmark() {
 }

 void countBenchmark() {
-    enum n = 50_000;
+    enum n = 3_000;
     enum nIter = 1_000;
     auto nums = new int[n];
     foreach(ref elem; nums) elem = uniform(0, 3);
@@ -585,9 +795,53 @@ void countBenchmark() {
     writeln("Parallel count by pred: ", sw.peek.msecs);
 }

+void adjacentDifferenceBenchmark() {
+    enum n = 500_000;
+    enum nIter = 100;
+
+    auto input = new int[n];
+    auto output = new int[n - 1];
+    foreach(ref elem; input) {
+        elem = uniform(0, 10_000);
+    }
+
+    auto sw = StopWatch(AutoStart.yes);
+    foreach(iter; 0..nIter) {
+        // Quick n' dirty inline serial impl.
+        foreach(i; 0..n - 1) {
+            output[i] = input[i + 1] - input[i];
+        }
+    }
+
+    writeln("Serial adjacent difference: ", sw.peek.msecs);
+
+    sw.reset();
+    foreach(iter; 0..nIter) parallelAdjacentDifference(input, output);
+    writeln("Parallel adjacent difference: ", sw.peek.msecs);
+}
+
+void equalBenchmark() {
+    enum n = 20_000;
+    enum nIter = 1_000;
+
+    auto a = new int[n];
+    auto b = new int[n];
+
+    auto sw = StopWatch(AutoStart.yes);
+    foreach(i; 0..nIter) equal(a, b);
+    writeln("Serial equal: ", sw.peek.msecs);
+
+    sw.reset();
+    foreach(i; 0..nIter) parallelEqual(a, b);
+    writeln("Parallel equal: ", sw.peek.msecs);
+}
+
 void main() {
     mergeBenchmark();
+    mergeInPlaceBenchmark();
     sortBenchmark();
     dotProdBenchmark();
     countBenchmark();
+    adjacentDifferenceBenchmark();
+    equalBenchmark();
 }
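Every routine added here threads a TaskPool parameter through and falls back to the global taskPool when it is null. The sketch below shows how they could be pointed at a dedicated pool instead; the two-worker pool and the helper name poolExample are assumptions for illustration, not part of this commit.

// Hypothetical sketch, not part of this commit: using a dedicated
// two-worker pool instead of the global taskPool.
void poolExample() {
    auto pool = new TaskPool(2);  // two worker threads
    scope(exit) pool.finish();    // let the workers shut down afterwards

    // Sorted evens followed by sorted odds; the merge boundary is 50_000.
    auto merged = array(iota(0, 100_000, 2)) ~ array(iota(1, 100_000, 2));

    // Explicit minParallel and pool for the in-place merge...
    parallelMergeInPlace(merged, 50_000, 1024, pool);

    // ...and a default-sized work unit plus the same pool for the check.
    assert(parallelEqual(merged, array(iota(100_000)), size_t.max, pool));
}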
