/**
 * This class implements the Square block cipher.
 *
 * <P>
 * <b>References</b>
 *
 * <P>
 * The Square algorithm was developed by <a href="mailto:Daemen.J@banksys.com">Joan Daemen</a>
 * and <a href="mailto:vincent.rijmen@esat.kuleuven.ac.be">Vincent Rijmen</a>, and is
 * in the public domain.
 *
 * See
 *      J. Daemen, L.R. Knudsen, V. Rijmen,
 *      "The block cipher Square,"
 *      <cite>Fast Software Encryption, 4th International Workshop (FSE'97),
 *      Haifa, Israel, Proceedings</cite>,
 *      LNCS 1267, E. Biham, Ed., Springer-Verlag, 1997, pp. 149-165.
 *
 * <P>
 * @author  This public domain Java implementation was written by
 * <a href="mailto:pbarreto@nw.com.br">Paulo S.L.M. Barreto</a> based on C software
 * originally written by Vincent Rijmen.  Packaged and stuff by Mr. Tines
 * after the fact.
 *
 * @version 2.1 (1997.08.11)
 *
 *
 *
 * =============================================================================
 *
 * Differences from version 2.0 (1997.07.28)
 *
 * -- Simplified the static initialization by directly using the coefficients of
 *    the diffusion polynomial and its inverse (as chosen in the defining paper)
 *    instead of generating the full diffusion and inverse diffusion matrices
 *    G[][] and iG[][].  This avoids the burden of the matrix inversion code.
 * -- Generalized the code to an arbitrary number of rounds by explicitly
 *    computing the round offsets and explicitly looping the round function.
 * -- Simplified the mappings between byte arrays and Square data blocks.
 *    Together with the other changes, this reduces the bytecode size.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS ''AS IS'' AND ANY EXPRESS
 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
 * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
 * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */

package uk.co.demon.windsong.crypt.cea;

import uk.co.demon.windsong.crypt.cea.CEA;

/**
 * Square block cipher (16-byte block, 16-byte key, 8 rounds) behind the
 * {@link CEA} jacket.
 *
 * <p>An instance is keyed with {@link #init}, transforms single blocks with
 * {@link #ecb} and is wiped with {@link #destroy}.  In triple mode three
 * consecutive 16-byte keys are used: encryption applies key 0, then key 1,
 * then key 2; decryption applies the inverse operations in the reverse order.
 *
 * <p>Instances hold mutable key material and perform no synchronisation.
 */
public final class Square implements CEA
{

    /**
    * Number of bytes in the block
    */
    public  static final int BLOCK_LENGTH = 16;
    /**
    * Number of bytes in the key
    */
    public  static final int KEY_LENGTH = BLOCK_LENGTH;
    /**
    * number of rounds
    */
    private static final int R = 8;

    /**
    * cipher constant tables, all filled in by the static initialiser:
    * offset[] - per-round constants mixed into the key evolution
    * phi[]    - diffusion of a raw byte (used on round keys)
    * Se[]/Sd[] - substitution box and its inverse
    * Te[]/Td[] - substitution followed by (inverse) diffusion
    */
    private static final int[] offset = new int[R];
    private static final int[] phi = new int[256];
    private static final int[] Se = new int[256];
    private static final int[] Sd = new int[256];
    private static final int[] Te = new int[256];
    private static final int[] Td = new int[256];

    ////////////////////////////////////////////////////////////////////////////

    /** Reduction polynomial of GF(2**8), including the x**8 term. */
    private static final int ROOT = 0x1f5;
    /** exp[i] = 0x02**i in GF(2**8); exp[255] wraps round to 1. */
    private static final int[] exp = new int[256];
    /** log[x] = i such that exp[i] == x, for x != 0. */
    private static final int[] log = new int[256];


    /**
    * multiply two elements of GF(2**8)
    * @param a first factor, 0..255
    * @param b second factor, 0..255
    * @return the product, 0..255
    */
    private static final int mul (int a, int b)
    {
        return (a == 0 || b == 0) ? 0 :
            exp[(log[a] + log[b]) % 255];
    } // mul


    /**
    * Multiply one byte by the four coefficients of a diffusion polynomial
    * and pack the products into a word: coeff[3] lands in the top byte,
    * coeff[0] in the bottom byte.
    * @param v      the byte to spread, 0..255
    * @param coeff  four polynomial coefficients
    * @return the packed 32-bit table entry
    */
    private static int diffuse (int v, int[] coeff)
    {
        return mul (v, coeff[3]) << 24
             ^ mul (v, coeff[2]) << 16
             ^ mul (v, coeff[1]) <<  8
             ^ mul (v, coeff[0]);
    } // diffuse


