
/* This is an independent implementation of the E2 encryption       */
/* algorithm designed by a Nippon Telegraph and Telephone team in	*/
/* Japan and submitted as a candidate in the US NIST Advanced		*/
/* Encryption Standard (AES) programme.                             */
/*                                                                  */
/* Copyright in this implementation is held by Dr B R Gladman. In	*/
/* accordance with the wishes of NTT this implementation is made	*/
/* available for academic and study purposes only.					*/
/*                                                                  */
/* This implementation is partially optimised and achieves a speed	*/
/* of about 25 Megabits/second on my Pentium Pro 200.				*/
/*                                                                  */
/* Dr Brian Gladman (gladman@seven77.demon.co.uk) 24th July 1998    */
/*                                                                  */

#include "../std_defs.h"

/* Name of this algorithm as reported by cipher_name().  Held in a		*/
/* writable static array (rather than a char * to a string literal) so	*/
/* that the non-const pointer handed out does not point into a literal,	*/
/* which callers could otherwise modify with undefined behaviour.		*/

static char alg_name[] = "E2";

/* Return the algorithm name, "E2".  The pointer refers to static		*/
/* storage owned by this module: callers must not free it, and the same	*/
/* pointer is returned on every call.									*/

char *cipher_name(void)
{
    return alg_name;
}

/* An 8-byte scratch value that can be viewed as one u8byte, two u4byte,	*/
/* four u2byte or eight u1byte elements.  All members overlay the same	*/
/* storage (assuming each uNbyte type is N bytes wide, as the names		*/
/* suggest).															*/

typedef union
{
	u8byte	qq[1];		/* the whole value					*/
	u4byte	ll[2];		/* as two words						*/
	u2byte	ww[4];		/* as four half-words				*/
	u1byte	bb[8];		/* as eight individual bytes		*/
} un8byte;

#ifndef _MSC_VER

/* Portable 32-bit rotates (assumes u4byte is a 32-bit unsigned type).	*/
/* The operand is converted to u4byte first so that narrow arguments,	*/
/* such as u1byte S-box entries, are not promoted to signed int, where	*/
/* shifting a set bit into bit 31 is undefined behaviour.  Both shift	*/
/* counts are reduced modulo 32 so a rotate by 0 returns x instead of	*/
/* shifting by the full width (also undefined); for counts 1..31 the	*/
/* result is unchanged.  Note that x is evaluated twice.				*/

#define rotr(x,n)   (((u4byte)(x) >> ((int)(n) & 31)) | ((u4byte)(x) << ((32 - (int)(n)) & 31)))
#define rotl(x,n)   (((u4byte)(x) << ((int)(n) & 31)) | ((u4byte)(x) >> ((32 - (int)(n)) & 31)))

#else

/* MSVC: use the compiler's rotate intrinsics.							*/

#include <stdlib.h>

#pragma intrinsic(_lrotr,_lrotl)
#define rotr(x,n)   _lrotr(x,n)
#define rotl(x,n)   _lrotl(x,n)

#endif

/* byte-lane masks for a 32-bit word									*/

#define	mask_0	0x000000ff
#define	mask_1	0x0000ff00
#define	mask_2	0x00ff0000
#define	mask_3	0xff000000

/* initial key-schedule chaining value v[0], v[1] (see set_key)			*/

#define	v_0		0x67452301
#define	v_1		0xefcdab89

/* s_fun(s_fun(s_fun(v)))  -- key words 4, 5 used when key_len <= 4		*/

#define	k2_0	0x30d32e58
#define	k2_1	0xb89e4984

/* s_fun(s_fun(s_fun(s_fun(v))))  -- key words 6, 7 used when key_len <= 6 */

#define	k3_0	0x0957cfec
#define	k3_1	0xd800502e

/* reverse the byte order of a 32-bit word								*/

#define bswap(x)	((rotl(x, 8) & (mask_0 | mask_2)) | (rotr(x, 8) & (mask_1 | mask_3)))

/* S-box substitution of the byte at bit offset y of x, returned in the	*/
/* same byte position.  The table entry is widened to u4byte before the	*/
/* rotate: as a u1byte it would be promoted to signed int, and for		*/
/* y == 24 shifting an entry >= 128 into bit 31 is undefined behaviour.	*/

#define s_b(x,y)	rotl((u4byte)s_box[rotr((x),y) & mask_0],y)

/* substitute all four bytes of the word x in place						*/

#define s_fun(x)	do { (x) = s_box[(x) & mask_0] | s_b(x,8) | s_b(x,16) | s_b(x,24); } while(0)

/* S-box both words, then mix them (in place) through XORs and rotates	*/

#define	sp_fun(a,b)	do { s_fun(a); s_fun(b); (b) ^= (a); (a) ^= rotl((b), 16);	\
						 (b) ^= rotl((a), 8); (a) ^= (b); } while(0)

