﻿
# coding: utf-8

"""
    
    pyTomCrypt.py
    
    A Python interface to LibTomCrypt implemented with ctypes.
    
    pyTomCrypt is experimental software and intended for 
    educational purposes only.  To make your learning and 
    experimentation less cumbersome, pyTomCrypt is placed 
    in the Public Domain and is free for any use.      
    
    THIS PROGRAM IS PROVIDED WITHOUT WARRANTY OR GUARANTEE
    OF ANY KIND.  USE AT YOUR OWN RISK.  
    
    
    Enjoy,
    
    Larry Bugbee
    bugbee@seanet.com
    March 6, 2007
    
"""

from ctypes import *
import os, sys, time, struct

__VERSION__  = 'v0.20'

USE_MATH_LIB = 'LTM'                    #  'TFM' or 'LTM'

SHOW_STATUS  = 0                        # nonzero: print load diagnostics


#=============================================================================
# load the libraries

if SHOW_STATUS:
    print('  pyTomCrypt %s' % __VERSION__)
    print('  sys.platform = %s' % sys.platform)
    print('  os.name      = %s' % os.name)

# Keyed by sys.platform.  NOTE(review): 'nt' is an os.name value; on Windows
# sys.platform is 'win32', so that entry is never matched -- confirm intent.
libsuffixes = {'darwin': '.so',         # '.dylib', 
               'linux':  '.so', 
               'linux2': '.so', 
               'nt':     '.lib'}        # .so ???
try:
    libsuffix = libsuffixes[sys.platform]
except KeyError:
    raise Exception('pyTomCrypt does not know library suffix for "%s"' % sys.platform)


def _load_library(basename):
    """Load the shared library '<basename>.so' with ctypes.CDLL.

    If that fails on darwin, retry with the platform suffix from
    libsuffixes.  On other platforms the failure raises
    Exception('unable to find or load <basename>.so').  Only OSError (what
    CDLL raises when a library cannot be loaded) is caught.
    """
    try:
        return CDLL(basename + '.so')       # look for .so first
    except OSError:
        # perhaps a platform specific suffix
        if sys.platform == 'darwin':
            return CDLL(basename + libsuffix)
        raise Exception('unable to find or load %s.so' % basename)


LTC = _load_library('libtomcrypt')              # load libtomcrypt
if SHOW_STATUS:
    print('  loaded: %s' % LTC._name)

LTCplus = _load_library('libtomcrypt_plus')     # load libtomcrypt_plus library
if SHOW_STATUS:
    print('  loaded: %s' % LTCplus._name)

if USE_MATH_LIB == 'TFM':
    LTCplus.init_TFM()              # init TomsFastMath library
else:
    LTCplus.init_LTM()              # init LibTomMath


#=============================================================================
# named constants (get the right values via call to tomcrypt so the libs
# need to be loaded)

# Each constant below is obtained by calling a tomcrypt_plus accessor at
# import time, so it reflects the values the C library was built with.
# The commented-out literals are the values the author recorded by hand;
# they are kept for reference only and are not used.

# key type selectors for the public-key functions
#PUBLIC_KEY  = 0
#PRIVATE_KEY = 1
PUBLIC_KEY  = LTCplus.PUBLIC_KEY()
PRIVATE_KEY = LTCplus.PRIVATE_KEY()

# PKCS #1 padding selectors
#LTC_PKCS_1_V1_5 = 1         # PKCS #1 v1.5 padding (\sa ltc_pkcs_1_v1_5_blocks)
#LTC_PKCS_1_OAEP = 2         # PKCS #1 v2.0 encryption padding
#LTC_PKCS_1_PSS  = 3         # PKCS #1 v2.1 signature padding
LTC_PKCS_1_V1_5 = LTCplus.PKCS_1_V1_5()
LTC_PKCS_1_OAEP = LTCplus.PKCS_1_OAEP()
LTC_PKCS_1_PSS  = LTCplus.PKCS_1_PSS()

# CTR-mode counter flags (Cipher_Context passes one of these to ctr_start).
# NOTE(review): recent LibTomCrypt headers define these as bit masks rather
# than 0/1/2; the literals below may be stale -- the runtime values come
# from the C library either way.
#CTR_COUNTER_LITTLE_ENDIAN = 0   # not based on endian of machine
#CTR_COUNTER_BIG_ENDIAN    = 1   # not based on endian of machine
#LTC_CTR_RFC3686           = 2   # not good ???   (doesn't work here)
CTR_COUNTER_LITTLE_ENDIAN = LTCplus.COUNTER_LITTLE_ENDIAN()
CTR_COUNTER_BIG_ENDIAN    = LTCplus.COUNTER_BIG_ENDIAN()
LTC_CTR_RFC3686           = LTCplus.CTR_RFC3686()


#-----------------------------------------------------------------------------
# Determine the size of the basic working states.  The amount of memory 
# to allocate is provided by tomcrypt_plus and if not installed, estimates
# are provided. Crashes may result if those estimates don't cover future
# memory requirements.  Install a current version of tomcrypt_plus if at
# all possible.

if LTCplus:
    # exact sizes reported by the tomcrypt_plus helper library
    hash_state_size    = LTCplus.hash_state_size()
    symmetric_ECB_size = LTCplus.symmetric_ECB_size()
    symmetric_CBC_size = LTCplus.symmetric_CBC_size()
    symmetric_CTR_size = LTCplus.symmetric_CTR_size()
    symmetric_CFB_size = LTCplus.symmetric_CFB_size()
    symmetric_OFB_size = LTCplus.symmetric_OFB_size()
    symmetric_LRW_size = LTCplus.symmetric_LRW_size()
    symmetric_F8_size  = LTCplus.symmetric_F8_size()
    prng_state_size    = LTCplus.prng_state_size()
else:
    # Fallback estimates.  Every name defined in the branch above is also
    # defined here; previously the ECB/CFB/OFB/LRW/F8 sizes were missing,
    # so constructing a Cipher_Context in those modes raised NameError.
    hash_state_size    = 300            # default estimate
    symmetric_ECB_size = 4500           # default estimate
    symmetric_CBC_size = 4500           # default estimate
    symmetric_CTR_size = 4500           # default estimate
    symmetric_CFB_size = 4500           # default estimate
    symmetric_OFB_size = 4500           # default estimate
    # generous guess: LRW state may hold large precomputed tables -- verify
    symmetric_LRW_size = 70000
    symmetric_F8_size  = 4500           # default estimate
    prng_state_size    = 14000          # default estimate

#=============================================================================
# no convenient way to get the error codes from tomcrypt...  it's brute force 
# either way.
# Maps each LibTomCrypt status code to its CRYPT_* name; codes are the
# consecutive integers 0..24, so the names are listed in code order.
error_codes = dict(enumerate([
    'CRYPT_OK',                 #  0 Result OK
    'CRYPT_ERROR',              #  1 Generic Error
    'CRYPT_NOP',                #  2 Not a failure, but no operation performed

    'CRYPT_INVALID_KEYSIZE',    #  3 Invalid key size given
    'CRYPT_INVALID_ROUNDS',     #  4 Invalid number of rounds
    'CRYPT_FAIL_TESTVECTOR',    #  5 Algorithm failed test vectors

    'CRYPT_BUFFER_OVERFLOW',    #  6 Not enough space for output
    'CRYPT_INVALID_PACKET',     #  7 Invalid input packet given

    'CRYPT_INVALID_PRNGSIZE',   #  8 Invalid number of bits for a PRNG
    'CRYPT_ERROR_READPRNG',     #  9 Could not read enough from PRNG

    'CRYPT_INVALID_CIPHER',     # 10 Invalid cipher specified
    'CRYPT_INVALID_HASH',       # 11 Invalid hash specified
    'CRYPT_INVALID_PRNG',       # 12 Invalid PRNG specified

    'CRYPT_MEM',                # 13 Out of memory

    'CRYPT_PK_TYPE_MISMATCH',   # 14 Not equivalent types of PK keys
    'CRYPT_PK_NOT_PRIVATE',     # 15 Requires a private PK key

    'CRYPT_INVALID_ARG',        # 16 Generic invalid argument
    'CRYPT_FILE_NOTFOUND',      # 17 File Not Found

    'CRYPT_PK_INVALID_TYPE',    # 18 Invalid type of PK key
    'CRYPT_PK_INVALID_SYSTEM',  # 19 Invalid PK system specified
    'CRYPT_PK_DUP',             # 20 Duplicate key already in key ring
    'CRYPT_PK_NOT_FOUND',       # 21 Key not found in keyring
    'CRYPT_PK_INVALID_SIZE',    # 22 Invalid size input for PK parameters

    'CRYPT_INVALID_PRIME_SIZE', # 23 Invalid size of prime requested
    'CRYPT_PK_INVALID_PADDING', # 24 Invalid padding on input
]))


