/*
 * aes-cbc.cpp
 *
 *  Created on: Feb 19, 2013
 *      Author: nick
 */

#include <stdint.h>
#include <stdio.h>
#include <wmmintrin.h>

/*
 * ALIGN16: declaration prefix that requests 16-byte alignment.
 * Only defined here for GNU-compatible compilers (via __attribute__), and
 * only if no earlier definition exists. On other compilers the macro is left
 * undefined by this block, so it must be supplied before this point.
 */
#if !defined (ALIGN16)
# if defined (__GNUC__)
# define ALIGN16 __attribute__ ( (aligned (16)))
# endif
#endif

/**
 * AES-128 block cipher implemented with the AES-NI intrinsics from
 * <wmmintrin.h>. Provides ECB and CBC encryption/decryption over buffers
 * whose length is a multiple of 16 bytes.
 *
 * Typical use: call set_user_key() once with a 16-byte key, then call any of
 * the *_encrypt / *_decrypt methods. Check_CPU_support_AES() reports whether
 * the running CPU offers the required instructions.
 */
class AES
{
public:
	AES()
	{

	}

	/**
	 * Loads a 16-byte key and derives both round-key schedules from it.
	 *
	 * @param user_cypher pointer to 16 key bytes; no alignment is required
	 *                    because an unaligned load is used.
	 */
	void set_user_key(const unsigned char* user_cypher)
	{
		user_key = _mm_loadu_si128((const __m128i *) user_cypher);
		calculate_encryption_key();
		calculate_decryption_key();
	}

	/** Prints a label followed by the 16 bytes of `data` in hex. */
	static void print_m128i_with_string(char* string, __m128i data);
	/** Prints a label followed by `length` 16-byte blocks in hex, one per line. */
	static void print_m128i_with_string(char* string, __m128i* data, unsigned int length);

	/** ECB encryption; returns 0 on success, -1 if length is not a multiple of 16. */
	int ECB_encrypt(const unsigned char *in, //pointer to the PLAINTEXT
			unsigned char *out, //pointer to the CIPHERTEXT buffer
			unsigned long length //text length in bytes
			);
	/** ECB decryption; returns 0 on success, -1 if length is not a multiple of 16. */
	int ECB_decrypt(const unsigned char *in, //pointer to the CIPHERTEXT
			unsigned char *out, //pointer to the DECRYPTED TEXT buffer
			unsigned long length //text length in bytes
			);
	/** CBC encryption starting from the 16-byte `ivec`; same return convention. */
	int CBC_encrypt(const unsigned char *in, unsigned char *out,
			unsigned char ivec[16], unsigned long length);
	/** CBC decryption starting from the 16-byte `ivec`; same return convention. */
	int CBC_decrypt(const unsigned char *in, unsigned char *out,
			unsigned char ivec[16], unsigned long length);

	/** Returns non-zero when the CPU reports AES instruction support. */
	static int Check_CPU_support_AES();
private:
	// The key schedule and the cipher routines index round keys 0..10
	// inclusive, so each schedule needs eleven slots. Ten slots would let
	// index 10 spill into the next data member.
	enum { ROUND_KEY_COUNT = 11 };

	__m128i enc_key[ROUND_KEY_COUNT];
	__m128i dec_key[ROUND_KEY_COUNT];

	__m128i user_key;
	__m128i init_vec;

	inline __m128i AES_128_ASSIST(__m128i temp1, __m128i temp2);
	void calculate_encryption_key();
	void calculate_decryption_key();
	void encrypt_helper(__m128i *in, __m128i * out);
	void decrypt_helper(__m128i *in, __m128i * out);

};

/*****************************************************************************/

/*
 * One step of the AES-128 key expansion.
 *
 * temp1 is the previous round key, temp2 the raw result of
 * _mm_aeskeygenassist_si128 on that key. The highest 32-bit lane of temp2 is
 * broadcast to all lanes, temp1 is XOR-ed with itself shifted left by 4, 8
 * and 12 bytes, and the two results are XOR-ed to give the next round key.
 */
inline __m128i AES::AES_128_ASSIST(__m128i temp1, __m128i temp2)
{
	const __m128i broadcast = _mm_shuffle_epi32 (temp2, 0xff);

	__m128i shifted = temp1;
	__m128i folded = temp1;
	for (int step = 0; step < 3; ++step)
	{
		// Each pass moves the running copy another 4 bytes to the left
		// and folds it into the accumulator.
		shifted = _mm_slli_si128 (shifted, 0x4);
		folded = _mm_xor_si128(folded, shifted);
	}

	return _mm_xor_si128(folded, broadcast);
}

/*
 * Expands user_key into the encryption round keys enc_key[0] .. enc_key[10].
 * enc_key[0] is the user key itself; each following key is derived from the
 * previous one with _mm_aeskeygenassist_si128 and AES_128_ASSIST, using the
 * round constants 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36.
 */
void AES::calculate_encryption_key()
{
	// _mm_aeskeygenassist_si128 takes its round constant as an immediate
	// operand, so the ten steps are written out through a local macro
	// rather than a runtime loop.
#define AES_EXPAND_STEP(slot, rcon) \
	round_key = AES_128_ASSIST(round_key, \
			_mm_aeskeygenassist_si128(round_key, (rcon))); \
	enc_key[(slot)] = round_key

	__m128i round_key = user_key;
	enc_key[0] = round_key;

	AES_EXPAND_STEP(1, 0x1);
	AES_EXPAND_STEP(2, 0x2);
	AES_EXPAND_STEP(3, 0x4);
	AES_EXPAND_STEP(4, 0x8);
	AES_EXPAND_STEP(5, 0x10);
	AES_EXPAND_STEP(6, 0x20);
	AES_EXPAND_STEP(7, 0x40);
	AES_EXPAND_STEP(8, 0x80);
	AES_EXPAND_STEP(9, 0x1b);
	AES_EXPAND_STEP(10, 0x36);

#undef AES_EXPAND_STEP
}

/*
 * Builds the decryption schedule dec_key[0..10] from enc_key[0..10].
 * The schedule is the encryption schedule in reverse order. The two end keys
 * are copied unchanged, and the nine keys in between are passed through
 * _mm_aesimc_si128 so they can be fed to _mm_aesdec_si128.
 */
void AES::calculate_decryption_key()
{
	// End keys: copied as-is, positions swapped.
	dec_key[0] = enc_key[10];
	dec_key[10] = enc_key[0];

	// Middle keys: mirrored index, transformed for the decrypt rounds.
	for (int round = 1; round <= 9; ++round)
	{
		dec_key[10 - round] = _mm_aesimc_si128(enc_key[round]);
	}
}


/**
 * Encrypts a buffer with AES-128 in ECB mode: every 16-byte block is
 * encrypted independently with the round keys in enc_key.
 *
 * @param in     plaintext bytes; no particular alignment is required
 * @param out    ciphertext buffer, at least `length` bytes; no particular
 *               alignment is required
 * @param length number of bytes to process; must be a multiple of 16
 * @return 0 on success; -1 (after printing a diagnostic to stdout) when
 *         `length` is not a multiple of 16, in which case nothing is written
 */
int AES::ECB_encrypt(const unsigned char *in, //pointer to the PLAINTEXT
		unsigned char *out, //pointer to the CIPHERTEXT buffer
		unsigned long length //text length in bytes
		)
{
	if (length % 16 != 0)
	{
		printf("input length is not multiple times of 16 bytes!!!");
		return -1;
	}

	const unsigned long block_count = length / 16;
	const __m128i *src = (const __m128i *) in;
	__m128i *dst = (__m128i *) out;

	// The index has the same width as block_count so it cannot wrap before
	// the loop condition becomes false.
	for (unsigned long i = 0; i < block_count; i++)
	{
		// Unaligned load, matching the unaligned store below: callers hand
		// in plain byte buffers that need not be 16-byte aligned.
		__m128i state = _mm_loadu_si128(&src[i]);
		state = _mm_xor_si128(state, enc_key[0]);
		for (unsigned int j = 1; j < 10; j++)
		{
			state = _mm_aesenc_si128(state, enc_key[j]);
		}
		state = _mm_aesenclast_si128(state, enc_key[10]);
		_mm_storeu_si128(&dst[i], state);
	}
	return 0;
}


/**
 * Decrypts a buffer with AES-128 in ECB mode: every 16-byte block is
 * decrypted independently with the round keys in dec_key.
 *
 * @param in     ciphertext bytes; no particular alignment is required
 * @param out    buffer for the decrypted text, at least `length` bytes; no
 *               particular alignment is required
 * @param length number of bytes to process; must be a multiple of 16
 * @return 0 on success; -1 (after printing a diagnostic to stdout) when
 *         `length` is not a multiple of 16, in which case nothing is written
 */
int AES::ECB_decrypt(const unsigned char *in, //pointer to the CIPHERTEXT
		unsigned char *out, //pointer to the DECRYPTED TEXT buffer
		unsigned long length //text length in bytes
		)
{
	if (length % 16 != 0)
	{
		printf("input length is not multiple times of 16 bytes!!!");
		return -1;
	}

	const unsigned long block_count = length / 16;
	const __m128i *src = (const __m128i *) in;
	__m128i *dst = (__m128i *) out;

	// The index has the same width as block_count so it cannot wrap before
	// the loop condition becomes false.
	for (unsigned long i = 0; i < block_count; i++)
	{
		// Unaligned load, matching the unaligned store below: callers hand
		// in plain byte buffers that need not be 16-byte aligned.
		__m128i state = _mm_loadu_si128(&src[i]);
		state = _mm_xor_si128(state, dec_key[0]);
		for (unsigned int j = 1; j < 10; j++)
		{
			state = _mm_aesdec_si128(state, dec_key[j]);
		}
		state = _mm_aesdeclast_si128(state, dec_key[10]);
		_mm_storeu_si128(&dst[i], state);
	}
	return 0;
}

/**
 * Encrypts a buffer with AES-128 in CBC mode: each plaintext block is XOR-ed
 * with the previous ciphertext block (the IV for the first block) before it
 * is encrypted with the round keys in enc_key.
 *
 * @param in     plaintext bytes; no particular alignment is required
 * @param out    ciphertext buffer, at least `length` bytes; no particular
 *               alignment is required
 * @param ivec   16-byte initialisation vector; read only, not updated
 * @param length number of bytes to process; must be a multiple of 16
 * @return 0 on success; -1 (after printing a diagnostic to stdout) when
 *         `length` is not a multiple of 16, in which case nothing is written
 */
int AES::CBC_encrypt(const unsigned char *in, unsigned char *out,
		unsigned char ivec[16], unsigned long length)
{
	if (length % 16 != 0)
	{
		printf("input length is not multiple times of 16 bytes!!!");
		return -1;
	}

	const unsigned long block_count = length / 16;
	const __m128i *src = (const __m128i *) in;
	__m128i *dst = (__m128i *) out;

	__m128i feedback = _mm_loadu_si128((const __m128i *) ivec);

	// The index has the same width as block_count so it cannot wrap before
	// the loop condition becomes false.
	for (unsigned long i = 0; i < block_count; i++)
	{
		// Unaligned load, matching the unaligned store below: callers hand
		// in plain byte buffers that need not be 16-byte aligned.
		const __m128i plain = _mm_loadu_si128(&src[i]);

		// Chain in the previous ciphertext block, then apply round key 0.
		feedback = _mm_xor_si128(plain, feedback);
		feedback = _mm_xor_si128(feedback, enc_key[0]);

		for (unsigned int j = 1; j < 10; j++)
		{
			feedback = _mm_aesenc_si128(feedback, enc_key[j]);
		}
		feedback = _mm_aesenclast_si128(feedback, enc_key[10]);
		_mm_storeu_si128(&dst[i], feedback);
	}
	return 0;
}

/**
 * Decrypts a buffer with AES-128 in CBC mode: each ciphertext block is
 * decrypted with the round keys in dec_key and then XOR-ed with the previous
 * ciphertext block (the IV for the first block).
 *
 * @param in     ciphertext bytes; no particular alignment is required
 * @param out    buffer for the decrypted text, at least `length` bytes; no
 *               particular alignment is required
 * @param ivec   16-byte initialisation vector; read only, not updated
 * @param length number of bytes to process; must be a multiple of 16
 * @return 0 on success; -1 (after printing a diagnostic to stdout) when
 *         `length` is not a multiple of 16, in which case nothing is written
 */
int AES::CBC_decrypt(const unsigned char *in, unsigned char *out,
		unsigned char ivec[16], unsigned long length)
{
	if (length % 16 != 0)
	{
		printf("input length is not multiple times of 16 bytes!!!");
		return -1;
	}

	const unsigned long block_count = length / 16;
	const __m128i *src = (const __m128i *) in;
	__m128i *dst = (__m128i *) out;

	__m128i feedback = _mm_loadu_si128((const __m128i *) ivec);

	// The index has the same width as block_count so it cannot wrap before
	// the loop condition becomes false.
	for (unsigned long i = 0; i < block_count; i++)
	{
		// Keep the raw ciphertext block: it is the chaining value for the
		// next iteration and must be read before the output is stored.
		const __m128i last_in = _mm_loadu_si128(&src[i]);
		__m128i data = _mm_xor_si128(last_in, dec_key[0]);

		for (unsigned int j = 1; j < 10; j++)
		{
			data = _mm_aesdec_si128(data, dec_key[j]);
		}
		data = _mm_aesdeclast_si128(data, dec_key[10]);
		data = _mm_xor_si128(data, feedback);
		_mm_storeu_si128(&dst[i], data);
		feedback = last_in;
	}
	return 0;
}

/*
 * Prints one line to stdout: the label left-justified in a 40-character
 * field, then the 16 bytes of `data` in memory order as lowercase hex inside
 * "[0x" ... "]".
 */
void AES::print_m128i_with_string(char* string, __m128i data)
{
	const unsigned char *bytes = (const unsigned char*) &data;
	const unsigned char *const stop = bytes + 16;

	printf("%-40s[0x", string);
	while (bytes != stop)
	{
		printf("%02x", *bytes);
		++bytes;
	}
	printf("]\n");
}

/*
 * Prints the label followed by ':' on its own line, then one line per
 * 16-byte block: 40 columns of padding and the block's bytes in memory order
 * as lowercase hex inside "[0x" ... "]".
 *
 * `length` is the number of 16-byte blocks in `data`, not a byte count.
 */
void AES::print_m128i_with_string(char* string, __m128i* data, unsigned int length)
{
	printf("%s:\n", string);

	const __m128i *const last = data + length;
	for (const __m128i *block = data; block != last; ++block)
	{
		const unsigned char *bytes = (const unsigned char*) block;

		// Blank label column so the hex lines up with the single-block form.
		printf("%-40s[0x", " ");
		for (int pos = 0; pos != 16; ++pos)
		{
			printf("%02x", bytes[pos]);
		}
		printf("]\n");
	}
}

/*
 * cpuid(func, ax, bx, cx, dx): runs the x86 CPUID instruction with the value
 * of `func` in EAX and stores the resulting EAX, EBX, ECX and EDX into the
 * four given lvalues. Written as GCC-style extended inline assembly.
 * The expansion already ends in ';'.
 */
#define cpuid(func,ax,bx,cx,dx)\
__asm__ __volatile__ ("cpuid":\
"=a" (ax), "=b" (bx), "=c" (cx), "=d" (dx) : "a" (func));

/*
 * Reports whether the CPU advertises the AES instructions.
 *
 * Queries CPUID leaf 1 and tests bit 25 of ECX (mask 0x2000000), which is the
 * AES feature flag according to Intel's CPUID documentation.
 *
 * Returns the masked value itself: non-zero (0x2000000) when the bit is set,
 * 0 otherwise.
 */
int AES::Check_CPU_support_AES()
{
	unsigned int a, b, c, d;
	cpuid(1, a, b, c, d);
	return (c & 0x2000000);
}
/*****************************************************************************/
// Reference AES-128 ECB ciphertext for the first two plaintext blocks used by the tests below:
//3AD77BB40D7A3660A89ECAF32466EF97
//F5D3D58503B9699DE785895A96FDBAAF
int test1()
{
	printf("%s\n", __PRETTY_FUNCTION__);
	ALIGN16 uint8_t THE_TEST_VECTOR[] =
	{
			//6bc1bee22e409f96e93d7e117393172a
			0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96,
			0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
	};

	ALIGN16 uint8_t THE_TEST_KEY[] =
	{
			//Encryption key: 2b7e151628aed2a6abf7158809cf4f3c
			0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
			0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c

	};
	ALIGN16 uint8_t THE_EXPECTED[] =
	{
			//3ad77bb40d7a3660a89ecaf32466ef97
			0x3a, 0xd7, 0x7b, 0xb4, 0x0d, 0x7a, 0x36, 0x60,
			0xa8, 0x9e, 0xca, 0xf3, 0x24, 0x66, 0xef, 0x97,
	};
	AES ecb;
	if (AES::Check_CPU_support_AES())
	{
		AES::print_m128i_with_string("test vector",
				*((__m128i *) THE_TEST_VECTOR));

		AES::print_m128i_with_string("test key", *((__m128i *) THE_TEST_KEY));

		AES::print_m128i_with_string("test expected",
				*((__m128i *) THE_EXPECTED));
		ecb.set_user_key(THE_TEST_KEY);
		ALIGN16 uint8_t buffer[sizeof(THE_TEST_VECTOR)];
		ecb.ECB_encrypt(THE_TEST_VECTOR, buffer, sizeof(THE_TEST_VECTOR));
		AES::print_m128i_with_string("cypher text", *((__m128i *) buffer));

		ecb.ECB_decrypt(THE_EXPECTED, buffer, sizeof(THE_EXPECTED));
		AES::print_m128i_with_string("message text", *((__m128i *) buffer));

	}

	return 0;

}

/*****************************************************************************/
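// AES-128 ECB reference ciphertext blocks (kept for comparison; the CBC test below expects a different value):
//3AD77BB40D7A3660A89ECAF32466EF97
//F5D3D58503B9699DE785895A96FDBAAF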
int test2()
{
	printf("%s\n", __PRETTY_FUNCTION__);
	ALIGN16 uint8_t THE_TEST_VECTOR[] =
	{
			//6bc1bee22e409f96e93d7e117393172a
			0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96,
			0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93, 0x17, 0x2a,
	};

	ALIGN16 uint8_t THE_TEST_KEY[] =
	{
			//Encryption key: 2b7e151628aed2a6abf7158809cf4f3c
			0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
			0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c

	};
	ALIGN16 uint8_t CBC_IV[] =
	{
			0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
			0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
	};

	ALIGN16 uint8_t THE_EXPECTED[] =
	{
			//3ad77bb40d7a3660a89ecaf32466ef97
			0x76, 0x49, 0xab, 0xac, 0x81, 0x19, 0xb2, 0x46,
			0xce, 0xe9, 0x8e, 0x9b, 0x12, 0xe9, 0x19, 0x7d,
	};
	AES cbc;
	if (AES::Check_CPU_support_AES())
	{
		AES::print_m128i_with_string("test vector",
				*((__m128i *) THE_TEST_VECTOR));

		AES::print_m128i_with_string("test key", *((__m128i *) THE_TEST_KEY));

		AES::print_m128i_with_string("test cbc_vec", *((__m128i *) CBC_IV));

		AES::print_m128i_with_string("test expected", *((__m128i *) THE_EXPECTED));

		cbc.set_user_key(THE_TEST_KEY);
		ALIGN16 uint8_t buffer[sizeof(THE_TEST_VECTOR)];
		cbc.CBC_encrypt(THE_TEST_VECTOR, buffer, CBC_IV,
				sizeof(THE_TEST_VECTOR));
		AES::print_m128i_with_string("cypher text", *((__m128i *) buffer));

		cbc.CBC_decrypt(THE_EXPECTED, buffer, CBC_IV, sizeof(THE_EXPECTED));
		AES::print_m128i_with_string("message text", *((__m128i *) buffer));

	}

	return 0;

}


int test3()
{
	printf("%s\n", __PRETTY_FUNCTION__);
	ALIGN16 uint8_t THE_TEST_VECTOR[] =
	{
			0x6b,0xc1,0xbe,0xe2,0x2e,0x40,0x9f,0x96,
			0xe9,0x3d,0x7e,0x11,0x73,0x93,0x17,0x2a,
			0xae,0x2d,0x8a,0x57,0x1e,0x03,0xac,0x9c,
			0x9e,0xb7,0x6f,0xac,0x45,0xaf,0x8e,0x51,
			0x30,0xc8,0x1c,0x46,0xa3,0x5c,0xe4,0x11,
			0xe5,0xfb,0xc1,0x19,0x1a,0x0a,0x52,0xef,
			0xf6,0x9f,0x24,0x45,0xdf,0x4f,0x9b,0x17,
			0xad,0x2b,0x41,0x7b,0xe6,0x6c,0x37,0x10


	};

	ALIGN16 uint8_t THE_TEST_KEY[] =
	{
			//Encryption key: 2b7e151628aed2a6abf7158809cf4f3c
			0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
			0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c

	};
	ALIGN16 uint8_t CBC_IV[] =
	{
			0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
			0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
	};

	ALIGN16 uint8_t THE_EXPECTED[] =
	{
			0x76,0x49,0xab,0xac,0x81,0x19,0xb2,0x46,
			0xce,0xe9,0x8e,0x9b,0x12,0xe9,0x19,0x7d,
			0x50,0x86,0xcb,0x9b,0x50,0x72,0x19,0xee,
			0x95,0xdb,0x11,0x3a,0x91,0x76,0x78,0xb2,
			0x73,0xbe,0xd6,0xb8,0xe3,0xc1,0x74,0x3b,
			0x71,0x16,0xe6,0x9e,0x22,0x22,0x95,0x16,
			0x3f,0xf1,0xca,0xa1,0x68,0x1f,0xac,0x09,
			0x12,0x0e,0xca,0x30,0x75,0x86,0xe1,0xa7

	};
	AES cbc;
	if (AES::Check_CPU_support_AES())
	{
		int block_number = sizeof(THE_TEST_VECTOR)/sizeof(__m128i);
		AES::print_m128i_with_string("test vector",
				((__m128i *) THE_TEST_VECTOR), block_number);

		AES::print_m128i_with_string("test key", *((__m128i *) THE_TEST_KEY));

		AES::print_m128i_with_string("test cbc_vec", *((__m128i *) CBC_IV));

		AES::print_m128i_with_string("test expected", ((__m128i *) THE_EXPECTED),
				block_number);

		cbc.set_user_key(THE_TEST_KEY);
		ALIGN16 uint8_t buffer[sizeof(THE_TEST_VECTOR)];
		cbc.CBC_encrypt(THE_TEST_VECTOR, buffer, CBC_IV,
				sizeof(THE_TEST_VECTOR));
		AES::print_m128i_with_string("cypher text", ((__m128i *) buffer), block_number);

		cbc.CBC_decrypt(THE_EXPECTED, buffer, CBC_IV, sizeof(THE_EXPECTED));
		AES::print_m128i_with_string("message text", ((__m128i *) buffer), block_number);

	}

	return 0;

}

int test4()
{
	printf("%s\n", __PRETTY_FUNCTION__);
	ALIGN16 uint8_t THE_TEST_VECTOR[] =
	{
			0x6b,0xc1,0xbe,0xe2,0x2e,0x40,0x9f,0x96,
			0xe9,0x3d,0x7e,0x11,0x73,0x93,0x17,0x2a,
			0xae,0x2d,0x8a,0x57,0x1e,0x03,0xac,0x9c,
			0x9e,0xb7,0x6f,0xac,0x45,0xaf,0x8e,0x51,
			0x30,0xc8,0x1c,0x46,0xa3,0x5c,0xe4,0x11,
			0xe5,0xfb,0xc1,0x19,0x1a,0x0a,0x52,0xef,
			0xf6,0x9f,0x24,0x45,0xdf,0x4f,0x9b,0x17,
			0xad,0x2b,0x41,0x7b,0xe6,0x6c,0x37,0x10


	};

	ALIGN16 uint8_t THE_TEST_KEY[] =
	{
			//Encryption key: 2b7e151628aed2a6abf7158809cf4f3c
			0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
			0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c

	};

	ALIGN16 uint8_t THE_EXPECTED[] =
	{
			0x3a,0xd7,0x7b,0xb4,0x0d,0x7a,0x36,0x60,
			0xa8,0x9e,0xca,0xf3,0x24,0x66,0xef,0x97,
			0xf5,0xd3,0xd5,0x85,0x03,0xb9,0x69,0x9d,
			0xe7,0x85,0x89,0x5a,0x96,0xfd,0xba,0xaf,
			0x43,0xb1,0xcd,0x7f,0x59,0x8e,0xce,0x23,
			0x88,0x1b,0x00,0xe3,0xed,0x03,0x06,0x88,
			0x7b,0x0c,0x78,0x5e,0x27,0xe8,0xad,0x3f,
			0x82,0x23,0x20,0x71,0x04,0x72,0x5d,0xd4

	};
	AES cbc;
	if (AES::Check_CPU_support_AES())
	{
		int block_number = sizeof(THE_TEST_VECTOR)/sizeof(__m128i);
		AES::print_m128i_with_string("test vector",
				((__m128i *) THE_TEST_VECTOR), block_number);

		AES::print_m128i_with_string("test key", *((__m128i *) THE_TEST_KEY));


		AES::print_m128i_with_string("test expected", ((__m128i *) THE_EXPECTED),
				block_number);

		cbc.set_user_key(THE_TEST_KEY);
		ALIGN16 uint8_t buffer[sizeof(THE_TEST_VECTOR)];
		cbc.ECB_encrypt(THE_TEST_VECTOR, buffer, sizeof(THE_TEST_VECTOR));
		AES::print_m128i_with_string("cypher text", ((__m128i *) buffer), block_number);

		cbc.ECB_decrypt(THE_EXPECTED, buffer, sizeof(THE_EXPECTED));
		AES::print_m128i_with_string("message text", ((__m128i *) buffer), block_number);

	}

	return 0;
}

/*
 * Runs the four demonstration tests in order. Their return values are
 * ignored and the process always exits with status 0.
 */
int main()
{
	typedef int (*demo_fn)();
	demo_fn const suite[] = { test1, test2, test3, test4 };
	const unsigned int suite_size = sizeof(suite) / sizeof(suite[0]);

	for (unsigned int which = 0; which < suite_size; ++which)
	{
		suite[which]();
	}
	return 0;
}
