I was reading a few pages of Knuth’s The Art of Computer Programming, Volume 4A about “branchless computation” (p. 180), in which he demonstrates how to get rid of branches by using conditional instructions. As an instructive example he considers the inner part of merge sort, in which we merge two sorted lists of numbers into one bigger sorted list. The description as given by Knuth is as follows:
If $x_i < y_j$ set $z_k \gets x_i$, $i \gets i+1$, and go to x_done if $i = i_{max}$.
Otherwise set $z_k \gets y_j$, $j \gets j+1$, and go to y_done if $j = j_{max}$.
Then set $k \gets k+1$ and go to z_done if $k = k_{max}$.
$x$ and $y$ are the input lists, and $z$ is the merged output list. $i$, $j$, and $k$ are loop indices for the three respective lists, and the $_{max}$ variants are the lists’ lengths.
I got curious and decided to see how a standard optimizing compiler would
handle this case, and whether writing the assembly yourself would provide any
gain in performance. After all, this is just slightly more complicated than the
trivial examples used to show off good codegen, so it would not be unreasonable
for the compiler to manage to fix a bad implementation of this. In addition, it
would serve as a great excuse to finally learn how to write x86.
Here’s the inner loop in C code:
void branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
               uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    while (k < zmax) {
        if (xs[i] < ys[j]) {
            zs[k++] = xs[i++];
            if (i == xmax) { // x_done
                memcpy(zs + k, ys + j, 8 * (zmax - k));
                return;
            }
        } else {
            zs[k++] = ys[j++];
            if (j == ymax) { // y_done
                memcpy(zs + k, xs + i, 8 * (zmax - k));
                return;
            }
        }
    } // z_done
}
This seems to be a more or less straightforward textbook implementation of the
procedure, so it will do fine as a benchmark. As a quick check before going any
deeper into this we can use godbolt.org to see whether
this experiment is even worth doing. Godbolt’s x86-64 gcc 8.3 with -O3 spits
out this (annotations are mine):
branching(unsigned long*, unsigned long, unsigned long*, unsigned long,
unsigned long*, unsigned long):
test r9, r9 ; if (r9 == 0)
je .L15 ; goto .L15
push r13 ;
xor eax, eax ;
xor r11d, r11d ; j = 0
xor r10d, r10d ; i = 0
push r12 ;
push rbp ;
push rbx ;
jmp .L2 ;
.L17:
add r10, 1 ; i++
mov QWORD PTR [r8-8+rax*8], rbp ; zs[k-1] = xi
cmp r10, rsi ; if (i == xmax)
je .L16 ; goto .L16
.L6:
cmp r9, rax ; if (k == zmax)
je .L1 ; goto .L1
.L2:
lea r12, [rdi+r10*8] ; calculate xs + i
lea r13, [rdx+r11*8] ; calculate ys + j
add rax, 1 ; k++
mov rbp, QWORD PTR [r12] ; xi = xs[i]
mov rbx, QWORD PTR [r13+0] ; yj = ys[j]
cmp rbp, rbx ; if (xi < yj)
jb .L17 ; goto .L17
add r11, 1 ; j++
mov QWORD PTR [r8-8+rax*8], rbx ; zs[k-1] = yj
cmp r11, rcx ; if (j != ymax)
jne .L6 ; goto .L6
sub r9, rax ; y_done
pop rbx ;
mov rsi, r12 ;
pop rbp ;
lea rdi, [r8+rax*8] ;
pop r12 ;
lea rdx, [0+r9*8] ;
pop r13 ;
jmp memcpy ;
.L1:
pop rbx ; z_done
pop rbp ;
pop r12 ;
pop r13 ;
ret ;
.L16:
sub r9, rax ; x_done
pop rbx ;
mov rsi, r13 ;
pop rbp ;
lea rdi, [r8+rax*8] ;
pop r12 ;
lea rdx, [0+r9*8] ;
pop r13 ;
jmp memcpy ;
.L15:
ret
Plenty of branches!^{1}
Now, maybe it turns out that it doesn’t matter if we’re branching or not and that the compiler knows best. We could guess that the reason we’re still getting branches is that branching really is the best way to go here. After all, “you can’t beat the compiler” seems to be the consensus in many programming circles. Let’s try to write a version in C without excessive use of branching. Then perhaps the compiler will generate different code, and we can see what that difference amounts to in terms of running time. We can adopt Knuth’s branchless version:
void nonbranching_but_branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                                uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        int64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zs[k] = yj;
        i += t;
        xi = xs[i];
        t ^= 1;
        j += t;
        yj = ys[j];
        k += 1;
    }
    if (i == xmax)
        memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax)
        memcpy(zs + k, xs + i, 8 * (zmax - k));
}
What is going on, you might ask? The general idea is to first get min(xi, yj),
and then have a number t that’s 1 if xi < yj and 0 otherwise: we can add t to
i, since t=1 if we just wrote xi to zs[k]. Then we can xor it with 1,
effectively flipping 1 to 0 and 0 to 1, and then add t^1 to j; this causes
either i or j to be incremented, but not both. We used two convenience
functions here, one_if_lt and min, both implemented straightforwardly with
branching, in the hope that the compiler will figure this out for us, now that
the branches are much smaller.
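The two helpers are not shown here, so the following is only a minimal sketch of what they might look like, together with a single step of the index update. The names one_if_lt_sketch, min_sketch and merge_step are hypothetical, and the cast to int64_t assumes that both inputs stay below 2^63 so that the difference has the expected sign.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: 1 if the (signed) difference is negative, else 0. */
static uint64_t one_if_lt_sketch(int64_t diff) {
    return diff < 0 ? 1 : 0;
}

/* Hypothetical helper: the smaller of two values, written with a plain branch. */
static uint64_t min_sketch(uint64_t a, uint64_t b) {
    return a < b ? a : b;
}

/* One merge step: returns the value to store in zs[k] and advances exactly
 * one of *i and *j. Assumes xi and yj are both below 2^63. */
static uint64_t merge_step(uint64_t xi, uint64_t yj, size_t *i, size_t *j) {
    assert(xi < (UINT64_C(1) << 63) && yj < (UINT64_C(1) << 63));
    uint64_t t = one_if_lt_sketch((int64_t)(xi - yj));
    *i += t;     /* t == 1 exactly when xi is the value written */
    *j += t ^ 1; /* otherwise yj is written and j advances */
    return min_sketch(xi, yj);
}
```

On ties (xi == yj) the difference is zero, so t is 0 and it is j that advances, which matches the else branch of the branching version.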
Next, if we cheat a little and assume that the highest bit in the numbers is never set, we can get rid of those branches^{2}:
void nonbranching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                  uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zs[k] = yj;
        i += neg;
        xi = xs[i];
        neg ^= 1;
        j += neg;
        yj = ys[j];
        k += 1;
    }
    if (i == xmax)
        memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax)
        memcpy(zs + k, xs + i, 8 * (zmax - k));
}
What is up with (xi - yj) >> 63, you may ask? If xi < yj the unsigned
subtraction wraps around, and since we assumed that the highest bit of the
inputs is never set, the most significant bit of the result will be set.
Then we shift down logically (since we’re using unsigned integers^{3}), so
the bits that are filled in are all zeroes. Since the width is 64, we effectively
move the upper bit to the lowest position while setting all other bits to zero.
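To see the shift trick in isolation, here is a small sketch with hypothetical helper names (lt_bit and min_by_bit). It only illustrates the expression used in the loop; the precondition that both values are below 2^63 is the same assumption as in the text.

```c
#include <assert.h>
#include <stdint.h>

/* 1 if a < b, else 0. Only meaningful when both a and b are below 2^63:
 * the unsigned subtraction wraps, and the logical shift moves the top bit
 * down to bit 0. */
static uint64_t lt_bit(uint64_t a, uint64_t b) {
    return (a - b) >> 63;
}

/* Branch-free select built on lt_bit: returns a if a < b, else b. */
static uint64_t min_by_bit(uint64_t a, uint64_t b) {
    assert(a < (UINT64_C(1) << 63) && b < (UINT64_C(1) << 63));
    uint64_t neg = lt_bit(a, b);
    return neg * a + (1 - neg) * b;
}
```

The precondition matters: with a = 2^63 and b = 0 the difference already has its top bit set, so lt_bit reports 1 even though a is the larger value.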
Knuth has another quirk, namely that his array pointers usually point to the
end of the array, and his indices are negative, going from -xmax up to 0
instead of the more standard going from 0 up to xmax. One consequence of this
is that the termination check can be done with one comparison instead of three,
by ANDing together the three indices: since they are negative they have their
most significant bit set, unless zero. Here are both of the previous versions
with this reversal trick:
void nonbranching_but_branching_reverse(uint64_t *xs, size_t xmax,
                                        uint64_t *ys, size_t ymax,
                                        uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zse[k] = yj;
        i += t;
        xi = xse[i];
        t ^= 1;
        j += t;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}
void nonbranching_reverse(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                          uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zse[k] = yj;
        i += neg;
        xi = xse[i];
        neg ^= 1;
        j += neg;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}
Technically, I suppose we do assume that the lengths of the arrays are below
2**63, so that they fit in an ssize_t, but considering that the address space
of x86-64 is not 64 bits, but merely 48 bits^{4}, this is not a problem, even
in theory.
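The two ingredients of the reversal trick can be looked at separately. The sketch below uses hypothetical names (all_unfinished and sum_from_end) and the standard ptrdiff_t in place of ssize_t; it assumes the usual two’s complement representation, where every negative index has its sign bit set.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Nonzero while all three indices are still negative. Assumes each index is
 * <= 0: negative values share the sign bit, so the AND is nonzero unless at
 * least one index has reached zero. */
static int all_unfinished(ptrdiff_t i, ptrdiff_t j, ptrdiff_t k) {
    assert(i <= 0 && j <= 0 && k <= 0);
    return (i & j & k) != 0;
}

/* Sums n elements by walking a negative index up to zero, relative to a
 * pointer to the end of the array. */
static uint64_t sum_from_end(const uint64_t *xs, size_t n) {
    const uint64_t *end = xs + n;
    uint64_t sum = 0;
    for (ptrdiff_t i = -(ptrdiff_t)n; i != 0; i++)
        sum += end[i];
    return sum;
}
```

Note that the single AND only works because the indices never become positive; with ordinary indices counting up from 0 there is no such shared bit to test.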
Lastly, we can try to write the assembly ourselves. When translating the
branch-free routine by Knuth into x86 there are a number of things to do.
First we need to figure out how to get -1/0/+1 by comparing two variables, as
MMIX’s CMP instruction does. However, instead of trying to translate this
line by line, which would leave us with more instructions than needed,
we should rather look more closely at what we’re doing, so that we really
understand the minimal amount of work that we have to do.
We only need to do two things: compare $x_i$ and $y_j$ and load the smaller
into a register, and increment either i or j. The former can be done using
cmovl (a signed comparison, which agrees with the unsigned one as long as the
highest bit is never set), and the latter can be done in a similar fashion as
Knuth does it, which is basically what we’ve been doing up to this point in C.
This is the version I ended up with (here in inline-GCC asm format):
1: mov %[minxy], %[yj] ;
cmp %[xi], %[yj] ; minxy = min(xi, yj)
cmovl %[minxy], %[xi] ;
mov QWORD PTR [%[zse]+8*%[k]], %[minxy] ; zs[k] = minxy
mov %[t], 0 ; t = 0
cmovl %[t], %[one] ; if xi < yj: t = 1
add %[i], %[t] ; i += t
mov %[xi], QWORD PTR [%[xse]+8*%[i]] ; xi = xs[i]
xor %[t], 1 ; t ^= 1
add %[j], %[t] ; j += t
mov %[yj], QWORD PTR [%[yse]+8*%[j]] ; yj = ys[j]
add %[k], 1 ; k += 1
mov %[u], %[i] ;
and %[u], %[j] ;
test %[u], %[k] ; if ((i & j & k) != 0)
jnz 1b ; goto 1
There are a few quirks here, like having a couple of mov instructions in
between the second conditional move and the instruction it conditions on, and
the fact that cmovl can’t take an immediate value, so I had to set up a
register with only the value 1 in it. A sneaky detail to keep in mind is that
when we set t = 0 we cannot use the trick of XORing t with itself,
since this will change the flags, causing the subsequent cmovl to be wrong.
Now we can take a look at the assembly generated from some of the other
functions by using objdump -d.
Our own programs are compiled with -O3 -march=native.
Here is the inner loop in nonbranching_reverse:
<nonbranching_reverse>:
1ef0: mov rax,rdi
1ef3: sub rax,rsi
1ef6: shr rax,0x3f
1efa: mov rdx,r8
1efd: sub rdx,rax
1f00: imul rdx,rsi
1f04: imul rdi,rax
1f08: add rbp,rax
1f0b: xor rax,0x1
1f0f: add rdi,rdx
1f12: mov QWORD PTR [r13+r12*8+0x0],rdi
1f17: add rcx,rax
1f1a: inc r12
1f1d: mov rax,rbp
1f20: and rax,r12
1f23: mov rdi,QWORD PTR [rbx+rbp*8]
1f27: mov rsi,QWORD PTR [r10+rcx*8]
1f2b: test rax,rcx
1f2e: jne 1ef0 <nonbranching_reverse+0x40>
Sure looks a lot better than branching!
This seems more or less reasonable, but we can see that the multiplication
trickery that we used to avoid the min branch takes up some space here;
presumably it also takes some time. Maybe one little branch isn’t too bad
though, and perhaps the compiler is more willing to use conditional
instructions if we use the ternary operator, like this:
void nonbranching_reverse_ternary(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                                  uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t ybig = (xi - yj) >> 63;
        yj = ybig ? xi : yj;
        zse[k] = yj;
        i += ybig;
        xi = xse[i];
        ybig ^= 1;
        j += ybig;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}
This time, if we look at the assembly, we can see that the compiler is finally getting it: cmove!
2080: mov rax,yj ;
2083: sub rax,xi ;
2086: shr rax,0x3f ; t = (yj - xi) >> 63
208a: cmove yj,xi ; yj = t == 0 ? xi : yj
208e: add j,rax ; j += t
2091: mov QWORD PTR [zs+k*8],yj ; z[k] = yj
2096: xor rax,0x1 ; t ^= 1
209a: inc k ; k++
209d: add i,rax ; i += t
20a0: mov rax,k ;
20a3: and rax,j ; t = k & j
20a6: mov yj,QWORD PTR [ys+j*8] ; yj = ys[j]
20aa: mov xi,QWORD PTR [xs+i*8] ; xi = xs[i]
20ae: test rax,i ; if ((i & j & k) != 0)
20b1: jne 2080 ; goto .2080
So we see it’s really the same! Curiously, the compiler turned our code around
to have t be 1 if xi was the bigger, whereas our ybig was 1 if yj was the
bigger.
And now for the results! We fill two arrays with random elements and run
branching on them, such that we get the merged array back. This is used as the
ground truth which all other variations are checked against, in case we have
messed up. Then we use clock_gettime to measure the wall clock time that we
spend, per method. The following is running time in milliseconds where both
lists are 2**25 elements long, averaged over 100 runs; 10 iterations per seed
and 10 different seeds (srand(i) for each iteration).
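For readers who want to reproduce this kind of measurement, here is a rough sketch of the timing and averaging steps. It is not the harness from the repository: the function names are hypothetical, CLOCK_MONOTONIC is an assumption (the clock id used is not stated here), and the spread is computed as a population variance, which may differ from how the “var” column was produced.

```c
#define _POSIX_C_SOURCE 199309L
#include <assert.h>
#include <stddef.h>
#include <time.h>

/* Milliseconds elapsed between two timespec readings. */
double elapsed_ms(struct timespec start, struct timespec end) {
    return (double)(end.tv_sec - start.tv_sec) * 1e3
         + (double)(end.tv_nsec - start.tv_nsec) / 1e6;
}

/* Times a single call of fn(arg) in milliseconds.
 * Assumption: CLOCK_MONOTONIC is the clock of interest. */
double time_call_ms(void (*fn)(void *), void *arg) {
    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    fn(arg);
    clock_gettime(CLOCK_MONOTONIC, &t1);
    return elapsed_ms(t0, t1);
}

/* Mean and population variance of n samples (n > 0). */
void mean_var(const double *xs, size_t n, double *mean, double *var) {
    assert(n > 0);
    double sum = 0.0, sq = 0.0;
    for (size_t i = 0; i < n; i++)
        sum += xs[i];
    *mean = sum / (double)n;
    for (size_t i = 0; i < n; i++)
        sq += (xs[i] - *mean) * (xs[i] - *mean);
    *var = sq / (double)n;
}
```

A caller would collect one time_call_ms sample per run into an array and pass it to mean_var once all runs are done.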
These are the numbers I got on an Intel i7-7500U@2.7GHz (avg +/- var):
branching: 30.998 +/- 0.001
nonbranching_but_branching: 27.330 +/- 0.002
nonbranching: 24.770 +/- 0.000
nonbranching_but_branching_reverse: 19.387 +/- 0.000
nonbranching_reverse: 20.015 +/- 0.000
nonbranching_reverse_ternary: 19.038 +/- 0.000
asm_nb_rev: 18.987 +/- 0.001
I also ran the suite on another machine with an Intel i5-8250U@1.60GHz, in order to see if there would be any significant difference:
branching: 31.405 +/- 0.034
nonbranching_but_branching: 27.646 +/- 0.097
nonbranching: 27.894 +/- 0.021
nonbranching_but_branching_reverse: 22.760 +/- 0.040
nonbranching_reverse: 21.284 +/- 0.050
nonbranching_reverse_ternary: 19.299 +/- 0.002
asm_nb_rev: 19.793 +/- 0.009
Interestingly, on this CPU our assembly is slightly slower than the ternary
version; I guess this is due to us using a cmovl where the compiler-generated
version used the shifting trick.
We can’t possibly have done all this merging without making a proper
mergesort in the end! Luckily for us, the merge part is really the
only difficult part of the routine:
void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return;
    size_t h = n / 2;
    merge_sort(xs, h, buf);
    merge_sort(xs + h, n - h, buf + h);
    merge(xs, h, xs + h, n - h, buf, n);
    memcpy(xs, buf, 8 * n);
}
Unfortunately we have to merge to a buffer and then memcpy it back. Perhaps
this is fixable: we can make the sorting routine put the result either in xs
or in buf, and have the recursive calls report which; we can then merge into
the other, assuming both recursive calls agree(!!^{5}). That is, if the
recursive calls say that the sorted subarrays are in xs, we merge into buf
and tell our caller that our result is in buf. At the end, we just need to
make sure that the final sorted numbers are in xs.
void _sort_asm(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
    if (n < 2) {
        *into_buf = 0;
        return;
    }
    size_t h = n / 2;
    int res_in_buf;
    _sort_asm(xs, h, buf, &res_in_buf);             // WARNING: `res_in_buf` for the two calls need
    _sort_asm(xs + h, n - h, buf + h, &res_in_buf); // not be the same in the real world!
    *into_buf = res_in_buf ^ 1;
    if (res_in_buf)
        asm_nb_rev(buf, h, buf + h, n - h, xs, n);
    else
        asm_nb_rev(xs, h, xs + h, n - h, buf, n);
}
void sort_asm(uint64_t *xs, size_t n, uint64_t *buf) {
    int res_in_buf;
    _sort_asm(xs, n, buf, &res_in_buf);
    if (res_in_buf) {
        memcpy(xs, buf, 8 * n);
    }
}
and similarly for the other variants.
You might see the branch and wonder if we can remove it. I tried, by making
an array {xs, buf} and indexing it with res_in_buf, but it caused a minor
slowdown: maybe some branching is fine after all.
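The warning in the code comment can be sidestepped by telling each recursive call where its result has to end up, instead of asking it afterwards. The following is a sketch of that common top-down ping-pong scheme, not the code from the repository: all names are hypothetical, a plain branching merge stands in for the merge variants, and it pays for one up-front memcpy. It has not been benchmarked here.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Plain branching merge of xs[0..xmax) and ys[0..ymax) into zs; a stand-in
 * for any of the merge variants. */
static void merge_sketch(const uint64_t *xs, size_t xmax,
                         const uint64_t *ys, size_t ymax, uint64_t *zs) {
    size_t i = 0, j = 0, k = 0;
    while (i < xmax && j < ymax)
        zs[k++] = (ys[j] < xs[i]) ? ys[j++] : xs[i++];
    memcpy(zs + k, xs + i, sizeof *xs * (xmax - i)); /* leftover xs, if any */
    k += xmax - i;
    memcpy(zs + k, ys + j, sizeof *ys * (ymax - j)); /* leftover ys, if any */
}

/* Precondition: dst[0..n) and src[0..n) hold the same elements.
 * Postcondition: dst[0..n) is sorted; src[0..n) is left scrambled. */
static void sort_into(uint64_t *dst, uint64_t *src, size_t n) {
    if (n < 2) return;                   /* dst already holds the element */
    size_t h = n / 2;
    sort_into(src, dst, h);              /* sorted halves land in src ... */
    sort_into(src + h, dst + h, n - h);
    merge_sketch(src, h, src + h, n - h, dst); /* ... and merge into dst */
}

/* Sorts xs in place, using buf (room for n elements) as scratch space. */
static void sort_pingpong(uint64_t *xs, size_t n, uint64_t *buf) {
    assert(xs != buf);
    memcpy(buf, xs, sizeof *xs * n);     /* the single up-front copy */
    sort_into(xs, buf, n);
}
```

Because the roles of the two arrays swap at every level regardless of how the sizes split, this works for any n, not only powers of two.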
Here are the running times:
                                          i7-7500U            i5-8250U
sort_branching:                           369.479 +/- 0.047   393.762 +/- 0.082
sort_nonbranching_but_branching:          324.337 +/- 0.014   337.120 +/- 0.099
sort_nonbranching:                        325.658 +/- 0.028   352.802 +/- 0.120
sort_nonbranching_but_branching_reverse:  279.237 +/- 0.164   287.799 +/- 0.154
sort_nonbranching_reverse:                283.927 +/- 0.033   299.277 +/- 0.929
sort_nonbranching_reverse_ternary:        270.668 +/- 0.009   278.644 +/- 1.677
sort_asm_nb_rev:                          270.228 +/- 0.009   281.657 +/- 0.360
If you would like to run the suite yourself, the git repo is available here.
Thanks for reading.
1. Originally I had omitted the _done parts, and the code was much cleaner; I’m not sure why having them in complicates this that much. Also, why is k incremented before storing zs[k], so that we have to store zs[k-1] instead?
2. Curiously, if we change from uint64_t to int64_t and use ((a-b)>>63)&1 for the test, we do not depend on the magnitudes of the numbers (as the compiler can assume signed overflow will not happen); also, the and never makes it to the assembly, and we still use logical instead of arithmetic shift.
3. The alternative is arithmetic shift, in which the sign bit is propagated down. In this case we would end up with either all zeroes or all ones.
4. https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details
5. This is really only the case if n is a power of two: otherwise you’ll have two siblings in the call tree with different ns, and this difference will cause two leaf nodes to be at different depths, which in turn will make them “out of sync”.
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License