#=============================================================================
# supported algorithms
#
# If you use a custom built version of LibTomCrypt that implements a subset of
# the available algorithms, the following lists need to be likewise tailored.

public_key_algs   = 'DSA RSA ECC'.split()          # ECC is both ECDSA and ECDH

# chc is left out: problems with chc
hash_algorithms   = ('md2 md4 md5 '
                     'rmd128 rmd160 rmd256 rmd320 '
                     'sha1 sha224 sha256 sha384 sha512 '
                     'tiger whirlpool').split()

# safer is left out: problems with safer_desc
cipher_algorithms = ('aes rijndael twofish blowfish '
                     'des rc2 des3 cast5 kasumi '
                     'anubis kseed khazad noekeon '
                     'rc5 rc6 skipjack xtea').split()

# lrw and f8 are not wired up yet
cipher_modes      = 'ecb cbc ctr cfb ofb'.split()

prng_algorithms   = 'fortuna rc4 sprng yarrow sober128'.split()

# no file, mem variants
mac_algorithms    = 'hmac omac pmac pelican xcbc f9'.split()

#=============================================================================
# Register all supported algorithms (hash, symmetric cipher, PRNG)
#
# Registered algorithms are available for higher level function calls
# but registration is typically not required when specific individual
# algorithms are called.  ...there may be some exceptions.

def _register_all(names, register):
    """Register the LTC descriptor '<name>_desc' for every name in names.

    register is the LTC registration entry point (register_hash,
    register_cipher or register_prng).  Returns the list of values it
    returned, in the same order as names (per LibTomCrypt convention,
    presumably -1 marks a failed registration -- confirm).
    """
    indices = []
    for alg in names:
        desc = getattr(LTC, alg + '_desc')      # get the descriptor
        indices.append(register(desc))
    return indices

def register_hashes():
    "register every hash in hash_algorithms; returns the registration indices"
    return _register_all(hash_algorithms, LTC.register_hash)

def register_ciphers():
    "register every cipher in cipher_algorithms; returns the registration indices"
    return _register_all(cipher_algorithms, LTC.register_cipher)

def register_prngs():
    "register every PRNG in prng_algorithms; returns the registration indices"
    return _register_all(prng_algorithms, LTC.register_prng)


#-----------------------------------------------------------------------------
# low level functions (hash, symmetric cipher, PRNG, etc)

def find_hash(name):
    """Look up a registered hash by name.

    Returns its index in LibTomCrypt's registry, or -1 if no hash of that
    name has been registered.
    """
    idx = LTC.find_hash(name)
    return idx

def hash_is_valid(idx):
    """Return LibTomCrypt's status code for hash index idx (0 == CRYPT_OK).

    NOTE(review): this was described as running internal test vectors; in
    LibTomCrypt hash_is_valid is documented as checking that idx refers to
    a registered hash, while test vectors live in the descriptor's test().
    Confirm which behaviour callers expect.
    """
    status = LTC.hash_is_valid(idx)
    return status
    
# The stubs below are placeholders.  They raise NotImplementedError, a
# subclass of Exception, so existing "except Exception" handlers still work.

def register_hash(name, ID, hashsize, blocksize, OID, 
                  init, process, done, test):
    "not implemented; always raises NotImplementedError"
    raise NotImplementedError('  *** register_hash not yet implemented ***')
    
def unregister_hash(hash_desc):
    "not implemented; always raises NotImplementedError"
    raise NotImplementedError('  *** unregister_hash not yet implemented ***')
    
def find_hash_id(ID):
    "not implemented; always raises NotImplementedError"
    raise NotImplementedError('  *** find_hash_id not yet implemented ***')
    
def find_hash_oid(ID=None, IDlen=None):
    """not implemented; always raises NotImplementedError

    Takes optional (ID, IDlen) so callers can already pass an OID without a
    TypeError; previously the function accepted no arguments at all.
    """
    raise NotImplementedError('  *** find_hash_oid not yet implemented ***')
    
def find_hash_any(digestlen):           # digestlen in bytes
    "not implemented; always raises NotImplementedError"
    raise NotImplementedError('  *** find_hash_any not yet implemented ***')

# Accessors for the module-level algorithm lists.  They return the lists
# themselves (not copies), exactly as before.  get_supported_modes was
# previously defined twice; the redundant second definition is removed.

def get_supported_hashes():
    "return the list of hash names wrapped by this module"
    return hash_algorithms
    
def get_supported_prngs():
    "return the list of PRNG names wrapped by this module"
    return prng_algorithms
    
def get_supported_ciphers():
    "return the list of block cipher names wrapped by this module"
    return cipher_algorithms

def get_supported_modes():
    "return the list of supported chaining mode names"
    return cipher_modes
    
def get_supported_public_key_algorithms():
    "return the list of public-key algorithm names"
    return public_key_algs
    
#-----------------------------------------------------------------------------
# common definitions

class MP_Int(Structure):
    """ctypes view of the math library's big-integer header, read by mpi2hex().

    Only the field order and types matter; they must match the C struct of
    the math library in use.  NOTE(review): this mirrors a LibTomMath-style
    layout; with USE_MATH_LIB = 'TFM' the TomsFastMath integer layout may
    differ -- confirm before relying on it.
    """
    _fields_ = [('used',   c_int),      # number of digits in use (mpi2hex reads this many)
                ('alloc',  c_int),      # not read by this module
                ('size',   c_int),      # unused; NOTE(review): presumably the C 'sign' field -- confirm
                ('dp', POINTER(c_char)),    # digit array, least significant digit first
               ]

#-----------------------------------------------------------------------------
# utilities: there may be better ways, but this is local and convenient

def savetofile(filename, content):
    """Write content to filename in binary mode, replacing any existing file.

    The file is closed even if the write raises.
    """
    f = open(filename, 'wb')
    try:
        f.write(content)
    finally:
        f.close()
    
def loadfmfile(filename):
    """Return the entire content of filename, read in binary mode.

    The file is closed even if the read raises.
    """
    f = open(filename, 'rb')
    try:
        return f.read()
    finally:
        f.close()
    
def byt2hex(bs):
    """Convert a byte string to lowercase hex, two digits per byte.

    Example: '\\x00\\x0f\\xff' -> '000fff'.  Builds the result with one join
    instead of repeated string concatenation.
    """
    return ''.join(['%02x' % ord(c) for c in bs])

def hex2byt(h):
    """Convert a hex string to a byte string (the inverse of byt2hex).

    Accepts hex()/long-repr forms: a trailing 'L' and a leading '0x' are
    removed, and an odd number of digits is left-padded with '0'.  Each digit
    pair is parsed with int(pair, 16), which raises ValueError for invalid
    pairs.  This replaces eval() of caller-supplied text.
    """
    if h[-1:] == 'L':  h = h[:-1]       # remove any trailing 'L'
    if h[:2] == '0x':  h = h[2:]        # remove any leading '0x'
    if len(h) % 2:     h = '0'+h        # be sure num chars is even
    return ''.join([chr(int(h[i:i+2], 16)) for i in range(0, len(h), 2)])

    
def mpi2hex(mp_int, digit_bytes=4, digit_bits=28):
    """Convert Tom's 'infamous' mp_int to a hex string.

    mp_int      -- address of an mp_int, as handed back by the C library
    digit_bytes -- storage size of one digit in bytes (4 or 8)
    digit_bits  -- number of value bits held in each digit

    The defaults (28-bit digits in 4-byte words) are the assumptions this
    function always hard-coded.  Math-library builds with other digit sizes
    need different arguments.

    Returns lowercase hex with an even number of digits and no leading zero
    bytes.  Zero is returned as ''.  The sign is ignored.

    Digits are read as native-endian words.  The previous byte-wise parsing,
    which dropped the first hex character of every 4-byte group, was only
    correct on big-endian hosts.
    """
    mpi = cast(mp_int, POINTER(MP_Int))[0]
    word_type = {4: c_uint32, 8: c_uint64}[digit_bytes]
    digits = cast(mpi.dp, POINTER(word_type))
    mask = (1 << digit_bits) - 1
    value = 0
    # digits are stored least significant first, so fold from the top down
    for i in range(mpi.used - 1, -1, -1):
        value = (value << digit_bits) | (digits[i] & mask)
    if not value:
        return ''
    h = '%x' % value
    if len(h) % 2:                      # if odd count
        h = '0' + h                     #   pad with a leading '0'
    return h
    

#=============================================================================
# Registering is so other libtomcrypt functions can find and use the 
# algorithms.  Registering does not directly support a Python program;
# the code that follows does.
#
#=============================================================================
# MAC Algorithms
#
# The goal here is to create a class, one for each MAC algorithm, for 
# standalone use.  (The registration process is not required to use these 
# classes.)
#
# Typical usage:
#     hmac = HMAC('sha256', keyanylen)
#     hmac.process(msg)
#     sig = hmac.done()
#

