
Rust ints to Rust enums with less instructions

I have an enum. It looks a bit like this:

pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

I have a number. It looks a lot like this: let number = 0b10. Turning an enum into a number is pretty simple; I can do SomeEnum::B as u8. Turning a number into an enum is considerably more difficult; number as SomeEnum won't compile, because number can't be constrained to only the values of the provided enum (i.e. there's no u<2> type).
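For reference, here's a minimal sketch of the easy direction. It assumes the enum is marked #[repr(u8)] (which the post only assumes later, for the transmute) and derives the usual traits; B_AS_U8 is a hypothetical name. The cast works even in a const, while the reverse cast is rejected by the compiler:

```rust
/// The enum from above, with #[repr(u8)] so it is guaranteed to be one byte.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Enum -> integer: an `as` cast is allowed for fieldless enums, even in a const.
pub const B_AS_U8: u8 = SomeEnum::B as u8;

// Integer -> enum: the compiler rejects this cast, which is the whole problem.
// pub const NUMBER_AS_ENUM: SomeEnum = 0b10u8 as SomeEnum;
```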

Ideally, an enum like SomeEnum above is nothing more than a u8 at runtime, so the ideal conversion is a black box function fn to_enum(e: u8) -> SomeEnum that does close to nothing.

Option 1: #[derive]

There are a few crates that provide this functionality:

  • derive_more's TryFrom
  • int_enum is explicitly built to do this via TryFrom

These work fine for normal code, but they won't work in const fns on stable Rust: they are trait implementations, and trait methods can't be const fn. Also, much of the time they force you to handle a pointless Result when you know full well that you're only ever going to pass in correct values (usually masked off a bitfield); in that case it's better to panic.
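To show what that looks like, here's a hand-written TryFrom<u8> of roughly the shape such derives produce. This is a sketch, not the exact output of derive_more or int_enum, and InvalidDiscriminant is a hypothetical error type. Because try_from is a trait method, it can't be called from a const fn on stable:

```rust
use std::convert::TryFrom;

#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Hypothetical error type carrying the rejected value.
#[derive(Debug, PartialEq, Eq)]
pub struct InvalidDiscriminant(pub u8);

// Roughly what a TryFrom derive generates: one arm per variant plus an error arm.
// Trait methods can't be `const fn` on stable, so this is runtime-only.
impl TryFrom<u8> for SomeEnum {
    type Error = InvalidDiscriminant;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(SomeEnum::A),
            1 => Ok(SomeEnum::B),
            2 => Ok(SomeEnum::C),
            3 => Ok(SomeEnum::D),
            other => Err(InvalidDiscriminant(other)),
        }
    }
}
```

A caller that has already masked the value, e.g. SomeEnum::try_from(bits & 0b11), still has to unwrap() or expect() the Result, which is the redundancy described above.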

Option 2: std::mem::transmute

Transmute is pretty simple: it reinterprets the bits of a value of type A as type B, with no checks beyond both types having the same size. Given that SomeEnum has four variants with the bit patterns 0b00, 0b01, 0b10, and 0b11, I should AND off the upper bits so that this doesn't cause demons to erupt:

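// Assume I put a #[repr(u8)] on the enum above, so it has the layout of a u8.

pub const fn accursed_unutterable_perform_conversion(e: u8) -> SomeEnum {
    // SAFETY: e & 0b11 is in 0..=3, and every value in that range is a discriminant of SomeEnum.
    return unsafe { std::mem::transmute::<u8, SomeEnum>(e & 0b11) };
}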

Checking on Godbolt, this compiles to a very simple function that does just and r0, r0, #0b11 followed by bx lr; in real code, this would be inlined and become just the and.

Quick primer on ARM assembly: and dest, source, const means dest = source & const. bx lr means branch (with exchange) to the address in the link register, i.e. the return address in the caller. The ARM calling convention passes the first four arguments in r0 to r3 and returns values in r0, so r0 holds my first argument and, after the and, also the return value.

I'm using 32-bit ARM assembly because I like it more than x86. The x86 assembly is often very similar anyway.

This is very much a "trust the developer" function. It only works if the enum's variants cover every value from 0 to 2^N - 1 (here N = 2); transmuting any other value is UB. If the bitfield is widened and the enum gains more variants, the mask silently truncates the input: variants above 0b11 can never be produced, which gives wrong results but not UB. On the other hand, if a variant is removed, some masked value no longer corresponds to any variant, and transmuting it is UB.
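Here's a sketch that makes the happy path and the widening failure mode concrete, assuming #[repr(u8)]. WidenedEnum and stale_conversion are hypothetical names for the enum after a fifth variant has been added while the mask was left at 0b11:

```rust
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

pub const fn accursed_unutterable_perform_conversion(e: u8) -> SomeEnum {
    // SAFETY: e & 0b11 is in 0..=3, and every value in that range is a
    // discriminant of SomeEnum. Upper bits are simply discarded.
    unsafe { std::mem::transmute::<u8, SomeEnum>(e & 0b11) }
}

/// Hypothetical: SomeEnum after gaining a fifth variant, with the mask in the
/// conversion below left at 0b11 by mistake.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WidenedEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
    E = 4,
}

pub const fn stale_conversion(e: u8) -> WidenedEnum {
    // SAFETY: 0..=3 are still valid discriminants, so this is sound,
    // but it can never produce WidenedEnum::E: an input of 4 becomes A.
    unsafe { std::mem::transmute::<u8, WidenedEnum>(e & 0b11) }
}
```

Removing a variant instead (say B, leaving 1 without a variant) would make the same transmute UB for any input whose low bits are 0b01, which no test can reliably catch.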

Option 3: match

This is very self-explanatory. Instead of directly transmuting, I just write a match over all the possible bit patterns and return the corresponding enum variant.

pub const fn uncursed_utterable_perform_conversion(e: u8) -> SomeEnum {
    return match e {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        // Obviously you'd put a better message in, but you can't do that with unreachable!()
        // in const contexts.
        _ => unreachable!()
    };
}
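Unlike the TryFrom derives, this is a plain const fn, so it can be evaluated at compile time. A minimal sketch, assuming #[repr(u8)] and the derives shown; FROM_BITFIELD and BAD are hypothetical names:

```rust
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

pub const fn uncursed_utterable_perform_conversion(e: u8) -> SomeEnum {
    match e {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        _ => unreachable!(),
    }
}

/// Evaluated entirely at compile time.
pub const FROM_BITFIELD: SomeEnum = uncursed_utterable_perform_conversion(0b10);

// An out-of-range input in a const would hit unreachable!() during const
// evaluation and fail the build, rather than panicking at runtime:
// pub const BAD: SomeEnum = uncursed_utterable_perform_conversion(4);
```

At runtime, the same out-of-range input panics instead of producing an invalid enum.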

This is where I would imagine most people would stop, but I'm interested in seeing what it looks like at the assembly level. Looking at Godbolt again, it generates roughly what I would expect:

  • uxtb r1, r0 zero-extends the lowest 8 bits of r0 (the argument) into r1.
  • cmp r1, #0b100 compares r1 with the constant and sets the condition flags in the program status register.
  • bxlo is bx with a condition code: if the previous compare found r1 to be lower (unsigned) than the constant, it returns to the caller with r0 unchanged; otherwise execution falls through to the panic path.

Post-optimisation, this code does almost exactly the same thing as the transmute: it passes the value through unchanged, provided it's in range, but without any unsafe code. Good! But that extra branch (and the panic path behind it) bothers me a little; that's quite a few extra bytes of code! Even -Os can't remove it, but there's a simple trick: mask off the bits in the match itself. Changing match e to match e & 0b11 makes both functions compile to exactly the same code.
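Spelled out, the masked version looks like this. It's a sketch assuming #[repr(u8)]; masked_perform_conversion is a hypothetical name. The _ arm is still required by the type checker because e & 0b11 is still a u8, but every value it can actually take has an arm, which is what lets the optimiser drop the panic path:

```rust
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

pub const fn masked_perform_conversion(e: u8) -> SomeEnum {
    // Masking first means only 0..=3 can reach the match.
    match e & 0b11 {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        // Needed for exhaustiveness on u8, but never taken.
        _ => unreachable!(),
    }
}
```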

Making it a bit "better"

This has a similar issue to the transmute example: adding new variants makes this code buggy, because the hardcoded 0b11 mask silently truncates anything above it. However, there's a nice way to solve this with a bit of defensive programming via std::mem::variant_count, at the cost of reaching for nightly Rust.

Small soapbox about this: variant_count has been unstable for five years with zero real changes. It's a very useful function (e.g. for sizing arrays indexed by an enum, for type safety), but it has stalled because it could be part of a better design.

In fact, this function was nominated for stabilisation around a year and a half ago, prompted by a Reddit thread about unstabilised features; after a month it was unnominated in favour of a new RFC, which has now been stuck in committee for well over a year. I don't expect to see this feature stabilised within the next five years either.

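// Nightly only: needs #![feature(variant_count)] at the crate root.
#[unsafe(no_mangle)]
pub const fn noncursed_utterable_perform_conversion(e: u8) -> SomeEnum {
    let variants = std::mem::variant_count::<SomeEnum>();
    // Bits needed for discriminants 0..variants, i.e. ceil(log2(variants)).
    // usize::BITS, not u8::BITS: leading_zeros counts from the top of a usize.
    let size = usize::BITS - (variants - 1).leading_zeros();
    // The low `size` bits set, e.g. 0b11 for four variants.
    let mask = !0u8 >> (8 - size);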

    return match e & mask {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        _ => unreachable!()
    };
}

In this case, the new prologue builds a mask from the (ceil'd) log2 of the number of variants. If the match has an arm for every value the mask can produce (0 through mask), the unreachable!() is optimised out and this function reduces to the same and + bx pair as before. If it doesn't, for example because the enum gained a fifth variant and the mask widened to 0b111 while the match still has only four arms, then the unreachable!() survives and will be hit at runtime for the unmatched values. Either way, adding or removing enum variants fails loudly with a panic (or an error case, if you change it to return Result) rather than silently truncating or causing UB.
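The mask arithmetic itself works on stable, so here's a sketch of the same defensive pattern without the nightly feature. VARIANTS, MASK, mask_for, and defensive_perform_conversion are hypothetical names, and VARIANTS is a hand-maintained stand-in for std::mem::variant_count::<SomeEnum>(). Forgetting to update it is exactly the mistake variant_count prevents, so this only buys part of the safety:

```rust
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

/// Hand-maintained variant count (stable stand-in for variant_count).
pub const VARIANTS: usize = 4;

/// Low ceil(log2(variants)) bits set. Assumes 2 <= variants <= 256.
pub const fn mask_for(variants: usize) -> u8 {
    let size = usize::BITS - (variants - 1).leading_zeros();
    !0u8 >> (8 - size)
}

/// Computed at compile time: 0b11 for four variants.
const MASK: u8 = mask_for(VARIANTS);

pub const fn defensive_perform_conversion(e: u8) -> SomeEnum {
    match e & MASK {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        // Reachable only if VARIANTS and the arms disagree.
        _ => unreachable!(),
    }
}
```

For example, mask_for(4) is 0b11 and mask_for(5) is 0b111, so bumping VARIANTS to 5 without adding an arm leaves inputs 4 to 7 hitting unreachable!().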

Is it faster?

Using arrays of [0u8; 4096]:

running 3 tests
test accursed_match  ... bench:       2,349.23 ns/iter (+/- 71.76)
test optimised_match ... bench:       2,305.93 ns/iter (+/- 61.23)
test regular_match   ... bench:       1,576.90 ns/iter (+/- 33.30)

Using arrays of std::array::from_fn(|it| (it % 4) as u8) (also 4096 elements):

running 3 tests
test accursed_match  ... bench:       1,681.19 ns/iter (+/- 245.20)
test optimised_match ... bench:       1,681.06 ns/iter (+/- 261.87)
test regular_match   ... bench:       2,339.23 ns/iter (+/- 74.51)

Using arrays of [0u8; 16384]:

test accursed_match  ... bench:       9,373.10 ns/iter (+/- 486.20)
test optimised_match ... bench:       9,396.42 ns/iter (+/- 597.27)
test regular_match   ... bench:       6,392.49 ns/iter (+/- 76.26)

Using arrays of std::array::from_fn(|it| (it % 4) as u8) (also 16384 elements):

test accursed_match  ... bench:       9,577.98 ns/iter (+/- 1,145.13)
test optimised_match ... bench:       9,393.89 ns/iter (+/- 793.25)
test regular_match   ... bench:       6,379.16 ns/iter (+/- 154.36)

Looks like the regular match is faster in 3 of the 4 cases; the exception is the 4096-element cycling input, where both masked versions win.
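The harness behind these numbers isn't shown here. As a rough way to run this kind of comparison on stable Rust, here's a sketch (not the harness used above) built on the two input shapes from the text. convert_all and time_ns_per_iter are hypothetical helpers; std::hint::black_box keeps the optimiser from deleting the work, and wall-clock timing like this is much noisier than a proper benchmark framework:

```rust
use std::hint::black_box;
use std::time::Instant;

#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SomeEnum {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

pub const fn masked_perform_conversion(e: u8) -> SomeEnum {
    match e & 0b11 {
        0b00 => SomeEnum::A,
        0b01 => SomeEnum::B,
        0b10 => SomeEnum::C,
        0b11 => SomeEnum::D,
        _ => unreachable!(),
    }
}

/// Converts every byte and sums the discriminants so the result is used.
pub fn convert_all(input: &[u8], convert: fn(u8) -> SomeEnum) -> u64 {
    input.iter().map(|&b| convert(black_box(b)) as u64).sum()
}

/// Average wall-clock nanoseconds per pass over `input`.
pub fn time_ns_per_iter(input: &[u8], convert: fn(u8) -> SomeEnum, iters: u32) -> f64 {
    let start = Instant::now();
    for _ in 0..iters {
        black_box(convert_all(black_box(input), convert));
    }
    start.elapsed().as_nanos() as f64 / f64::from(iters)
}
```

Usage would be something like building zeros = [0u8; 4096] and cycling: [u8; 4096] = std::array::from_fn(|it| (it % 4) as u8), then calling time_ns_per_iter on each input with each conversion function.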