
"""

    pyTomCrypt.py  --  version 0.10
    
    A ctypes Python interface to LibTomCrypt.
    
    Implemented:
      hash algorithms: all except chc
      prng algorithms: all except sober128
      symmetric encryption: all except safer
      modes: only ecb, cbc, and ctr
      public key: ecc sign/verify
    
    To come (as time permits): 
      chc, sober128, safer  (encountered problems)
      the remaining modes   (just didn't take the time)
      DSA, ASN1 and other utilities  (makes sense, will do)
      more robustness through better error checking and handling
    
    Developed and tested on PPC MacOSX; not yet tested on 
    Windows or Linux.
    
    Examples of use:  see demos and development tests
    
    Works best if the tomcrypt_plus library is installed.  If not,
    the sizes of some data structures are estimated and a math 
    library is not initialized.  Not having a math library 
    prevents the use of all public key algorithms (ECC, RSA, 
    DSA); all hashing, prng and symmetric encryption algorithms 
    are unaffected.  (Will ask that future versions of tomcrypt 
    fold in the dozen or so tomcrypt_plus library calls.)
    
    I like the implementation approach I took for hash, prng,
    and ecc.  Not happy with my pattern for modes, may redo.
    RSA works but is in need of attention/rework.  That said, 
    the whole thing will likely see some refinement.
    
    
    Larry Bugbee
    December 2006
    
"""

from pyTomCrypt import *
import sys, time


def savetofile(filename, content):
    """Write *content* to *filename* in binary mode, replacing any existing file.

    The file handle is closed even if the write raises, so a failed write
    does not leak an open descriptor.
    """
    f = open(filename, 'wb')
    try:
        f.write(content)
    finally:
        f.close()
    
def loadfmfile(filename):
    """Return the entire contents of *filename*, read in binary mode.

    The file handle is closed even if the read raises.
    """
    f = open(filename, 'rb')
    try:
        return f.read()
    finally:
        f.close()
    

#-----------------------------------------------------------------------------
# some hashing tests

def test_hashes():
            
    print 
    
    tests = [           #  (message, (algorithm, expeted value))
      ('', (                        # message
        ('md5',    ""),             # (algorithm, expected hash value)
        ('rmd128', ""),
        ('sha1',   ""),
        ('tiger',  ""),
        ('sha224', ""),
        ('sha256', ""),
        ('sha384', ""),
        ('sha512', ""),
        ('whirlpool', """
            19FA61D75522A466 9B44E39C1D2E1726 C530232130D407F8 
            9AFEE0964997F7A7 3E83BE698B288FEB CF88E3E03C4F0757 
            EA8964E59B63D937 08B138CC42A66EB3 """),
      )),
      
      ('a', (
        ('md5',    ""),
        ('rmd128', ""),
        ('sha1',   ""),
        ('tiger',  ""),
        ('sha224', ""),
        ('sha256', ""),
        ('sha384', ""),
        ('sha512', ""),
        ('whirlpool', """
            8ACA2602792AEC6F 11A67206531FB7D7 F0DFF59413145E69 
            73C45001D0087B42 D11BC645413AEFF6 3A42391A39145A59 
            1A92200D560195E5 3B478584FDAE231A """),
      )),
      
      ('abc', (
        ('md5',    ""),
        ('sha1',   ""),
        ('tiger',  ""),
        ('whirlpool', """
            4E2448A4C6F486BB 16B6562C73B4020B F3043E3A731BCE72 
            1AE1B303D97E6D4C 7181EEBDB6C57E27 7D0E34957114CBD6 
            C797FC9D95D8B582 D225292076D4EEF5 """),
      )),
      
      ('abcdefghijklmnopqrstuvwxyz', (
        ('md5',    ""),
        ('rmd128', ""),
        ('sha1',   ""),
        ('tiger',  ""),
        ('sha224', ""),
        ('sha256', ""),
        ('sha384', ""),
        ('sha512', ""),
        ('whirlpool', """
            F1D754662636FFE9 2C82EBB9212A484A 8D38631EAD4238F5 
            442EE13B8054E41B 08BF2A9251C30B6A 0B8AAE86177AB4A6 
            F68F673E7207865D 5D9819A3DBA4EB3B """),
      )),
      
      ('1234567890'*8, (
        ('md5',    ""),
        ('sha1',   ""),
        ('tiger',  ""),
        ('whirlpool', """
            466EF18BABB0154D 25B9D38A6414F5C0 8784372BCCB204D6 
            549C4AFADB601429 4D5BD8DF2A6C44E5 38CD047B2681A51A 
            2C60481E88C5A20B 2C2A80CF3A9A083B """),
      )),
    ]
    
    for suite in tests:
        message = suite[0]
        print '\n[%s]  length %d' % (message, len(message))
        for algorithm, expected in suite[1]:
            # create an instance, hash the data, and get the digest
            hasher = eval(algorithm+'()')
            hasher.update(message)
            digest = hasher.digest()
            
            # show results
            print '  %-9s: %s' % (algorithm, byt2hex(digest)),
            if expected:                # test and print status if provided
                expected = hex2byt(expected.replace(' ','').replace('\n',''))
                if digest == expected:
                    print 'good'
                else:
                    print '\n      *** bad *** \n'
            else:
                print
    # verify that the hash algs got registered against _myHashIndices 
    template = "  registration index error: %s, %d, %d"
    for alg in hash_algorithms:
        idx = LTC.find_hash(alg)