# MAC is a base class
class MAC(object):
    """Base class for the MAC wrappers below.

    A subclass sets _state_size and the LTC entry points _init_func,
    _proc_func, _done_func and _test_func, then calls _init().
    """
    def _init(self, alg, idx, key):
        "allocate the MAC state and key it; idx is the hash or cipher index"
        self.name = '%s(%s)' % (self.__class__.__name__, alg)
        self.alg = alg
        self._state = c_buffer(self._state_size)
        keybuf = c_buffer(key)
        err = self._init_func(byref(self._state), idx, byref(keybuf), len(key))
        if err: raise Exception('init error = %d (%s)' % (err, error_codes[err]))

    def process(self, message):
        "chunk message and loop this step if desired"
        err = self._proc_func(byref(self._state), message, len(message))
        if err: raise Exception('process error = %d (%s)' % (err, error_codes[err]))
    
    def done(self, message=''):
        """returns the signature (MAC tag)

        message, if given, is processed first.  Up to 64 bytes are requested,
        and the tag is truncated to the length the library reports in outlen.
        """
        if message:
            self.update(message)
        # The C *_done functions take 'unsigned long *outlen'.  A c_int is
        # too small for that on LP64 platforms, so the library would write
        # past it; c_ulong matches the C type on every platform.
        outlen = c_ulong(64)                # same as hash size, start big
        self._sig = c_buffer(outlen.value)
        err = self._done_func(byref(self._state), byref(self._sig), byref(outlen))
        if err: raise Exception('done error = %d (%s)' % (err, error_codes[err]))
        return self._sig.raw[:outlen.value]

    def test(self):
        "run the library self-test; raises on failure, returns 0 on success"
        # the LibTomCrypt MAC *_test functions take no arguments
        err = self._test_func()
        if err: raise Exception('test error = %d (%s)' % (err, error_codes[err]))
        return 0

    update = process                        # common Python synonym
    final  = done                           # common Python synonym

def _setup_mac(mac, prefix, find_index, alg, key):
    """Bind the LTC '<prefix>_*' entry points to mac, then key it.

    find_index maps alg to its registry index: LTC.find_hash for HMAC and
    LTC.find_cipher for the cipher-based MACs.
    """
    mac._state_size = getattr(LTCplus, prefix + '_state_size')()
    mac._init_func  = getattr(LTC, prefix + '_init')
    mac._proc_func  = getattr(LTC, prefix + '_process')
    mac._done_func  = getattr(LTC, prefix + '_done')
    mac._test_func  = getattr(LTC, prefix + '_test')
    mac._init(alg, find_index(alg), key)

class HMAC(MAC):
    "HMAC over a registered hash, e.g. HMAC('sha256', key)"
    def __init__(self, hash_alg, key):
        _setup_mac(self, 'hmac', LTC.find_hash, hash_alg, key)

class OMAC(MAC):
    "OMAC over a registered block cipher"
    def __init__(self, cipher_alg, key):
        _setup_mac(self, 'omac', LTC.find_cipher, cipher_alg, key)

class PMAC(MAC):
    "PMAC over a registered block cipher"
    def __init__(self, cipher_alg, key):
        _setup_mac(self, 'pmac', LTC.find_cipher, cipher_alg, key)

class XCBC(MAC):
    "XCBC-MAC over a registered block cipher"
    def __init__(self, cipher_alg, key):
        _setup_mac(self, 'xcbc', LTC.find_cipher, cipher_alg, key)

class F9(MAC):
    "F9-MAC over a registered block cipher"
    def __init__(self, cipher_alg, key):
        _setup_mac(self, 'f9', LTC.find_cipher, cipher_alg, key)

# Pelican assumes AES... just one more detail to not have to worry about.
#   ...but because AES is hardcoded, there is no need for outlen in done()
#   so we will override done below.
class Pelican(MAC):
    "Pelican MAC: the cipher is fixed by the library, so only a key is given"
    def __init__(self, key):
        self.name = self.__class__.__name__
        state_size = LTCplus.pelican_state_size()
        self._state = c_buffer(state_size)
        self._proc_func, self._done_func, self._test_func = (
            LTC.pelican_process, LTC.pelican_done, LTC.pelican_test)
        keybuf = c_buffer(key)
        err = LTC.pelican_init(byref(self._state), byref(keybuf), len(key))
        if err:
            raise Exception('init error = %d (%s)' % (err, error_codes[err]))

    def done(self, message=''):
        "returns the 16-byte tag; message, if given, is processed first"
        if message:
            self.update(message)
        tag = c_buffer(16)                  # determined by AES
        self._sig = tag
        err = LTC.pelican_done(byref(self._state), byref(tag))
        if err:
            raise Exception('done error = %d (%s)' % (err, error_codes[err]))
        return tag.raw



#=============================================================================
# Hashing Algorithms
#
# The goal here is to create a class, one for each hashing algorithm, for 
# standalone use.  (The registration process is not required to use these 
# classes.)
#
# Typical usage:
#     md = sha256()
#     md.update('Kilroy was ')
#     md.update('here!')
#     digest = md.final()
#

class HashDescriptors(Structure):
    """This descriptor is needed to gain access to the output digest size.

    ctypes view of LibTomCrypt's hash descriptor struct.  The field order and
    types must match the C struct for the values read here to be valid.  The
    function-pointer fields only reserve pointer-sized slots; their
    prototypes are placeholders and are not called in this part of the file.
    """
    _fields_ = [('name',       c_char_p),   # algorithm name, e.g. 'sha256'
                ('ID',         c_byte),     # NOTE(review): likely unsigned char in C; c_byte is signed
                ('hashsize',   c_ulong),    # digest size in bytes
                ('blocksize',  c_ulong),    # presumably the input block size in bytes -- confirm
                ('OID',        c_ulong*16), # OID arcs; only the first OIDlen are meaningful
                ('OIDlen',     c_ulong),
                ('init',       CFUNCTYPE(c_int, c_char_p)),
                ('process',    CFUNCTYPE(c_int, c_char_p)),
                ('done',       CFUNCTYPE(c_int, c_char_p)),
                ('test',       CFUNCTYPE(c_int)),
                ('hmac_block', c_ulong),    # not implemented
               ]

# HashAlgorithm is a base class
class HashAlgorithm(object):
    """ This is a base class for each of the hashing algorithms.
        Most of the logic is here because of the parallel structure
        adopted by Tom.

        A generated subclass sets _state_size, _digest_size (in bits) and
        the LTC entry points _init_func, _proc_func, _done_func and
        _test_func, then calls _init().
    """
    def _init(self):
        "setup instance and then call self.init()"
        self.name = self.__class__.__name__
        self._hash_state = c_buffer(self._state_size)
        self.init()
    
    def init(self):
        "initializes state to initial conditions"
        err = self._init_func(byref(self._hash_state))
        if err: raise Exception('init error = %d (%s)' % (err, error_codes[err]))

    def process(self, message):
        "chunk message and loop this step if desired"
        err = self._proc_func(byref(self._hash_state), message, len(message))
        if err: raise Exception('process error = %d (%s)' % (err, error_codes[err]))
    
    def done(self, message=''):
        """returns the digest (non-destructively) so may do additional hashing

        The C *_done call finalizes (pads) whatever state it is given.  It is
        therefore run on a copy of the running state, so this object can keep
        absorbing data and be finalized again later, like hashlib's digest().
        """
        if message:
            self.update(message)
        nbytes = self._digest_size // 8
        work = c_buffer(self._state_size)
        memmove(work, self._hash_state, self._state_size)
        self._digest = c_buffer(nbytes)
        err = self._done_func(byref(work), byref(self._digest))
        if err: raise Exception('done error = %d (%s)' % (err, error_codes[err]))
        return self._digest.raw[:nbytes]

    def _estimate_state_size(self):
        "not used if libtomcrypt_plus library is installed"
        est = 400                      # hopefully large enough
        print('  Note: estimated state buffer for %s is %d bytes' % (self.name, est))
        return est
    
    def test(self):
        "run the library self-test; returns its status code (0 == CRYPT_OK)"
        return self._test_func()

    update = process
    digest = final = done


# dynamically create a class for each hashing algorithm by 'exec'uting
# the following template.  Note the inheritance from HashAlgorithm and
# the use of %s and %d placeholders.

hash_class_template = """
class %s(HashAlgorithm):
    def __init__(self):
        self._state_size  = hash_state_size     # global
        self._digest_size = %d*8                # in bits
        self._init_func   = LTC.%s_init
        self._proc_func   = LTC.%s_process
        self._done_func   = LTC.%s_done
        self._test_func   = LTC.%s_test
        self._init()
"""

