Rust ints to Rust enums with fewer instructions
I have an enum. It looks a bit like this:
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}
I have a number. It looks a lot like this: let number = 0b10. Turning an enum into a number is pretty simple; I can do SomeEnum::B as u8. Turning a number into an enum is considerably more difficult. number as SomeEnum won't compile, because number can't be constrained to only the values of the provided enum (i.e. there's no u<2> type).
Ideally, an enum like SomeEnum above is nothing more than a u8 at runtime. That makes the ideal conversion a black-box function fn from_int(n: u8) -> SomeEnum that does close to nothing.
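Here is a small sketch of the easy direction. It assumes #[repr(u8)] on the enum, which pins the layout to a single byte. The original definition doesn't show that attribute, but the transmute option relies on it. The helper name enum_to_u8 is hypothetical: it shows that enum-to-int is a plain cast that even works in a const fn.

```rust
// Assumption: #[repr(u8)], so each variant is stored as exactly one u8.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Hypothetical helper: enum -> int is just an `as` cast reading the discriminant.
pub const fn enum_to_u8(e: SomeEnum) -> u8 {
    e as u8
}

// The enum is exactly as big as a u8; this is checked at compile time.
const _: () = assert!(std::mem::size_of::<SomeEnum>() == std::mem::size_of::<u8>());

// The other direction has no cast: `0b10u8 as SomeEnum` is rejected by the compiler.
```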
Option 1: #[derive]
There are a few crates that provide this functionality:
- derive_more's TryFrom
- int_enum, which is explicitly built to do this via TryFrom
These work fine for normal code. If you're using const fns, though, they won't work on stable Rust: they're implemented as trait derives, and trait methods can't be const fn. Also, a lot of the time they force you to handle a Result that's useless when you know full well that you're only ever going to pass in valid values (usually masked off a bitfield). In that case it's better to panic.
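For comparison, here is roughly the shape of a TryFrom-based conversion, written by hand rather than derived. This is a sketch, not the actual expansion produced by either crate, and the error type InvalidDiscriminant is a hypothetical name. It shows both complaints above. The caller gets a Result even for values already masked to two bits, and try_from is a trait method, so a const fn can't call it on stable Rust.

```rust
use std::convert::TryFrom;

#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Hypothetical error type: carries the byte that didn't match any variant.
#[derive(Debug, PartialEq, Eq)]
pub struct InvalidDiscriminant(pub u8);

impl TryFrom<u8> for SomeEnum {
    type Error = InvalidDiscriminant;

    // A trait method, so it can't be declared `const fn` on stable Rust.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(SomeEnum::A),
            1 => Ok(SomeEnum::B),
            2 => Ok(SomeEnum::C),
            3 => Ok(SomeEnum::D),
            other => Err(InvalidDiscriminant(other)),
        }
    }
}

/// Even when the input is masked to two bits and can't fail,
/// the caller still has to deal with the Result.
pub fn from_bitfield(bits: u8) -> SomeEnum {
    SomeEnum::try_from(bits & 0b11).expect("masked value is always in range")
}
```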
Option 2: std::mem::transmute
Transmute is pretty simple: it reinterprets the bits of type A as type B without any type checking (only the sizes have to match). SomeEnum has four variants with the bit patterns 0b00, 0b01, 0b10, and 0b11. So I should AND off the upper bits, so that this doesn't cause demons to erupt:
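// Assume I put a #[repr(u8)] on the enum above.
pub const fn accursed_unutterable_perform_conversion(e: u8) -> SomeEnum {
    // SAFETY: `e & 0b11` is always in 0..=3, and each of those values is a valid
    // SomeEnum discriminant, so the transmuted byte is always a valid variant.
    return unsafe { std::mem::transmute::<u8, SomeEnum>(e & 0b11) };
}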
Indeed, checking on Godbolt confirms that this generates a very simple function that does just and r0, r0, #0b11 followed by a bx lr. In real code, this would be inlined and become just the and.
Quick primer on ARM assembly: and dest, source, const means dest = source & const. bx lr means branch (with exchange) to the link register, which holds the return address, i.e. back to the caller. The ARM ABI passes the first four arguments in registers r0 to r3 and returns the result in r0. So r0 is both my argument and my return value.
I'm using 32-bit ARM assembly because I like it more than x86. The x86 assembly is often very similar anyway.
This is very much a "trust the developer" function. It only works if the enum has exactly 2^N variants numbered 0 to 2^N - 1, and it produces UB if the masked value isn't a valid variant. If the bitfield is widened and the enum gains more variants, inputs above 0b11 are silently truncated and map to the wrong variant, but this doesn't cause UB. On the other hand, if a variant is removed, the masked value can land on a bit pattern with no matching variant, and transmuting that is UB.
Option 3: match
This is very self-explanatory. Instead of transmuting directly, I just write a match over all the possible bit patterns and return the corresponding enum variant.
pub const fn uncursed_utterable_perform_conversion(e: u8) -> SomeEnum {
    return match e {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        // Obviously you'd put a better message in, but you can't do that with `unreachable!`
        // in const contexts.
        _ => unreachable!()
    };
}
This is where I would imagine most people would stop, but I'm interested in seeing what it looks like at the assembly level. Looking at Godbolt again, it generates roughly what I would expect:
- uxtb r1, r0 extracts the lowest 8 bits of the passed-in register and puts them in r1.
- cmp r1, #0b100 compares r1 with the constant and sets the appropriate flags in the program status register.
- bxlo is a branch with a condition code. If the previous compare found r1 lower (unsigned) than the constant, it returns to the caller; otherwise execution falls through to the panic from unreachable!().
After optimisation, this code does nearly exactly the same thing as the transmute, passing the value through unchanged as long as it's within range, but without any unsafe code. Good! But that extra branch bothers me a little bit; that's quite a few extra bytes of code! Even -Os can't remove it, since out-of-range values still have to reach the panic. There's a simple trick, though: just mask off the bits in the match statement. With match e turned into match e & 0b11, both functions compile to exactly the same code.
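Here is the masked match as a self-contained sketch. The function name is hypothetical, and the enum gets the same #[repr(u8)] as before. Because e & 0b11 can only be 0 to 3, every value reaching the match has an arm. The _ arm is only there because exhaustiveness checking doesn't reason about the mask, and the optimiser can delete it. Being a const fn, it also works at compile time.

```rust
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Hypothetical name: the safe match, with the input masked to two bits first.
pub const fn masked_perform_conversion(e: u8) -> SomeEnum {
    match e & 0b11 {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        // Required for exhaustiveness, but dead: `e & 0b11` is at most 3.
        _ => unreachable!(),
    }
}

// Evaluated at compile time: the upper bits of 0b1110 are discarded, leaving 0b10.
pub const FROM_CONST: SomeEnum = masked_perform_conversion(0b1110);
```

Note the truncation: an input like 0xFE quietly becomes SomeEnum::C rather than panicking.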
Making it a bit "better"
This has a similar issue to the transmute example: adding new variants will make this code buggy, since the hard-coded 0b11 mask can never produce them. However, there is a nice way to solve this with some defensive programming via std::mem::variant_count instead, at the cost of reaching for nightly Rust.
Small soapbox about this: variant_count has been unstable for five years with zero real changes. It's a very useful function, e.g. when using arrays indexed by enums for type safety. However, it has stalled because it could be part of a better design.
In fact, this function was nominated for stabilisation around a year and a half ago, at the prompting of a Reddit thread about unstabilised features. After a month it was unnominated in favour of a new RFC, which has now died in committee for well over a year. I don't expect to see this feature stabilised within the next five years either.
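// Nightly only: needs #![feature(variant_count)] at the crate root.
#[unsafe(no_mangle)]
pub const fn noncursed_utterable_perform_conversion(e: u8) -> SomeEnum {
    // variant_count returns a usize, so count leading zeros against usize::BITS.
    let variants = std::mem::variant_count::<SomeEnum>();
    // Bits needed for the largest discriminant (variants - 1): ceil(log2(variants)).
    let size = usize::BITS - (variants - 1).leading_zeros();
    // All-ones mask covering exactly `size` low bits (0b11 for four variants).
    let mask = !0u8 >> (8 - size);
    return match e & mask {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        _ => unreachable!()
    };
}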
In this case, the new prologue generates a mask from the (ceil'd) log2 of the number of variants. If the match has an arm for every value from 0 up to the mask, the unreachable!() is optimised out and this function reduces to the same and + bx pair as before. If the number of arms doesn't match the mask, the unreachable!() still exists and will be hit at runtime. It doesn't matter whether variants are added to or removed from the enum. Either way this fails appropriately: with a panic (or an error case, if you change it to return Result), or with a compile error if an arm names a removed variant. It never silently truncates or causes UB.
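The mask arithmetic is easier to poke at on stable Rust if the variant count is passed in as a plain number. The sketch below restates the prologue in my own words; the name mask_for and the constant SOME_ENUM_MASK are hypothetical. It computes the bits needed for the largest discriminant (variants - 1), which is ceil(log2(variants)), and turns that into an all-ones mask. It also special-cases a single-variant enum, where the shift would otherwise be out of range.

```rust
/// Hypothetical helper: the smallest all-ones mask covering discriminants 0..variants.
pub const fn mask_for(variants: usize) -> u8 {
    assert!(variants >= 1 && variants <= 256, "a u8 can only encode 1..=256 variants");
    // Bits needed for the largest discriminant, variants - 1; this is ceil(log2(variants)).
    let size = usize::BITS - (variants - 1).leading_zeros();
    if size == 0 {
        // A single variant needs no bits at all.
        0
    } else {
        // Keep the low `size` bits; `size` is in 1..=8 here, so the shift is in range.
        u8::MAX >> (u8::BITS - size)
    }
}

// A hand-maintained count standing in for what variant_count would provide.
pub const SOME_ENUM_MASK: u8 = mask_for(4);
```

Bumping the count to 5 widens the mask to 0b111. That is what makes the four-arm match above reach its unreachable!() arm instead of silently truncating.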
Is it faster?
Using arrays of [0u8; 4096]:
running 3 tests
test accursed_match ... bench: 2,349.23 ns/iter (+/- 71.76)
test optimised_match ... bench: 2,305.93 ns/iter (+/- 61.23)
test regular_match ... bench: 1,576.90 ns/iter (+/- 33.30)
Using arrays of std::array::from_fn(|it| (it % 4) as u8) (also 4096):
running 3 tests
test accursed_match ... bench: 1,681.19 ns/iter (+/- 245.20)
test optimised_match ... bench: 1,681.06 ns/iter (+/- 261.87)
test regular_match ... bench: 2,339.23 ns/iter (+/- 74.51)
Using arrays of [0u8; 16384]:
test accursed_match ... bench: 9,373.10 ns/iter (+/- 486.20)
test optimised_match ... bench: 9,396.42 ns/iter (+/- 597.27)
test regular_match ... bench: 6,392.49 ns/iter (+/- 76.26)
Using arrays of std::array::from_fn(|it| (it % 4) as u8) (also 16384):
test accursed_match ... bench: 9,577.98 ns/iter (+/- 1,145.13)
test optimised_match ... bench: 9,393.89 ns/iter (+/- 793.25)
test regular_match ... bench: 6,379.16 ns/iter (+/- 154.36)
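Looks like the regular match is faster in three of the four runs, taking roughly 30% less time. The exception is the 4096-element cycling input, where the two masked versions win.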
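The benchmark code itself isn't shown above. For anyone wanting to reproduce it, here is a guess at its general shape on stable Rust. It converts every byte of an input array and sums the discriminants so the work can't be optimised away, and std::hint::black_box hides each input byte from the optimiser. The function names and the summing are assumptions, not the original harness, and the masked match stands in as the conversion under test. On nightly, a body like convert_all would presumably sit inside a #[bench] function's Bencher::iter call, which is what prints the ns/iter figures.

```rust
use std::hint::black_box;

#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

pub const fn masked_perform_conversion(e: u8) -> SomeEnum {
    match e & 0b11 {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        _ => unreachable!(),
    }
}

/// Hypothetical benchmark body: convert every byte, then sum the discriminants
/// so the result depends on every single conversion.
pub fn convert_all(input: &[u8]) -> u32 {
    input
        .iter()
        .map(|&b| masked_perform_conversion(black_box(b)) as u32)
        .sum()
}

/// The two input patterns from the runs above, at the 4096-element size.
pub fn zeroes() -> [u8; 4096] {
    [0u8; 4096]
}

pub fn cycling() -> [u8; 4096] {
    std::array::from_fn(|it| (it % 4) as u8)
}
```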