/* One round: with round key words k[0..3], compute a function of (c,d)	*/
/* and XOR it into (a,b).  The final combine is a 64-bit rotate right by	*/
/* one byte of the pair (v:u).  NOTE: uses the CALLER's u4byte locals u	*/
/* and v as scratch (both are overwritten), so a, b, c, d must not be u	*/
/* or v.  c and d are only read.										*/

#define	f_fun(a,b,c,d,k)						\
do {	u = (c) ^ *(k); v = (d) ^ *((k) + 1);	\
	sp_fun(u, v);								\
	u ^= *((k) + 2); v ^= *((k) + 3);			\
	s_fun(u); s_fun(v);							\
	(a) ^= (u >> 8) | (v << 24);				\
	(b) ^= (v >> 8) | (u << 24);				\
} while(0)

/* Byte transposition: each output word gathers one byte from each of	*/
/* t, u, v, w.  The outputs a..d are written before all inputs are read,	*/
/* so they must be distinct variables from t..w.						*/

#define bp_fun(a, b, c, d, t, u, v, w)		\
do {	(a) = ((t) >> 24) | (((u) >> 8) & mask_1) | (((v) << 8) & mask_2) | ((w) << 24);	\
	(b) = ((u) >> 24) | (((v) >> 8) & mask_1) | (((w) << 8) & mask_2) | ((t) << 24);	\
	(c) = ((v) >> 24) | (((w) >> 8) & mask_1) | (((t) << 8) & mask_2) | ((u) << 24);	\
	(d) = ((w) >> 24) | (((t) >> 8) & mask_1) | (((u) << 8) & mask_2) | ((v) << 24);	\
} while(0)

/* 256-entry byte substitution table used by s_fun / s_b to replace each	*/
/* byte of a 32-bit word.  Presumably the S-box from the E2 specification	*/
/* -- the values must be reproduced exactly; do not regenerate by hand.	*/

u1byte	s_box[] =
{
	225,  66,  62, 129,  78,  23, 158, 253, 180,  63,  44, 218,  49,  30, 224,  65, 
	204, 243, 130, 125, 124,  18, 142, 187, 228,  88,  21, 213, 111, 233,  76,  75, 
	 53, 123,  90, 154, 144,  69, 188, 248, 121, 214,  27, 136,   2, 171, 207, 100, 
	  9,  12, 240,   1, 164, 176, 246, 147,  67,  99, 134, 220,  17, 165, 131, 139, 
	201, 208,  25, 149, 106, 161,  92,  36, 110,  80,  33, 128,  47, 231,  83,  15, 
	145,  34,   4, 237, 166,  72,  73, 103, 236, 247, 192,  57, 206, 242,  45, 190, 
	 93,  28, 227, 135,   7,  13, 122, 244, 251,  50, 245, 140, 219, 143,  37, 150, 
	168, 234, 205,  51, 101,  84,   6, 141, 137,  10,  94, 217,  22,  14, 113, 108, 
     11, 255,  96, 210,  46, 211, 200,  85, 194,  35, 183, 116, 226, 155, 223, 119, 
	 43, 185,  60,  98,  19, 229, 148,  52, 177,  39, 132, 159, 215,  81,   0,  97, 
	173, 133, 115,   3,   8,  64, 239, 104, 254, 151,  31, 222, 175, 102, 232, 184, 
	174, 189, 179, 235, 198, 107,  71, 169, 216, 167, 114, 238,  29, 126, 170, 182, 
	117, 203, 212,  48, 105,  32, 127,  55,  91, 157, 120, 163, 241, 118, 250,   5, 
	 61,  58,  68,  87,  59, 202, 199, 138,  24,  70, 156, 191, 186,  56,  86,  26, 
	146,  77,  38,  41, 162, 152,  16, 153, 112, 160, 197,  40, 193, 109,  20, 172, 
	249,  95,  79, 196, 195, 209, 252, 221, 178,  89, 230, 181,  54,  82,  74,  42
};

/* The single key schedule, written by set_key and read by encrypt and	*/
/* decrypt (so only one key is active at a time; not reentrant):		*/
/*   l_key[ 0..47]  round keys, four words per f_fun round, 12 rounds	*/
/*   l_key[48..51]  XORed into the input by encrypt, the output by decrypt	*/
/*   l_key[52..55]  odd multipliers applied to the input by encrypt		*/
/*   l_key[56..59]  odd multipliers applied to the input by decrypt		*/
/*   l_key[60..63]  XORed into the output by encrypt, the input by decrypt	*/
/*   l_key[64..71]  inverses mod 2^32 of l_key[52..59], applied on output	*/
u4byte	l_key[72];

