1#!/usr/local/bin/python2
2#
3# Copyright (c) 2014 The FreeBSD Foundation
4# All rights reserved.
5#
6# This software was developed by John-Mark Gurney under
7# the sponsorship from the FreeBSD Foundation.
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions
10# are met:
11# 1.  Redistributions of source code must retain the above copyright
12#     notice, this list of conditions and the following disclaimer.
13# 2.  Redistributions in binary form must reproduce the above copyright
14#     notice, this list of conditions and the following disclaimer in the
15#     documentation and/or other materials provided with the distribution.
16#
17# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
18# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20# ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
21# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
23# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
24# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
25# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
26# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
27# SUCH DAMAGE.
28#
29# $FreeBSD$
30#
31
32from __future__ import print_function
33import errno
34import cryptodev
35import itertools
36import os
37import struct
38import unittest
39from cryptodev import *
40from glob import iglob
41
katdir = '/usr/local/share/nist-kat'

def katg(base, glob):
	"""Return an iterator over the KAT files in katdir/base matching glob.

	Raises AssertionError if katdir itself is missing (the nist-kat
	package is not installed).  The check is an explicit raise rather
	than an assert statement so it is not stripped under ``python -O``,
	where it would otherwise degrade into a silent skip below.  Raises
	unittest.SkipTest when only the requested vector set is missing.
	"""
	if not os.path.exists(katdir):
		raise AssertionError("Please 'pkg install nist-kat'")
	if not os.path.exists(os.path.join(katdir, base)):
		raise unittest.SkipTest("Missing %s test vectors" % (base))
	return iglob(os.path.join(katdir, base, glob))
