//-----------------------------------------------------------------------------
// F326_AES_Cipher.c
//-----------------------------------------------------------------------------
// Copyright 2007 Silicon Laboratories, Inc.
// http://www.silabs.com
//
// Program Description:
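// AES (Rijndael) block encryption. Cipher() encrypts one 16-byte block using
// the pre-expanded round keys stored in EXP_KEYS.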
//
// How To Test:    See Readme.txt
//
//         
// Target:         C8051F326
// Tool chain:     Keil C51 7.50 / Keil EVAL C51
//                 Silicon Laboratories IDE version 2.91
// Command Line:   See Readme.txt
// Project Name:   F326_AES
//
//
// Release 1.0
//    -Initial Revision (CG/GP)
//    -11 JUN 2007
//

//-----------------------------------------------------------------------------
// Includes
//-----------------------------------------------------------------------------

#include "F326_AES_Typedef.h"
#include "F326_AES_Parameters.h"
#include "F326_AES_Sbox.h"
#include "F326_AES_Cipher.h"

//-----------------------------------------------------------------------------
// Global Variables
//-----------------------------------------------------------------------------

// Expanded key schedule, defined outside this file. LoadKeys() reads it as
// consecutive 16-byte round keys (round i starts at byte 16 * i). The pdata
// qualifier places it in Keil C51 paged external RAM.
extern pdata byte EXP_KEYS[];

static byte data State[4][4];          // State vector, indexed [row][col];
                                       // kept in directly addressable
                                       // internal RAM (Keil "data") for speed

static byte data CurrentKey[4][4];     // Round key currently being applied,
                                       // same [row][col] layout as State;
                                       // filled by LoadKeys()

//-----------------------------------------------------------------------------
// Prototypes
//-----------------------------------------------------------------------------

void Cipher (byte *in, byte *out);

// Round transformations from the AES specification (external linkage)
void SubBytes    (void);
void ShiftRows   (void);
void MixColumns  (void);

// File-local helpers. These prototypes must be static because the definitions
// below are static: a non-static prototype followed by a static definition
// gives one identifier both external and internal linkage, which is undefined
// behavior in ISO C (C11 6.2.2p7) and a hard error in GCC/Clang.
static void AddRoundKey (void);
static void StateIn     (byte *in);
static void StateOut    (byte *out);
static void LoadKeys    (char i);
static byte xtime       (byte input);

//-----------------------------------------------------------------------------
// Support Subroutines
//-----------------------------------------------------------------------------

//-----------------------------------------------------------------------------
// Cipher
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : 1) byte *in  - 16-byte plaintext block to encrypt
//                2) byte *out - buffer that receives the 16-byte ciphertext
//
// Encrypts one block with AES/Rijndael using the round keys held in EXP_KEYS.
// Round 0 is a key addition only. Rounds 1 .. Nr-1 apply SubBytes, ShiftRows,
// MixColumns and AddRoundKey. The final round (Nr) omits MixColumns.
//
//-----------------------------------------------------------------------------

void Cipher (byte *in, byte *out)
{
   char round;

   StateIn (in);                       // Copy plaintext into State

   round = 0;                          // Round 0: initial key addition only
   LoadKeys (round);
   AddRoundKey ();

   round = 1;
   while (round < Nr)                  // Full rounds 1 .. Nr-1
   {
      SubBytes ();
      ShiftRows ();
      MixColumns ();
      LoadKeys (round);
      AddRoundKey ();
      round++;
   }

   // The counter has now reached the final round: no MixColumns here
   SubBytes ();
   ShiftRows ();
   LoadKeys (round);
   AddRoundKey ();

   StateOut (out);                     // Copy ciphertext out of State
}

//-----------------------------------------------------------------------------
// StateIn
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : 1) byte *in - 16 bytes of plaintext
//
// Loads the input block into the global State array column by column:
// in[0..3] fill column 0, in[4..7] fill column 1, and so on.
//
//-----------------------------------------------------------------------------

static void StateIn(byte *in)
{
   byte i;

   for (i = 0; i < 16; i++)
   {
      State[i & 3][i >> 2] = in[i];    // row = i % 4, col = i / 4
   }
}

//-----------------------------------------------------------------------------
// StateOut
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : 1) byte *out - buffer that receives 16 bytes of ciphertext
//
// Writes the global State array to the output column by column, the inverse
// of StateIn(): column 0 goes to out[0..3], column 1 to out[4..7], and so on.
//
//-----------------------------------------------------------------------------

static void StateOut (byte *out)
{
   byte i;

   for (i = 0; i < 16; i++)
   {
      out[i] = State[i & 3][i >> 2];   // row = i % 4, col = i / 4
   }
}

//-----------------------------------------------------------------------------
// SubBytes
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : None
//
// Replaces every byte of State with its entry in the substitution table
// (Sbox). The 16 lookups are independent and are written out one by one,
// without a loop, to avoid loop and index overhead.
//
//-----------------------------------------------------------------------------

void SubBytes (void)
{
   // Column 0
   State[0][0] = Sbox[State[0][0]];
   State[1][0] = Sbox[State[1][0]];
   State[2][0] = Sbox[State[2][0]];
   State[3][0] = Sbox[State[3][0]];

   // Column 1
   State[0][1] = Sbox[State[0][1]];
   State[1][1] = Sbox[State[1][1]];
   State[2][1] = Sbox[State[2][1]];
   State[3][1] = Sbox[State[3][1]];

   // Column 2
   State[0][2] = Sbox[State[0][2]];
   State[1][2] = Sbox[State[1][2]];
   State[2][2] = Sbox[State[2][2]];
   State[3][2] = Sbox[State[3][2]];

   // Column 3
   State[0][3] = Sbox[State[0][3]];
   State[1][3] = Sbox[State[1][3]];
   State[2][3] = Sbox[State[2][3]];
   State[3][3] = Sbox[State[3][3]];
}

//-----------------------------------------------------------------------------
// ShiftRows
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : None
//
// Cyclically rotates row r of State left by r positions. Row 0 is unchanged,
// row 2 is done as two swaps, and row 3 is done as a single right rotation
// (equivalent to rotating left by three).
//
//-----------------------------------------------------------------------------

void ShiftRows (void)
{
   byte tmp;

   // Row 3: left by three == right by one
   tmp         = State[3][3];
   State[3][3] = State[3][2];
   State[3][2] = State[3][1];
   State[3][1] = State[3][0];
   State[3][0] = tmp;

   // Row 2: left by two == swap each element with the one two columns away
   tmp         = State[2][1];
   State[2][1] = State[2][3];
   State[2][3] = tmp;

   tmp         = State[2][0];
   State[2][0] = State[2][2];
   State[2][2] = tmp;

   // Row 1: left by one
   tmp         = State[1][0];
   State[1][0] = State[1][1];
   State[1][1] = State[1][2];
   State[1][2] = State[1][3];
   State[1][3] = tmp;

   // Row 0 is not shifted
}

//-----------------------------------------------------------------------------
// MixColumns
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : None
//
// Multiplies each State column (s0, s1, s2, s3) by the fixed AES matrix
//
//    | 2 3 1 1 |
//    | 1 2 3 1 |
//    | 1 1 2 3 |
//    | 3 1 1 2 |
//
// using only xtime() and XOR. With t = s0 ^ s1 ^ s2 ^ s3, each output is
// s_r' = s_r ^ t ^ xtime(s_r ^ s_(r+1)). The bottom row instead uses the
// identity s0' ^ s1' ^ s2' ^ s3' = t, which saves one xtime() per column.
//
//-----------------------------------------------------------------------------

void MixColumns (void)
{
   byte s0, s1, s2, s3;                // Snapshot of the column being mixed
   byte sum;                           // s0 ^ s1 ^ s2 ^ s3

   // Column 0
   s0  = State[0][0];
   s1  = State[1][0];
   s2  = State[2][0];
   s3  = State[3][0];
   sum = s0 ^ s1 ^ s2 ^ s3;
   s0 ^= sum ^ xtime (s0 ^ s1);        // s1..s3 still hold the old values
   s1 ^= sum ^ xtime (s1 ^ s2);
   s2 ^= sum ^ xtime (s2 ^ s3);
   State[0][0] = s0;
   State[1][0] = s1;
   State[2][0] = s2;
   State[3][0] = s0 ^ s1 ^ s2 ^ sum;

   // Column 1
   s0  = State[0][1];
   s1  = State[1][1];
   s2  = State[2][1];
   s3  = State[3][1];
   sum = s0 ^ s1 ^ s2 ^ s3;
   s0 ^= sum ^ xtime (s0 ^ s1);
   s1 ^= sum ^ xtime (s1 ^ s2);
   s2 ^= sum ^ xtime (s2 ^ s3);
   State[0][1] = s0;
   State[1][1] = s1;
   State[2][1] = s2;
   State[3][1] = s0 ^ s1 ^ s2 ^ sum;

   // Column 2
   s0  = State[0][2];
   s1  = State[1][2];
   s2  = State[2][2];
   s3  = State[3][2];
   sum = s0 ^ s1 ^ s2 ^ s3;
   s0 ^= sum ^ xtime (s0 ^ s1);
   s1 ^= sum ^ xtime (s1 ^ s2);
   s2 ^= sum ^ xtime (s2 ^ s3);
   State[0][2] = s0;
   State[1][2] = s1;
   State[2][2] = s2;
   State[3][2] = s0 ^ s1 ^ s2 ^ sum;

   // Column 3
   s0  = State[0][3];
   s1  = State[1][3];
   s2  = State[2][3];
   s3  = State[3][3];
   sum = s0 ^ s1 ^ s2 ^ s3;
   s0 ^= sum ^ xtime (s0 ^ s1);
   s1 ^= sum ^ xtime (s1 ^ s2);
   s2 ^= sum ^ xtime (s2 ^ s3);
   State[0][3] = s0;
   State[1][3] = s1;
   State[2][3] = s2;
   State[3][3] = s0 ^ s1 ^ s2 ^ sum;
}


//-----------------------------------------------------------------------------
// AddRoundKey
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : None
//
// XORs the round key held in CurrentKey into State, element by element. Both
// arrays share the same [row][col] layout. LoadKeys() must have been called
// first for the round being applied.
//
//-----------------------------------------------------------------------------

static void AddRoundKey (void)
{
   // Column 0
   State[0][0] ^= CurrentKey[0][0];
   State[1][0] ^= CurrentKey[1][0];
   State[2][0] ^= CurrentKey[2][0];
   State[3][0] ^= CurrentKey[3][0];

   // Column 1
   State[0][1] ^= CurrentKey[0][1];
   State[1][1] ^= CurrentKey[1][1];
   State[2][1] ^= CurrentKey[2][1];
   State[3][1] ^= CurrentKey[3][1];

   // Column 2
   State[0][2] ^= CurrentKey[0][2];
   State[1][2] ^= CurrentKey[1][2];
   State[2][2] ^= CurrentKey[2][2];
   State[3][2] ^= CurrentKey[3][2];

   // Column 3
   State[0][3] ^= CurrentKey[0][3];
   State[1][3] ^= CurrentKey[1][3];
   State[2][3] ^= CurrentKey[2][3];
   State[3][3] ^= CurrentKey[3][3];
}

//-----------------------------------------------------------------------------
// LoadKeys
//-----------------------------------------------------------------------------
//
// Return Value : None
// Parameters   : 1) char i - round number (0 .. Nr); selects the 16-byte round
//                   key that starts at EXP_KEYS[16 * i]
//
// Copies the key for round i from the expanded key array EXP_KEYS into
// CurrentKey, so that AddRoundKey() can apply it. EXP_KEYS is declared pdata,
// which is Keil C51 paged external RAM rather than Flash.
//
//-----------------------------------------------------------------------------

static void LoadKeys (char i)
{
   // Byte offset of this round's key. An 8-bit offset suffices: a pdata page
   // is at most 256 bytes, and the largest AES schedule (Nr = 14) needs only
   // 15 * 16 = 240 bytes.
   unsigned char index = (unsigned char) (i * 16);

   // Each round key is stored column by column: four consecutive bytes fill
   // rows 0..3 of one CurrentKey column.
   CurrentKey[0][0] = EXP_KEYS[index++];
   CurrentKey[1][0] = EXP_KEYS[index++];
   CurrentKey[2][0] = EXP_KEYS[index++];
   CurrentKey[3][0] = EXP_KEYS[index++];

   CurrentKey[0][1] = EXP_KEYS[index++];
   CurrentKey[1][1] = EXP_KEYS[index++];
   CurrentKey[2][1] = EXP_KEYS[index++];
   CurrentKey[3][1] = EXP_KEYS[index++];

   CurrentKey[0][2] = EXP_KEYS[index++];
   CurrentKey[1][2] = EXP_KEYS[index++];
   CurrentKey[2][2] = EXP_KEYS[index++];
   CurrentKey[3][2] = EXP_KEYS[index++];

   CurrentKey[0][3] = EXP_KEYS[index++];
   CurrentKey[1][3] = EXP_KEYS[index++];
   CurrentKey[2][3] = EXP_KEYS[index++];
   CurrentKey[3][3] = EXP_KEYS[index++];
}


//-----------------------------------------------------------------------------
// xtime
//-----------------------------------------------------------------------------
//
// Return Value : byte - input multiplied by x in GF(2^8)
// Parameters   : 1) byte input - field element (polynomial) to multiply
//
// Multiplies the polynomial given by `input` by x and reduces the result
// modulo m(x) = x^8 + x^4 + x^3 + x + 1. At the byte level this is a one-bit
// left shift, followed by an XOR with 0x1B when the bit shifted out was 1.
// Example (FIPS-197 section 4.2.1): xtime(0x57) = 0xAE, xtime(0xAE) = 0x47.
//
// NOTE: per the original design notes, the decryption code uses a lookup
// table for field multiplication; xtime() was chosen here because it is
// faster for encryption.
//
//-----------------------------------------------------------------------------

static byte xtime (byte input)
{
   byte result = input + input;        // Shift left one bit (multiply by x)

   if (input & 0x80)                   // Bit 7 was shifted out of the byte:
   {
      result ^= 0x1B;                  // reduce modulo m(x)
   }

   return result;
}

//-----------------------------------------------------------------------------
// End Of File
//-----------------------------------------------------------------------------