# =================================================================== # # Copyright (c) 2015, Legrandin # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions # are met: # # 1. Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in # the documentation and/or other materials provided with the # distribution. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS # FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE # COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
# ===================================================================

import unittest
import time
from Crypto.SelfTest.st_common import list_test_cases
from Crypto.SelfTest.loader import load_tests

from Crypto.PublicKey import ECC
from Crypto.PublicKey.ECC import EccPoint, _curve, EccKey


class TestEccPoint_NIST(unittest.TestCase):
    """Tests defined in section 4.3 of https://www.nsa.gov/ia/_files/nist-routines.pdf"""

    # Fixed input points S and T shared by every test in this class.
    pointS = EccPoint(
        0xde2444bebc8d36e682edd27e0f271508617519b3221a8fa0b77cab3989da97c9,
        0xc093ae7ff36e5380fc01a5aad1e66659702de80f53cec576b6350b243042a256)

    pointT = EccPoint(
        0x55a8b00f8da1d44e62f6b3b25316212e39540dc861c89575bb8cf92e35e0986b,
        0x5421c3209c2d6c704835d82ac4c3dd90f61a8a52598b9e7ab656e9d8c8b24316)

    # After set(), the target point must compare equal to the source point.
    def test_set(self):
        pointW = EccPoint(0, 0)
        pointW.set(self.pointS)
        self.assertEqual(pointW, self.pointS)

    # copy() must yield an equal but independent point: overwriting the
    # copy with T must leave the class-level S untouched.
    def test_copy(self):
        pointW = self.pointS.copy()
        self.assertEqual(pointW, self.pointS)
        pointW.set(self.pointT)
        self.assertEqual(pointW, self.pointT)
        self.assertNotEqual(self.pointS, self.pointT)

    # R = S + T, plus the cases where one or both operands are the
    # point at infinity.
    def test_addition(self):
        pointRx = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
        pointRy = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264

        pointR = self.pointS + self.pointT
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)

        pai = EccPoint.point_at_infinity()

        # S + 0
        pointR = self.pointS + pai
        self.assertEqual(pointR, self.pointS)

        # 0 + S
        pointR = pai + self.pointS
        self.assertEqual(pointR, self.pointS)

        # 0 + 0
        pointR = pai + pai
        self.assertEqual(pointR, pai)

    # Same checks as test_addition, but through the += operator.
    # Copies are used so that the class-level points are never modified.
    def test_inplace_addition(self):
        pointRx = 0x72b13dd4354b6b81745195e98cc5ba6970349191ac476bd4553cf35a545a067e
        pointRy = 0x8d585cbb2e1327d75241a8a122d7620dc33b13315aa5c9d46d013011744ac264

        pointR = self.pointS.copy()
        pointR += self.pointT
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)

        pai = EccPoint.point_at_infinity()

        # S + 0
        pointR = self.pointS.copy()
        pointR += pai
        self.assertEqual(pointR, self.pointS)

        # 0 + S
        pointR = pai.copy()
        pointR += self.pointS
        self.assertEqual(pointR, self.pointS)

        # 0 + 0
        pointR = pai.copy()
        pointR += pai
        self.assertEqual(pointR, pai)

    # R = 2S via double(), and via adding a point to itself in place.
    def test_doubling(self):
        pointRx = 0x7669e6901606ee3ba1a8eef1e0024c33df6c22f3b17481b82a860ffcdb6127b0
        pointRy = 0xfa878162187a54f6c39f6ee0072f33de389ef3eecd03023de10ca2c1db61d0c7

        pointR = self.pointS.copy()
        pointR.double()
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)

        # 2*0
        pai = self.pointS.point_at_infinity()
        pointR = pai.copy()
        pointR.double()
        self.assertEqual(pointR, pai)

        # S + S
        pointR = self.pointS.copy()
        pointR += pointR
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)

    # R = d*S, plus the zero scalar and a negative scalar (which the
    # test expects to be rejected with ValueError).
    def test_scalar_multiply(self):
        d = 0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd
        pointRx = 0x51d08d5f2d4278882946d88d83c97d11e62becc3cfc18bedacc89ba34eeca03f
        pointRy = 0x75ee68eb8bf626aa5b673ab51f6e744e06f8fcf8a6c0cf3035beca956a7b41d5

        pointR = self.pointS * d
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)

        # 0*S
        pai = self.pointS.point_at_infinity()
        pointR = self.pointS * 0
        self.assertEqual(pointR, pai)

        # -1*S
        self.assertRaises(ValueError, lambda: self.pointS * -1)

    # R = d*S + e*T.
    def test_joing_scalar_multiply(self):
        d = 0xc51e4753afdec1e6b6c6a5b992f43f8dd0c7a8933072708b6522468b2ffb06fd
        e = 0xd37f628ece72a462f0145cbefe3f0b355ee8332d37acdd83a358016aea029db7
        pointRx = 0xd867b4679221009234939221b8046245efcf58413daacbeff857b8588341f6b8
        pointRy = 0xf2504055c03cede12d22720dad69c745106b6607ec7e50dd35d54bd80f615275

        pointR = self.pointS * d + self.pointT * e
        self.assertEqual(pointR.x, pointRx)
        self.assertEqual(pointR.y, pointRy)


class TestEccPoint_PAI(unittest.TestCase):
    """Test vectors from http://point-at-infinity.org/ecc/nisttv"""

    # Point built from the curve's Gx/Gy; every vector multiplies it by k.
    pointG = EccPoint(_curve.Gx, _curve.Gy)