#        if _myHashIndices[alg] != idx:
#            print template % (alg, _myHashIndices[alg], idx)
    
    print 'done'
    

#-----------------------------------------------------------------------------
# some PRNG tests

def test_prngs():
            
    # sprng, rc4, yarrow and fortuna work
    # sober128 doesn't like add_entropy
    
    print
    for alg in prng_algorithms:
    
        print alg
        prng = eval('%s()' % alg)
        prng.test()
        
        prng.add_entropy('hello'*1000)
        prng.ready()
        prng.test()

        randombytes = prng.read(60)
        print '  %s' % byt2hex(randombytes)
        
        state = prng.get_state()
        savetofile('prng-state-%s' % alg, state)
        state = loadfmfile('prng-state-%s' % alg)
        prng.load_state(state)
        
        randombytes = prng.read(60)
        print '  %s' % byt2hex(randombytes)
        randombytes = prng.read(60)
        print '  %s' % byt2hex(randombytes)
        
        
        prng.test()
        prng.done()
        prng.test()


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# some symmetric encryption tests

def test_ciphers():
    
    def test_one(alg, mode, key, orig_message, IV=''):
        print
        print '  alg      ', alg
        print '  mode     ', mode
        print '  key      ', key
        print '  origIV   ', byt2hex(IV)

        engine = Cipher_Context(alg, mode, key, IV)
        IV = engine.getiv()
        print '  actualIV ', byt2hex(IV)
        engine.setiv(IV)

        # assuming the last block...
        padded_message = pad(orig_message, engine._blocksize)
        encrypted_text = engine.encrypt(padded_message)
        
        
        if 1:                       # why ???
            engine2 = Cipher_Context(alg, mode, key, IV)   # works
        else:
            engine2 = Cipher_Context(alg, mode, key)       # no go
            print '  setIV    ', byt2hex(IV)
            engine2.setiv(IV)
        decrypted_text = engine2.decrypt(encrypted_text)
        # assuming the last block...
        decrypted_text = unpad(decrypted_text)
        
