# -*- coding: iso-8859-1 -*-
""" crypto.cipher.rijndael

    Rijndael encryption algorithm

    This byte oriented implementation is intended to closely
    match FIPS specification for readability.  It is not implemented
    for performance.

    Copyright © (c) 2002 by Paul A. Lambert
    Read LICENSE.txt for license information.

    2002-06-01
"""

from crypto.cipher.base import BlockCipher, padWithPadLen, noPadding

class Rijndael(BlockCipher):
    """ Rijndael encryption algorithm

        Byte oriented block cipher whose key size and block size can each be
        16, 20, 24, 28 or 32 bytes.  Blocks and keys are handled as binary
        strings (one character per byte).  The working state is a list of Nb
        columns, each a list of four byte values: state[column][row].
    """
    def __init__(self, key = None, padding = padWithPadLen(), keySize=16, blockSize=16 ):
        """ Configure the cipher and, when a key is supplied, expand it.

            key       - optional key string of exactly keySize characters
            padding   - padding object, stored as self.padding
            keySize   - key length in bytes    (16, 20, 24, 28 or 32)
            blockSize - block length in bytes  (16, 20, 24, 28 or 32)

            Raises AssertionError when keySize or blockSize is not supported.
        """
        # NOTE(review): the default padding object is created once, when this
        # 'def' statement runs, and is shared by every instance that does not
        # pass its own padding.
        self.name       = 'RIJNDAEL'
        self.keySize    = keySize
        self.strength   = keySize*8  # keySize expressed in bits
        self.blockSize  = blockSize  # blockSize is in bytes
        self.padding    = padding    # change default to noPadding() to get normal ECB behavior

        # Sizes are validated as counts of 32-bit words against the round
        # table.  'in' and '//' are used instead of has_key() and '/' so the
        # checks keep working where has_key() is gone and '/' is true division.
        assert keySize%4==0 and (keySize//4) in NrTable[4], 'key size must be 16,20,24,28 or 32 bytes'
        assert blockSize%4==0 and (blockSize//4) in NrTable, 'block size must be 16,20,24,28 or 32 bytes'

        self.Nb = self.blockSize//4         # Nb is number of columns of 32 bit words
        self.Nk = keySize//4                # Nk is the key length in 32-bit words
        self.Nr = NrTable[self.Nb][self.Nk] # The number of rounds (Nr) is a function of
                                            # the block (Nb) and key (Nk) sizes.
        if key is not None:
            self.setKey(key)

    def setKey(self, key):
        """ Set a key and generate the expanded key.

            Raises AssertionError unless len(key) equals Nk*4.
        """
        assert( len(key) == (self.Nk*4) ), 'Key length must be same as keySize parameter'
        self.__expandedKey = keyExpansion(self, key)
        self.reset()                   # BlockCipher.reset()

    def _roundKey(self, roundNumber):
        """ Return the Nb expanded-key words used by round 'roundNumber' (0..Nr) """
        first = roundNumber*self.Nb
        return self.__expandedKey[first:first+self.Nb]

    def encryptBlock(self, plainTextBlock):
        """ Encrypt one block.

            plainTextBlock must be a binary string of 4*Nb characters; the
            result is a binary string of the same length.  A key must have
            been set beforehand (the expanded key is created by setKey).
        """
        self.state = self._toBlock(plainTextBlock)
        AddRoundKey(self, self._roundKey(0))
        for roundNumber in range(1,self.Nr):    # for round = 1 step 1 to Nr-1
            SubBytes(self)
            ShiftRows(self)
            MixColumns(self)
            AddRoundKey(self, self._roundKey(roundNumber))
        # The final round omits MixColumns.
        SubBytes(self)
        ShiftRows(self)
        AddRoundKey(self, self._roundKey(self.Nr))
        return self._toBString(self.state)


    def decryptBlock(self, encryptedBlock):
        """ Decrypt one block.

            encryptedBlock must be a binary string of 4*Nb characters; the
            result is a binary string of the same length.  The steps of
            encryptBlock are undone in reverse order.
        """
        self.state = self._toBlock(encryptedBlock)
        AddRoundKey(self, self._roundKey(self.Nr))
        for roundNumber in range(self.Nr-1,0,-1):
            InvShiftRows(self)
            InvSubBytes(self)
            AddRoundKey(self, self._roundKey(roundNumber))
            InvMixColumns(self)
        # The last inverse round has no InvMixColumns.
        InvShiftRows(self)
        InvSubBytes(self)
        AddRoundKey(self, self._roundKey(0))
        return self._toBString(self.state)

    def _toBlock(self, bs):
        """ Convert binary string to array of bytes, state[col][row].

            Raises AssertionError unless len(bs) equals 4*Nb.
        """
        assert ( len(bs) == 4*self.Nb ), 'Rijndael blocks must be of size blockSize'
        # Each run of four consecutive characters becomes one column.
        return [[ord(char) for char in bs[4*col:4*col+4]] for col in range(self.Nb)]

    def _toBString(self, block):
        """ Convert block (array of bytes) to binary string, column by column """
        return ''.join([chr(byteValue) for column in block for byteValue in column])
#-------------------------------------
"""    Number of rounds Nr = NrTable[Nb][Nk]

            Nb  Nk=4   Nk=5   Nk=6   Nk=7   Nk=8
            -------------------------------------   """
# Round counts keyed first by block size Nb, then by key size Nk (both in
# 32-bit words).  Every entry equals max(Nb, Nk) + 6.
NrTable = {
    4: {4: 10, 5: 11, 6: 12, 7: 13, 8: 14},
    5: {4: 11, 5: 11, 6: 12, 7: 13, 8: 14},
    6: {4: 12, 5: 12, 6: 12, 7: 13, 8: 14},
    7: {4: 13, 5: 13, 6: 13, 7: 13, 8: 14},
    8: {4: 14, 5: 14, 6: 14, 7: 14, 8: 14},
}
#-------------------------------------
def keyExpansion(algInstance, keyString):
    """ Expand a key string of Nk*4 characters into the round-key schedule.

        algInstance - object supplying the Nk, Nb and Nr attributes
        keyString   - the key as a binary string, one character per byte

        Returns a list of words, each a list of four byte values.  The first
        Nk words are the key itself; further words are appended until the
        list holds Nb*(Nr+1) words.
    """
    Nk, Nb, Nr = algInstance.Nk, algInstance.Nb, algInstance.Nr # for readability
    key = [ord(byte) for byte in keyString]  # convert string to list
    w = [[key[4*i],key[4*i+1],key[4*i+2],key[4*i+3]] for i in range(Nk)]
    for i in range(Nk,Nb*(Nr+1)):
        temp = w[i-1]        # a four byte column
        if (i%Nk) == 0 :
            temp     = temp[1:]+[temp[0]]  # RotWord(temp)
            temp     = [ Sbox[byte] for byte in temp ]  # SubWord(temp)
            # Floor division keeps the index an integer even where '/' is
            # true division.  i >= Nk here, so index 0 of Rcon is never used.
            temp[0] ^= Rcon[i//Nk]
        elif Nk > 6 and  i%Nk == 4 :
            # Keys longer than 6 words get an extra SubWord mid-way.
            temp     = [ Sbox[byte] for byte in temp ]  # SubWord(temp)
        w.append( [ w[i-Nk][byte]^temp[byte] for byte in range(4) ] )
    return w

def _buildRcon(count):
    """ Return a tuple holding a leading 0 followed by 'count' round constants.

        The constants start at 0x01 and each one is the previous value
        doubled in GF(2^8), i.e. shifted left and reduced with 0x11b
        whenever the shift overflows eight bits.
    """
    constants = [0]     # note extra '0': index 0 is only a placeholder
    value = 0x01
    for step in range(count):
        constants.append(value)
        value <<= 1
        if value & 0x100:
            value ^= 0x11b
    return tuple(constants)

Rcon = _buildRcon(30)

#-------------------------------------
def AddRoundKey(algInstance, keyBlock):
    """ Combine the state with round-key material, in place.

        Each of the Nb state columns is XORed, byte by byte, with the
        column at the same index in keyBlock.  keyBlock is not modified.
    """
    for columnIndex in range(algInstance.Nb):
        stateColumn = algInstance.state[columnIndex]
        keyColumn = keyBlock[columnIndex]
        for rowIndex in range(4):
            stateColumn[rowIndex] = stateColumn[rowIndex] ^ keyColumn[rowIndex]
#-------------------------------------

def SubBytes(algInstance):
    """ Replace every byte of the state with its Sbox entry, in place """
    for columnIndex in range(algInstance.Nb):
        cells = algInstance.state[columnIndex]
        for rowIndex in range(4):
            cells[rowIndex] = Sbox[cells[rowIndex]]

def InvSubBytes(algInstance):
    """ Replace every byte of the state with its InvSbox entry, in place """
    for columnIndex in range(algInstance.Nb):
        cells = algInstance.state[columnIndex]
        for rowIndex in range(4):
            cells[rowIndex] = InvSbox[cells[rowIndex]]

# Forward substitution table.  SubBytes replaces each state byte b with
# Sbox[b], and keyExpansion applies the same lookup to key-schedule words.
# 256 entries, eight per line, indexed directly by the byte value.
Sbox =    (0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,
           0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
           0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,
           0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
           0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,
           0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
           0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,
           0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
           0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,
           0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
           0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,
           0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
           0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,
           0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
           0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,
           0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
           0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,
           0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
           0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,
           0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
           0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,
           0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
           0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,
           0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
           0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,
           0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
           0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,
           0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
           0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,
           0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
           0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,
           0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16)

# Inverse substitution table used by InvSubBytes: each state byte b is
# replaced with InvSbox[b].  It is meant to undo Sbox, i.e.
# InvSbox[Sbox[b]] == b (for example Sbox[0x00] is 0x63 and InvSbox[0x63]
# is 0x00).  256 entries, eight per line, indexed by the byte value.
InvSbox = (0x52,0x09,0x6a,0xd5,0x30,0x36,0xa5,0x38,
           0xbf,0x40,0xa3,0x9e,0x81,0xf3,0xd7,0xfb,
           0x7c,0xe3,0x39,0x82,0x9b,0x2f,0xff,0x87,
           0x34,0x8e,0x43,0x44,0xc4,0xde,0xe9,0xcb,
           0x54,0x7b,0x94,0x32,0xa6,0xc2,0x23,0x3d,
           0xee,0x4c,0x95,0x0b,0x42,0xfa,0xc3,0x4e,
           0x08,0x2e,0xa1,0x66,0x28,0xd9,0x24,0xb2,
           0x76,0x5b,0xa2,0x49,0x6d,0x8b,0xd1,0x25,
           0x72,0xf8,0xf6,0x64,0x86,0x68,0x98,0x16,
           0xd4,0xa4,0x5c,0xcc,0x5d,0x65,0xb6,0x92,
           0x6c,0x70,0x48,0x50,0xfd,0xed,0xb9,0xda,
           0x5e,0x15,0x46,0x57,0xa7,0x8d,0x9d,0x84,
           0x90,0xd8,0xab,0x00,0x8c,0xbc,0xd3,0x0a,
           0xf7,0xe4,0x58,0x05,0xb8,0xb3,0x45,0x06,
           0xd0,0x2c,0x1e,0x8f,0xca,0x3f,0x0f,0x02,
           0xc1,0xaf,0xbd,0x03,0x01,0x13,0x8a,0x6b,
           0x3a,0x91,0x11,0x41,0x4f,0x67,0xdc,0xea,
           0x97,0xf2,0xcf,0xce,0xf0,0xb4,0xe6,0x73,
           0x96,0xac,0x74,0x22,0xe7,0xad,0x35,0x85,
           0xe2,0xf9,0x37,0xe8,0x1c,0x75,0xdf,0x6e,
           0x47,0xf1,0x1a,0x71,0x1d,0x29,0xc5,0x89,
           0x6f,0xb7,0x62,0x0e,0xaa,0x18,0xbe,0x1b,
           0xfc,0x56,0x3e,0x4b,0xc6,0xd2,0x79,0x20,
           0x9a,0xdb,0xc0,0xfe,0x78,0xcd,0x5a,0xf4,
           0x1f,0xdd,0xa8,0x33,0x88,0x07,0xc7,0x31,
           0xb1,0x12,0x10,0x59,0x27,0x80,0xec,0x5f,
           0x60,0x51,0x7f,0xa9,0x19,0xb5,0x4a,0x0d,
           0x2d,0xe5,0x7a,0x9f,0x93,0xc9,0x9c,0xef,
           0xa0,0xe0,0x3b,0x4d,0xae,0x2a,0xf5,0xb0,
           0xc8,0xeb,0xbb,0x3c,0x83,0x53,0x99,0x61,
           0x17,0x2b,0x04,0x7e,0xba,0x77,0xd6,0x26,
           0xe1,0x69,0x14,0x63,0x55,0x21,0x0c,0x7d)

#-------------------------------------
""" For each block size (Nb), the ShiftRow operation shifts row i
    by the amount Ci.  Note that row 0 is not shifted.
                 Nb      C1 C2 C3
               -------------------  """
# Rotation amounts looked up as shiftOffset[Nb][row] by ShiftRows and
# InvShiftRows.  The entry for row 0 is always 0.
shiftOffset = {
    4: (0, 1, 2, 3),
    5: (0, 1, 2, 3),
    6: (0, 1, 2, 3),
    7: (0, 1, 2, 4),
    8: (0, 1, 3, 4),
}
def ShiftRows(algInstance):
    """ Rotate rows 1..3 of the state, in place.

        After the call, the byte at column c of row r is the byte that was
        at column (c + shiftOffset[Nb][r]) % Nb.  Row 0 is left untouched.
    """
    Nb = algInstance.Nb
    for row in (1, 2, 3):      # row 0 remains unchanged and can be skipped
        # Read the whole rotated row before writing any of it back.
        rotated = [algInstance.state[(col + shiftOffset[Nb][row]) % Nb][row]
                   for col in range(Nb)]
        for col, value in enumerate(rotated):
            algInstance.state[col][row] = value
def InvShiftRows(algInstance):
    """ Rotate rows 1..3 of the state the opposite way to ShiftRows, in place.

        After the call, the byte at column c of row r is the byte that was
        at column (c + Nb - shiftOffset[Nb][r]) % Nb.  Row 0 is untouched.
    """
    Nb = algInstance.Nb
    for row in (1, 2, 3):      # row 0 remains unchanged and can be skipped
        # Read the whole rotated row before writing any of it back.
        rotated = [algInstance.state[(col + Nb - shiftOffset[Nb][row]) % Nb][row]
                   for col in range(Nb)]
        for col, value in enumerate(rotated):
            algInstance.state[col][row] = value
#-------------------------------------
def MixColumns(a):
    """ Mix the four bytes of every state column in a linear way, in place.

        Output byte i of a column is the XOR of mul(coefficient, byte) over
        the column's four bytes, using row i of the matrix below.
    """
    matrix = ((2, 3, 1, 1),
              (1, 2, 3, 1),
              (1, 1, 2, 3),
              (3, 1, 1, 2))
    for j in range(a.Nb):    # for each column
        column = a.state[j]
        mixed = []
        for coefficients in matrix:
            total = 0
            for i in range(4):
                total ^= mul(coefficients[i], column[i])
            mixed.append(total)
        # Write back only after all four outputs are computed.
        for i in range(4):
            column[i] = mixed[i]

def InvMixColumns(a):
    """ Mix the four bytes of every state column in a linear way, in place.
        This is the opposite operation of MixColumns.

        Output byte i of a column is the XOR of mul(coefficient, byte) over
        the column's four bytes, using row i of the matrix below.
    """
    matrix = ((0x0E, 0x0B, 0x0D, 0x09),
              (0x09, 0x0E, 0x0B, 0x0D),
              (0x0D, 0x09, 0x0E, 0x0B),
              (0x0B, 0x0D, 0x09, 0x0E))
    for j in range(a.Nb):    # for each column
        column = a.state[j]
        mixed = []
        for coefficients in matrix:
            total = 0
            for i in range(4):
                total ^= mul(coefficients[i], column[i])
            mixed.append(total)
        # Write back only after all four outputs are computed.
        for i in range(4):
            column[i] = mixed[i]

#-------------------------------------
def mul(a, b):
    """ Multiply two field elements (byte values) with the log tables.

        Needed for MixColumns and InvMixColumns.  A zero operand gives 0
        without any table lookup; otherwise the two logarithms are added
        modulo 255 and the sum is mapped back through Alogtable.
    """
    if a == 0 or b == 0:
        return 0
    return Alogtable[(Logtable[a] + Logtable[b])%255]

# Logarithm table used by mul(): Logtable[x] is the exponent e for which
# Alogtable[e] == x (e.g. Logtable[3] is 1 and Alogtable[1] is 3).
# Entry 0 is only a placeholder; mul() returns early for a zero operand and
# never looks it up.  256 entries, sixteen per line, indexed by byte value.
Logtable = ( 0,   0,  25,   1,  50,   2,  26, 198,  75, 199,  27, 104,  51, 238, 223,   3,
           100,   4, 224,  14,  52, 141, 129, 239,  76, 113,   8, 200, 248, 105,  28, 193,
           125, 194,  29, 181, 249, 185,  39, 106,  77, 228, 166, 114, 154, 201,   9, 120,
           101,  47, 138,   5,  33,  15, 225,  36,  18, 240, 130,  69,  53, 147, 218, 142,
           150, 143, 219, 189,  54, 208, 206, 148,  19,  92, 210, 241,  64,  70, 131,  56,
           102, 221, 253,  48, 191,   6, 139,  98, 179,  37, 226, 152,  34, 136, 145,  16,
           126, 110,  72, 195, 163, 182,  30,  66,  58, 107,  40,  84, 250, 133,  61, 186,
            43, 121,  10,  21, 155, 159,  94, 202,  78, 212, 172, 229, 243, 115, 167,  87,
           175,  88, 168,  80, 244, 234, 214, 116,  79, 174, 233, 213, 231, 230, 173, 232,
            44, 215, 117, 122, 235,  22,  11, 245,  89, 203,  95, 176, 156, 169,  81, 160,
           127,  12, 246, 111,  23, 196,  73, 236, 216,  67,  31,  45, 164, 118, 123, 183,
           204, 187,  62,  90, 251,  96, 177, 134,  59,  82, 161, 108, 170,  85,  41, 157,
           151, 178, 135, 144,  97, 190, 220, 252, 188, 149, 207, 205,  55,  63,  91, 209,
            83,  57, 132,  60,  65, 162, 109,  71,  20,  42, 158,  93,  86, 242, 211, 171,
            68,  17, 146, 217,  35,  32,  46, 137, 180, 124, 184,  38, 119, 153, 227, 165,
           103,  74, 237, 222, 197,  49, 254,  24,  13,  99, 140, 128, 192, 247, 112,   7)

# Antilogarithm table used by mul(): Alogtable[e] is the field element
# whose logarithm is e; the sequence starts 1, 3, 5, 15, ... (successive
# powers of 3 -- presumably in GF(2^8) as used by Rijndael).
# mul() indexes it with a sum reduced modulo 255, so only entries 0..254
# are reached; the final entry (index 255) repeats entry 0.
Alogtable= ( 1,   3,   5,  15,  17,  51,  85, 255,  26,  46, 114, 150, 161, 248,  19,  53,
            95, 225,  56,  72, 216, 115, 149, 164, 247,   2,   6,  10,  30,  34, 102, 170,
           229,  52,  92, 228,  55,  89, 235,  38, 106, 190, 217, 112, 144, 171, 230,  49,
            83, 245,   4,  12,  20,  60,  68, 204,  79, 209, 104, 184, 211, 110, 178, 205,
            76, 212, 103, 169, 224,  59,  77, 215,  98, 166, 241,   8,  24,  40, 120, 136,
           131, 158, 185, 208, 107, 189, 220, 127, 129, 152, 179, 206,  73, 219, 118, 154,
           181, 196,  87, 249,  16,  48,  80, 240,  11,  29,  39, 105, 187, 214,  97, 163,
           254,  25,  43, 125, 135, 146, 173, 236,  47, 113, 147, 174, 233,  32,  96, 160,
           251,  22,  58,  78, 210, 109, 183, 194,  93, 231,  50,  86, 250,  21,  63,  65,
           195,  94, 226,  61,  71, 201,  64, 192,  91, 237,  44, 116, 156, 191, 218, 117,
           159, 186, 213, 100, 172, 239,  42, 126, 130, 157, 188, 223, 122, 142, 137, 128,
           155, 182, 193,  88, 232,  35, 101, 175, 234,  37, 111, 177, 200,  67, 197,  84,
           252,  31,  33,  99, 165, 244,   7,   9,  27,  45, 119, 153, 176, 203,  70, 202,
            69, 207,  74, 222, 121, 139, 134, 145, 168, 227,  62,  66, 198,  81, 243,  14,
            18,  54,  90, 238,  41, 123, 141, 140, 143, 138, 133, 148, 167, 242,  13,  23,
            57,  75, 221, 124, 132, 151, 162, 253,  28,  36, 108, 180, 199,  82, 246,   1)