/* Multiplicative inverse of x modulo 2^32 (u4byte assumed 32 bits),		*/
/* computed by the extended Euclidean algorithm on (2^32, x) in			*/
/* wrap-around unsigned arithmetic.  Only odd x is invertible: 0 is		*/
/* returned for even x (0 is never a valid inverse, so it is an			*/
/* unambiguous failure value).  x == 1 is handled up front because		*/
/* 2^32 mod 1 == 0 would make the first divisor below zero.				*/

u4byte	mod_inv(u4byte x)
{	u4byte	y1, y2, a, b, q;

	if((x & 1) == 0)		/* even: no inverse modulo 2^32			*/
		return 0;

	if(x == 1)				/* would otherwise divide by b == 0		*/
		return 1;

	/* -x / x == floor(2^32 / x) - 1, so y1 == -floor(2^32 / x) and		*/
	/* b == y1 * x == 2^32 mod x: the first Euclid step on (2^32, x)	*/
	/* done without a 33-bit value.  Invariants maintained below:		*/
	/*     a == x * y2   and   b == x * y1   (mod 2^32)				*/

	y1 = ~(-x / x); y2 = 1;

	a = x; b = y1 * x;

	for(;;)
	{
		q = a / b;

		if((a -= q * b) == 0)	/* b is the gcd (1), so x * y1 == 1	*/

			return (x * y1 == 1 ? y1 : -y1);

		y2 -= q * y1;

		q = b / a;

		if((b -= q * a) == 0)	/* a is the gcd (1), so x * y2 == 1	*/

			return (x * y2 == 1 ? y2 : -y2);

		y1 -= q * y2;
	}
}

/* Key-schedule G function.  Walks the eight words of y[] as four pairs.	*/
/* For each pair it applies sp_fun to the pair and to the chaining value	*/
/* v[], XORs the transformed pair into v[], and stores the new v[] in the	*/
/* matching two words of l[].  y[] and v[] are updated in place, so		*/
/* successive calls continue the sequence.								*/

void g_fun(u4byte y[8], u4byte l[8], u4byte v[2])
{
	u4byte	i;

	for(i = 0; i < 8; i += 2)
	{
		sp_fun(y[i], y[i + 1]);
		sp_fun(v[0], v[1]);

		v[0] ^= y[i];
		l[i] = v[0];

		v[1] ^= y[i + 1];
		l[i + 1] = v[1];
	}
}

/* Expand a key into the file-scope schedule l_key and return a pointer	*/
/* to it.  Because the schedule is global, this is not reentrant.		*/
/*																		*/
/* key_blk holds the key as u4byte words.  Words 0-3 are always read;		*/
/* words 4-5 only if key_len > 4 and words 6-7 only if key_len > 6, with	*/
/* the constants k2_* / k3_* substituted otherwise.  NOTE(review): this	*/
/* treats key_len as a count of 32-bit words (4, 6 or 8) -- confirm the	*/
/* caller does not pass a length in bits, which would read past a short	*/
/* key.																	*/
/*																		*/
/* Result: the bytes of l_key[0..63] are filled from eight G-function		*/
/* outputs (byte placement follows host byte order), then l_key[52..59]	*/
/* are byte-swapped and forced odd and l_key[64..71] receive their		*/
/* inverses modulo 2^32.												*/
/*																		*/
/* The G output is read through a u1byte pointer rather than by casting	*/
/* lout to un8byte *: un8byte contains a u8byte and may need stricter		*/
/* alignment than a u4byte array provides, which makes that cast			*/
/* undefined.  The bytes read are the same.								*/

u4byte *set_key(u4byte key_blk[], u4byte key_len)
{	u4byte	lk[8], v[2], lout[8];
	u1byte	*lb = (u1byte*)lout;	/* G output viewed as 32 bytes		*/
	u1byte	*bp = (u1byte*)l_key;	/* schedule viewed as bytes			*/
	u4byte	i, j, k;

	v[0] = v_0; v[1] = v_1;

	lk[0] = key_blk[0]; lk[1] = key_blk[1];
	lk[2] = key_blk[2]; lk[3] = key_blk[3];

	lk[4] = (key_len > 4 ? key_blk[4] : k2_0);
	lk[5] = (key_len > 4 ? key_blk[5] : k2_1);

	lk[6] = (key_len > 6 ? key_blk[6] : k3_0);
	lk[7] = (key_len > 6 ? key_blk[7] : k3_1);

	/* the output of this first G call is not stored: it only advances	*/
	/* lk and v before the eight calls whose output is used below		*/

	g_fun(lk, lout, v);

	for(i = 0; i < 8; ++i)
	{
		g_fun(lk, lout, v);

		/* Scatter the 32 output bytes.  Byte j of each 8-byte quarter	*/
		/* of lout goes to the 32-byte row j of the schedule: quarters	*/
		/* 0 and 1 to columns 2*i and 2*i + 16, quarters 2 and 3 to		*/
		/* columns 2*i + 1 and 2*i + 17.								*/

		for(j = 0; j < 8; ++j)
		{
			k = 32 * j + 2 * i;

			bp[k]      = lb[j];
			bp[k + 16] = lb[j + 8];
			bp[k + 1]  = lb[j + 16];
			bp[k + 17] = lb[j + 24];
		}
	}

	/* multiplier keys must be odd to be invertible modulo 2^32			*/

	for(i = 52; i < 60; ++i)
	{
		l_key[i] = bswap(l_key[i]) | 1;

		l_key[i + 12] = mod_inv(l_key[i]);
	}

	return l_key;
}