#        engine.test()
#        engine2.test()

        '''
        chunk = orig_message[:16]
        orig_message = orig_message[16:]
        # assuming the last block...
        encrypted_text = engine.encrypt(chunk)
        
        padded_message = pad(orig_message, engine._blocksize)
        encrypted_text += engine.encrypt(padded_message)
        
        
        engine2 = Cipher_Context(alg, mode, key, IV)   # works
        decrypted_text = engine2.decrypt(encrypted_text)
        # assuming the last block...
        decrypted_text = unpad(decrypted_text)
        
        print '  chunk    ', chunk,                   ' (length %d)' % len(chunk)
        '''

        print '  original ', orig_message,            ' (length %d)' % len(orig_message)
        print '  encrypted', byt2hex(encrypted_text), ' (length %d)' % len(encrypted_text)
        print '  decrypted', decrypted_text,          ' (length %d)' % len(decrypted_text)
    
    '''
    'aes', 'rijndael', 'twofish', 'blowfish', 'des', 'rc2', 'des3', 
    'cast5', 'kasumi', 'anubis', 'kseed', 'khazad', 'noekeon', 'rc5', 
     'rc6', 'skipjack', 'xtea',     # problems with safer_desc
    '''

    print
    alg  = 'anubis'
    alg  = 'xtea'
    alg  = 'rijndael'
    alg  = 'aes'
    alg  = 'des3'
    alg  = 'twofish'
    mode = 'ctr'
    mode = 'ecb'
    mode = 'cbc'

    key  = 'hello'
    IV   = '1234567890123456'
    IV   = ''
    message = 'Kilroy was here!  ...and there.'

    test_one(alg, mode, key, message)

    if 1:
        for alg in ['twofish', 'anubis', 'rijndael', 'aes', 'des3', 'xtea']:
            for mode in ['ecb', 'cbc', 'ctr', 'cfb', 'ofb']:
                test_one(alg, mode, key, message)


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# some RSA tests

def test_rsa():
    print '-'*20

    keylen = 4096           #  1024  2048  3092  4096   <<<<<<<<<<<<<<<<<<<
    print 'RSA keylen =', keylen
    
    rsa = RSA_Context()
    if 0:
        rsa.make_key(keylen)
        if 0:
            key = rsa.get_private_key()
            print len(key)
            savetofile('rsa-%d-pvtkey.der' % keylen, key)
            key = rsa.get_public_key()
            print len(key)
    else:
        key = loadfmfile('rsa-%d-pvtkey.der' % keylen)
        rsa.load_key(key)

    if 1:
        message = 'Kilroy was here  ...and there.'
        hash_alg = 'sha1'
        md = eval('%s()'%hash_alg)            #  sha1  sha256  sha512  whirlpool  tiger  <<<<<<<<<<<<<<<<
        md.update(message)
        digest = md.final()
        print '  digest         ', byt2hex(digest), len(digest)
        if 1:
            sig = rsa.sign_PSS(digest)
            print '  signature      ', byt2hex(sig), len(sig)
            good = rsa.verify_PSS(digest, sig)
        else:
            sig = rsa.sign(digest, hash_alg, LTC_PKCS_1_V1_5)
            print '  signature      ', byt2hex(sig), len(sig)
            good = rsa.verify(digest, hash_alg, sig, LTC_PKCS_1_V1_5)
        if good:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        print
    
    if 1:
        # encrypt and decrypt a 16-byte session key
        sk = 'abc456def012ghi6'
        secret = rsa.encrypt_OAEP(sk, 'dummyparam', 'sha1')
        sk2 = rsa.decrypt(secret, 'dummyparam', 'sha1')
        print '  session key - in  ', sk
        print '  session key - out ', sk2
        if sk == sk2:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        print
    

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# some DSA tests

def test_dsa():
    print '-'*20

    keylen = 4096           #  1024  2048  3092  4096   <<<<<<<<<<<<<<<<<<<
    print 'DSA keylen =', keylen
    
    dsa = DSA_Context()
    if 0:
        dsa.make_key(keylen)
        if 0:
            key = dsa.get_private_key()
            print len(key)
            savetofile('dsa-%d-pvtkey.der' % keylen, key)
            key = dsa.get_public_key()
            print len(key)
    else:
        key = loadfmfile('dsa-%d-pvtkey.der' % keylen)
        dsa.load_key(key)

    if 1:
        message = 'Kilroy was here ...and there.'
        md = sha1()            #  sha1  sha256  sha512  whirlpool  tiger  <<<<<<<<<<<<<<<<
        md.update(message)
        digest = md.final()
        print '  digest         ', byt2hex(digest), len(digest)
        sig = dsa.sign(digest)
        print '  signature      ', byt2hex(sig), len(sig)
        good = dsa.verify(digest, sig)
        if good:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        print
    
    if 0:
        # encrypt and decrypt a 16-byte session key
        sk = 'abc456def012ghi6'
        secret = dsa.encrypt(sk)
        sk2 = dsa.decrypt(secret)
        print 'session key - in  ', sk
        print 'session key - out ', sk2
        if sk == sk2:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        print
    
    print '-'*20
    

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# some ECC tests