49
# OCF drivers each family of known-answer tests is run against; the
# generated test class skips a family's tests on any driver not listed.
aesmodules = [
	'cryptosoft0',
	'aesni0',
	'ccr0',
	'ccp0',
]
desmodules = ['cryptosoft0']
shamodules = [
	'cryptosoft0',
	'aesni0',
	'ccr0',
	'ccp0',
]
53
def GenTestCase(cname):
	"""Return a unittest.TestCase subclass bound to the OCF driver cname.

	The driver id is looked up with cryptodev.Crypto.findcrid(); if that
	raises IOError, None is returned instead of a class so the caller can
	bind the result unconditionally.  The generated tests iterate NIST
	known-answer vector files obtained from katg() and compare the
	driver's output with the expected values.
	"""
	try:
		crid = cryptodev.Crypto.findcrid(cname)
	except IOError:
		return None

	class GendCryptoTestCase(unittest.TestCase):
		# Every method closes over crid/cname from GenTestCase.  Each
		# test_* method is skipped unless cname is listed in the matching
		# module list (aesmodules, desmodules or shamodules).

		###############
		##### AES #####
		###############
		@unittest.skipIf(cname not in aesmodules, 'skipping AES-XTS on %s' % (cname))
		def test_xts(self):
			for i in katg('XTSTestVectors/format tweak value input - data unit seq no', '*.rsp'):
				self.runXTS(i, cryptodev.CRYPTO_AES_XTS)

		@unittest.skipIf(cname not in aesmodules, 'skipping AES-CBC on %s' % (cname))
		def test_cbc(self):
			for i in katg('KAT_AES', 'CBC[GKV]*.rsp'):
				self.runCBC(i)

		@unittest.skipIf(cname not in aesmodules, 'skipping AES-CCM on %s' % (cname))
		def test_ccm(self):
			for i in katg('ccmtestvectors', 'V*.rsp'):
				self.runCCMEncrypt(i)

			for i in katg('ccmtestvectors', 'D*.rsp'):
				self.runCCMDecrypt(i)

		@unittest.skipIf(cname not in aesmodules, 'skipping AES-GCM on %s' % (cname))
		def test_gcm(self):
			for i in katg('gcmtestvectors', 'gcmEncrypt*'):
				self.runGCM(i, 'ENCRYPT')

			for i in katg('gcmtestvectors', 'gcmDecrypt*'):
				self.runGCM(i, 'DECRYPT')

		# GMAC algorithm paired with AES-GCM, keyed by AES key length
		# in bytes.
		_gmacsizes = { 32: cryptodev.CRYPTO_AES_256_NIST_GMAC,
			24: cryptodev.CRYPTO_AES_192_NIST_GMAC,
			16: cryptodev.CRYPTO_AES_128_NIST_GMAC,
		}
		def runGCM(self, fname, mode):
			"""Check AES-GCM against one NIST vector file.

			mode must be 'ENCRYPT' or 'DECRYPT'; anything else raises
			RuntimeError.  Only 12-byte IVs are exercised, and decryption
			only with 16-byte tags.  Vectors the driver rejects
			(EOPNOTSUPP at session setup, EINVAL on the request) are
			skipped.
			"""
			if mode not in ('ENCRYPT', 'DECRYPT'):
				raise RuntimeError('unknown mode: %r' % (mode,))

			# The parser's section label is unused; the direction comes
			# from mode instead.
			for _section, lines in cryptodev.KATParser(fname,
			    [ 'Count', 'Key', 'IV', 'CT', 'AAD', 'Tag', 'PT', ]):
				for data in lines:
					cipherkey = data['Key'].decode('hex')
					iv = data['IV'].decode('hex')
					aad = data['AAD'].decode('hex')
					tag = data['Tag'].decode('hex')
					# PT is only read for vectors not marked FAIL.
					if 'FAIL' not in data:
						pt = data['PT'].decode('hex')
					ct = data['CT'].decode('hex')

					if len(iv) != 12:
						# XXX - isn't supported
						continue

					try:
						c = Crypto(cryptodev.CRYPTO_AES_NIST_GCM_16,
						    cipherkey,
						    mac=self._gmacsizes[len(cipherkey)],
						    mackey=cipherkey, crid=crid,
						    maclen=16)
					except EnvironmentError as e:
						# Can't test algorithms the driver does not support.
						if e.errno != errno.EOPNOTSUPP:
							raise
						continue

					if mode == 'ENCRYPT':
						try:
							rct, rtag = c.encrypt(pt, iv, aad)
						except EnvironmentError as e:
							# Can't test inputs the driver does not support.
							if e.errno != errno.EINVAL:
								raise
							continue
						# The session uses maclen=16 but a vector may
						# expect a shorter tag; compare only its prefix.
						rtag = rtag[:len(tag)]
						data['rct'] = rct.encode('hex')
						data['rtag'] = rtag.encode('hex')
						self.assertEqual(rct, ct, repr(data))
						self.assertEqual(rtag, tag, repr(data))
					else:
						if len(tag) != 16:
							continue
						args = (ct, iv, aad, tag)
						if 'FAIL' in data:
							# A FAIL vector must be rejected with IOError.
							self.assertRaises(IOError,
								c.decrypt, *args)
						else:
							try:
								rpt, rtag = c.decrypt(*args)
							except EnvironmentError as e:
								# Can't test inputs the driver does not support.
								if e.errno != errno.EINVAL:
									raise
								continue
							data['rpt'] = rpt.encode('hex')
							data['rtag'] = rtag.encode('hex')
							self.assertEqual(rpt, pt,
							    repr(data))

		def runCBC(self, fname):
			"""Check AES-CBC against one NIST KAT file.

			ENCRYPT sections map PLAINTEXT to CIPHERTEXT; DECRYPT sections
			are run with the two fields swapped.
			"""
			for mode, lines in cryptodev.KATParser(fname,
			    [ 'COUNT', 'KEY', 'IV', 'PLAINTEXT', 'CIPHERTEXT', ]):
				if mode == 'ENCRYPT':
					swapptct = False
					curfun = Crypto.encrypt
				elif mode == 'DECRYPT':
					swapptct = True
					curfun = Crypto.decrypt
				else:
					raise RuntimeError('unknown mode: %r' % (mode,))

				for data in lines:
					cipherkey = data['KEY'].decode('hex')
					iv = data['IV'].decode('hex')
					pt = data['PLAINTEXT'].decode('hex')
					ct = data['CIPHERTEXT'].decode('hex')

					# For decryption the input is the ciphertext and the
					# expected output is the plaintext.
					if swapptct:
						pt, ct = ct, pt
					c = Crypto(cryptodev.CRYPTO_AES_CBC, cipherkey, crid=crid)
					r = curfun(c, pt, iv)
					self.assertEqual(r, ct)

		def runXTS(self, fname, meth):
			"""Check AES-XTS (cipher algorithm meth) against one KAT file.

			The 16-byte IV is the data unit sequence number packed as a
			native-order unsigned 64-bit integer followed by 64 zero bits.
			Data units whose bit length is not a multiple of 128 are
			skipped, as are sessions the driver rejects with EOPNOTSUPP.
			"""
			for mode, lines in cryptodev.KATParser(fname,
			    [ 'COUNT', 'DataUnitLen', 'Key', 'DataUnitSeqNumber', 'PT',
			    'CT' ]):
				if mode == 'ENCRYPT':
					swapptct = False
					curfun = Crypto.encrypt
				elif mode == 'DECRYPT':
					swapptct = True
					curfun = Crypto.decrypt
				else:
					raise RuntimeError('unknown mode: %r' % (mode,))

				for data in lines:
					nbits = int(data['DataUnitLen'])
					cipherkey = data['Key'].decode('hex')
					iv = struct.pack('QQ', int(data['DataUnitSeqNumber']), 0)
					pt = data['PT'].decode('hex')
					ct = data['CT'].decode('hex')

					if nbits % 128 != 0:
						# XXX - mark as skipped
						continue
					if swapptct:
						pt, ct = ct, pt
					try:
						c = Crypto(meth, cipherkey, crid=crid)
						r = curfun(c, pt, iv)
					except EnvironmentError as e:
						# Can't test hashes the driver does not support.
						if e.errno != errno.EOPNOTSUPP:
							raise
						continue
					self.assertEqual(r, ct)

		def runCCMEncrypt(self, fname):
			"""Check AES-CCM encryption against one NIST vector file.

			Only 12-byte nonces are exercised.  The expected CT field is
			compared with the driver's ciphertext followed by its 16-byte
			tag.
			"""
			for data in cryptodev.KATCCMParser(fname):
				Nlen = int(data['Nlen'])
				if Nlen != 12:
					# OCF only supports 12 byte IVs
					continue
				key = data['Key'].decode('hex')
				nonce = data['Nonce'].decode('hex')
				Alen = int(data['Alen'])
				if Alen != 0:
					aad = data['Adata'].decode('hex')
				else:
					aad = None
				payload = data['Payload'].decode('hex')
				ct = data['CT'].decode('hex')

				try:
					c = Crypto(crid=crid,
					    cipher=cryptodev.CRYPTO_AES_CCM_16,
					    key=key,
					    mac=cryptodev.CRYPTO_AES_CCM_CBC_MAC,
					    mackey=key, maclen=16)
					r, tag = Crypto.encrypt(c, payload,
					    nonce, aad)
				except EnvironmentError as e:
					if e.errno != errno.EOPNOTSUPP:
						raise
					continue

				out = r + tag
				self.assertEqual(out, ct,
				    "Count " + data['Count'] + " Actual: " + \
				    repr(out.encode("hex")) + " Expected: " + \
				    repr(data) + " on " + cname)

		def runCCMDecrypt(self, fname):
			"""Check AES-CCM decryption against one NIST vector file.

			Only 12-byte nonces with 16-byte tags are exercised.  The CT
			field holds the ciphertext with the tag appended; it is split
			before decryption.  Result=Fail vectors must raise IOError.
			"""
			# XXX: Note that all of the current CCM
			# decryption test vectors use IV and tag sizes
			# that aren't supported by OCF none of the
			# tests are actually ran.
			for data in cryptodev.KATCCMParser(fname):
				Nlen = int(data['Nlen'])
				if Nlen != 12:
					# OCF only supports 12 byte IVs
					continue
				Tlen = int(data['Tlen'])
				if Tlen != 16:
					# OCF only supports 16 byte tags
					continue
				key = data['Key'].decode('hex')
				nonce = data['Nonce'].decode('hex')
				Alen = int(data['Alen'])
				if Alen != 0:
					aad = data['Adata'].decode('hex')
				else:
					aad = None
				ct = data['CT'].decode('hex')
				tag = ct[-16:]
				ct = ct[:-16]

				try:
					c = Crypto(crid=crid,
					    cipher=cryptodev.CRYPTO_AES_CCM_16,
					    key=key,
					    mac=cryptodev.CRYPTO_AES_CCM_CBC_MAC,
					    mackey=key, maclen=16)
				except EnvironmentError as e:
					if e.errno != errno.EOPNOTSUPP:
						raise
					continue

				# Decrypt the ciphertext (not the expected payload, which
				# is not known at this point).
				if data['Result'] == 'Fail':
					self.assertRaises(IOError,
					    c.decrypt, ct, nonce, aad, tag)
				else:
					# With a MAC configured, decrypt returns a
					# (plaintext, tag) pair, as in runGCM.
					r, _rtag = Crypto.decrypt(c, ct, nonce,
					    aad, tag)

					payload = data['Payload'].decode('hex')
					Plen = int(data['Plen'])
					payload = payload[:Plen]
					self.assertEqual(r, payload,
					    "Count " + data['Count'] + \
					    " Actual: " + repr(r.encode("hex")) + \
					    " Expected: " + repr(data) + \
					    " on " + cname)

		###############
		##### DES #####
		###############
		@unittest.skipIf(cname not in desmodules, 'skipping DES on %s' % (cname))
		def test_tdes(self):
			for i in katg('KAT_TDES', 'TCBC[a-z]*.rsp'):
				self.runTDES(i)

		def runTDES(self, fname):
			"""Check 3DES-CBC against one NIST KAT file.

			The single KEYs value is repeated three times to form the
			cipher key.  DECRYPT sections are run with PT/CT swapped.
			"""
			for mode, lines in cryptodev.KATParser(fname,
			    [ 'COUNT', 'KEYs', 'IV', 'PLAINTEXT', 'CIPHERTEXT', ]):
				if mode == 'ENCRYPT':
					swapptct = False
					curfun = Crypto.encrypt
				elif mode == 'DECRYPT':
					swapptct = True
					curfun = Crypto.decrypt
				else:
					raise RuntimeError('unknown mode: %r' % (mode,))

				for data in lines:
					key = data['KEYs'] * 3
					cipherkey = key.decode('hex')
					iv = data['IV'].decode('hex')
					pt = data['PLAINTEXT'].decode('hex')
					ct = data['CIPHERTEXT'].decode('hex')

					if swapptct:
						pt, ct = ct, pt
					c = Crypto(cryptodev.CRYPTO_3DES_CBC, cipherkey, crid=crid)
					r = curfun(c, pt, iv)
					self.assertEqual(r, ct)

		###############
		##### SHA #####
		###############
		@unittest.skipIf(cname not in shamodules, 'skipping SHA on %s' % str(cname))
		def test_sha(self):
			for i in katg('shabytetestvectors', 'SHA*Msg.rsp'):
				self.runSHA(i)

		def runSHA(self, fname):
			"""Check plain SHA-1/SHA-2 digests against one NIST file.

			The digest algorithm is chosen from the section's "L=<bytes>"
			header; unknown lengths and SHA512/t files are skipped.
			"""
			# Skip SHA512_(224|256) tests
			if fname.find('SHA512_') != -1:
				return

			for hashlength, lines in cryptodev.KATParser(fname,
			    [ 'Len', 'Msg', 'MD' ]):
				# E.g., hashlength will be "L=20" (bytes)
				hashlen = int(hashlength.split("=")[1])

				if hashlen == 20:
					alg = cryptodev.CRYPTO_SHA1
				elif hashlen == 28:
					alg = cryptodev.CRYPTO_SHA2_224
				elif hashlen == 32:
					alg = cryptodev.CRYPTO_SHA2_256
				elif hashlen == 48:
					alg = cryptodev.CRYPTO_SHA2_384
				elif hashlen == 64:
					alg = cryptodev.CRYPTO_SHA2_512
				else:
					# Skip unsupported hashes
					# Slurp remaining input in section
					for data in lines:
						continue
					continue

				for data in lines:
					msg = data['Msg'].decode('hex')
					# Len is the message length in bits (NIST SHAVS);
					# truncating to Len/8 bytes drops the placeholder
					# byte presumably used to encode an empty message.
					msg = msg[:int(data['Len']) // 8]
					md = data['MD'].decode('hex')

					try:
						c = Crypto(mac=alg, crid=crid,
						    maclen=hashlen)
					except EnvironmentError as e:
						# Can't test hashes the driver does not support.
						if e.errno != errno.EOPNOTSUPP:
							raise
						continue

					# A MAC-only session returns (output, digest).
					_, r = c.encrypt(msg, iv="")

					self.assertEqual(r, md, "Actual: " + \
					    repr(r.encode("hex")) + " Expected: " + repr(data) + " on " + cname)

		@unittest.skipIf(cname not in shamodules, 'skipping SHA-HMAC on %s' % str(cname))
		def test_sha1hmac(self):
			for i in katg('hmactestvectors', 'HMAC.rsp'):
				self.runSHA1HMAC(i)

		def runSHA1HMAC(self, fname):
			"""Check SHA-1/SHA-2 HMACs against the NIST HMAC vector file.

			Despite the name, every supported digest length is tested.
			Keys longer than the hash block size are skipped, and the
			computed MAC is compared truncated to the vector's Tlen.
			"""
			for hashlength, lines in cryptodev.KATParser(fname,
			    [ 'Count', 'Klen', 'Tlen', 'Key', 'Msg', 'Mac' ]):
				# E.g., hashlength will be "L=20" (bytes)
				hashlen = int(hashlength.split("=")[1])

				blocksize = None
				if hashlen == 20:
					alg = cryptodev.CRYPTO_SHA1_HMAC
					blocksize = 64
				elif hashlen == 28:
					alg = cryptodev.CRYPTO_SHA2_224_HMAC
					blocksize = 64
				elif hashlen == 32:
					alg = cryptodev.CRYPTO_SHA2_256_HMAC
					blocksize = 64
				elif hashlen == 48:
					alg = cryptodev.CRYPTO_SHA2_384_HMAC
					blocksize = 128
				elif hashlen == 64:
					alg = cryptodev.CRYPTO_SHA2_512_HMAC
					blocksize = 128
				else:
					# Skip unsupported hashes
					# Slurp remaining input in section
					for data in lines:
						continue
					continue

				for data in lines:
					key = data['Key'].decode('hex')
					msg = data['Msg'].decode('hex')
					mac = data['Mac'].decode('hex')
					tlen = int(data['Tlen'])

					if len(key) > blocksize:
						continue

					try:
						c = Crypto(mac=alg, mackey=key,
						    crid=crid, maclen=hashlen)
					except EnvironmentError as e:
						# Can't test hashes the driver does not support.
						if e.errno != errno.EOPNOTSUPP:
							raise
						continue

					_, r = c.encrypt(msg, iv="")

					self.assertEqual(r[:tlen], mac, "Actual: " + \
					    repr(r.encode("hex")) + " Expected: " + repr(data) + \
					    " on " + cname)

	return GendCryptoTestCase
466
# Build one test class per OCF driver, in the same order as before.
# GenTestCase returns None when the driver id cannot be looked up;
# unittest's loader only collects TestCase subclasses, so such a name
# contributes no tests.
cryptosoft, aesni, ccr, ccp = map(GenTestCase,
    ('cryptosoft0', 'aesni0', 'ccr0', 'ccp0'))

if __name__ == '__main__':
	unittest.main()
474