April 24, 2019

I was reading a few pages of Knuth's *The Art of Computer Programming*, Volume
4A, about “branchless computation” (p. 180), in which he demonstrates how to get
rid of branches by using conditional instructions. As an instructive
example he considers the inner part of *merge sort*, in which we are to merge
two sorted lists of numbers into one bigger list. The description as given by
Knuth is as follows:

If $x_i < y_j$ set $z_k \gets x_i$, $i \gets i+1$, and go to *x_done* if $i = i_{max}$.

Otherwise set $z_k \gets y_j$, $j \gets j+1$, and go to *y_done* if $j = j_{max}$.

Then set $k \gets k+1$ and go to *z_done* if $k = k_{max}$.

$x$ and $y$ are the input lists, and $z$ is the output merged list. $i$, $j$, and $k$ are loop indices for the three respective lists, and the $_{max}$ variants are the lists' lengths.

I got curious and decided to see how a standard optimizing compiler would
handle this case, and whether writing the assembly yourself would provide any
gain in performance. After all, this is just slightly more complicated than the
trivial examples used to show off good codegen, so it would not be unreasonable
for the compiler to manage to fix a bad implementation of this. In addition, it
would serve as a great excuse to finally learn how to write `x86`.

Here’s the inner loop in C code:

```
#include <stdint.h>
#include <stddef.h>
#include <string.h>

void branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
               uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    while (k < zmax) {
        if (xs[i] < ys[j]) {
            zs[k++] = xs[i++];
            if (i == xmax) { // x_done
                memcpy(zs + k, ys + j, 8 * (zmax - k));
                return;
            }
        } else {
            zs[k++] = ys[j++];
            if (j == ymax) { // y_done
                memcpy(zs + k, xs + i, 8 * (zmax - k));
                return;
            }
        }
    } // z_done
}
```

This seems to be a more or less straightforward textbook implementation of the
procedure, so it will do fine as a benchmark. As a quick check before going any
deeper into this we can use godbolt.org to see whether this experiment is even
worth doing. Godbolt's `x86-64 gcc 8.3` with `-O3` spits out this (annotations
are by me):

```
branching(unsigned long*, unsigned long, unsigned long*, unsigned long,
          unsigned long*, unsigned long):
        test    r9, r9                        ; if (r9 == 0)
        je      .L15                          ; goto .L15
        push    r13                           ;
        xor     eax, eax                      ;
        xor     r11d, r11d                    ; j = 0
        xor     r10d, r10d                    ; i = 0
        push    r12                           ;
        push    rbp                           ;
        push    rbx                           ;
        jmp     .L2                           ;
.L17:
        add     r10, 1                        ; i++
        mov     QWORD PTR [r8-8+rax*8], rbp   ; zs[k-1] = xi
        cmp     r10, rsi                      ; if (i == xmax)
        je      .L16                          ; goto .L16
.L6:
        cmp     r9, rax                       ; if (k == zmax)
        je      .L1                           ; goto .L1
.L2:
        lea     r12, [rdi+r10*8]              ; calculate xs + i
        lea     r13, [rdx+r11*8]              ; calculate ys + j
        add     rax, 1                        ; k++
        mov     rbp, QWORD PTR [r12]          ; xi = xs[i]
        mov     rbx, QWORD PTR [r13+0]        ; yj = ys[j]
        cmp     rbp, rbx                      ; if (xi < yj)
        jb      .L17                          ; goto .L17
        add     r11, 1                        ; j++
        mov     QWORD PTR [r8-8+rax*8], rbx   ; zs[k-1] = yj
        cmp     r11, rcx                      ; if (j != ymax)
        jne     .L6                           ; goto .L6
        sub     r9, rax                       ; y_done
        pop     rbx                           ;
        mov     rsi, r12                      ;
        pop     rbp                           ;
        lea     rdi, [r8+rax*8]               ;
        pop     r12                           ;
        lea     rdx, [0+r9*8]                 ;
        pop     r13                           ;
        jmp     memcpy                        ;
.L1:
        pop     rbx                           ; z_done
        pop     rbp                           ;
        pop     r12                           ;
        pop     r13                           ;
        ret                                   ;
.L16:
        sub     r9, rax                       ; x_done
        pop     rbx                           ;
        mov     rsi, r13                      ;
        pop     rbp                           ;
        lea     rdi, [r8+rax*8]               ;
        pop     r12                           ;
        lea     rdx, [0+r9*8]                 ;
        pop     r13                           ;
        jmp     memcpy                        ;
.L15:
        ret
```

Plenty of branches!^{1}

Now, maybe it turns out that it doesn’t matter if we’re branching or not and
that the compiler knows best. We could guess that the reason we’re still
getting branches is because that’s really the best way to go here. After all
“you can’t beat the compiler” seems to be the consensus in *many* programming
circles. Let’s try to write a version in C without excessive use of branching.
Then perhaps the compiler will generate different code, and we can see what
that difference amounts to in terms of running time. We can adopt Knuth’s
branchless version:

```
void nonbranching_but_branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                                uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        int64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zs[k] = yj;
        i += t;
        xi = xs[i]; // note: reads one past the end when i reaches xmax
        t ^= 1;
        j += t;
        yj = ys[j]; // likewise for j and ymax
        k += 1;
    }
    if (i == xmax)
        memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax)
        memcpy(zs + k, xs + i, 8 * (zmax - k));
}
```

What is going on, you might ask? The general idea is to first get `min(xi, yj)`,
and then have a number `t` that’s `1` if `xi < yj` and `0` otherwise: we can add
`t` to `i`, since `t=1` if we just wrote `xi` to `zs[k]`. Then we can `xor` it
with `1`, effectively flipping `1` to `0` and `0` to `1`, and add `t^1` to `j`;
this causes either `i` or `j` to be incremented, but not both. We used two
convenience functions here, `one_if_lt` and `min`, both implemented
straightforwardly **with branching**, hoping that the compiler will figure this
out for us, now that the branches are much smaller.
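The post doesn’t show `one_if_lt` and `min`; plausible branching implementations might look like this (hypothetical code, assuming the same “top bit never set” caveat as below, so that the wrapped-around difference can be reinterpreted as signed):

```c
#include <stdint.h>

/* Hypothetical helpers -- not shown in the post. `one_if_lt` is handed the
 * difference xi - yj; if xi < yj the unsigned subtraction wrapped around, so
 * reinterpreting the result as signed yields a negative number. This only
 * works while the magnitudes stay below 2^63. */
static inline int64_t one_if_lt(uint64_t diff) {
    if ((int64_t) diff < 0)
        return 1;
    return 0;
}

static inline uint64_t min(uint64_t a, uint64_t b) {
    if (a < b)
        return a;
    return b;
}
```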

Next, if we cheat a little and assume that the highest bit in the numbers is
never set, we can get rid of those branches^{2}:

```
void nonbranching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                  uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    uint64_t xi = xs[i], yj = ys[j];
    while ((i < xmax) && (j < ymax) && (k < zmax)) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zs[k] = yj;
        i += neg;
        xi = xs[i];
        neg ^= 1;
        j += neg;
        yj = ys[j];
        k += 1;
    }
    if (i == xmax)
        memcpy(zs + k, ys + j, 8 * (zmax - k));
    if (j == ymax)
        memcpy(zs + k, xs + i, 8 * (zmax - k));
}
```

What is up with `(xi - yj) >> 63`, you may ask? The result is “negative” if
`xi < yj`: the subtraction wraps around, and the most significant bit of the
result ends up set. Then we shift down logically (since we’re using unsigned
integers^{3}), so the bits that are filled in are all zeroes. Since the width
is 64, we effectively move the upper bit to the lowest position while setting
all other bits to zero.
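The trick in isolation looks like this (the function name is mine; the assumption, as above, is that neither operand has its top bit set):

```c
#include <stdint.h>

/* Returns 1 exactly when xi < yj, assuming neither value uses bit 63:
 * the subtraction wraps around precisely when xi < yj, which sets the most
 * significant bit, and the logical shift moves that bit down to position 0. */
uint64_t one_if_wrapped(uint64_t xi, uint64_t yj) {
    return (xi - yj) >> 63;
}
```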

Knuth has another quirk, namely that his arrays usually point to the *end* of
the array, and his indices are negative, going from `-xmax` up to `0` instead
of the more standard `0` up to `xmax`. One consequence of this is that the
termination check can be done with one comparison instead of three, by
`and`ing together the three indices: since they are negative, they all have
their most significant bit set, unless zero. Here are both of the previous
versions with this reversal trick:

```
void nonbranching_but_branching_reverse(uint64_t *xs, size_t xmax,
                                        uint64_t *ys, size_t ymax,
                                        uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t t = one_if_lt(xi - yj);
        yj = min(xi, yj);
        zse[k] = yj;
        i += t;
        xi = xse[i];
        t ^= 1;
        j += t;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}

void nonbranching_reverse(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                          uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t neg = (xi - yj) >> 63;
        yj = neg * xi + (1 - neg) * yj;
        zse[k] = yj;
        i += neg;
        xi = xse[i];
        neg ^= 1;
        j += neg;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}
```
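The single-comparison termination check can be pulled out and looked at in isolation (a small sketch; the helper name is mine):

```c
#include <sys/types.h> /* ssize_t */

/* All three indices run from -max up toward 0. While all of them are
 * negative, each has its sign bit set, so the AND of the three also has the
 * sign bit set and is nonzero. As soon as any index reaches 0, it zeroes the
 * whole AND and the loop stops. */
int still_merging(ssize_t i, ssize_t j, ssize_t k) {
    return (i & j & k) != 0;
}
```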

Technically, I suppose we do assume that the lengths of the arrays are not
`>2**63`, so that they fit in an `ssize_t`, but considering that the address
space of `x86-64` is *not* 64 bits, but *merely* 48 bits^{4}, this is not a
problem, even in theory.

Lastly, we can try to write the assembly ourselves. When translating the
branch-free routine by Knuth into `x86` there are a number of things to do.
First we need to figure out how to get `-1/0/+1` by comparing two variables, as
`MMIX`'s `CMP` instruction does. However, instead of trying to translate this
line by line, which would end up with us having more instructions than needed,
we should rather look more closely at what we’re doing, so that we really
understand the minimal amount of work that we have to do.

We only need to do two things: compare $x_i$ and $y_j$ and load the smaller
into a register, and increment either `i` or `j`. The former can be done using
`cmovl`, and the latter can be done in a similar fashion as Knuth does it,
which is basically what we’ve been doing up to this point in C.
This is the version I ended up with (here in inline-GCC asm format):

```
1:      mov     %[minxy], %[yj]                     ;
        cmp     %[xi], %[yj]                        ; minxy = min(xi, yj)
        cmovl   %[minxy], %[xi]                     ;
        mov     QWORD PTR [%[zse]+8*%[k]], %[minxy] ; zs[k] = minxy
        mov     %[t], 0                             ; t = 0
        cmovl   %[t], %[one]                        ; if xi < yj: t = 1
        add     %[i], %[t]                          ; i += t
        mov     %[xi], QWORD PTR [%[xse]+8*%[i]]    ; xi = xs[i]
        xor     %[t], 1                             ; t ^= 1
        add     %[j], %[t]                          ; j += t
        mov     %[yj], QWORD PTR [%[yse]+8*%[j]]    ; yj = ys[j]
        add     %[k], 1                             ; k += 1
        mov     %[u], %[i]                          ;
        and     %[u], %[j]                          ;
        test    %[u], %[k]                          ; if ((i & j & k) != 0)
        jnz     1b                                  ; goto 1
```

There are a few quirks here, like having a couple of `mov` instructions in
between the second conditional load and the instruction it conditions on, and
the fact that `cmovl` couldn’t take an immediate value, so I had to set up a
register with only the value `1` in it. A sneaky detail to keep in mind is that
when we set `t = 0` we cannot use the trick of `xor`ing `t` with itself, since
this would change the flags, causing the subsequent `cmovl` to be wrong.

Now we can take a look at the assembly generated from some of the other
functions by using `objdump -d`. Our own programs are compiled with
`-O3 -march=native`. Here is the inner loop in `nonbranching_reverse`:

```
<nonbranching_reverse>:
1ef0: mov rax,rdi
1ef3: sub rax,rsi
1ef6: shr rax,0x3f
1efa: mov rdx,r8
1efd: sub rdx,rax
1f00: imul rdx,rsi
1f04: imul rdi,rax
1f08: add rbp,rax
1f0b: xor rax,0x1
1f0f: add rdi,rdx
1f12: mov QWORD PTR [r13+r12*8+0x0],rdi
1f17: add rcx,rax
1f1a: inc r12
1f1d: mov rax,rbp
1f20: and rax,r12
1f23: mov rdi,QWORD PTR [rbx+rbp*8]
1f27: mov rsi,QWORD PTR [r10+rcx*8]
1f2b: test rax,rcx
1f2e: jne 1ef0 <nonbranching_reverse+0x40>
```

Sure looks a lot better than `branching`! This seems more or less reasonable,
but we can see that the multiplication trickery that we used to avoid the
`min` branch takes up some space here; presumably it also takes some time.
Maybe one little branch isn’t too bad though, and perhaps the compiler is more
willing to use conditional instructions if we use the ternary operator, like
this:

```
void nonbranching_reverse_ternary(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                                  uint64_t *zs, size_t zmax) {
    uint64_t *xse = xs + xmax;
    uint64_t *yse = ys + ymax;
    uint64_t *zse = zs + zmax;
    ssize_t i = -((ssize_t) xmax);
    ssize_t j = -((ssize_t) ymax);
    ssize_t k = -((ssize_t) zmax);
    uint64_t xi = xse[i], yj = yse[j];
    while (i & j & k) {
        uint64_t ybig = (xi - yj) >> 63;
        yj = ybig ? xi : yj;
        zse[k] = yj;
        i += ybig;
        xi = xse[i];
        ybig ^= 1;
        j += ybig;
        yj = yse[j];
        k += 1;
    }
    if (i == 0)
        memcpy(zse + k, yse + j, -8 * k);
    if (j == 0)
        memcpy(zse + k, xse + i, -8 * k);
}
```

This time, if we look at the assembly, we can see that the compiler is finally
getting it: `cmove`!

```
2080: mov rax,yj ;
2083: sub rax,xi ;
2086: shr rax,0x3f ; t = (yj - xi) >> 63
208a: cmove yj,xi ; yj = t == 0 ? xi : yj
208e: add j,rax ; j += t
2091: mov QWORD PTR [zs+k*8],yj ; z[k] = yj
2096: xor rax,0x1 ; t ^= 1
209a: inc k ; k++
209d: add i,rax ; i += t
20a0: mov rax,k ;
20a3: and rax,j ; t = k & j
20a6: mov yj,QWORD PTR [ys+j*8] ; yj = ys[j]
20aa: mov xi,QWORD PTR [xs+i*8] ; xi = xs[i]
20ae: test rax,i ; if ((i & j & k) != 0)
20b1: jne 2080 ; goto .2080
```

So we see it’s really the same! Curiously, the compiler turned our code around
to have `t` be `1` if `xi` was the bigger, whereas our `ybig` was `1` if `yj`
was the bigger.

And now for the results! We fill two arrays with random elements and run
`branching` on them, such that we get the merged array back. This is used as
the ground truth against which all other variations are checked, in case we
have messed up. Then we use `clock_gettime` to measure the wall clock time
that we spend, per method. The following is running time in milliseconds where
both lists are `2**25` elements long, averaged over 100 runs; 10 iterations
per seed and 10 different seeds (`srand(i)` for each iteration).

These are the numbers I got on an Intel i7-7500U@2.7GHz (`avg +/- var`):

```
branching: 30.998 +/- 0.001
nonbranching_but_branching: 27.330 +/- 0.002
nonbranching: 24.770 +/- 0.000
nonbranching_but_branching_reverse: 19.387 +/- 0.000
nonbranching_reverse: 20.015 +/- 0.000
nonbranching_reverse_ternary: 19.038 +/- 0.000
asm_nb_rev: 18.987 +/- 0.001
```

I also ran the suite on another machine with an Intel i5-8250U@1.60GHz, in order to see if there would be any significant difference:

```
branching: 31.405 +/- 0.034
nonbranching_but_branching: 27.646 +/- 0.097
nonbranching: 27.894 +/- 0.021
nonbranching_but_branching_reverse: 22.760 +/- 0.040
nonbranching_reverse: 21.284 +/- 0.050
nonbranching_reverse_ternary: 19.299 +/- 0.002
asm_nb_rev: 19.793 +/- 0.009
```

Interestingly, on this CPU our assembly is slightly slower than the ternary
version; I guess this is due to us using a `cmovl` where the compiler-generated
version used the shifting trick.

We can’t possibly have done all this merging without making a proper
`mergesort` in the end! Luckily for us, the `merge` part is really the only
difficult part of the routine:

```
void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return;
    size_t h = n / 2;
    merge_sort(xs, h, buf);
    merge_sort(xs + h, n - h, buf + h);
    merge(xs, h, xs + h, n - h, buf, n);
    memcpy(xs, buf, 8 * n);
}
```
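Putting the pieces together, here is a self-contained sanity check of the recursion, with the `branching` merge from the top of the post standing in for `merge`:

```c
#include <stdint.h>
#include <stddef.h>
#include <string.h>

/* The `branching` merge from earlier, standing in for `merge`. */
static void merge(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax,
                  uint64_t *zs, size_t zmax) {
    size_t i = 0, j = 0, k = 0;
    while (k < zmax) {
        if (xs[i] < ys[j]) {
            zs[k++] = xs[i++];
            if (i == xmax) { memcpy(zs + k, ys + j, 8 * (zmax - k)); return; }
        } else {
            zs[k++] = ys[j++];
            if (j == ymax) { memcpy(zs + k, xs + i, 8 * (zmax - k)); return; }
        }
    }
}

void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return;
    size_t h = n / 2;
    merge_sort(xs, h, buf);              /* sort each half in place... */
    merge_sort(xs + h, n - h, buf + h);
    merge(xs, h, xs + h, n - h, buf, n); /* ...merge them into the buffer... */
    memcpy(xs, buf, 8 * n);              /* ...and copy the result back. */
}
```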

Unfortunately we have to merge into a buffer and then `memcpy` it back.
Perhaps this is fixable: we can make the sorting routine put the result either
in `xs` or in `buf`, and by having the recursive calls say which, we can merge
into the other, assuming both recursive calls agree(!!^{5}). That is, if the
recursive calls say that the sorted subarrays are in `xs`, we merge into `buf`
and tell our caller that *our* result is in `buf`. At the end, we just need to
make sure that the final sorted numbers are in `xs`.

```
void _sort_asm(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
    if (n < 2) {
        *into_buf = 0;
        return;
    }
    size_t h = n / 2;
    int res_in_buf;
    _sort_asm(xs, h, buf, &res_in_buf);             // WARNING: `res_in_buf` for the two calls
    _sort_asm(xs + h, n - h, buf + h, &res_in_buf); // need not be the same in the real world!
    *into_buf = res_in_buf ^ 1;
    if (res_in_buf)
        asm_nb_rev(buf, h, buf + h, n - h, xs, n);
    else
        asm_nb_rev(xs, h, xs + h, n - h, buf, n);
}

void sort_asm(uint64_t *xs, size_t n, uint64_t *buf) {
    int res_in_buf;
    _sort_asm(xs, n, buf, &res_in_buf);
    if (res_in_buf) {
        memcpy(xs, buf, 8 * n);
    }
}
```

and similarly for the other variants.
You might see the branch and wonder if we can remove it; I tried, by making
an array `{xs, buf}` and indexing it with `res_in_buf`, but it caused a minor
slowdown: maybe some branching is fine after all.
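For reference, the attempted branch-free dispatch looked roughly like this (a hypothetical sketch; the function name is mine):

```c
#include <stdint.h>

/* Pick the merge source and destination by indexing a two-element array with
 * res_in_buf instead of branching on it. (The post found this variant
 * slightly slower than the plain if/else.) */
void pick_merge_buffers(uint64_t *xs, uint64_t *buf, int res_in_buf,
                        uint64_t **src, uint64_t **dst) {
    uint64_t *bufs[2] = { xs, buf };
    *src = bufs[res_in_buf];     /* where the sorted halves currently live */
    *dst = bufs[res_in_buf ^ 1]; /* where this level merges into */
}
```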

Here are the running times:

```
i7-7500U i5-8250U
sort_branching: 369.479 +/- 0.047 393.762 +/- 0.082
sort_nonbranching_but_branching: 324.337 +/- 0.014 337.120 +/- 0.099
sort_nonbranching: 325.658 +/- 0.028 352.802 +/- 0.120
sort_nonbranching_but_branching_reverse: 279.237 +/- 0.164 287.799 +/- 0.154
sort_nonbranching_reverse: 283.927 +/- 0.033 299.277 +/- 0.929
sort_nonbranching_reverse_ternary: 270.668 +/- 0.009 278.644 +/- 1.677
sort_asm_nb_rev: 270.228 +/- 0.009 281.657 +/- 0.360
```

If you would like to run the suite yourself, the git repo is available here.

Thanks for reading.

1. Originally I had omitted the `_done` parts, and the code was much cleaner; I’m not sure why having them in complicates this that much. Also, why is `k` incremented before storing `zs[k]`, so that we have to store `zs[k-1]` instead?^{[return]}
2. Curiously, if we change from `uint64_t` to `int64_t` and use `((a-b)>>63)&1` for the test, we do not depend on the magnitudes of the numbers (as the compiler can assume signed overflow will not happen); also, the `and` never makes it to the assembly, and we still use a logical instead of an arithmetic shift.^{[return]}
3. The alternative is *arithmetic shift*, in which the sign bit is propagated down. In this case we would end up with either all zeroes or all ones.^{[return]}
4. https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details^{[return]}
5. This is really only the case if `n` is a power of two: otherwise you’ll have two siblings in the call tree with different `n`s, and this difference will cause two leaf nodes to be at different depths, which in turn will make them “out of sync”.^{[return]}