    static {
        /* produce log and exp, needed for multiplying in the field GF(2**8):
         */
        exp[0] = exp[255] = 1;
        log[1] = 0;
        for (int i = 1; i < 255; i++) {
            int j = exp[i - 1] << 1; // 0x02 is used as generator (mod ROOT)
            if (j >= 256) {
                j ^= ROOT; // reduce j (mod ROOT)
            }
            exp[i] = j;
            log[j] = i;
        }

        /* compute the substitution box Se[] and its inverse Sd[]
         * based on F(x) = x**{-1} plus affine transform of the output.
         *
         * The linear part of the affine transform is the bit matrix whose
         * rows are listed in trans[]:
           0x01     00000001
           0x03     00000011
           0x05     00000101
           0x0f     00001111
           0x1f     00011111
           0x3d     00111101
           0x7b     01111011
           0xd6     11010110
         * and the constant that is added afterwards is 0xb1.
         */
        final int[] trans = {0x01, 0x03, 0x05, 0x0f, 0x1f, 0x3d, 0x7b, 0xd6};
        final int affineConstant = 0xb1;
        for (int i = 0; i < 256; i++) {
            // multiplicative inverse of i, with 0 mapped to 0:
            int inverse = (i == 0) ? 0 : exp[255 - log[i]];
            int v = affineConstant;
            for (int t = 0; t < 8; t++) {
                // row t of the matrix times the inverse, over GF(2):
                int u = inverse & trans[t];
                // sum over GF(2) of all bits of u:
                u ^= u >> 4; u ^= u >> 2; u ^= u >> 1; u &= 1;
                // row alignment of the result:
                v ^= u << t;
            }
            Se[i] = v;
            Sd[v] = i; // inverse substitution box
        }

        /* diffusion and inverse diffusion polynomials:
         * by definition (cf. "The block cipher Square", section 2.1),
         * c(x)d(x) = 1 (mod 1 + x**4)
         * where the polynomial coefficients are taken from GF(2**8);
         * the actual polynomial and its inverse are:
         * c(x) = 3.x**3 + 1.x**2 + 1.x + 2
         * d(x) = B.x**3 + D.x**2 + 9.x + E
         */
        final int[] c = {0x2, 0x1, 0x1, 0x3};
        final int[] d = {0xE, 0x9, 0xD, 0xB};

        /* substitution/diffusion layers and key schedule transform.
         * No special case is needed for zero entries: mul() yields 0
         * whenever one of its arguments is 0.
         */
        for (int t = 0; t < 256; t++) {
            phi[t] = diffuse (t, c);
            Te[t]  = diffuse (Se[t], c);
            Td[t]  = diffuse (Sd[t], d);
        }

        /* offset table: successive powers of 0x02
         */
        offset[0] = 0x1;
        for (int i = 1; i < R; i++) {
            offset[i] = mul (offset[i - 1], 0x2);
        }
    } // static


    ////////////////////////////////////////////////////////////////////////////


    /** Rotate the 32-bit word x right by s bits (0 < s < 32). */
    private static final int rotr (int x, int s)
    {
        return (x >>> s) | (x <<  (32 - s));
    } // rotr


    /** Rotate the 32-bit word x left by s bits (0 < s < 32). */
    private static final int rotl (int x, int s)
    {
        return (x <<  s) | (x >>> (32 - s));
    } // rotl


    /**
    * apply the theta function to a round key, in place
    * @param roundKey the four words of the round key
    */
    private static void transform (int[] roundKey)
    {
        for (int i = 0; i < 4; i++) {
            int w = roundKey[i];
            roundKey[i] = phi[(w       ) & 0xff] ^
                    rotl (phi[(w >>>  8) & 0xff],  8) ^
                    rotl (phi[(w >>> 16) & 0xff], 16) ^
                    rotl (phi[(w >>> 24) & 0xff], 24);
        }
    } // transform


    /** Overwrite every word of the array with zero. */
    private static void wipe (int[] words)
    {
        for (int i = 0; i < words.length; i++) {
            words[i] = 0;
        }
    } // wipe