/* Encrypt one 128-bit block held as four u4byte words, using the		*/
/* schedule in l_key produced by set_key:								*/
/*   1. XOR with l_key[48..51], byte-swap, multiply by l_key[52..55];	*/
/*   2. byte-transpose (bp_fun) into the halves (a,b) and (c,d);		*/
/*   3. twelve f_fun rounds keyed by l_key[0..47], four words each,		*/
/*      alternately updating (a,b) from (c,d) and (c,d) from (a,b);		*/
/*   4. transpose back, multiply by l_key[68..71] (the inverses of		*/
/*      l_key[56..59]), byte-swap and XOR with l_key[60..63].			*/
/* Every input word is read before any output word is written, so		*/
/* in_blk and out_blk may refer to the same storage.					*/

void encrypt(u16byte in_blk, u16byte out_blk)
{	u4byte	a, b, c, d, t, u, v, w;	/* f_fun uses u and v as scratch	*/
	u4byte	r;

	t = bswap(in_blk[0] ^ l_key[48]) * l_key[52];
	u = bswap(in_blk[1] ^ l_key[49]) * l_key[53];
	v = bswap(in_blk[2] ^ l_key[50]) * l_key[54];
	w = bswap(in_blk[3] ^ l_key[51]) * l_key[55];

	bp_fun(a, b, c, d, t, u, v, w);

	/* two rounds per pass: keys at offsets r and r + 4				*/

	for(r = 0; r < 48; r += 8)
	{
		f_fun(a, b, c, d, l_key + r);
		f_fun(c, d, a, b, l_key + r + 4);
	}

	bp_fun(u, v, w, t, a, b, c, d);

	out_blk[0] = bswap(t * l_key[68]) ^ l_key[60];
	out_blk[1] = bswap(u * l_key[69]) ^ l_key[61];
	out_blk[2] = bswap(v * l_key[70]) ^ l_key[62];
	out_blk[3] = bswap(w * l_key[71]) ^ l_key[63];
}

/* Decrypt one 128-bit block held as four u4byte words; the inverse of	*/
/* encrypt, using the same l_key schedule:								*/
/*   1. XOR with l_key[60..63], byte-swap, multiply by l_key[56..59];	*/
/*   2. byte-transpose (bp_fun) into the halves (a,b) and (c,d);		*/
/*   3. twelve f_fun rounds keyed by l_key[44..47] down to l_key[0..3];	*/
/*   4. transpose back, multiply by l_key[64..67] (the inverses of		*/
/*      l_key[52..55]), byte-swap and XOR with l_key[48..51].			*/
/* Every input word is read before any output word is written, so		*/
/* in_blk and out_blk may refer to the same storage.					*/

void decrypt(u16byte in_blk, u16byte out_blk)
{	u4byte	a, b, c, d, t, u, v, w;	/* f_fun uses u and v as scratch	*/
	u4byte	r;

	t = bswap(in_blk[0] ^ l_key[60]) * l_key[56];
	u = bswap(in_blk[1] ^ l_key[61]) * l_key[57];
	v = bswap(in_blk[2] ^ l_key[62]) * l_key[58];
	w = bswap(in_blk[3] ^ l_key[63]) * l_key[59];

	bp_fun(a, b, c, d, t, u, v, w);

	/* round keys in reverse: offsets 44, 40 first and 4, 0 last		*/

	for(r = 48; r != 0; r -= 8)
	{
		f_fun(a, b, c, d, l_key + (r - 4));
		f_fun(c, d, a, b, l_key + (r - 8));
	}

	bp_fun(u, v, w, t, a, b, c, d);

	out_blk[0] = bswap(t * l_key[64]) ^ l_key[48];
	out_blk[1] = bswap(u * l_key[65]) ^ l_key[49];
	out_blk[2] = bswap(v * l_key[66]) ^ l_key[50];
	out_blk[3] = bswap(w * l_key[67]) ^ l_key[51];
}