def test_ecc():
            
    print '\n%s\nElliptical Curve' % ('-'*70)
            
    keylen = 160           #  160  192  256  384  521   <<<<<<<<<<<<<<<<<<<
    print '  ECC keylen =', keylen
    print
    
    ec = ECC_Context()
    ec.make_key(keylen)
    
    if 0:
        key = ec.get_private_key()
        print 'loading key'
        ec.load_key(key)
        key = ec.get_public_key()

    # - - - - - - - - - - - - - - - - - - - - - - - - -
    if 1:
        print '- '*30
        print 'sign & verify'
    
        message = 'Kilroy was here ...and there.'
        md = tiger()           #  sha1  sha256  sha512  whirlpool  tiger  <<<<<<<<<<<<<<<<
        md.update(message)
        digest = md.final()
        print '  digest         ', byt2hex(digest), len(digest)
        sig = ec.sign(digest)
        print '  signature      ', byt2hex(sig), len(sig)
        good = ec.verify(digest, sig)
        if good:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        print
    

    # - - - - - - - - - - - - - - - - - - - - - - - - -
    if 1:
        print '- '*30
        print 'encrypt & decrypt'
                
        # encrypt a 16-byte session key
#        ec = ECC_Context()
#        ec.make_key(256)         # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
        plaintext  = 'abc456def012ghi6'
        print '  plaintext ', plaintext
        ciphertext = ec.encrypt(plaintext)
        print '  ciphertext', byt2hex(ciphertext), len(ciphertext)
        decrypted  = ec.decrypt(ciphertext)
        print '  decrypted ', decrypted
        if decrypted == plaintext:
            print '  *** good ***'
        else:
            print '  *** bad ***'
        
        print
    

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# some ASN1 and other utility tests

def test_utilities():
    """Print the algorithm lists reported by pyTomCrypt; also hold ASN.1 scratch work.

    Only the first section runs.  The base64 round-trip check and the
    ASN.1 decoding experiment are disabled with ``if 0:``.
    """

    # list what the wrapper reports as supported
    if 1:
        print 'hashes ', get_supported_hashes()
        print 'prngs  ', get_supported_prngs()
        print 'ciphers', get_supported_ciphers()
        print 'modes  ', get_supported_modes()
        print 'public_key_algorithms', get_supported_public_key_algorithms()
        

    # base64 round trip: decode(encode(s)) must give back s.  The inputs have
    # lengths 4..7, so the encoder sees every remainder mod 3.
    if 0:
        strings = ['hell', 'hello', 'helloo', 'hellooo']
        for s in strings:
            b64 = base64_encode(s)
            s2 = base64_decode(b64)