# make a class, one for each algorithm.  The class name will be exactly
# as specified in "hash_algorithms" above.
for alg in hash_algorithms:
    hash_desc = eval('LTC.'+alg+'_desc')    # get the descriptor
    desc = cast(hash_desc, POINTER(HashDescriptors))[0]
    if desc.name == 'sha224':
        exec(hash_class_template % (alg, desc.hashsize, alg, 'sha256', alg, alg))
    elif desc.name == 'sha384':
        exec(hash_class_template % (alg, desc.hashsize, alg, 'sha512', alg, alg))
    else:
        exec(hash_class_template % (alg, desc.hashsize, alg, alg, alg, alg))
    if 0:                   # enable if want to see particulars
        print
        print '  name   ', desc.name
        print '  ID     ', desc.ID
        print '  hashsiz', desc.hashsize
        print '  blksize', desc.blocksize
        print '  OID    ', '.'.join([str(x) for x in desc.OID[:desc.OIDlen]])
    

#=============================================================================
# PRNG Algorithms
#
# The goal here is to create a class, one for each PRNG algorithm, for 
# standalone use.  (The registration process is not required to use these 
# classes.)
#
# Typical usage:
#     prng = fortuna()
#     prng.add_entropy(<some entropy>)  # optional
#     prng.ready()                      # optional, reqd if add more entropy
#     randbytes = prng.read(<number to read>)
#

class PRNGDescriptors(Structure):
    """ctypes view of LibTomCrypt's PRNG descriptor struct.

    Only the leading fields (name, export_size) hold meaningful values for
    Python.  The function-pointer fields reserve pointer-sized slots.
    NOTE(review): their CFUNCTYPE prototypes are placeholders that do not
    appear to match the C signatures, so they must not be called.
    """
    _fields_ = [('name',        c_char_p),  # algorithm name, e.g. 'fortuna'
                ('export_size', c_uint),    # presumably bytes written by *_export -- confirm
                ('start',       CFUNCTYPE(c_int, c_char_p)),
                ('add_entropy', CFUNCTYPE(c_char_p, c_ulong, c_char_p)),
                ('ready',       CFUNCTYPE(c_int, c_char_p)),
                ('read',        CFUNCTYPE(c_int, c_char_p)),
                ('done',        CFUNCTYPE(c_int, c_char_p)),
                ('pexport',     CFUNCTYPE(c_int, POINTER(c_int))),
                ('pimport',     CFUNCTYPE(c_int, POINTER(c_int))),
                ('test',        CFUNCTYPE(c_int)),
               ]


# PrngAlgorithm is a base class
class PrngAlgorithm(object):
    """ This is a base class for each of the PRNG algorithms.
        Most of the logic is here because of the parallel structure
        adopted by Tom.

        A generated subclass sets _state_size and the LTC entry points
        (_start_func, _add_entropy_func, _ready_func, _read_func,
        _done_func, _test_func, _export_func, _import_func), then calls
        _init().
    """
    def _init(self):
        "setup instance and then call self.start()"
        self.name = self.__class__.__name__
        self.idx = LTC.find_prng(self.name)
        self._prng_state = c_buffer(self._state_size)
        self.start()
    
    def start(self):
        "initializes state, then seeds it (add_entropy) and readies it"
        err = self._start_func(byref(self._prng_state))
        if err: raise Exception('start error = %d (%s)' % (err, error_codes[err]))
        # the next two obviate the need to perform them externally.
        # although, they may be called individually if you want to
        # set with a specific seed
        self.add_entropy()
        self.ready()

    def add_entropy(self, entropy=''):
        """Feed entropy into the PRNG.

        With no argument, 32 bytes derived from time.time() via sha256 are
        used.  That is convenient but predictable, NOT a strong entropy source.
        """
        if not entropy:
            entropy = struct.pack('d', time.time())     # length 8 bytes
            md = sha256()                               # sha256 returns 32 bytes
            md.update(entropy*4)                        # input 32 bytes
            entropy = md.final()
        # sober128 accepts entropy in multiples of 4 bytes
        # fortuna accepts a max of 32 bytes
        # the others appear to be unlimited by size
        if self.name == 'sober128':
            # Trim to a multiple of 4.  The previous entropy[:-(n % 4)] became
            # entropy[:0] (empty) whenever n was already a multiple of 4,
            # including the default 32 bytes.
            entropy = entropy[:len(entropy) - len(entropy) % 4]
        if self.name == 'fortuna':
            entropy = entropy[:32]
        err = self._add_entropy_func(entropy, len(entropy), byref(self._prng_state))
        if err: 
            raise Exception('add_entropy error = %d (%s)' % (err, error_codes[err]))
        
    def ready(self):
        "mark the PRNG ready to produce output (needed after adding entropy)"
        err = self._ready_func(byref(self._prng_state))
        if err: raise Exception('ready error = %d (%s)' % (err, error_codes[err]))
        
    def read(self, numtoread):
        "return numtoread random bytes; raises if fewer are produced"
        outbuf = c_buffer(numtoread)
        num = self._read_func(byref(outbuf), numtoread, byref(self._prng_state))
        if num < 0: raise Exception('read error')
        if num != numtoread: raise Exception('read error2')
        return outbuf.raw
        
    def done(self):
        "shut the PRNG down"
        err = self._done_func(byref(self._prng_state))
        if err: raise Exception('done error = %d (%s)' % (err, error_codes[err]))
        
    def test(self):
        "run the library self-test; raises on failure"
        err = self._test_func(byref(self._prng_state))
        if err: raise Exception('test error = %d (%s)' % (err, error_codes[err]))
        
    def get_state(self):
        "export the PRNG state as a byte string (see load_state)"
        out = c_buffer(self._state_size)
        # *_export takes 'unsigned long *outlen'; c_ulong matches it on LP64
        outlen = c_ulong(self._state_size)
        err = self._export_func(byref(out), byref(outlen), byref(self._prng_state))
        if err: raise Exception('export error = %d (%s)' % (err, error_codes[err]))
        return out.raw[:outlen.value]
        
    def load_state(self, state):
        "import a state previously produced by get_state"
        err = self._import_func(state, len(state), byref(self._prng_state))
        if err: raise Exception('import error = %d (%s)' % (err, error_codes[err]))
        
    def _estimate_state_size(self):
        "not used if libtomcrypt_plus library is installed"
        est = 14000                     # hopefully large enough
        print('  Note: estimated state buffer for %s is %d bytes' % (self.name, est))
        return est
    

# dynamically create a class for each PRNG algorithm by 'exec'uting
# the following template.  Note the inheritance from PrngAlgorithm and
# the use of %s and %d placeholders.  

prng_class_template = """
class %s(PrngAlgorithm):
    def __init__(self):
        self._state_size       = prng_state_size    # global
        self._start_func       = LTC.%s_start
        self._add_entropy_func = LTC.%s_add_entropy
        self._ready_func       = LTC.%s_ready
        self._read_func        = LTC.%s_read
        self._done_func        = LTC.%s_done
        self._test_func        = LTC.%s_test
        self._export_func      = LTC.%s_export
        self._import_func      = LTC.%s_import
        self._init()
"""

# make a class, one for each algorithm.  The class name will be exactly
# as specified in "prng_algorithms" above.
for alg in prng_algorithms:
    prng_desc = eval('LTC.'+alg+'_desc')    # get the descriptor
    desc = cast(prng_desc, POINTER(PRNGDescriptors))[0]
    exec(prng_class_template % (alg, alg, alg, alg, alg, alg, alg, alg, alg))
    if 0:                   # enable if want to see particulars
        print
        print '  name       ', desc.name
        print '  export_size', desc.export_size


#=============================================================================
# Symmetric Encryption Algorithms
#
# You will not be accessing the libtomcrypt encryption algorithms directly.  
# Instead, you will create an instance of Cipher_Context below, passing in
# the desired algorithm, mode, key, and depending on the mode chosen, an IV,
# and a CTR mode.  (Registration is necessary for symmetric encryption.)

class CipherDescriptors(Structure):
    """Leading fields of LibTomCrypt's cipher descriptor (the only ones read).

    In the C header the four size fields are declared 'int' and ID is
    'unsigned char'.  They were previously mirrored as c_ulong.  That only
    matches int where long is 32 bits; on LP64 platforms (64-bit Linux and
    macOS) every field after ID was misplaced and block_length was garbage.
    """
    _fields_ = [('name',            c_char_p),
                ('ID',              c_ubyte),
                ('min_key_length',  c_int),
                ('max_key_length',  c_int),
                ('block_length',    c_int),
                ('default_rounds',  c_int),
                # the remaining (function pointer) members are not mirrored;
                # this struct is only ever read through a pointer
               ]

