Issue
I'm dealing with some C++ code that has an optimised version that uses inline assembly. The optimised version is exhibiting behaviour that is not thread safe, which can be traced to 3 global variables that are accessed extensively from inside the assembly.
__attribute__ ((aligned (16))) unsigned int SHAVITE_MESS[16];
__attribute__ ((aligned (16))) thread_local unsigned char SHAVITE_PTXT[8*4];
__attribute__ ((aligned (16))) unsigned int SHAVITE_CNTS[4] = {0,0,0,0};
...
asm ("movaps xmm0, SHAVITE_PTXT[rip]");
asm ("movaps xmm1, SHAVITE_PTXT[rip+16]");
asm ("movaps xmm3, SHAVITE_CNTS[rip]");
asm ("movaps xmm4, SHAVITE256_XOR2[rip]");
asm ("pxor xmm2, xmm2");
I naively thought that the simplest way to solve this would be to make the variables thread_local; however, this leads to segfaults in the assembly - it appears that the assembly does not know that the variables are thread local?
I have dug around in the assembly of a small thread_local test case to see how gcc deals with them:
mov eax, DWORD PTR fs:num1@tpoff
and tried to modify the code to do the same:
asm ("movaps xmm0, fs:SHAVITE_PTXT@tpoff");
asm ("movaps xmm1, fs:SHAVITE_PTXT@tpoff+16");
asm ("movaps xmm3, fs:SHAVITE_CNTS@tpoff");
asm ("movaps xmm4, fs:SHAVITE256_XOR2@tpoff");
asm ("pxor xmm2, xmm2");
This works if all the variables are also thread_local, and it matches the (non-assembly) reference implementation, so it appears to be successful.
However, this seems very CPU specific. If I look at the output when compiling with -m32, I instead get:
mov eax, DWORD PTR gs:num1@ntpoff
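For reference, the test case was roughly as follows (a minimal sketch; output shown for gcc with -O2 -masm=intel):
// thread_local access test (sketch)
thread_local int num1;
int read_num1() { return num1; }
// 64-bit build: mov eax, DWORD PTR fs:num1@tpoff
// -m32 build:   mov eax, DWORD PTR gs:num1@ntpoff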
As the code is x86-specific anyway (it uses AES-NI), I suppose I could simply inspect the compiler's output and implement this for all possible variants.
However, I don't really like this as a solution; it feels a bit like programming by guesswork. Moreover, doing so wouldn't teach me anything for future cases that may be less specific to one architecture.
Is there a more generic/correct way to deal with this? How do I go about telling the assembly that the variables are thread_local in a more generic way? Or is there a way I can pass in the variables such that it doesn't need to know and works regardless?
Solution
As another answer says, the inline asm is a mess and is misused. Rewriting with intrinsics should be good, and lets you compile with or without -mavx (or -march=haswell or -march=znver1 or whatever) to let the compiler save a bunch of register-copy instructions. It also lets the compiler optimize (vector) register allocation and decide when to load/store, which is something compilers are pretty good at.
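A minimal sketch of what that buys (the comments describe typical output shapes; the compiler's actual choices will vary):
#include <x86intrin.h>

// One mixing-style step written as intrinsics.
// SSE build: psrldq/pxor are destructive two-operand forms, so the compiler
//            inserts a movdqa copy because b is still needed afterwards.
// AVX build: three-operand vpsrldq/vpxor write a fresh register, no copy.
__m128i mix_step(__m128i a, __m128i b)
{
    return _mm_xor_si128(_mm_xor_si128(a, _mm_srli_si128(b, 4)), b);
}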
Ok, well, I wasn't able to use the test data you provided. It uses several other routines not provided here, and I'm too lazy to go looking for them.
That said, I was able to cobble something together for test data. And my E256() returns the same values as yours. That doesn't mean I've got it 100% correct (you'll want to do your own testing), but given all the xor/aesenc against everything over and over again, if something were wrong, I'd expect it to show.
Converting to intrinsics wasn't particularly hard. Mostly you just need to find the equivalent _mm_ function for the given asm instruction. That, and track down all the places where you typed x12 when you meant x13 (grrr).
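For this routine the mapping is small; as a rough guide (see the Intel Intrinsics Guide for exact semantics):
- movaps (copy/load/store) -> plain assignment, or _mm_load_si128 / _mm_store_si128
- pxor -> _mm_xor_si128
- pshufb -> _mm_shuffle_epi8
- pshufd -> _mm_shuffle_epi32
- psrldq / pslldq -> _mm_srli_si128 / _mm_slli_si128
- aesenc -> _mm_aesenc_si128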
Note that while this code makes use of variables named x0-x15, that's only because it made the translation easier. There is no correlation between these C variable names and the registers gcc will use when it compiles the code. Also, gcc uses a lot of knowledge about SSE to re-order instructions, so the output (esp for -O3) is very different from the original asm. If you're thinking you can compare them to check for correctness (like I did), expect to be frustrated.
This code contains both the original routines (prefixed with "old") and the new, and calls both from main() to see if they produce the same output. I have made no effort to make any changes to the builtins in an attempt to optimize it. As soon as it worked, I just stopped. I'll leave any further improvements to you, now that it's all C code.
That said, gcc is able to optimize intrinsics (something it can't do for asm), which means that if you re-compile this code using -mavx2, the generated code is quite different.
Some stats:
- The original (fully expanded) code for E256() took up 287 instructions.
- Building with the intrinsics without -mavx2 takes 251.
- Building with the intrinsics with -mavx2 takes 196.
I haven't done any timings, but I like to believe that dropping ~100 lines of asm out will help. OTOH, sometimes gcc does a lousy job optimizing SSE, so don't assume anything.
Hope this helps.
// Compile with -O3 -msse4.2 -maes
// or -O3 -msse4.2 -maes -mavx2
#include <wmmintrin.h>
#include <x86intrin.h>
#include <stdio.h>
///////////////////////////
#define tos(a) #a
#define tostr(a) tos(a)
#define rev_reg_0321(j){ asm ("pshufb xmm" tostr(j)", [oldSHAVITE_REVERSE]"); }
#define replace_aes(i, j){ asm ("aesenc xmm" tostr(i)", xmm" tostr(j)""); }
__attribute__ ((aligned (16))) unsigned int oldSHAVITE_MESS[16];
__attribute__ ((aligned (16))) unsigned char oldSHAVITE_PTXT[8*4];
__attribute__ ((aligned (16))) unsigned int oldSHAVITE_CNTS[4] = {0,0,0,0};
__attribute__ ((aligned (16))) unsigned int oldSHAVITE_REVERSE[4] = {0x07060504, 0x0b0a0908, 0x0f0e0d0c, 0x03020100 };
__attribute__ ((aligned (16))) unsigned int oldSHAVITE256_XOR2[4] = {0x0, 0xFFFFFFFF, 0x0, 0x0};
__attribute__ ((aligned (16))) unsigned int oldSHAVITE256_XOR3[4] = {0x0, 0x0, 0xFFFFFFFF, 0x0};
__attribute__ ((aligned (16))) unsigned int oldSHAVITE256_XOR4[4] = {0x0, 0x0, 0x0, 0xFFFFFFFF};
#define oldmixing() do {\
asm("movaps xmm11, xmm15");\
asm("movaps xmm10, xmm14");\
asm("movaps xmm9, xmm13");\
asm("movaps xmm8, xmm12");\
\
asm("movaps xmm6, xmm11");\
asm("psrldq xmm6, 4");\
asm("pxor xmm8, xmm6");\
asm("movaps xmm6, xmm8");\
asm("pslldq xmm6, 12");\
asm("pxor xmm8, xmm6");\
\
asm("movaps xmm7, xmm8");\
asm("psrldq xmm7, 4");\
asm("pxor xmm9, xmm7");\
asm("movaps xmm7, xmm9");\
asm("pslldq xmm7, 12");\
asm("pxor xmm9, xmm7");\
\
asm("movaps xmm6, xmm9");\
asm("psrldq xmm6, 4");\
asm("pxor xmm10, xmm6");\
asm("movaps xmm6, xmm10");\
asm("pslldq xmm6, 12");\
asm("pxor xmm10, xmm6");\
\
asm("movaps xmm7, xmm10");\
asm("psrldq xmm7, 4");\
asm("pxor xmm11, xmm7");\
asm("movaps xmm7, xmm11");\
asm("pslldq xmm7, 12");\
asm("pxor xmm11, xmm7");\
} while(0);
void oldE256()
{
asm (".intel_syntax noprefix");
/* (L,R) = (xmm0,xmm1) */
asm ("movaps xmm0, [oldSHAVITE_PTXT]");
asm ("movaps xmm1, [oldSHAVITE_PTXT+16]");
asm ("movaps xmm3, [oldSHAVITE_CNTS]");
asm ("movaps xmm4, [oldSHAVITE256_XOR2]");
asm ("pxor xmm2, xmm2");
/* init key schedule */
asm ("movaps xmm8, [oldSHAVITE_MESS]");
asm ("movaps xmm9, [oldSHAVITE_MESS+16]");
asm ("movaps xmm10, [oldSHAVITE_MESS+32]");
asm ("movaps xmm11, [oldSHAVITE_MESS+48]");
/* xmm8..xmm11 = rk[0..15] */
/* start key schedule */
asm ("movaps xmm12, xmm8");
asm ("movaps xmm13, xmm9");
asm ("movaps xmm14, xmm10");
asm ("movaps xmm15, xmm11");
rev_reg_0321(12);
rev_reg_0321(13);
rev_reg_0321(14);
rev_reg_0321(15);
replace_aes(12, 2);
replace_aes(13, 2);
replace_aes(14, 2);
replace_aes(15, 2);
asm ("pxor xmm12, xmm3");
asm ("pxor xmm12, xmm4");
asm ("movaps xmm4, [oldSHAVITE256_XOR3]");
asm ("pxor xmm12, xmm11");
asm ("pxor xmm13, xmm12");
asm ("pxor xmm14, xmm13");
asm ("pxor xmm15, xmm14");
/* xmm12..xmm15 = rk[16..31] */
/* F3 - first round */
asm ("movaps xmm6, xmm8");
asm ("pxor xmm8, xmm1");
replace_aes(8, 9);
replace_aes(8, 10);
replace_aes(8, 2);
asm ("pxor xmm0, xmm8");
asm ("movaps xmm8, xmm6");
/* F3 - second round */
asm ("movaps xmm6, xmm11");
asm ("pxor xmm11, xmm0");
replace_aes(11, 12);
replace_aes(11, 13);
replace_aes(11, 2);
asm ("pxor xmm1, xmm11");
asm ("movaps xmm11, xmm6");
/* key schedule */
oldmixing();
/* xmm8..xmm11 - rk[32..47] */
/* F3 - third round */
asm ("movaps xmm6, xmm14");
asm ("pxor xmm14, xmm1");
replace_aes(14, 15);
replace_aes(14, 8);
replace_aes(14, 2);
asm ("pxor xmm0, xmm14");
asm ("movaps xmm14, xmm6");
/* key schedule */
asm ("pshufd xmm3, xmm3,135");
asm ("movaps xmm12, xmm8");
asm ("movaps xmm13, xmm9");
asm ("movaps xmm14, xmm10");
asm ("movaps xmm15, xmm11");
rev_reg_0321(12);
rev_reg_0321(13);
rev_reg_0321(14);
rev_reg_0321(15);
replace_aes(12, 2);
replace_aes(13, 2);
replace_aes(14, 2);
replace_aes(15, 2);
asm ("pxor xmm12, xmm11");
asm ("pxor xmm14, xmm3");
asm ("pxor xmm14, xmm4");
asm ("movaps xmm4, [oldSHAVITE256_XOR4]");
asm ("pxor xmm13, xmm12");
asm ("pxor xmm14, xmm13");
asm ("pxor xmm15, xmm14");
/* xmm12..xmm15 - rk[48..63] */
/* F3 - fourth round */
asm ("movaps xmm6, xmm9");
asm ("pxor xmm9, xmm0");
replace_aes(9, 10);
replace_aes(9, 11);
replace_aes(9, 2);
asm ("pxor xmm1, xmm9");
asm ("movaps xmm9, xmm6");
/* key schedule */
oldmixing();
/* xmm8..xmm11 = rk[64..79] */
/* F3 - fifth round */
asm ("movaps xmm6, xmm12");
asm ("pxor xmm12, xmm1");
replace_aes(12, 13);
replace_aes(12, 14);
replace_aes(12, 2);
asm ("pxor xmm0, xmm12");
asm ("movaps xmm12, xmm6");
/* F3 - sixth round */
asm ("movaps xmm6, xmm15");
asm ("pxor xmm15, xmm0");
replace_aes(15, 8);
replace_aes(15, 9);
replace_aes(15, 2);
asm ("pxor xmm1, xmm15");
asm ("movaps xmm15, xmm6");
/* key schedule */
asm ("pshufd xmm3, xmm3, 147");
asm ("movaps xmm12, xmm8");
asm ("movaps xmm13, xmm9");
asm ("movaps xmm14, xmm10");
asm ("movaps xmm15, xmm11");
rev_reg_0321(12);
rev_reg_0321(13);
rev_reg_0321(14);
rev_reg_0321(15);
replace_aes(12, 2);
replace_aes(13, 2);
replace_aes(14, 2);
replace_aes(15, 2);
asm ("pxor xmm12, xmm11");
asm ("pxor xmm13, xmm3");
asm ("pxor xmm13, xmm4");
asm ("pxor xmm13, xmm12");
asm ("pxor xmm14, xmm13");
asm ("pxor xmm15, xmm14");
/* xmm12..xmm15 = rk[80..95] */
/* F3 - seventh round */
asm ("movaps xmm6, xmm10");
asm ("pxor xmm10, xmm1");
replace_aes(10, 11);
replace_aes(10, 12);
replace_aes(10, 2);
asm ("pxor xmm0, xmm10");
asm ("movaps xmm10, xmm6");
/* key schedule */
oldmixing();
/* xmm8..xmm11 = rk[96..111] */
/* F3 - eighth round */
asm ("movaps xmm6, xmm13");
asm ("pxor xmm13, xmm0");
replace_aes(13, 14);
replace_aes(13, 15);
replace_aes(13, 2);
asm ("pxor xmm1, xmm13");
asm ("movaps xmm13, xmm6");
/* key schedule */
asm ("pshufd xmm3, xmm3, 135");
asm ("movaps xmm12, xmm8");
asm ("movaps xmm13, xmm9");
asm ("movaps xmm14, xmm10");
asm ("movaps xmm15, xmm11");
rev_reg_0321(12);
rev_reg_0321(13);
rev_reg_0321(14);
rev_reg_0321(15);
replace_aes(12, 2);
replace_aes(13, 2);
replace_aes(14, 2);
replace_aes(15, 2);
asm ("pxor xmm12, xmm11");
asm ("pxor xmm15, xmm3");
asm ("pxor xmm15, xmm4");
asm ("pxor xmm13, xmm12");
asm ("pxor xmm14, xmm13");
asm ("pxor xmm15, xmm14");
/* xmm12..xmm15 = rk[112..127] */
/* F3 - ninth round */
asm ("movaps xmm6, xmm8");
asm ("pxor xmm8, xmm1");
replace_aes(8, 9);
replace_aes(8, 10);
replace_aes(8, 2);
asm ("pxor xmm0, xmm8");
asm ("movaps xmm8, xmm6");
/* F3 - tenth round */
asm ("movaps xmm6, xmm11");
asm ("pxor xmm11, xmm0");
replace_aes(11, 12);
replace_aes(11, 13);
replace_aes(11, 2);
asm ("pxor xmm1, xmm11");
asm ("movaps xmm11, xmm6");
/* key schedule */
oldmixing();
/* xmm8..xmm11 = rk[128..143] */
/* F3 - eleventh round */
asm ("movaps xmm6, xmm14");
asm ("pxor xmm14, xmm1");
replace_aes(14, 15);
replace_aes(14, 8);
replace_aes(14, 2);
asm ("pxor xmm0, xmm14");
asm ("movaps xmm14, xmm6");
/* F3 - twelfth round */
asm ("movaps xmm6, xmm9");
asm ("pxor xmm9, xmm0");
replace_aes(9, 10);
replace_aes(9, 11);
replace_aes(9, 2);
asm ("pxor xmm1, xmm9");
asm ("movaps xmm9, xmm6");
/* feedforward */
asm ("pxor xmm0, [oldSHAVITE_PTXT]");
asm ("pxor xmm1, [oldSHAVITE_PTXT+16]");
asm ("movaps [oldSHAVITE_PTXT], xmm0");
asm ("movaps [oldSHAVITE_PTXT+16], xmm1");
asm (".att_syntax noprefix");
return;
}
void oldCompress256(const unsigned char *message_block, unsigned char *chaining_value, unsigned long long counter,
const unsigned char salt[32])
{
int i, j;
for (i=0;i<8*4;i++)
oldSHAVITE_PTXT[i]=chaining_value[i];
for (i=0;i<16;i++)
oldSHAVITE_MESS[i] = *((unsigned int*)(message_block+4*i));
oldSHAVITE_CNTS[0] = (unsigned int)(counter & 0xFFFFFFFFULL);
oldSHAVITE_CNTS[1] = (unsigned int)(counter>>32);
/* encryption + Davies-Meyer transform */
oldE256();
for (i=0; i<4*8; i++)
chaining_value[i]=oldSHAVITE_PTXT[i];
return;
}
////////////////////////////////
__attribute__ ((aligned (16))) unsigned int SHAVITE_MESS[16];
__attribute__ ((aligned (16))) unsigned char SHAVITE_PTXT[8*4];
__attribute__ ((aligned (16))) unsigned int SHAVITE_CNTS[4] = {0,0,0,0};
__attribute__ ((aligned (16))) unsigned int SHAVITE_REVERSE[4] = {0x07060504, 0x0b0a0908, 0x0f0e0d0c, 0x03020100 };
__attribute__ ((aligned (16))) unsigned int SHAVITE256_XOR2[4] = {0x0, 0xFFFFFFFF, 0x0, 0x0};
__attribute__ ((aligned (16))) unsigned int SHAVITE256_XOR3[4] = {0x0, 0x0, 0xFFFFFFFF, 0x0};
__attribute__ ((aligned (16))) unsigned int SHAVITE256_XOR4[4] = {0x0, 0x0, 0x0, 0xFFFFFFFF};
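// mixing(): expand the key schedule in place. Each new word x8..x11 starts as a
// copy of x12..x15, xors in the preceding word shifted right by 4 bytes, and then
// xors in its own partial value shifted left by 12 bytes.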
#define mixing() do {\
x11 = x15; \
x10 = x14; \
x9 = x13;\
x8 = x12;\
\
x6 = x11;\
x6 = _mm_srli_si128(x6, 4);\
x8 = _mm_xor_si128(x8, x6);\
x6 = x8;\
x6 = _mm_slli_si128(x6, 12);\
x8 = _mm_xor_si128(x8, x6);\
\
x7 = x8;\
x7 = _mm_srli_si128(x7, 4);\
x9 = _mm_xor_si128(x9, x7);\
x7 = x9;\
x7 = _mm_slli_si128(x7, 12);\
x9 = _mm_xor_si128(x9, x7);\
\
x6 = x9;\
x6 = _mm_srli_si128(x6, 4);\
x10 = _mm_xor_si128(x10, x6);\
x6 = x10;\
x6 = _mm_slli_si128(x6, 12);\
x10 = _mm_xor_si128(x10, x6);\
\
x7 = x10;\
x7 = _mm_srli_si128(x7, 4);\
x11 = _mm_xor_si128(x11, x7);\
x7 = x11;\
x7 = _mm_slli_si128(x7, 12);\
x11 = _mm_xor_si128(x11, x7);\
} while(0);
void E256()
{
__m128i x0;
__m128i x1;
__m128i x2;
__m128i x3;
__m128i x4;
__m128i x5;
__m128i x6;
__m128i x7;
__m128i x8;
__m128i x9;
__m128i x10;
__m128i x11;
__m128i x12;
__m128i x13;
__m128i x14;
__m128i x15;
/* (L,R) = (xmm0,xmm1) */
const __m128i ptxt1 = _mm_loadu_si128((const __m128i*)SHAVITE_PTXT);
const __m128i ptxt2 = _mm_loadu_si128((const __m128i*)(SHAVITE_PTXT+16));
x0 = ptxt1;
x1 = ptxt2;
x3 = _mm_loadu_si128((__m128i*)SHAVITE_CNTS);
x4 = _mm_loadu_si128((__m128i*)SHAVITE256_XOR2);
x2 = _mm_setzero_si128();
/* init key schedule */
x8 = _mm_loadu_si128((__m128i*)SHAVITE_MESS);
x9 = _mm_loadu_si128((__m128i*)(SHAVITE_MESS+4));
x10 = _mm_loadu_si128((__m128i*)(SHAVITE_MESS+8));
x11 = _mm_loadu_si128((__m128i*)(SHAVITE_MESS+12));
/* xmm8..xmm11 = rk[0..15] */
/* start key schedule */
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
const __m128i xtemp = _mm_loadu_si128((__m128i*)SHAVITE_REVERSE);
x12 = _mm_shuffle_epi8(x12, xtemp);
x13 = _mm_shuffle_epi8(x13, xtemp);
x14 = _mm_shuffle_epi8(x14, xtemp);
x15 = _mm_shuffle_epi8(x15, xtemp);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x3);
x12 = _mm_xor_si128(x12, x4);
x4 = _mm_loadu_si128((__m128i*)SHAVITE256_XOR3);
x12 = _mm_xor_si128(x12, x11);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[16..31] */
/* F3 - first round */
x6 = x8;
x8 = _mm_xor_si128(x8, x1);
x8 = _mm_aesenc_si128(x8, x9);
x8 = _mm_aesenc_si128(x8, x10);
x8 = _mm_aesenc_si128(x8, x2);
x0 = _mm_xor_si128(x0, x8);
x8 = x6;
/* F3 - second round */
x6 = x11;
x11 = _mm_xor_si128(x11, x0);
x11 = _mm_aesenc_si128(x11, x12);
x11 = _mm_aesenc_si128(x11, x13);
x11 = _mm_aesenc_si128(x11, x2);
x1 = _mm_xor_si128(x1, x11);
x11 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 - rk[32..47] */
/* F3 - third round */
x6 = x14;
x14 = _mm_xor_si128(x14, x1);
x14 = _mm_aesenc_si128(x14, x15);
x14 = _mm_aesenc_si128(x14, x8);
x14 = _mm_aesenc_si128(x14, x2);
x0 = _mm_xor_si128(x0, x14);
x14 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 135);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, xtemp);
x13 = _mm_shuffle_epi8(x13, xtemp);
x14 = _mm_shuffle_epi8(x14, xtemp);
x15 = _mm_shuffle_epi8(x15, xtemp);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x14 = _mm_xor_si128(x14, x3);
x14 = _mm_xor_si128(x14, x4);
x4 = _mm_loadu_si128((__m128i*)SHAVITE256_XOR4);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 - rk[48..63] */
/* F3 - fourth round */
x6 = x9;
x9 = _mm_xor_si128(x9, x0);
x9 = _mm_aesenc_si128(x9, x10);
x9 = _mm_aesenc_si128(x9, x11);
x9 = _mm_aesenc_si128(x9, x2);
x1 = _mm_xor_si128(x1, x9);
x9 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[64..79] */
/* F3 - fifth round */
x6 = x12;
x12 = _mm_xor_si128(x12, x1);
x12 = _mm_aesenc_si128(x12, x13);
x12 = _mm_aesenc_si128(x12, x14);
x12 = _mm_aesenc_si128(x12, x2);
x0 = _mm_xor_si128(x0, x12);
x12 = x6;
/* F3 - sixth round */
x6 = x15;
x15 = _mm_xor_si128(x15, x0);
x15 = _mm_aesenc_si128(x15, x8);
x15 = _mm_aesenc_si128(x15, x9);
x15 = _mm_aesenc_si128(x15, x2);
x1 = _mm_xor_si128(x1, x15);
x15 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 147);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, xtemp);
x13 = _mm_shuffle_epi8(x13, xtemp);
x14 = _mm_shuffle_epi8(x14, xtemp);
x15 = _mm_shuffle_epi8(x15, xtemp);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x13 = _mm_xor_si128(x13, x3);
x13 = _mm_xor_si128(x13, x4);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[80..95] */
/* F3 - seventh round */
x6 = x10;
x10 = _mm_xor_si128(x10, x1);
x10 = _mm_aesenc_si128(x10, x11);
x10 = _mm_aesenc_si128(x10, x12);
x10 = _mm_aesenc_si128(x10, x2);
x0 = _mm_xor_si128(x0, x10);
x10 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[96..111] */
/* F3 - eighth round */
x6 = x13;
x13 = _mm_xor_si128(x13, x0);
x13 = _mm_aesenc_si128(x13, x14);
x13 = _mm_aesenc_si128(x13, x15);
x13 = _mm_aesenc_si128(x13, x2);
x1 = _mm_xor_si128(x1, x13);
x13 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 135);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, xtemp);
x13 = _mm_shuffle_epi8(x13, xtemp);
x14 = _mm_shuffle_epi8(x14, xtemp);
x15 = _mm_shuffle_epi8(x15, xtemp);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x15 = _mm_xor_si128(x15, x3);
x15 = _mm_xor_si128(x15, x4);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[112..127] */
/* F3 - ninth round */
x6 = x8;
x8 = _mm_xor_si128(x8, x1);
x8 = _mm_aesenc_si128(x8, x9);
x8 = _mm_aesenc_si128(x8, x10);
x8 = _mm_aesenc_si128(x8, x2);
x0 = _mm_xor_si128(x0, x8);
x8 = x6;
/* F3 - tenth round */
x6 = x11;
x11 = _mm_xor_si128(x11, x0);
x11 = _mm_aesenc_si128(x11, x12);
x11 = _mm_aesenc_si128(x11, x13);
x11 = _mm_aesenc_si128(x11, x2);
x1 = _mm_xor_si128(x1, x11);
x11 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[128..143] */
/* F3 - eleventh round */
x6 = x14;
x14 = _mm_xor_si128(x14, x1);
x14 = _mm_aesenc_si128(x14, x15);
x14 = _mm_aesenc_si128(x14, x8);
x14 = _mm_aesenc_si128(x14, x2);
x0 = _mm_xor_si128(x0, x14);
x14 = x6;
/* F3 - twelfth round */
x6 = x9;
x9 = _mm_xor_si128(x9, x0);
x9 = _mm_aesenc_si128(x9, x10);
x9 = _mm_aesenc_si128(x9, x11);
x9 = _mm_aesenc_si128(x9, x2);
x1 = _mm_xor_si128(x1, x9);
x9 = x6;
/* feedforward */
x0 = _mm_xor_si128(x0, ptxt1);
x1 = _mm_xor_si128(x1, ptxt2);
_mm_storeu_si128((__m128i *)SHAVITE_PTXT, x0);
_mm_storeu_si128((__m128i *)(SHAVITE_PTXT + 16), x1);
return;
}
void Compress256(const unsigned char *message_block, unsigned char *chaining_value, unsigned long long counter,
const unsigned char salt[32])
{
int i, j;
for (i=0;i<8*4;i++)
SHAVITE_PTXT[i]=chaining_value[i];
for (i=0;i<16;i++)
SHAVITE_MESS[i] = *((unsigned int*)(message_block+4*i));
SHAVITE_CNTS[0] = (unsigned int)(counter & 0xFFFFFFFFULL);
SHAVITE_CNTS[1] = (unsigned int)(counter>>32);
/* encryption + Davies-Meyer transform */
E256();
for (i=0; i<4*8; i++)
chaining_value[i]=SHAVITE_PTXT[i];
return;
}
int main(int argc, char *argv[])
{
const int cvlen = 32;
unsigned char *cv = (unsigned char *)malloc(cvlen);
for (int x=0; x < cvlen; x++)
cv[x] = x + argc;
const int mblen = 64;
unsigned char *mb = (unsigned char *)malloc(mblen);
for (int x=0; x < mblen; x++)
mb[x] = x + argc;
unsigned long long counter = 0x1234567812345678ull;
unsigned char s[32] = {0};
oldCompress256(mb, cv, counter, s);
printf("old: ");
for (int x=0; x < cvlen; x++)
printf("%2x ", cv[x]);
printf("\n");
for (int x=0; x < cvlen; x++)
cv[x] = x + argc;
Compress256(mb, cv, counter, s);
printf("new: ");
for (int x=0; x < cvlen; x++)
printf("%2x ", cv[x]);
printf("\n");
}
Edit:
The globals are only used to pass values between C and asm. Perhaps the asm writer didn't know how to access parameters? In any case, they're unnecessary (and the source of the thread-safety issues); a minimal extended-asm sketch of passing the data in without globals appears after the listing. Here's the code without them (along with some cosmetic changes):
#include <x86intrin.h>
#include <stdio.h>
#include <time.h>
#define mixing() \
x11 = x15;\
x10 = x14;\
x9 = x13;\
x8 = x12;\
\
x6 = x11;\
x6 = _mm_srli_si128(x6, 4);\
x8 = _mm_xor_si128(x8, x6);\
x6 = x8;\
x6 = _mm_slli_si128(x6, 12);\
x8 = _mm_xor_si128(x8, x6);\
\
x7 = x8;\
x7 = _mm_srli_si128(x7, 4);\
x9 = _mm_xor_si128(x9, x7);\
x7 = x9;\
x7 = _mm_slli_si128(x7, 12);\
x9 = _mm_xor_si128(x9, x7);\
\
x6 = x9;\
x6 = _mm_srli_si128(x6, 4);\
x10 = _mm_xor_si128(x10, x6);\
x6 = x10;\
x6 = _mm_slli_si128(x6, 12);\
x10 = _mm_xor_si128(x10, x6);\
\
x7 = x10;\
x7 = _mm_srli_si128(x7, 4);\
x11 = _mm_xor_si128(x11, x7);\
x7 = x11;\
x7 = _mm_slli_si128(x7, 12);\
x11 = _mm_xor_si128(x11, x7);
// If mess & chain won't be 16byte aligned, change _mm_load to _mm_loadu and
// _mm_store to _mm_storeu
void Compress256(const __m128i *mess, __m128i *chain, unsigned long long counter, const unsigned char salt[32])
{
// note: _mm_set_epi32 uses (int e3, int e2, int e1, int e0)
const __m128i SHAVITE_REVERSE = _mm_set_epi32(0x03020100, 0x0f0e0d0c, 0x0b0a0908, 0x07060504);
const __m128i SHAVITE256_XOR2 = _mm_set_epi32(0x0, 0x0, 0xFFFFFFFF, 0x0);
const __m128i SHAVITE256_XOR3 = _mm_set_epi32(0x0, 0xFFFFFFFF, 0x0, 0x0);
const __m128i SHAVITE256_XOR4 = _mm_set_epi32(0xFFFFFFFF, 0x0, 0x0, 0x0);
const __m128i SHAVITE_CNTS =
_mm_set_epi32(0, 0, (unsigned int)(counter>>32), (unsigned int)(counter & 0xFFFFFFFFULL));
__m128i x0, x1, x2, x3, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
/* (L,R) = (xmm0,xmm1) */
const __m128i ptxt1 = _mm_load_si128(chain);
const __m128i ptxt2 = _mm_load_si128(chain+1);
x0 = ptxt1;
x1 = ptxt2;
x3 = SHAVITE_CNTS;
x2 = _mm_setzero_si128();
/* init key schedule */
x8 = _mm_load_si128(mess);
x9 = _mm_load_si128(mess+1);
x10 = _mm_load_si128(mess+2);
x11 = _mm_load_si128(mess+3);
/* xmm8..xmm11 = rk[0..15] */
/* start key schedule */
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, SHAVITE_REVERSE);
x13 = _mm_shuffle_epi8(x13, SHAVITE_REVERSE);
x14 = _mm_shuffle_epi8(x14, SHAVITE_REVERSE);
x15 = _mm_shuffle_epi8(x15, SHAVITE_REVERSE);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x3);
x12 = _mm_xor_si128(x12, SHAVITE256_XOR2);
x12 = _mm_xor_si128(x12, x11);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[16..31] */
/* F3 - first round */
x6 = x8;
x8 = _mm_xor_si128(x8, x1);
x8 = _mm_aesenc_si128(x8, x9);
x8 = _mm_aesenc_si128(x8, x10);
x8 = _mm_aesenc_si128(x8, x2);
x0 = _mm_xor_si128(x0, x8);
x8 = x6;
/* F3 - second round */
x6 = x11;
x11 = _mm_xor_si128(x11, x0);
x11 = _mm_aesenc_si128(x11, x12);
x11 = _mm_aesenc_si128(x11, x13);
x11 = _mm_aesenc_si128(x11, x2);
x1 = _mm_xor_si128(x1, x11);
x11 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 - rk[32..47] */
/* F3 - third round */
x6 = x14;
x14 = _mm_xor_si128(x14, x1);
x14 = _mm_aesenc_si128(x14, x15);
x14 = _mm_aesenc_si128(x14, x8);
x14 = _mm_aesenc_si128(x14, x2);
x0 = _mm_xor_si128(x0, x14);
x14 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 135);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, SHAVITE_REVERSE);
x13 = _mm_shuffle_epi8(x13, SHAVITE_REVERSE);
x14 = _mm_shuffle_epi8(x14, SHAVITE_REVERSE);
x15 = _mm_shuffle_epi8(x15, SHAVITE_REVERSE);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x14 = _mm_xor_si128(x14, x3);
x14 = _mm_xor_si128(x14, SHAVITE256_XOR3);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 - rk[48..63] */
/* F3 - fourth round */
x6 = x9;
x9 = _mm_xor_si128(x9, x0);
x9 = _mm_aesenc_si128(x9, x10);
x9 = _mm_aesenc_si128(x9, x11);
x9 = _mm_aesenc_si128(x9, x2);
x1 = _mm_xor_si128(x1, x9);
x9 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[64..79] */
/* F3 - fifth round */
x6 = x12;
x12 = _mm_xor_si128(x12, x1);
x12 = _mm_aesenc_si128(x12, x13);
x12 = _mm_aesenc_si128(x12, x14);
x12 = _mm_aesenc_si128(x12, x2);
x0 = _mm_xor_si128(x0, x12);
x12 = x6;
/* F3 - sixth round */
x6 = x15;
x15 = _mm_xor_si128(x15, x0);
x15 = _mm_aesenc_si128(x15, x8);
x15 = _mm_aesenc_si128(x15, x9);
x15 = _mm_aesenc_si128(x15, x2);
x1 = _mm_xor_si128(x1, x15);
x15 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 147);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, SHAVITE_REVERSE);
x13 = _mm_shuffle_epi8(x13, SHAVITE_REVERSE);
x14 = _mm_shuffle_epi8(x14, SHAVITE_REVERSE);
x15 = _mm_shuffle_epi8(x15, SHAVITE_REVERSE);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x13 = _mm_xor_si128(x13, x3);
x13 = _mm_xor_si128(x13, SHAVITE256_XOR4);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[80..95] */
/* F3 - seventh round */
x6 = x10;
x10 = _mm_xor_si128(x10, x1);
x10 = _mm_aesenc_si128(x10, x11);
x10 = _mm_aesenc_si128(x10, x12);
x10 = _mm_aesenc_si128(x10, x2);
x0 = _mm_xor_si128(x0, x10);
x10 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[96..111] */
/* F3 - eighth round */
x6 = x13;
x13 = _mm_xor_si128(x13, x0);
x13 = _mm_aesenc_si128(x13, x14);
x13 = _mm_aesenc_si128(x13, x15);
x13 = _mm_aesenc_si128(x13, x2);
x1 = _mm_xor_si128(x1, x13);
x13 = x6;
/* key schedule */
x3 = _mm_shuffle_epi32(x3, 135);
x12 = x8;
x13 = x9;
x14 = x10;
x15 = x11;
x12 = _mm_shuffle_epi8(x12, SHAVITE_REVERSE);
x13 = _mm_shuffle_epi8(x13, SHAVITE_REVERSE);
x14 = _mm_shuffle_epi8(x14, SHAVITE_REVERSE);
x15 = _mm_shuffle_epi8(x15, SHAVITE_REVERSE);
x12 = _mm_aesenc_si128(x12, x2);
x13 = _mm_aesenc_si128(x13, x2);
x14 = _mm_aesenc_si128(x14, x2);
x15 = _mm_aesenc_si128(x15, x2);
x12 = _mm_xor_si128(x12, x11);
x15 = _mm_xor_si128(x15, x3);
x15 = _mm_xor_si128(x15, SHAVITE256_XOR4);
x13 = _mm_xor_si128(x13, x12);
x14 = _mm_xor_si128(x14, x13);
x15 = _mm_xor_si128(x15, x14);
/* xmm12..xmm15 = rk[112..127] */
/* F3 - ninth round */
x6 = x8;
x8 = _mm_xor_si128(x8, x1);
x8 = _mm_aesenc_si128(x8, x9);
x8 = _mm_aesenc_si128(x8, x10);
x8 = _mm_aesenc_si128(x8, x2);
x0 = _mm_xor_si128(x0, x8);
x8 = x6;
/* F3 - tenth round */
x6 = x11;
x11 = _mm_xor_si128(x11, x0);
x11 = _mm_aesenc_si128(x11, x12);
x11 = _mm_aesenc_si128(x11, x13);
x11 = _mm_aesenc_si128(x11, x2);
x1 = _mm_xor_si128(x1, x11);
x11 = x6;
/* key schedule */
mixing();
/* xmm8..xmm11 = rk[128..143] */
/* F3 - eleventh round */
x6 = x14;
x14 = _mm_xor_si128(x14, x1);
x14 = _mm_aesenc_si128(x14, x15);
x14 = _mm_aesenc_si128(x14, x8);
x14 = _mm_aesenc_si128(x14, x2);
x0 = _mm_xor_si128(x0, x14);
x14 = x6;
/* F3 - twelfth round */
x6 = x9;
x9 = _mm_xor_si128(x9, x0);
x9 = _mm_aesenc_si128(x9, x10);
x9 = _mm_aesenc_si128(x9, x11);
x9 = _mm_aesenc_si128(x9, x2);
x1 = _mm_xor_si128(x1, x9);
x9 = x6;
/* feedforward */
x0 = _mm_xor_si128(x0, ptxt1);
x1 = _mm_xor_si128(x1, ptxt2);
_mm_store_si128(chain, x0);
_mm_store_si128(chain + 1, x1);
}
int main(int argc, char *argv[])
{
__m128i chain[2], mess[4];
unsigned char *p;
// argc prevents compiler from precalculating results
p = (unsigned char *)mess;
for (int x=0; x < 64; x++)
p[x] = x + argc;
p = (unsigned char *)chain;
for (int x=0; x < 32; x++)
p[x] = x + argc;
unsigned long long counter = 0x1234567812345678ull + argc;
// Unused, but prototype requires it.
unsigned char s[32] = {0};
Compress256(mess, chain, counter, s);
for (int x=0; x < 32; x++)
printf("%02x ", p[x]);
printf("\n");
struct timespec start, end;
clock_gettime(CLOCK_MONOTONIC, &start);
unsigned char res = 0;
for (int x=0; x < 400000; x++)
{
Compress256(mess, chain, counter, s);
// Ensure optimizer doesn't omit the calc
res ^= *p;
}
clock_gettime(CLOCK_MONOTONIC, &end);
unsigned long long delta_us = (end.tv_sec - start.tv_sec) * 1000000ull + (end.tv_nsec - start.tv_nsec) / 1000ull;
printf("%x: %llu\n", res, delta_us);
}
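As an aside (not needed once everything is intrinsics): had the asm been kept, extended-asm operands would have let the compiler pass the data in directly, so the asm text never names a global and never needs to know whether the storage is thread_local; the compiler emits the right addressing (plain, fs:...@tpoff, gs:...@ntpoff, whatever) by itself. A minimal sketch (AT&T syntax, gcc's default):
#include <x86intrin.h>

// Sketch only: load one 16-byte block through a memory operand.
// "%1" is filled in by the compiler, including any TLS addressing,
// so the template works unchanged for -m32, -m64, thread_local or not.
static inline __m128i asm_load16(const __m128i *p)
{
    __m128i v;
    __asm__ ("movaps %1, %0" : "=x"(v) : "m"(*p));
    return v;
}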
Answered By - David Wohlferd