# Each vector provides a decimal scalar "k" and hexadecimal coordinates
# "x" and "y"; the converters below turn them into integers.
tv_pai = load_tests(("Crypto", "SelfTest", "PublicKey", "test_vectors", "ECC"),
                    "point-at-infinity.org-P256.txt",
                    "P-256 tests from point-at-infinity.org",
                    {"k": lambda k: int(k),
                     "x": lambda x: int(x, 16),
                     "y": lambda y: int(y, 16)})
assert tv_pai

# Attach one test method per vector to TestEccPoint_PAI, named after the
# vector's count.
for tv in tv_pai:
    # The vector's values are bound as default arguments so that each
    # generated method keeps its own data instead of all of them seeing
    # the last value of the loop variable.
    def new_test(self, scalar=tv.k, x=tv.x, y=tv.y):
        result = self.pointG * scalar
        self.assertEqual(result.x, x)
        self.assertEqual(result.y, y)
    setattr(TestEccPoint_PAI, "test_%d" % tv.count, new_test)


class TestEccKey(unittest.TestCase):
    """Tests for direct construction and comparison of EccKey objects."""

    # A key built from d=1 (with or without an explicit point) is private
    # and its public point has the curve's Gx/Gy coordinates.
    def test_private_key(self):
        key = EccKey(curve="P-256", d=1)
        self.assertEqual(key.d, 1)
        self.assertTrue(key.has_private())
        self.assertEqual(key.pointQ.x, _curve.Gx)
        self.assertEqual(key.pointQ.y, _curve.Gy)

        point = EccPoint(_curve.Gx, _curve.Gy)
        key = EccKey(curve="P-256", d=1, point=point)
        self.assertEqual(key.d, 1)
        self.assertTrue(key.has_private())
        self.assertEqual(key.pointQ, point)

        # Other names (only checked for being accepted without error)
        EccKey(curve="secp256r1", d=1)
        EccKey(curve="prime256v1", d=1)

    # A key built from a point only has no private component.
    def test_public_key(self):
        point = EccPoint(_curve.Gx, _curve.Gy)
        key = EccKey(curve="P-256", point=point)
        self.assertFalse(key.has_private())
        self.assertEqual(key.pointQ, point)

    # public_key() drops the private component but keeps the same point.
    def test_public_key_derived(self):
        priv_key = EccKey(curve="P-256", d=3)
        pub_key = priv_key.public_key()
        self.assertFalse(pub_key.has_private())
        self.assertEqual(priv_key.pointQ, pub_key.pointQ)

    # An unknown curve name must be rejected with ValueError.
    def test_invalid_curve(self):
        self.assertRaises(ValueError, lambda: EccKey(curve="P-257", d=1))

    # d=0 and d=order must both be rejected with ValueError.
    def test_invalid_d(self):
        self.assertRaises(ValueError, lambda: EccKey(curve="P-256", d=0))
        self.assertRaises(ValueError, lambda: EccKey(curve="P-256", d=_curve.order))

    # Keys compare equal only when built from the same d, and a public
    # key never equals its private counterpart.
    def test_equality(self):
        private_key = ECC.construct(d=3, curve="P-256")
        private_key2 = ECC.construct(d=3, curve="P-256")
        private_key3 = ECC.construct(d=4, curve="P-256")

        public_key = private_key.public_key()
        public_key2 = private_key2.public_key()
        public_key3 = private_key3.public_key()

        self.assertEqual(private_key, private_key2)
        self.assertNotEqual(private_key, private_key3)

        self.assertEqual(public_key, public_key2)
        self.assertNotEqual(public_key, public_key3)

        self.assertNotEqual(public_key, private_key)


class TestEccModule(unittest.TestCase):
    """Tests for the module-level ECC.generate() and ECC.construct()."""

    # A generated key is private and its point equals (Gx, Gy) * d.
    def test_generate(self):
        key = ECC.generate(curve="P-256")
        self.assertTrue(key.has_private())
        self.assertEqual(key.pointQ, EccPoint(_curve.Gx, _curve.Gy) * key.d)

        # Other names
        ECC.generate(curve="secp256r1")
        ECC.generate(curve="prime256v1")

    # construct() from d alone gives a private key; from coordinates
    # alone it gives a public key. Both end up with point G.
    def test_construct(self):
        key = ECC.construct(curve="P-256", d=1)
        self.assertTrue(key.has_private())
        self.assertEqual(key.pointQ, _curve.G)

        key = ECC.construct(curve="P-256", point_x=_curve.Gx, point_y=_curve.Gy)
        self.assertFalse(key.has_private())
        self.assertEqual(key.pointQ, _curve.G)

        # Other names
        ECC.construct(curve="secp256r1", d=1)
        ECC.construct(curve="prime256v1", d=1)

    # construct() must raise ValueError both for the coordinates (10, 4)
    # and for d=2 combined with the coordinates of G.
    def test_negative_construct(self):
        coord = dict(point_x=10, point_y=4)
        coordG = dict(point_x=_curve.Gx, point_y=_curve.Gy)

        self.assertRaises(ValueError, ECC.construct, curve="P-256", **coord)
        self.assertRaises(ValueError, ECC.construct, curve="P-256", d=2, **coordG)


def get_tests(config=None):
    """Return the list of test cases defined in this module.

    ``config`` is not read by this function; it is presumably accepted so
    that callers can use one signature for every self-test module (to be
    confirmed against the caller). The default is ``None`` rather than a
    shared mutable ``{}``.
    """
    tests = []
    tests += list_test_cases(TestEccPoint_NIST)
    tests += list_test_cases(TestEccPoint_PAI)
    tests += list_test_cases(TestEccKey)
    tests += list_test_cases(TestEccModule)
    return tests


if __name__ == '__main__':
    # unittest.main() looks up the name 'suite' in this module and calls it.
    def suite():
        return unittest.TestSuite(get_tests())

    unittest.main(defaultTest='suite')