class Cipher_Context(object):
    """A symmetric cipher keyed for one chaining mode.

    alg      -- a name from cipher_algorithms ('rijndael' is treated as 'aes')
    mode     -- 'ecb', 'cbc', 'ctr', 'cfb' or 'ofb'
    key      -- key bytes; null-padded or truncated to the length returned by
                _get_key_len(alg), i.e. what the cipher grants for a
                1000-byte request
    IV       -- IV for non-ECB modes.  Only the first block-size bytes are
                used, and a shorter IV raises.  If empty, one is read from
                a_prng, which is not defined in this part of the file --
                NOTE(review): confirm it exists at module level.
    ctr_mode -- counter flags passed to ctr_start (CTR mode only)
    """

    def __init__(self, alg, mode, key, IV='', ctr_mode=CTR_COUNTER_BIG_ENDIAN):

        # fix alg issues (inconsistent use of some names)
        if alg == 'rijndael':
            alg = 'aes'
        # some ciphers are registered in LTC under a name that differs from
        # the prefix of their *_desc / *_keysize symbols
        ltc_name = {'des3': '3des', 'kseed': 'seed'}.get(alg, alg)
        cipher_idx = LTC.find_cipher(ltc_name)
        if cipher_idx == -1:
            raise Exception('unsupported algorithm')
        # lrw and f8 are not supported yet
        if mode not in ('ecb', 'cbc', 'ctr', 'cfb', 'ofb'):
            raise Exception('unsupported mode')

        # fix key issues
        keylen = self._get_key_len(alg)
        key = (key+chr(0)*keylen)[:keylen]  # pad w/nulls, then lop off

        # fix IV issues
        blocksize = self._get_IV_len(alg)
        if mode != 'ecb':
            if not IV:
                IV = a_prng.read(blocksize)
            if len(IV) < blocksize:
                # the *_start calls read a full block from IV; a shorter
                # string would make the C code read past its end
                raise Exception('IV must be at least %d bytes' % blocksize)
        IV = IV[:blocksize]

        self._alg  = alg
        self._mode = mode
        self._key  = key
        self._IV   = IV
        self._ctr_mode = ctr_mode
        self._blocksize = blocksize

        rounds = 0                      # use default
        if mode == 'ecb':
            self._state = c_buffer(symmetric_ECB_size)
            err = LTC.ecb_start(cipher_idx, key, len(key), rounds, self._state)
        elif mode == 'cbc':
            self._state = c_buffer(symmetric_CBC_size)
            err = LTC.cbc_start(cipher_idx, IV, key, len(key), rounds, self._state)
        elif mode == 'ctr':
            self._state = c_buffer(symmetric_CTR_size)
            err = LTC.ctr_start(cipher_idx, IV, key, len(key), rounds, ctr_mode, self._state)
        elif mode == 'cfb':
            self._state = c_buffer(symmetric_CFB_size)
            err = LTC.cfb_start(cipher_idx, IV, key, len(key), rounds, self._state)
        else:                           # 'ofb' (mode was validated above)
            self._state = c_buffer(symmetric_OFB_size)
            err = LTC.ofb_start(cipher_idx, IV, key, len(key), rounds, self._state)
        # the start status was previously ignored
        self._check('start', err)

    def _check(self, what, err):
        "raise Exception('<what> error = N (NAME)') when err is a nonzero LTC status"
        if err:
            raise Exception('%s error = %d (%s)' % (what, err, error_codes[err]))

    def encrypt(self, plain_text):
        "encrypt plain_text; the output has the same length as the input"
        size = len(plain_text)
        encrypted_text = c_buffer(size)
        encrypt_func = getattr(LTC, '%s_encrypt' % self._mode)
        self._check('encrypt', encrypt_func(plain_text, encrypted_text, size, self._state))
        return encrypted_text.raw

    def decrypt(self, cipher_text, IV=''):
        """decrypt cipher_text

        For non-ECB modes the IV is reset first, to IV if given and otherwise
        to the IV used at construction.  Each call therefore decrypts from
        the start of a message.
        """
        if not IV:
            IV = self._IV
        if self._mode != 'ecb':
            self.setiv(IV)
        size = len(cipher_text)
        decrypted_text = c_buffer(size)
        decrypt_func = getattr(LTC, '%s_decrypt' % self._mode)
        self._check('decrypt', decrypt_func(cipher_text, decrypted_text, size, self._state))
        return decrypted_text.raw

    def done(self):
        "release the mode state"
        done_func = getattr(LTC, '%s_done' % self._mode)
        self._check('done', done_func(self._state))
        
    def test(self):
        """run the mode's self-test (skipped for ecb and cbc)

        NOTE(review): LibTomCrypt may not export cfb_test/ofb_test -- confirm.
        """
        if self._mode not in ['ecb', 'cbc']:
            test_func = getattr(LTC, '%s_test' % self._mode)
            self._check('test', test_func(self._state))
        
    def setiv(self, IV):
        "replace the current IV / chaining state (no-op for ecb)"
        if self._mode != 'ecb':
            setiv_func = getattr(LTC, '%s_setiv' % self._mode)
            self._check('setiv', setiv_func(IV, len(IV), self._state))
        
    def getiv(self):
        "return the current IV / chaining state ('' for ecb)"
        if self._mode == 'ecb':
            return ''
        IV = c_buffer(self._blocksize)
        # *_getiv takes 'unsigned long *len'; c_ulong matches it on LP64
        IVlen = c_ulong(self._blocksize)
        getiv_func = getattr(LTC, '%s_getiv' % self._mode)
        self._check('getiv', getiv_func(IV, byref(IVlen), self._state))
        if IVlen.value != self._blocksize: raise Exception('getiv error2')
        return IV.raw

    def _get_key_len(self, alg, prefkeylen=1000):
        """Return the key size the cipher accepts for a prefkeylen-byte request.

        Should be the desired key length, but every cipher appears to lower
        an oversized request to its maximum.  For now the default request of
        1000 therefore yields the largest supported key size.
        """
        keylen_p = c_uint(prefkeylen)     
        if alg in ('aes', 'rijndael'):
            # presumably aes_keysize is only a macro alias of
            # rijndael_keysize in the C headers, so it has no symbol
            keysize_func = LTC.rijndael_keysize
        else:
            keysize_func = getattr(LTC, '%s_keysize' % alg)
        self._check('keysize', keysize_func(byref(keylen_p)))
        return int(keylen_p.value)
    
    def _get_IV_len(self, alg):
        "return the cipher's block length in bytes (also the IV length)"
        cipher_desc = getattr(LTC, '%s_desc' % alg)
        desc = cast(cipher_desc, POINTER(CipherDescriptors))[0]
        return desc.block_length
    get_blocksize = _get_IV_len
        
    def _estimate_state_size(self):
        "not used if libtomcrypt_plus library is installed"
        est = 5000                     # hopefully large enough
        # Cipher_Context has no 'name' attribute; report the class name
        print('  Note: estimated state buffer for %s is %d bytes'
              % (self.__class__.__name__, est))
        return est
    
# Several padding schemes exist.  For now, this is the only one supported.
# Marker and fill bytes of the "0x80 then 0x00 ... 0x00" padding scheme.
# (The marker was previously chr(80) == 'P', i.e. 0x50, by mistake.)
_PAD_MARKER = chr(0x80)
_PAD_FILL = chr(0)

def pad(message, blocksize):
    """Pad `message` to a multiple of `blocksize` bytes.

    Padding scheme:  80 + 00 00 00 ... until block filled.  A single 0x80
    byte is always appended, followed by as many 0x00 bytes as needed to
    complete the final block.  The result is therefore 1 to `blocksize`
    bytes longer than `message`.
    """
    padding = _PAD_MARKER
    remain = (len(message) + 1) % blocksize
    if remain:
        padding += _PAD_FILL * (blocksize - remain)
    return message + padding

def unpad(message):
    """Remove padding added by pad().

    Locates the last 0x80 marker byte and checks that only 0x00 bytes
    follow it.  Returns everything before the marker.

    Raises Exception if no marker is present, or if any non-zero byte
    follows the marker.
    """
    marker = message.rfind(_PAD_MARKER)
    if marker == -1:
        raise Exception('no trailing pad found')
    # Everything after the marker must be fill bytes.
    if message[marker + 1:].strip(_PAD_FILL):
        raise Exception('invalid padding')
    return message[:marker]

#=============================================================================
#=============================================================================
# Elliptic Curve Cryptography