#            print s, len(s)
#            print b64, len(b64)
#            print s2, len(s2)
            if s2 != s:  print '  *** bad base64 ***'

    # ASN.1/DER decoding experiment on a signature value (presumably DSA,
    # per the layout note at the end of this function).
    if 0:
        #print dir(ASN1_List)
        # r and s as integers; neither is used below (r is rebound next)
        r = 815992158877728526763878534080536901593045180893L
        s = 259352924525753817601304407399003355017482042934L
        r = '0215008eee5a98b0ac6c233e0ebe26e653fd1f53cca1dd'
        # Only the second 'sig' assignment takes effect.  It uses tag 0x04
        # (OCTET STRING) where the first uses 0x02 (INTEGER).
        # Each is 94 hex chars, i.e. 47 bytes once decoded.
        sig = '302d0215008eee5a98b0ac6c233e0ebe26e653fd1f53cca1dd02142d6dc9901f8733e72cb8ee69ba5f36f1ce777e36'  # 47
        sig = '302d0415008eee5a98b0ac6c233e0ebe26e653fd1f53cca1dd04142d6dc9901f8733e72cb8ee69ba5f36f1ce777e36'  # 47
        val = sig
        
        print val,
        val = hex2byt(val)
        print  len(val)
        
        # NOTE(review): asn1_decode appears to return the head of a linked
        # structure with .type/.data/.size/.used and .child/.next pointers
        # (inferred from the attribute accesses below) -- confirm.
        seqlist = asn1_decode(val)
        
    #    seqlist = cast(seqlist, POINTER(ASN1_List))[0]
        print 'seqlist', seqlist
        print
        print 'type', seqlist.type,
        print 'data', seqlist.data,
        print 'size', seqlist.size,
        print 'used', seqlist.used
    
        # first element inside the outer SEQUENCE
        child = seqlist.child[0]
        print 'type', child.type,
        print 'data', child.data,
        print 'size', child.size,
        print 'used', child.used
    
        # second element (sibling of the first); NOTE: shadows builtin next()
        next = child.next[0]
        print 'type', next.type,
        print 'data', next.data,
        print 'size', next.size,
        print 'used', next.used
        print
        
    #    num = cast(next.data, POINTER(c_char))
        print byt2hex(child.data[:child.size])
        print byt2hex(next.data[:next.size])
    
    # Reference note (an unused string statement) describing the DER layout
    # of a DSA signature, SEQUENCE { INTEGER r, INTEGER s }:
    '''
        Encodes a signature in a way it meets the W3C standard for DSA XML 
        signature values. Extracts the ASN.1 encoded values for r and s from 
        a DER encoded byte array. 
        ASN.1 Notation: sequence { integer r integer s } --> 
        Der-Encoding byte 0x30   // Sequence 
                     byte 44 + x // len in bytes (x = {0|1|2} depending on r and s 
                     byte 0x02   // integer 
                     byte <= 21  // len of r (21: if first bit of r set, we need a leading 0 --> 20 + 1 bytes) 
                     byte[] ...  // value of r (with leading zero if necessary) 
                     byte 0x02   // integer 
                     byte <= 21  // len of s (21: if first bit of s set, we need a leading 0 --> 20 + 1 bytes) 
                     byte[] ...  // value of s (with leading zero if necessary)
    
    '''

#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------

def test_macs():
    if 1:
        msg = 'Kilroy was here!  ...and there.'
        key8   = 'keyvalue'
        key16  = 'keyvalue90123456'
        keyany = 'Now is the time for all good men...'
        
        hmac = HMAC('md5', keyany)
        hmac.process(msg)
        sig = hmac.done()
        print hmac.name
        print ' ', byt2hex(sig)
        
        hmac = HMAC('sha1', keyany)
        hmac.process(msg)
        sig = hmac.done()
        print hmac.name
        print ' ', byt2hex(sig)
        
        hmac = HMAC('sha256', keyany)
        hmac.process(msg)
        sig = hmac.done()
        print hmac.name
        print ' ', byt2hex(sig)
        
        hmac = HMAC('whirlpool', keyany)
        hmac.process(msg)
        sig = hmac.done()
        print hmac.name
        print ' ', byt2hex(sig)
        
        omac = OMAC('twofish', key16)
        omac.process(msg)
        sig = omac.done()
        print omac.name
        print ' ', byt2hex(sig)
        
        print 'A quick OMAC(twofish) signature...'
        print ' ', byt2hex(OMAC('twofish', key16).done(msg))
        
        omac = OMAC('aes', key16)
        omac.process(msg)
        sig = omac.done()
        print omac.name
        print ' ', byt2hex(sig)
        
        omac = OMAC('des', key8)
        omac.process(msg)
        sig = omac.done()
        print omac.name
        print ' ', byt2hex(sig)
        
        pmac = PMAC('twofish', key16)
        pmac.process(msg)
        sig = pmac.done()
        print pmac.name
        print ' ', byt2hex(sig)
        
        pmac = Pelican(key16)
        pmac.process(msg)
        sig = pmac.done()
        print pmac.name
        print ' ', byt2hex(sig)
        
        xcbc = XCBC('twofish', key16)
        xcbc.process(msg)
        sig = xcbc.done()
        print xcbc.name
        print ' ', byt2hex(sig)
        
        if 1:
            omac = OMAC('aes', key16)
        #    msg2 = msg+chr(0x40)+str(chr(0x00)*(16-((len(msg)+1)%16)))
            msg2 = msg+chr(0x40)
            omac.process(msg2)
            sig = omac.done()
            print omac.name
            print ' ', byt2hex(sig)
        
        f9 = F9('aes', key16)
        f9.test()
        msg2 = msg+chr(0x40)
        pads = str(chr(0x00)*(16-(len(msg)+1)%16))
        if len(pads) != 16:
            msg2 += pads
        f9.process(msg2)
        sig = f9.done()
        print f9.name
        print ' ', byt2hex(sig)

    # - - - - - - - - - - - - - - - - - - - - - - - -

    if 0:
        problems = 0
        TVs = {}
        
        def build_hmac_tv(filename):
            line = ''
            
            f = open(filename)
            for line in f:
                line = line[:-1]
                if not line: continue
                parts = line.split()
                if len(parts) == 1:
                    key = parts[0]
                    TVs[key] = []
                if len(parts) == 2 and parts[0][-1] == ':':
                    TVs[key].append(parts[1])
        
        def showlist(values):
            for v in values:
                print '   ', v
        
        def showTVs():
            for key,values in TVs.items():
                print
                print ' ', key, len(values)
                showlist(values)
        
        def nnstr(nn):
            s = ''
            for i in range(nn):
                s += chr(i)
            return s
        
        
        build_hmac_tv('../libtomcrypt-1.16/notes/hmac_tv.txt')
    #    showTVs()
        
        algs = ['md5', 'sha1', 'sha224', 'whirlpool']
        
        for alg in algs:
            values = TVs['HMAC-'+alg]
            
