# Making rustc output better assembly
I was profiling some of our workloads at work, and while reading some of the assembly output I noticed
that the `is_zero` method of the `rust_decimal` crate was being called quite a lot: not a hot spot, but
enough for me to take a look at it. Our codebase is already pretty well optimized, and we have gotten
rid of the low-hanging fruit in terms of performance, so any small improvement can add up quickly
(and so can small performance regressions, which lead to death by a thousand unoptimized cuts).
What really caught my eye, though, was the assembly output. I was expecting it to be much shorter,
but it somehow wasn't, and it had a bunch of jumps. I already knew the internal representation of their
`Decimal` struct because I had to write some custom code to serialize it efficiently, so I knew the
compiler could do better at generating its `is_zero` function. So I opened the wonderful Compiler
Explorer website and started to dig in.
## `Decimal` under the hood
In the `rust_decimal` crate, `Decimal` is represented as follows:

```rust
struct Decimal {
    flags: u32,
    // The lo, mid, hi, and flags fields contain the representation of the
    // Decimal value as a 96-bit integer.
    hi: u32,
    lo: u32,
    mid: u32,
}
```
Of course, this is an internal detail and is subject to change, but when you deal with performance,
you need to intimately know your data. The `is_zero` function states that a `Decimal` is zero if
the `hi`, `lo`, and `mid` fields are all zero.
```rust
impl Decimal {
    pub const fn is_zero(&self) -> bool {
        self.lo == 0 && self.mid == 0 && self.hi == 0
    }
}
```
This is the code that `rustc` gives us with `-C opt-level=3`:

```asm
is_zero:
        mov     eax, dword ptr [rdi + 12]
        or      eax, dword ptr [rdi + 8]
        jne     .LBB0_1
        cmp     dword ptr [rdi + 4], 0
        sete    al
        ret
.LBB0_1:
        xor     eax, eax
        ret
```
`jne`? `cmp`? So much code. So much branching. Of course, we are short-circuiting, which implies
branching, but I would expect the compiler to be smarter than that.
## The bitwise approach
EDIT: This approach has been merged in the v1 branch
We can just compare all of these fields at once. This should result in fewer instructions and fewer
jumps (which may or may not be expensive; you should always measure). Our new approach could look like:

```rust
pub const fn is_zero(&self) -> bool {
    (self.lo | self.mid | self.hi) == 0
}
```
This is the output that `rustc` gives us:

```asm
is_zero:
        mov     eax, dword ptr [rdi + 12]
        or      eax, dword ptr [rdi + 8]
        or      eax, dword ptr [rdi + 4]
        sete    al
        ret
```
Much better. We `or` the bits of these fields together and store the result in `eax`. At the end of
the day, even though each field is a number, it is still made of bits. We want to check all of them
at once and avoid the extra branching and code needed to exit early when one of the numbers is not
zero. By treating the numbers as plain bits, we can achieve this more efficiently. Zero is zero, so
we can just check for it at the bit level; no fancy bitmasks going on here. Since it's a bitwise
`or`, if at least one bit is set to 1, the final number in `eax` will have a 1 in it. And you don't
need to be a computer science wiz to know that if there is at least one 1 somewhere in the binary
representation of a number, it can't be zero.
## The unsafe way
⚠️ WARNING: This is not for the faint of heart, especially if you are a good boy rustacean who avoids `unsafe`. What I am about to show you requires a strong stomach. Reader discretion is advised.
This duplicated `or` bothers me. I want to truly check it all in one fell swoop. Can I use SIMD
somehow? Do I really need to?

If we look at the data more closely, we can see that a `Decimal` is made of 4 fields, 32 bits each.
And wouldn't you know it, Rust has an unsigned 128-bit integer type. We can just transmute the memory
(a safer and fancier form of memory casting) to interpret the struct as one big 128-bit number:

```rust
let num = unsafe { core::mem::transmute::<Decimal, u128>(dec) };
```

Since they are the same size, `transmute` lets us do it (another great safety feature from Rust: you
can still use `unsafe`, but you get a safer `unsafe` experience).
We can't just check whether this number is 0 and call it a day, since we don't use the `flags` field,
which is the very first field of the struct and is therefore also part of our number. To get rid of
it, we can shift those 32 bits out (to the right rather than the left, since we are running on a
little-endian machine, which makes our code technically not portable), with the vacated bits being
filled with zeroes. Now the whole content of our 128-bit number is composed of the values of the
`mid`, `lo`, and `hi` fields, with the rest being zeroes, so we can check for equality. The final
code would look like this (`transmute` takes an owned value as a parameter, so we need to change the
function signature; this is fine since `Decimal` implements `Copy` and is only 128 bits):
```rust
pub const fn is_zero(dec: Decimal) -> bool {
    let num = unsafe { core::mem::transmute::<Decimal, u128>(dec) };
    num >> 32 == 0
}
```
This gets rid of the extra `or`:

```asm
is_zero:
        mov     eax, dword ptr [rdi + 4]
        or      rax, qword ptr [rdi + 8]
        sete    al
        ret
```
However, I would still expect to see a shift instruction, not an `or`. The x86-64 calling convention
says that a small struct (which `Decimal` qualifies as) can be packed into registers, so I would
expect to see a shift on a register, without even loading from memory.
## C++ and clang
Since this all seemed like something an optimizing compiler should be able to do on its own, I tried porting the code to C++ to see what clang did with it, since both compilers are based on LLVM (rustc uses its own fork, though).
With the original implementation in Rust (the one with the short-circuiting `&&`), if we change the
code to take a copy of `Decimal` instead of a reference, the assembly output stays the same. If we
port the same piece of code to C++ and compile it with the latest stable `clang`, we get the same
assembly output as in Rust when using references, but when using a copy, `clang` is able to generate
even better assembly than our custom unsafe Rust version:

```cpp
#include <cstdint>

struct Decimal {
    uint32_t flags;
    uint32_t hi;
    uint32_t lo;
    uint32_t mid;
};

bool is_zero(Decimal dec) {
    return dec.hi == 0 && dec.lo == 0 && dec.mid == 0;
}
```
```asm
is_zero(Decimal):
        shr     rdi, 32
        or      rdi, rsi
        sete    al
        ret
```
Since the compiler sees only a tiny amount of code to compile, it is unclear whether rustc would end
up putting our `Decimal` in registers too (it seems likely), but what is clear is that clang was able
to generate much better code from the get-go, without our needing to be as explicit as in Rust.
## Results?
Our sample workloads got 2-3% faster with this little code change, which is worth it for our use case, since we try to squeeze as much performance as possible. Your mileage may vary, so always make sure to measure on your machine.
All in all, the unsafe version is not worth it just to get rid of one extra `or`. It is less clear to
read, more complex, and more error-prone. My goal was to make the compiler output a `shr`
instruction, but it seems to be stubborn (or maybe it is more conservative and plays it safer when it
has less context about how the function will be called, who knows).

Again, these were just tiny code snippets in a vacuum, and you would need to measure for yourself, but what is clear is that using an explicit bitwise comparison yielded much better assembly than the original in our production setup.