def select_hash(inlen, pklen):                  # everything in bytes
    """Choose a hash for ECC encryption of an `inlen`-byte payload.

    Candidates are tried from largest to smallest digest (sha512, sha256,
    sha1, md5).  The first one whose digest length is at least `inlen`
    and at most `pklen` (the ECC key size) wins.  All sizes are in bytes.

    Returns (hash name, digest length).  Raises Exception if no
    candidate fits.
    """
    by_size = (('sha512', 512 // 8),
               ('sha256', 256 // 8),
               ('sha1',   160 // 8),
               ('md5',    128 // 8))
    for name, digest_len in by_size:
        if inlen <= digest_len <= pklen:
            return name, digest_len
    raise Exception('need another hash - in %d ecc %d' % (inlen, pklen))
    
class ECC_Set_Type(Structure):
    """ctypes mirror of LibTomCrypt's per-curve domain parameter record.

    ECC_Key.dp points at one of these.  As with any ctypes Structure, the
    field order and types must match the C declaration exactly.
    """
    _fields_ = [('size',   c_int),      # The size of the curve in octets
                ('name',   c_char_p),   # The name of the curve
                ('prime',  c_char_p),   # The prime that defines the field the
                                        # curve is in (encoded in hex)
                ('B',      c_char_p),   # The field's B parameter (hex)
                ('order',  c_char_p),   # The order of the curve (hex)
                ('Gx',     c_char_p),   # The x co-ordinate of the base point
                                        # on the curve (hex)
                ('Gy',     c_char_p),   # The y co-ordinate of the base point
                                        # on the curve (hex)
               ]

class ECC_Point(Structure):
    """ctypes mirror of a LibTomCrypt ECC point.

    It holds three pointers to MP_Int big integers and is embedded by
    value as ECC_Key.pubkey.  Field order must match the C declaration.
    """
    _fields_ = [('x', POINTER(MP_Int)), # The x co-ordinate
                ('y', POINTER(MP_Int)), # The y co-ordinate
                ('z', POINTER(MP_Int)), # The z co-ordinate
               ]

class ECC_Key(Structure):
    """ctypes mirror of a LibTomCrypt ECC key.

    Instances are passed by reference (byref) to the ecc_* functions.
    Field order and types must match the C declaration exactly.
    """
    _fields_ = [('type',   c_int),      # Type of key, PK_PRIVATE or PK_PUBLIC
                ('idx',    c_int),      # Index into ltc_ecc_sets[] for the
                                        # parameters of this curve; if -1, then
                                        # this key uses the user-supplied curve
                                        # in dp
                ('dp', POINTER(ECC_Set_Type)), # pointer to domain parameters; either
                                        # points to NIST curves (identified by
                                        # idx >= 0) or a user-supplied curve
                ('pubkey', ECC_Point),  # The public key
                ('k', POINTER(MP_Int)), # The private key
               ]



class ECC_Context:
    """Elliptic-curve key pair and operations backed by LibTomCrypt.

    Holds one ECC_Key (self.key) and its size in bits (self.keylen).
    Randomness for key generation, signing and encryption comes from the
    module-level default PRNG `a_prng`.  Every LibTomCrypt error code is
    raised as an Exception carrying the code and its error_codes text.
    """

    def __init__(self):
        self.key = ECC_Key()            # empty
        self.key.type = -1              # flag as invalid
        self.keylen = None

    def make_key(self, keylen):
        """Generate a new key pair of `keylen` bits.

        keylen // 8 bytes is requested from ecc_make_key.
        """
        err = LTC.ecc_make_key(a_prng._prng_state, a_prng.idx,
                               keylen // 8, byref(self.key))
        if err:
            raise Exception('ecc_make_key error = %d (%s)' % (err, error_codes[err]))
        self.keylen = keylen

    def load_key(self, key):
        """Import a key as returned by get_private_key/get_public_key."""
        err = LTC.ecc_import(key, len(key), byref(self.key))
        if err:
            raise Exception('ecc_import error = %d (%s)' % (err, error_codes[err]))
        # Read the curve parameters only after a successful import; on
        # failure key.dp may still be NULL and dereferencing it raises.
        self.keylen = self.key.dp.contents.size*8

    def load_key_ex(self, key, params=None):
        """Import a key, optionally with explicit curve parameters.

        `params` is an ECC_Set_Type instance.  When it is None, a NULL
        pointer is passed.  NOTE(review): per the LibTomCrypt docs,
        ecc_import_ex() then picks a built-in curve, as ecc_import() does --
        confirm.  byref(None) used to raise a TypeError here.
        """
        dp = None
        if params is not None:
            dp = byref(params)
        err = LTC.ecc_import_ex(key, len(key), byref(self.key), dp)
        if err:
            raise Exception('ecc_import error = %d (%s)' % (err, error_codes[err]))
        self.keylen = self.key.dp.contents.size*8

    def free_key(self):
        """Release the key's big-number storage, if a key is held."""
        if self.key.type in [0, 1]:
            # ecc_free() expects a pointer to the key structure.
            # (This previously referenced a nonexistent self.ecc_key.)
            LTC.ecc_free(byref(self.key))
            self.key.type = -1          # flag as invalid; avoids a double free

    def get_key_size(self):
        return self.keylen

    def get_size(self):
        """Return the curve size in bits, as reported by ecc_get_size()."""
        if self.key.type not in [0, 1]:
            # An empty key has a NULL dp, which ecc_get_size may dereference.
            raise Exception('ecc get_size: no key loaded')
        size = LTC.ecc_get_size(byref(self.key))
        if size == 0:
            raise Exception("\n  *** get_size doesn't work ***\n")
        return size * 8

    def _get_sizes(self):
        "return min and max key lengths in bytes, mult by 8 to get bits"
        low  = c_int()
        high = c_int()
        print('ecc_sizes')
        LTC.ecc_sizes(byref(low), byref(high))
        print('  low high %d %d' % (low.value, high.value))
        return low.value, high.value

    def _export(self, size, which, label):
        """Export the key as PRIVATE_KEY or PUBLIC_KEY into a `size`-byte buffer."""
        out = c_buffer(size)
        outlen = c_ulong(size)          # ecc_export takes unsigned long *outlen
        err = LTC.ecc_export(out, byref(outlen), which, byref(self.key))
        if err:
            raise Exception('ecc_export %s error = %d (%s)' % (label, err, error_codes[err]))
        return out.raw[:outlen.value]

    def get_private_key(self):
        """Return the exported private key as a byte string."""
        # 215 bytes is the largest observed for ECC-521
        return self._export(300, PRIVATE_KEY, 'pvt')

    def get_public_key(self):
        """Return the exported public key as a byte string."""
        # 146 bytes is the largest observed for ECC-521
        return self._export(200, PUBLIC_KEY, 'pub')

    def sign(self, digest):
        "digest is a byte string, signature is DER"
        size = 200          # 139 is the largest observed for ECC-521
        sig = c_buffer(size)
        siglen = c_ulong(size)          # unsigned long *outlen in the C API
        err = LTC.ecc_sign_hash(digest, len(digest),
                                sig, byref(siglen),
                                a_prng._prng_state, a_prng.idx,
                                byref(self.key))
        if err:
            raise Exception('ecc_sign_hash error = %d (%s)' % (err, error_codes[err]))
        return sig.raw[:siglen.value]

    def verify(self, digest, signature):
        """digest is a byte string, signature is DER.

        Returns the status flag written by ecc_verify_hash (nonzero when
        the signature is valid).
        """
        good = c_int()
        err = LTC.ecc_verify_hash(signature, len(signature),
                                  digest, len(digest),
                                  byref(good), byref(self.key))
        if err:
            raise Exception('ecc_verify_hash error = %d (%s)' % (err, error_codes[err]))
        return good.value

    def encrypt(self, plaintext):
        """plaintext is a byte string, ciphertext is a byte string.

        The hash is chosen by select_hash() from the plaintext length and
        the key size in bytes.
        """
        hash_alg, hashlen = select_hash(len(plaintext), self.keylen // 8)
        hash_idx = LTC.find_hash(hash_alg)
        ciphertext = c_buffer(512)
        lenciphertext = c_ulong(512)    # unsigned long *outlen in the C API
        err = LTC.ecc_encrypt_key(plaintext, len(plaintext),
                                  byref(ciphertext), byref(lenciphertext),
                                  a_prng._prng_state, a_prng.idx,
                                  hash_idx,
                                  byref(self.key))
        if err:
            raise Exception('ecc_encrypt_key error = %d (%s)' % (err, error_codes[err]))
        return ciphertext.raw[:lenciphertext.value]

    def decrypt(self, ciphertext):
        """Decrypt a ciphertext produced by encrypt(); returns a byte string."""
        plaintext = c_buffer(512)
        lenplaintext = c_ulong(512)     # unsigned long *outlen in the C API
        err = LTC.ecc_decrypt_key(ciphertext, len(ciphertext),
                                  byref(plaintext), byref(lenplaintext),
                                  byref(self.key))
        if err:
            raise Exception('ecc_decrypt_key error = %d (%s)' % (err, error_codes[err]))
        return plaintext.raw[:lenplaintext.value]

    def test(self):
        """Run LibTomCrypt's built-in ECC self test; raise on failure."""
        err = LTC.ecc_test()
        if err:
            raise Exception('ecc_test error = %d (%s)' % (err, error_codes[err]))
        

#=============================================================================
# The RSA Algorithm


class RSA_Key(Structure):
    """ctypes mirror of a LibTomCrypt RSA key.

    Instances are passed by reference (byref) to the rsa_* functions.
    Field order and types must match the C declaration exactly.
    """
    _fields_ = [('type',    c_int),     # Type of key, PK_PRIVATE or PK_PUBLIC
                ('e',  POINTER(MP_Int)),  # The public exponent
                ('d',  POINTER(MP_Int)),  # The private exponent
                ('N',  POINTER(MP_Int)),  # The modulus
                ('p',  POINTER(MP_Int)),  # The p factor of N
                ('q',  POINTER(MP_Int)),  # The q factor of N
                ('qP', POINTER(MP_Int)),  # The 1/q mod p CRT param
                ('dP', POINTER(MP_Int)),  # The d mod (p - 1) CRT param
                ('dQ', POINTER(MP_Int)),  # The d mod (q - 1) CRT param
               ]


class RSA_Context:
    """RSA key pair and operations backed by LibTomCrypt.

    Holds one RSA_Key (self.key) and its modulus size in bits
    (self.keylen).  Randomness comes from the module-level default PRNG
    `a_prng`.  LibTomCrypt error codes are raised as Exceptions carrying
    the code and its error_codes text.
    """

    def __init__(self, key=None):
        # NOTE(review): `key` is accepted but ignored; use load_key().
        self.key = RSA_Key()            # empty
        self.key.type = -1              # flag as invalid
        self.keylen = None

    def make_key(self, keylen):
        """Generate a key pair with a `keylen`-bit modulus and e = 65537."""
        e = 65537                       # default (my choice)
        err = LTC.rsa_make_key(byref(a_prng._prng_state), a_prng.idx,
                               keylen // 8, e, byref(self.key))
        if err:
            raise Exception('rsa_make_key error = %d (%s)' % (err, error_codes[err]))
        self.keylen = keylen

    def load_key(self, key):
        """Import a key as returned by get_private_key/get_public_key."""
        err = LTC.rsa_import(key, len(key), byref(self.key))
        if err:
            raise Exception('rsa_import error = %d (%s)' % (err, error_codes[err]))
        # The key length is the size of the modulus N (4 bits per hex char),
        # which matches what make_key records.  The factor p is only half
        # that size and is presumably unset for a public key.
        self.keylen = len(mpi2hex(self.key.N.contents))*4

    def free_key(self):
        """Release the key's big-number storage, if a key is held."""
        if self.key.type in [0, 1]:
            # rsa_free() expects a pointer to the key, not the struct by value.
            LTC.rsa_free(byref(self.key))
            self.key.type = -1          # flag as invalid; avoids a double free

    def get_key_size(self):
        return self.keylen

    def _export(self, size, which, label):
        """Export the key as PRIVATE_KEY or PUBLIC_KEY into a `size`-byte buffer."""
        out = c_buffer(size)
        outlen = c_ulong(size)          # rsa_export takes unsigned long *outlen
        err = LTC.rsa_export(byref(out), byref(outlen),
                             which, byref(self.key))
        if err:
            raise Exception('rsa_export %s error = %d (%s)' % (label, err, error_codes[err]))
        return out.raw[:outlen.value]

    def get_private_key(self):
        """Return the exported private key as a byte string."""
        # 2350 bytes is the largest observed for RSA-4096, 610 for RSA-1024
        return self._export(2400, PRIVATE_KEY, 'pvt')

    def get_public_key(self):
        """Return the exported public key as a byte string."""
        # 526 bytes is the largest observed for RSA-4096, 140 for RSA-1024
        return self._export(550, PUBLIC_KEY, 'pub')

    def sign(self, digest, hash_alg, padding=LTC_PKCS_1_V1_5):
        """Sign `digest`, a byte string produced by hash `hash_alg`.

        `padding` must be LTC_PKCS_1_V1_5 (default) or LTC_PKCS_1_PSS.
        Returns the signature as a byte string.
        """
        if padding not in [LTC_PKCS_1_PSS, LTC_PKCS_1_V1_5]:
            raise Exception('invalid padding type for RSA signature')
        size = 512          # 512 is the largest observed for RSA-4096
        sig = c_buffer(size)
        siglen = c_ulong(size)          # unsigned long *outlen in the C API
        saltlen = 8                     # 8-16, see pg 111
        hash_idx = LTC.find_hash(hash_alg)
        err = LTC.rsa_sign_hash_ex(digest, len(digest),
                                   byref(sig), byref(siglen),
                                   padding,
                                   a_prng._prng_state, a_prng.idx,
                                   hash_idx,
                                   saltlen, byref(self.key))
        if err:
            raise Exception('rsa_sign_hash error = %d (%s)' % (err, error_codes[err]))
        return sig.raw[:siglen.value]

    def sign_PSS(self, digest):
        """Sign a SHA-1 `digest` with PSS padding."""
        return self.sign(digest, 'sha1', LTC_PKCS_1_PSS)

    def verify(self, digest, hash_alg, signature, padding=LTC_PKCS_1_V1_5):
        """Verify `signature` over `digest` (produced by `hash_alg`).

        Returns the status flag written by rsa_verify_hash_ex (nonzero
        when the signature is valid).
        """
        if padding not in [LTC_PKCS_1_PSS, LTC_PKCS_1_V1_5]:
            raise Exception('invalid padding type for RSA verification')
        saltlen = 8                     # 8-16, see pg 111
        hash_idx = LTC.find_hash(hash_alg)
        good = c_int()
        err = LTC.rsa_verify_hash_ex(signature, len(signature),
                                     digest, len(digest),
                                     padding,
                                     hash_idx,
                                     saltlen, byref(good),
                                     byref(self.key))
        if err:
            raise Exception('rsa_verify_hash error = %d (%s)' % (err, error_codes[err]))
        return good.value

    def verify_PSS(self, digest, signature):
        """Verify a PSS signature over a SHA-1 `digest`."""
        return self.verify(digest, 'sha1', signature, LTC_PKCS_1_PSS)

    def encrypt(self, payload, lparam, hash_alg='sha1', padding=LTC_PKCS_1_V1_5):
        """Encrypt `payload` with `padding` (default PKCS #1 v1.5).

        `lparam` and `hash_alg` are passed to rsa_encrypt_key_ex.
        """
        hash_idx = LTC.find_hash(hash_alg)
        out = c_buffer(512)
        outlen = c_ulong(512)           # unsigned long *outlen in the C API
        err = LTC.rsa_encrypt_key_ex(payload, len(payload),
                            byref(out), byref(outlen),
                            lparam, len(lparam),
                            a_prng._prng_state, a_prng.idx,
                            hash_idx, padding,
                            byref(self.key))
        if err:
            raise Exception('rsa_encrypt_key_ex error = %d (%s)' % (err, error_codes[err]))
        return out.raw[:outlen.value]

    encrypt_PKCS1 = encrypt

    def encrypt_OAEP(self, payload, lparam, hash_alg='sha1'):
        return self.encrypt(payload, lparam, hash_alg, LTC_PKCS_1_OAEP)

    def decrypt(self, payload, lparam, hash_alg='sha1', padding=LTC_PKCS_1_V1_5):
        """Decrypt `payload` that was encrypted with the same `padding`.

        The padding argument is now honoured.  It was previously
        hard-coded to 2 (OAEP), so decrypt()/decrypt_PKCS1() could not
        undo encrypt()/encrypt_PKCS1().

        Raises Exception on a LibTomCrypt error or if the decoded padding
        is reported invalid.
        """
        hash_idx = LTC.find_hash(hash_alg)
        out = c_buffer(512)
        outlen = c_ulong(512)           # unsigned long *outlen in the C API
        good = c_int()
        err = LTC.rsa_decrypt_key_ex(payload, len(payload),
                                     byref(out), byref(outlen),
                                     lparam, len(lparam),
                                     hash_idx, padding, byref(good),
                                     byref(self.key))
        if err:
            raise Exception('rsa_decrypt_key_ex error = %d (%s)' % (err, error_codes[err]))
        if not good.value:
            raise Exception('invalid rsa decryption')
        return out.raw[:outlen.value]

    decrypt_PKCS1 = decrypt

    def decrypt_OAEP(self, payload, lparam, hash_alg='sha1'):
        return self.decrypt(payload, lparam, hash_alg, LTC_PKCS_1_OAEP)
    

#=============================================================================
# The DSA Algorithm

class DSA_Key(Structure):
    """ctypes mirror of LibTomCrypt's dsa_key.

    Field ORDER must match the C declaration, which lists the big-number
    pointers as g, q, p, x, y.  The previous order here (g, p, q, y, x)
    made `.p` read q and `.y` read x, so DSA_Context.load_key computed the
    key length from q.  NOTE(review): re-check against the installed
    tomcrypt_pk.h if the library version changes.
    """
    _fields_ = [('type',    c_int),       # Type of key, PK_PRIVATE or PK_PUBLIC
                ('ord',     c_int),       # order of the sub-group in bytes (C: qord)
                ('g',  POINTER(MP_Int)),  # param - base generator
                ('q',  POINTER(MP_Int)),  # param - order of sub-group
                ('p',  POINTER(MP_Int)),  # param - prime modulus
                ('x',  POINTER(MP_Int)),  # private key
                ('y',  POINTER(MP_Int)),  # public key
               ]

class DSA_Context:
    """DSA key pair and signature operations backed by LibTomCrypt.

    Holds one DSA_Key (self.key) and its modulus size in bits
    (self.keylen).  Randomness comes from the module-level default PRNG
    `a_prng`.  Encryption is not supported by DSA.
    """

    def __init__(self, key=None):
        # NOTE(review): `key` is accepted but ignored; use load_key().
        self.key = DSA_Key()            # empty
        self.key.type = -1              # flag as invalid
        self.keylen = None

    def make_key(self, keylen):
        """Generate a key pair with a `keylen`-bit prime modulus p.

        keylen // 8 bytes is requested for p, and 20 bytes of sub-group
        order for each full 1024 bits of keylen.  Raises ValueError when
        keylen is below 1024.  This is an explicit check rather than an
        assert, which -O would strip.
        """
        if keylen < 1024:
            raise ValueError('minimum supported DSA key length is 1024')
        # keylen // 8 (instead of 128 * (keylen / 1024)) keeps lengths that
        # are not a multiple of 1024, e.g. 1536, from being rounded down.
        modulus_size = keylen // 8
        group_size = 20 * (keylen // 1024)
        err = LTC.dsa_make_key(byref(a_prng._prng_state), a_prng.idx,
                               group_size, modulus_size,
                               byref(self.key))
        if err:
            raise Exception('dsa_make_key error = %d (%s)' % (err, error_codes[err]))
        self.keylen = keylen

    def load_key(self, key):
        """Import a key as returned by get_private_key/get_public_key."""
        err = LTC.dsa_import(key, len(key), byref(self.key))
        if err:
            raise Exception('dsa_import error = %d (%s)' % (err, error_codes[err]))
        self.keylen = len(mpi2hex(self.key.p.contents))*4    # 4 bits per hex char

    def free_key(self):
        """Release the key's big-number storage, if a key is held."""
        if self.key.type in [0, 1]:
            # dsa_free() expects a pointer to the key, not the struct by value.
            LTC.dsa_free(byref(self.key))
            self.key.type = -1          # flag as invalid; avoids a double free

    def get_key_size(self):
        return self.keylen

    def _export(self, size, which, label):
        """Export the key as PRIVATE_KEY or PUBLIC_KEY into a `size`-byte buffer."""
        out = c_buffer(size)
        outlen = c_ulong(size)          # dsa_export takes unsigned long *outlen
        err = LTC.dsa_export(byref(out), byref(outlen),
                             which, byref(self.key))
        if err:
            raise Exception('dsa_export %s error = %d (%s)' % (label, err, error_codes[err]))
        return out.raw[:outlen.value]

    def get_private_key(self):
        """Return the exported private key as a byte string."""
        # 1724 bytes is the largest observed for DSA-4096, 449 for DSA-1024
        return self._export(1800, PRIVATE_KEY, 'pvt')

    def get_public_key(self):
        """Return the exported public key as a byte string."""
        # 1642 bytes is the largest observed for DSA-4096, 426 for DSA-1024
        return self._export(1800, PUBLIC_KEY, 'pub')

    def sign(self, digest):
        "digest is a byte string, signature is ASN1"
        size = 200          # 169 is the largest observed for DSA-4096
        sig = c_buffer(size)
        siglen = c_ulong(size)          # unsigned long *outlen in the C API
        err = LTC.dsa_sign_hash(digest, len(digest),
                                byref(sig), byref(siglen),
                                a_prng._prng_state, a_prng.idx,
                                byref(self.key))
        if err:
            raise Exception('dsa_sign_hash error = %d (%s)' % (err, error_codes[err]))
        return sig.raw[:siglen.value]

    def verify(self, digest, signature):
        """digest is a byte string, signature is ASN1.

        Returns the status flag written by dsa_verify_hash (nonzero when
        the signature is valid).
        """
        good = c_int()
        err = LTC.dsa_verify_hash(signature, len(signature),
                                  digest, len(digest),
                                  byref(good),
                                  byref(self.key))
        if err:
            raise Exception('dsa_verify_hash error = %d (%s)' % (err, error_codes[err]))
        return good.value

    def encrypt(self, plaintext):
        raise Exception('dsa encrypt not implemented')

    def decrypt(self, ciphertext):
        raise Exception('dsa decrypt not implemented')


#=============================================================================
# ASN.1 and other utilities

class ASN1_List(Structure):
    """One node of a decoded ASN.1 structure, as returned by asn1_decode().

    The structure refers to itself, so _fields_ is assigned after the class
    statement.  Field order must match LibTomCrypt's C declaration.
    """
    pass

ASN1_List._fields_ = [
                ('type',   c_int),              # ASN.1 type identifier of this node
                ('data',   POINTER(c_char)),    # may need to recast to int, etc
                ('size',   c_ulong),            # size of `data` -- NOTE(review): units depend on type; confirm in LTC docs
                ('used',   c_int),              # NOTE(review): presumably a flag set while decoding (e.g. CHOICE) -- confirm
                ('prev',   POINTER(ASN1_List)), # presumably the previous sibling node
                ('next',   POINTER(ASN1_List)), # presumably the next sibling node
                ('child',  POINTER(ASN1_List)), # presumably the first nested node
                ('parent', POINTER(ASN1_List)), # presumably the enclosing node
               ]

def asn1_decode(asn1_in):
    """Decode DER bytes of unknown structure into an ASN1_List tree.

    Wraps LibTomCrypt's der_decode_sequence_flexi(), which writes the
    address of the decoded list into `out`.  Returns the first node.

    Raises Exception if decoding fails.  Previously the error was only
    printed, and the placeholder node was returned.

    NOTE(review): the decoded list appears to be owned by LibTomCrypt and
    is never released here -- confirm whether der_sequence_free is needed.
    """
    inlen = c_ulong(len(asn1_in))
    out = pointer(ASN1_List())
    err = LTC.der_decode_sequence_flexi(asn1_in, byref(inlen), byref(out))
    if err:
        raise Exception('der_decode_sequence_flexi error = %d (%s)' % (err, error_codes[err]))
    return out[0]                   # the [0] dereferences it
    
def base64_encode(bytes_in):
    """Return the base64 encoding of the byte string `bytes_in`.

    Raises Exception if LibTomCrypt's base64_encode() reports an error.
    """
    # (n + 3) * 4 // 3 bytes covers the 4 * ceil(n / 3) encoded characters
    # plus at least one spare byte for every n >= 0.
    # The length is passed as `unsigned long *` in the C API, hence c_ulong.
    outlen = c_ulong((len(bytes_in) + 3) * 4 // 3)
    out = c_buffer(outlen.value)
    err = LTC.base64_encode(bytes_in, len(bytes_in), byref(out), byref(outlen))
    if err:
        raise Exception('base64_encode error = %d (%s)' % (err, error_codes[err]))
    return out.raw[:outlen.value]

def base64_decode(bytes_in):
    """Decode the base64 byte string `bytes_in` and return the raw bytes.

    Raises Exception if LibTomCrypt's base64_decode() reports an error.
    """
    # Every 4 input characters decode to at most 3 bytes.
    # The length is passed as `unsigned long *` in the C API, hence c_ulong.
    outlen = c_ulong(len(bytes_in) * 3 // 4)
    out = c_buffer(outlen.value)
    err = LTC.base64_decode(bytes_in, len(bytes_in), byref(out), byref(outlen))
    if err:
        raise Exception('base64_decode error = %d (%s)' % (err, error_codes[err]))
    return out.raw[:outlen.value]


#=============================================================================
# Register algorithms and fire up a default PRNG

# Register hashes, ciphers and PRNGs with LibTomCrypt (the helpers are
# defined elsewhere in this file).  NOTE(review): presumably this is what
# lets lookups such as LTC.find_hash() above succeed -- confirm.
register_hashes()
register_ciphers()
register_prngs()

# Module-wide default PRNG.  ECC_Context, RSA_Context and DSA_Context use
# a_prng._prng_state and a_prng.idx whenever LibTomCrypt needs randomness.
a_prng = sober128()         # default PRNG initialized with 
                            # sha256(current time*4).  

#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------