    /** Overwrite every byte of the array with zero. */
    private static void wipe (byte[] bytes)
    {
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] = 0;
        }
    } // wipe


    /**
    * Expanded key for one 128-bit user key: R+1 round keys for encryption
    * and R+1 for decryption.  The class is static because it uses only
    * the shared constant tables, never the state of a Square instance.
    */
    private static final class Keyschedule
    {
        private int[][] roundKeys_e;
        private int[][] roundKeys_d;

        /**
         * This creates a Square key schedule from a byte array user key.
         * @param key        buffer holding the 128-bit user key.
         * @param keyoffset  index of the first key byte within the buffer.
         */
        Keyschedule (byte[] key, int keyoffset)
        {
            roundKeys_e = new int[R+1][4];
            roundKeys_d = new int[R+1][4];
            // map user key to first round key (little-endian words):
            for (int i = 0; i < 4; i++) {
                int p = keyoffset + 4*i;
                roundKeys_e[0][i] =
                     (key[p    ] & 0xff)        |
                    ((key[p + 1] & 0xff) <<  8) |
                    ((key[p + 2] & 0xff) << 16) |
                    ((key[p + 3] & 0xff) << 24);
            }
            for (int t = 1; t <= R; t++) {
                int[] previous = roundKeys_e[t-1];
                int[] current  = roundKeys_e[t];
                // apply the key evolution function:
                current[0] = previous[0] ^ rotr (previous[3], 8) ^ offset[t-1];
                current[1] = previous[1] ^ current[0];
                current[2] = previous[2] ^ current[1];
                current[3] = previous[3] ^ current[2];
                // the decryption schedule takes the freshly evolved key,
                // i.e. before theta is applied to it, in reverse order:
                System.arraycopy (current, 0, roundKeys_d[R-t], 0, 4);
                // apply the theta diffusion function:
                transform (previous);
            }
            // the last decryption key is encryption key 0 after theta:
            System.arraycopy (roundKeys_e[0], 0, roundKeys_d[R], 0, 4);
        } // Keyschedule


        /**
        * The round function to transform an intermediate data block <code>block</code> with
        * the substitution-diffusion table <code>T</code> and the round key <code>roundKey</code>
        * @param   block       the data block, updated in place
        * @param   T           the substitution-diffusion table
        * @param   roundKey    the 128-bit round key
        */
        private static void round (int[] block, int[] T, int[] roundKey)
        {
            int t0 = block[0];
            int t1 = block[1];
            int t2 = block[2];
            int t3 = block[3];

            // output word i collects byte i of every input word
            for (int i = 0; i < 4; i++) {
                int shift = 8*i;
                block[i] =  T[(t0 >>> shift) & 0xff]
                    ^ rotl (T[(t1 >>> shift) & 0xff],  8)
                    ^ rotl (T[(t2 >>> shift) & 0xff], 16)
                    ^ rotl (T[(t3 >>> shift) & 0xff], 24)
                    ^ roundKey[i];
            }

            // destroy potentially sensitive information:
            t0 = t1 = t2 = t3 = 0;
        } // round


        /**
        * Common body of encryption and decryption: initial key addition,
        * R - 1 rounds with table <code>T</code>, one final round with the
        * plain substitution table <code>S</code>.
        * The in and out buffers can be the same, as the whole input block
        * is read before any output is written.
        * @param in            source buffer
        * @param in_offset     start of the block within the source buffer
        * @param out           destination buffer
        * @param out_offset    start of the block within the destination buffer
        * @param roundKeys     the R+1 round keys to apply, in order
        * @param T             substitution-diffusion table for full rounds
        * @param S             substitution table for the last round
        */
        private static void crypt (byte[] in, int in_offset, byte[] out, int out_offset,
            int[][] roundKeys, int[] T, int[] S)
        {
            int[] block = new int[4];
            try {
                // map byte array to block and add initial key:
                for (int i = 0; i < 4; i++) {
                    block[i] =
                         (in[in_offset++] & 0xff)        ^
                        ((in[in_offset++] & 0xff) <<  8) ^
                        ((in[in_offset++] & 0xff) << 16) ^
                        ((in[in_offset++] & 0xff) << 24) ^
                        roundKeys[0][i];
                }

                // R - 1 full rounds:
                for (int r = 1; r < R; r++) {
                    round (block, T, roundKeys[r]);
                }

                // last round (diffusion becomes only transposition):
                round (block, S, roundKeys[R]);

                // map block to byte array:
                for (int i = 0; i < 4; i++) {
                    int w = block[i];
                    out[out_offset++] = (byte)(w       );
                    out[out_offset++] = (byte)(w >>>  8);
                    out[out_offset++] = (byte)(w >>> 16);
                    out[out_offset++] = (byte)(w >>> 24);
                }
            } finally {
                // destroy sensitive data, also when a buffer was too short:
                wipe (block);
            }
        } // crypt


        /**
         * Encrypt a block.
         * The in and out buffers can be the same.
         * @param in            The data to be encrypted.
         * @param in_offset     The start of data within the in buffer.
         * @param out           The encrypted data.
         * @param out_offset    The start of data within the out buffer.
         */
        final void blockEncrypt (byte in[], int in_offset, byte out[], int out_offset)
        {
            crypt (in, in_offset, out, out_offset, roundKeys_e, Te, Se);
        } // blockEncrypt


        /**
        * Decrypt a block.
        * The in and out buffers can be the same.
        * @param in            The data to be decrypted.
        * @param in_offset     The start of data within the in buffer.
        * @param out           The decrypted data.
        * @param out_offset    The start of data within the out buffer.
        */
        final void blockDecrypt (byte in[], int in_offset, byte out[], int out_offset)
        {
            crypt (in, in_offset, out, out_offset, roundKeys_d, Td, Sd);
        } // blockDecrypt


        /**
        * Wipe key schedule information.  The schedule cannot be used
        * afterwards; calling this a second time does nothing.
        */
        final void destroy ()
        {
            if (roundKeys_e != null) {
                for (int r = 0; r < roundKeys_e.length; r++) {
                    wipe (roundKeys_e[r]);
                    roundKeys_e[r] = null;
                }
                roundKeys_e = null;
            }
            if (roundKeys_d != null) {
                for (int r = 0; r < roundKeys_d.length; r++) {
                    wipe (roundKeys_d[r]);
                    roundKeys_d[r] = null;
                }
                roundKeys_d = null;
            }
        } // destroy
    } // keyschedule

    /**
    * Is the jacket doing triple encryption?
    */
    private boolean triple = false;
    /**
    * The key schedule data: one entry, or three in triple mode;
    * null while the object is not keyed
    */
    private Keyschedule[] ks = null;

    /**
    * Initialise the object with one or three key blocks.
    * Any key schedule left over from an earlier call is wiped first; if the
    * key buffer turns out to be too short the object is left un-keyed.
    * @param key array of key bytes, 1 or 3 key block lengths
    * @param keyoffset index of the first key byte within the array
    * @param triple true if three keys for triple application
    */
    public void init(byte[] key, int keyoffset, boolean triple)
    {
        destroy();

        int keys = triple ? 3 : 1;
        Keyschedule[] schedules = new Keyschedule[keys];
        for(int i=0; i < keys; i++)
        {
            schedules[i] = new Keyschedule(key, keyoffset+i*KEY_LENGTH);
        }

        this.triple = triple;
        this.ks = schedules;
    }

    /**
    * Transform one block in ecb mode.
    * In triple mode encryption applies keys 0, 1, 2 in that order and
    * decryption undoes them in the order 2, 1, 0, so that all three keys
    * contribute to the result.
    * @param encrypt true if forwards transformation
    * @param in input block
    * @param offin offset into block of input data
    * @param out output block
    * @param offout offset into block of output data
    */
    public final void ecb(boolean encrypt, byte[] in, int offin,
        byte[] out, int offout)
    {
        if(!triple)
        {
            if(encrypt) ks[0].blockEncrypt(in, offin, out, offout);
            else ks[0].blockDecrypt(in, offin, out, offout);
            return;
        }

        byte[] tmp = new byte[BLOCK_LENGTH];
        byte[] tmp2 = new byte[BLOCK_LENGTH];
        try
        {
            if(encrypt)
            {
                ks[0].blockEncrypt(in, offin, tmp, 0);
                ks[1].blockEncrypt(tmp, 0, tmp2, 0);
                ks[2].blockEncrypt(tmp2, 0, out, offout);
            }
            else
            {
                ks[2].blockDecrypt(in, offin, tmp, 0);
                ks[1].blockDecrypt(tmp, 0, tmp2, 0);
                ks[0].blockDecrypt(tmp2, 0, out, offout);
            }
        }
        finally
        {
            // the intermediate values must not outlive the call
            wipe(tmp);
            wipe(tmp2);
        }
    }

    /**
    * Wipe key schedule information.  Safe to call on an object that was
    * never keyed or has already been destroyed.
    */
    public final void destroy()
    {
        if(ks != null)
        {
            for(int i=0; i<ks.length; ++i)
            {
                if(ks[i] != null)
                {
                    ks[i].destroy();
                    ks[i] = null;
                }
            }
            ks = null;
        }
        triple = false;
    }

    /**
    * Provide information of desired key size
    * @return byte length of key
    */
    public int getKeysize()
    {
        return KEY_LENGTH;
    }

    /**
    * Provide information of algorithm block size
    * @return byte length of block
    */
    public int getBlocksize()
    {
        return BLOCK_LENGTH;
    }


    /**
    * drive the test vectors
    * @param args ignored
    */
    public static void main(String[] args)
    {
        test();
    }

    /**
    * Digit values for output formatting
    */
    private static final String hexDigit [] = {
        "0", "1", "2", "3", "4", "5", "6", "7",
        "8", "9", "a", "b", "c", "d", "e", "f",
    };

    /**
    * format output for test vectors: the first <code>len</code> bytes of
    * <code>buf</code> in lower-case hex followed by <code>tag</code>, written
    * to standard output.  A space separates groups of 16 bytes and a line
    * break is inserted after every 32 bytes that are followed by more data,
    * so the tag always shares a line with the last bytes.
    * @param buf  the bytes to print
    * @param len  how many bytes of buf to print
    * @param tag  label appended after the hex dump
    */
    protected static void printBuffer (byte[] buf, int len, String tag)
    {
        StringBuilder text = new StringBuilder ();

        for (int i = 0; i < len; i++) {
            byte b = buf[i];
            text.append (hexDigit[(b >>> 4) & 0x0f]).append (hexDigit[b & 0x0f]);
            int printed = i + 1;
            // the 32-byte test has to come first: every multiple of 32 is
            // also a multiple of 16
            if ((printed & 31) == 0 && printed < len) {
                text.append ('\n'); // break line every 32 bytes
            } else if ((printed & 15) == 0) {
                text.append (' ');  // put a space every 16 bytes
            }
        }
        if ((len & 15) != 0) {
            text.append (' ');      // no separator was written after the last byte
        }
        System.out.println (text.append (tag));
    } // printBuffer

    /**
    * Compare the first BLOCK_LENGTH bytes of two buffers.
    */
    private static boolean sameBlock (byte[] x, byte[] y)
    {
        for (int i = 0; i < BLOCK_LENGTH; i++) {
            if (x[i] != y[i]) {
                return false;
            }
        }
        return true;
    } // sameBlock

    /**
    * Run <code>iterations</code> in-place block operations on
    * <code>data</code> and print the elapsed time and throughput.
    */
    private static void reportSpeed (Square sq, boolean encrypt, byte[] data, int iterations)
    {
        long elapsed = -System.currentTimeMillis ();
        for (int i = 0; i < iterations; i++) {
            sq.ecb (encrypt, data, 0, data, 0);
        }
        elapsed += System.currentTimeMillis ();
        // a run too short to measure is reported as one second, which also
        // keeps the division below away from zero
        float secs = (elapsed > 1) ? (float)elapsed/1000 : 1;
        System.out.println ("Elapsed time = " + secs + ", speed = "
            + ((float)iterations*Square.BLOCK_LENGTH/1024/secs) + " kbytes/s");
    } // reportSpeed

    /**
    * drives the test vectors: checks one known-answer encryption and its
    * decryption in single-key mode, then times both directions.  Results
    * are written to standard output.
    */
    public static void test()
    {
        // only the first KEY_LENGTH bytes are used by the single-key test
        byte[] key = {
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
            0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
            0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        };
        System.out.println ("Raw test:");
        byte[] plaintext = {
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
            0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        };
        byte[] ciphertext = {
            (byte)0x7c, (byte)0x34, (byte)0x91, (byte)0xd9,
            (byte)0x49, (byte)0x94, (byte)0xe7, (byte)0x0f,
            (byte)0x0e, (byte)0xc2, (byte)0xe7, (byte)0xa5,
            (byte)0xcc, (byte)0xb5, (byte)0xa1, (byte)0x4f,
        };
        byte[] data = new byte[Square.BLOCK_LENGTH];
        printBuffer (key, key.length, "key");
        printBuffer (plaintext, plaintext.length, "plaintext");
        Square sq = new Square ();
        sq.init(key, 0, false);

        sq.ecb (true, plaintext, 0, data, 0);
        if (sameBlock (data, ciphertext)) {
            printBuffer (data, data.length, "encrypted(OK)");
        } else {
            printBuffer (data, data.length, "encrypted(ERROR)");
            printBuffer (ciphertext, ciphertext.length, "expected");
        }

        sq.ecb (false, data, 0, data, 0);
        if (sameBlock (data, plaintext)) {
            printBuffer (data, data.length, "decrypted(OK)");
        } else {
            printBuffer (data, data.length, "decrypted(ERROR)");
        }

        int iterations = 10000;

        System.out.println ("Speed test for " + iterations + " iterations:");
        System.out.println ("Measuring encryption speed...");
        reportSpeed (sq, true, data, iterations);
        System.out.println ("Measuring decryption speed...");
        reportSpeed (sq, false, data, iterations);

        sq.destroy ();
    }
}