#            print
#            print ' ', alg, len(values)
#            showlist(values)
            
            for nn in range(1,len(values)):
                msg = nnstr(nn)        
                key = hex2byt(values[nn-1])
                
                hmac = HMAC(alg, key)
                hmac.process(msg)
                sig = hmac.done()
                
                if sig != hex2byt(values[nn]):
                    print '  *** %s problem with vector %d' % (alg, nn)
                    problems += 1
        #        if nn == 5: break
            
        if not problems:
            print
            print '  *** no HMAC problems detected with', str(algs)
            print
        
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

if __name__ == '__main__':
    # Demos to run, in order; uncomment an entry to enable it.
    for demo in (test_hashes,
                 # test_macs,
                 # test_prngs,
                 # test_ciphers,
                 # test_rsa,
                 # test_dsa,
                 test_ecc,
                 # test_utilities,
                 ):
        demo()

'''
Notes:
    ecc sign, verify, decrypt don't need a hash alg
    ecc encrypt does and encodes OID into output (hashlen must be > inlen)

    rsa decrypt doesn't need a hash alg
    rsa encrypt does if OAEP padding and encodes OID into output (hashlen must be > inlen)
                        LTC_PKCS_1_V1_5 doesn't work
'''
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------


'''
    def build_hmac_tv(filename):
        TVs = {}
        
        def skip_blank_lines():
            while 1:
                line = f.readline()[:-1]
                print line
                if line:
                    break
            return line
        
        def skip_section():
            while 1:
                line = f.readline()[:-1]
                print line
                if not line:
                    line = skip_blank_lines()
                    break
            return line
        
        def addTV(line):
            if line.startswith('HMAC'):
                key = line
                val = []
                while 1:
                    line = f.readline()[:-1]
                    print line
                    if not line:
                        break
                    n,v = line.split(': ')
                    val.append(v)
                TVs[key] = val
                line = skip_blank_lines()
            else:
                line = skip_section()
            return line
        
        f = open(filename)
        line = f.readline()[:-1]
        print line
        line = skip_section()
        while line:
            line = addTV(line)
            break
        for key,value in TVs.items():
            print key,value
    
    
    def build_hmac_tv(filename):
        TVs = {}
        line = ''
        
        def skip_blank_lines():
            global line
            while 1:
                line = f.readline()[:-1]
                print line
                if line:
                    break
        
        def skip_section():
            global line
            while 1:
                line = f.readline()[:-1]
                print line
                if not line:
                    skip_blank_lines()
                    break
        
        def addTV():
            global line
            if line.startswith('HMAC'):
                key = line
                val = []
                while 1:
                    line = f.readline()[:-1]
                    print line
                    if not line:
                        break
                    n,v = line.split(': ')
                    val.append(v)
                TVs[key] = val
        
        f = open(filename)
        line = f.readline()[:-1]
        print line
        skip_section()
        while line:
            addTV()
            skip_blank_lines()

        for key,value in TVs.items():
            print
            print key,value

'''