"""
.. module:: ristretto
ristretto module
================
This module exports the classes :obj:`~oblivious.ristretto.point` and
:obj:`~oblivious.ristretto.scalar` for representing points and scalars. It
also exports the two wrapper classes/namespaces
:obj:`~oblivious.ristretto.python` and :obj:`~oblivious.ristretto.sodium`
that encapsulate pure-Python and shared/dynamic library variants of the
above (respectively) and also include low-level operations that correspond
more directly to the functions found in the underlying libraries.
* Under all conditions, the wrapper class :obj:`~oblivious.ristretto.python`
is defined and encapsulates a pure-Python variant of every class exported
by this module as a whole. It also includes pure-Python variants of low-level
operations that correspond to functions found in the underlying libraries.
* If a shared/dynamic library instance of the
`libsodium <https://doc.libsodium.org>`__ library is found on the system
(and successfully loaded at the time this module is imported) or the
optional `rbcl <https://pypi.org/project/rbcl>`__ package is installed,
then the wrapper class :obj:`~oblivious.ristretto.sodium` is defined.
Otherwise, the exported variable ``sodium`` is assigned ``None``.
* If a dynamic/shared library instance is loaded, all classes exported by
this module correspond to the variants defined within
:obj:`~oblivious.ristretto.sodium`. Otherwise, they correspond to the
variants defined within :obj:`~oblivious.ristretto.python`.
For most users, the classes :obj:`~oblivious.ristretto.point` and
:obj:`~oblivious.ristretto.scalar` should be sufficient. When using the
low-level operations that correspond to a specific implementation (*e.g.*,
:obj:`oblivious.ristretto.sodium.add`), users are responsible for ensuring
that inputs have the type and/or representation appropriate for that
operation.
"""
from __future__ import annotations
from typing import Any, NoReturn, Union, Optional
import doctest
import platform
import os
import hashlib
import ctypes
import ctypes.util
import secrets
import base64
import ge25519
#
# Attempt to import the optional rbcl package (which bundles a compiled subset
# of libsodium). If it cannot be imported and no libsodium shared/dynamic
# library file is found on the system, only the pure-Python implementations
# of the functions and methods will be available.
#
try: # pragma: no cover
    import rbcl # pylint: disable=E0401

    # rbcl exposes the byte lengths as upper-case constants, whereas the
    # wrapper code below invokes lower-case functions of the same names on
    # whichever library is selected; install lower-case callable synonyms so
    # rbcl can be used interchangeably.
    rbcl.crypto_core_ristretto255_scalarbytes = (
        lambda: rbcl.crypto_core_ristretto255_SCALARBYTES
    )
    rbcl.crypto_core_ristretto255_bytes = (
        lambda: rbcl.crypto_core_ristretto255_BYTES
    )
except: # pylint: disable=W0702 # pragma: no cover
    rbcl = None
#
# Use pure-Python implementations of primitives by default.
#
def _zero(n: bytes) -> bool:
    """
    Return ``True`` if every byte of the bytes-like object is zero (an empty
    object also counts as zero); otherwise, return ``False``. All bytes are
    OR-ed together before any test is made, so the loop never exits early.
    """
    acc = 0
    for byte in n:
        acc = acc | byte
    # With ``acc`` in [0, 255], ``(acc - 1) >> 8`` is -1 exactly when
    # ``acc == 0`` and 0 otherwise; the low bit distinguishes the two cases.
    return (((acc - 1) >> 8) & 1) == 1
# Little-endian byte encoding of the scalar modulus
# 2^252+27742317777372353535851937790883648493, as a list of byte values.
_sc25519_is_canonical_L = list(
    ((1 << 252) + 27742317777372353535851937790883648493).to_bytes(32, 'little')
)

def _sc25519_is_canonical(s: bytes) -> bool:
    """
    Confirm that the bytes-like object represents a canonical
    scalar (i.e., its little-endian value is less than the modulus
    encoded in ``_sc25519_is_canonical_L``).
    """
    below = 0 # Becomes 1 once a byte proves ``s`` is smaller.
    equal_so_far = 1 # Stays 1 while all higher-order bytes match.
    for idx in reversed(range(32)):
        limb = _sc25519_is_canonical_L[idx]
        below |= ((s[idx] - limb) >> 8) & equal_so_far
        equal_so_far &= ((s[idx] ^ limb) - 1) >> 8
    return below != 0
def _sc25519_mul(a: bytes, b: bytes) -> bytes:
    """
    Multiply the two scalars represented by the bytes-like objects (each read
    as a little-endian integer) and return the product reduced modulo
    ``2**252 + 27742317777372353535851937790883648493``, encoded as 32
    little-endian bytes.
    """
    modulus = (1 << 252) + 27742317777372353535851937790883648493
    product = int.from_bytes(a, 'little') * int.from_bytes(b, 'little')
    return (product % modulus).to_bytes(32, 'little')
def _sc25519_sqmul(s: bytes, n: int, a: bytes) -> bytes:
    """
    Square the scalar ``s`` a total of ``n`` times (raising it to the power
    ``2**n``), then return the product of that result and the scalar ``a``.
    """
    squared = s
    for _step in range(n):
        squared = _sc25519_mul(squared, squared)
    return _sc25519_mul(squared, a)
def _sc25519_invert(s: bytes) -> bytes:
    """
    Invert the scalar represented by the bytes-like object.

    The result is obtained purely by exponentiation: a fixed sequence of
    scalar squarings and multiplications (an addition chain) is applied to
    ``s``. The sequence appears to mirror libsodium's ``sc25519_invert`` and
    presumably computes ``s**(L - 2)`` modulo the scalar modulus ``L``
    (Fermat's little theorem) -- TODO confirm. A zero scalar yields zero,
    because every intermediate value is a product of powers of ``s``.
    """
    # Precompute small powers: each b_<bits> holds s raised to the exponent
    # whose binary representation is <bits> (e.g., b_101 == s**5).
    b_10 = _sc25519_mul(s, s)
    b_100 = _sc25519_mul(b_10, b_10)
    b_11 = _sc25519_mul(b_10, s)
    b_101 = _sc25519_mul(b_10, b_11)
    b_111 = _sc25519_mul(b_10, b_101)
    b_1001 = _sc25519_mul(b_10, b_111)
    b_1011 = _sc25519_mul(b_10, b_1001)
    b_1111 = _sc25519_mul(b_100, b_1011)
    # recip == s**16 (binary 10000): the leading bits of the exponent.
    recip = _sc25519_mul(b_1111, s)
    # Each _sc25519_sqmul(recip, k, b) step computes recip**(2**k) * b, i.e.,
    # it shifts the exponent accumulated so far left by k bits and adds the
    # exponent held by b. The order of these steps must not be changed.
    recip = _sc25519_sqmul(recip, 123 + 3, b_101)
    recip = _sc25519_sqmul(recip, 2 + 2, b_11)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 4, b_1001)
    recip = _sc25519_sqmul(recip, 2, b_11)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 1 + 3, b_101)
    recip = _sc25519_sqmul(recip, 3 + 3, b_101)
    recip = _sc25519_sqmul(recip, 3, b_111)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 2 + 3, b_111)
    recip = _sc25519_sqmul(recip, 2 + 2, b_11)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1011)
    recip = _sc25519_sqmul(recip, 2 + 4, b_1011)
    recip = _sc25519_sqmul(recip, 6 + 4, b_1001)
    recip = _sc25519_sqmul(recip, 2 + 2, b_11)
    recip = _sc25519_sqmul(recip, 3 + 2, b_11)
    recip = _sc25519_sqmul(recip, 3 + 2, b_11)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1001)
    recip = _sc25519_sqmul(recip, 1 + 3, b_111)
    recip = _sc25519_sqmul(recip, 2 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 1 + 4, b_1011)
    recip = _sc25519_sqmul(recip, 3, b_101)
    recip = _sc25519_sqmul(recip, 2 + 4, b_1111)
    recip = _sc25519_sqmul(recip, 3, b_101)
    recip = _sc25519_sqmul(recip, 1 + 2, b_11)
    return recip
def _ristretto255_is_canonical(s: bytes) -> bool:
    """
    Confirm that the bytes-like object represents a canonical
    Ristretto point encoding.

    An encoding is canonical only if (1) the little-endian integer it
    represents is less than ``p = 2**255 - 19`` (which in particular requires
    the most significant bit of the last byte to be clear), and (2) that
    integer is even (*i.e.*, a non-negative field element).
    """
    # ``c`` is zero only if bytes 1..31 match those of p (0xff ... 0xff, 0x7f),
    # ignoring the top bit of the last byte.
    c = ((s[31] & 0x7f) ^ 0x7f) % 256
    for i in range(30, 0, -1):
        c |= (s[i] ^ 0xff) % 256
    c = (c - 1) >> 8 # -1 if bytes 1..31 match p, 0 otherwise.
    d = ((0xed - 1 - s[0]) >> 8) % 256 # 255 if s[0] >= 0xed, 0 otherwise.
    e = s[31] >> 7 # 1 if the most significant bit is set (value >= 2**255).
    # Reject if the value is in [p, 2**255), if the top bit is set, or if
    # the value is odd (negative).
    return (((c & d) | e | s[0]) & 1) == 0
class python:
    """
    Wrapper class for pure-Python implementations of primitive operations.

    This class encapsulates pure-Python variants of all low-level operations
    and of both classes exported by this module:
    :obj:`python.scl <scl>`, :obj:`python.rnd <rnd>`,
    :obj:`python.inv <inv>`, :obj:`python.smu <smu>`,
    :obj:`python.pnt <pnt>`, :obj:`python.bas <bas>`,
    :obj:`python.can <can>`, :obj:`python.mul <mul>`,
    :obj:`python.add <add>`, :obj:`python.sub <sub>`,
    :obj:`python.neg <neg>`,
    :obj:`python.point <oblivious.ristretto.python.point>`, and
    :obj:`python.scalar <oblivious.ristretto.python.scalar>`.

    For example, you can perform addition of points using
    the pure-Python point addition implementation.

    >>> p = python.pnt()
    >>> q = python.pnt()
    >>> python.add(p, q) == python.add(q, p)
    True

    Pure-Python variants of the :obj:`python.point <point>` and
    :obj:`python.scalar <scalar>` classes always employ pure-Python
    implementations of operations when their methods are invoked.

    >>> p = python.point()
    >>> q = python.point()
    >>> p + q == q + p
    True

    Nevertheless, all bytes-like objects, :obj:`point` objects, and
    :obj:`scalar` objects accepted and emitted by the various operations and
    class methods in :obj:`python` are compatible with those accepted and
    emitted by the operations and class methods in :obj:`sodium`.
    """
    @staticmethod
    def pnt(h: Optional[bytes] = None) -> bytes:
        """
        Return point from 64-byte vector (normally obtained via hashing). If
        no vector is supplied, the SHA-512 digest of a random scalar is used.

        >>> p = python.pnt(hashlib.sha512('123'.encode()).digest())
        >>> p.hex()
        '047f39a6c6dd156531a25fa605f017d4bec13b0b6c42f0e9b641c8ee73359c5f'
        """
        return ge25519.ge25519_p3.from_hash_ristretto255(
            hashlib.sha512(python.rnd()).digest() if h is None else h
        )

    @staticmethod
    def bas(s: bytes) -> bytes:
        """
        Return base point multiplied by supplied scalar.

        >>> python.bas(scalar.hash('123'.encode())).hex()
        '4c207a5377f3badf358914f20b505cd1e2a6396720a9c240e5aff522e2446005'
        """
        t = bytearray(s)
        # Clear the most significant bit of the scalar before multiplying
        # (presumably mirroring libsodium's behavior -- TODO confirm).
        t[31] &= 127
        return ge25519.ge25519_p3.scalar_mult_base(t).to_bytes_ristretto255()

    @staticmethod
    def can(p: bytes) -> bytes:
        """
        Normalize the representation of a point into its canonical form.

        >>> p = point.hash('123'.encode())
        >>> python.can(p) == p
        True
        """
        return p # In this module, the canonical representation is used at all times.

    @staticmethod
    def mul(s: bytes, p: bytes) -> bytes:
        """
        Multiply the point by the supplied scalar and return the result. If
        the point is not a valid canonical encoding, 32 zero bytes are
        returned.

        >>> p = python.pnt(hashlib.sha512('123'.encode()).digest())
        >>> s = python.scl(bytes.fromhex(
        ...     '35c141f1c2c43543de9d188805a210abca3cd39a1e986304991ceded42b11709'
        ... ))
        >>> python.mul(s, p).hex()
        '183a06e0fe6af5d7913afb40baefc4dd52ae718fee77a3a0af8777c89fe16210'
        """
        p3 = ge25519.ge25519_p3.from_bytes_ristretto255(p)
        if not _ristretto255_is_canonical(p) or p3 is None:
            return bytes(32) # pragma: no cover
        t = bytearray(s)
        t[31] &= 127 # Clear the most significant bit of the scalar.
        return p3.scalar_mult(t).to_bytes_ristretto255()

    @staticmethod
    def add(p: bytes, q: bytes) -> bytes:
        """
        Return sum of the supplied points. If either point is not a valid
        canonical encoding, 32 zero bytes are returned.

        >>> p = point.hash('123'.encode())
        >>> q = point.hash('456'.encode())
        >>> python.add(p, q).hex()
        '7076739c9df665d416e68b9512f5513bf1d0181a2aacefdeb1b7244528a4dd77'
        """
        p_p3 = ge25519.ge25519_p3.from_bytes_ristretto255(p)
        q_p3 = ge25519.ge25519_p3.from_bytes_ristretto255(q)
        if (
            not _ristretto255_is_canonical(p) or p_p3 is None or
            not _ristretto255_is_canonical(q) or q_p3 is None
        ):
            return bytes(32) # pragma: no cover
        q_cached = ge25519.ge25519_cached.from_p3(q_p3)
        r_p1p1 = ge25519.ge25519_p1p1.add(p_p3, q_cached)
        r_p3 = ge25519.ge25519_p3.from_p1p1(r_p1p1)
        return r_p3.to_bytes_ristretto255()

    @staticmethod
    def sub(p: bytes, q: bytes) -> bytes:
        """
        Return result of subtracting second point from first point. If either
        point is not a valid canonical encoding, 32 zero bytes are returned.

        >>> p = point.hash('123'.encode())
        >>> q = point.hash('456'.encode())
        >>> python.sub(p, q).hex()
        '1a3199ca7debfe31a90171696d8bab91b99eb23a541b822a7061b09776e1046c'
        """
        p_p3 = ge25519.ge25519_p3.from_bytes_ristretto255(p)
        q_p3 = ge25519.ge25519_p3.from_bytes_ristretto255(q)
        if (
            not _ristretto255_is_canonical(p) or p_p3 is None or
            not _ristretto255_is_canonical(q) or q_p3 is None
        ):
            return bytes(32) # pragma: no cover
        q_cached = ge25519.ge25519_cached.from_p3(q_p3)
        r_p1p1 = ge25519.ge25519_p1p1.sub(p_p3, q_cached)
        r_p3 = ge25519.ge25519_p3.from_p1p1(r_p1p1)
        return r_p3.to_bytes_ristretto255()

    @staticmethod
    def neg(p: bytes) -> bytes:
        """
        Return the additive inverse of a point (computed by subtracting it
        from the point encoded as 32 zero bytes).

        >>> p = point.hash('123'.encode())
        >>> q = point.hash('456'.encode())
        >>> python.add(python.neg(p), python.add(p, q)) == q
        True
        """
        return python.sub(bytes(32), p)

    @staticmethod
    def rnd() -> bytes:
        """
        Return random non-zero scalar (as an immutable :obj:`bytes` object).

        >>> len(python.rnd())
        32
        >>> isinstance(python.rnd(), bytes)
        True
        """
        while True:
            r = bytearray(secrets.token_bytes(32))
            r[-1] &= 0x1f # Clear the top three bits so that r < 2**253.
            # Rejection sampling: retry until r is canonical and non-zero.
            if _sc25519_is_canonical(r) and not _zero(r):
                return bytes(r)

    @classmethod
    def scl(cls, s: Optional[bytes] = None) -> Optional[bytes]:
        """
        Return supplied byte vector if it is a valid scalar; otherwise, return
        ``None``. If no byte vector is supplied, return a random scalar.

        Note that the top three bits of the last byte are cleared before the
        validity check, so the returned vector may differ from the input in
        those bits.

        >>> s = python.scl()
        >>> t = python.scl(s)
        >>> s == t
        True
        >>> python.scl(bytes([255] * 32)) is None
        True
        """
        if s is None:
            return cls.rnd()
        s = bytearray(s)
        s[-1] &= 0x1f
        return bytes(s) if _sc25519_is_canonical(s) else None

    @staticmethod
    def inv(s: bytes) -> bytes:
        """
        Return the inverse of a scalar (modulo
        ``2**252 + 27742317777372353535851937790883648493``).

        >>> s = python.scl()
        >>> p = python.pnt()
        >>> python.mul(python.inv(s), python.mul(s, p)) == p
        True
        """
        return _sc25519_invert(s)

    @staticmethod
    def smu(s: bytes, t: bytes) -> bytes:
        """
        Return scalar multiplied by another scalar.

        >>> s = python.scl()
        >>> t = python.scl()
        >>> python.smu(s, t) == python.smu(t, s)
        True
        """
        return _sc25519_mul(s, t)
#
# Attempt to load primitives from libsodium, if it is present;
# otherwise, use the rbcl library, if it is present. Otherwise,
# silently assign ``None`` to ``sodium``.
#
try:
    def _call_variant_unwrapped(length, function, x=None, y=None):
        """
        Wrapper to invoke an external function from a ctypes-loaded library.

        The function receives a fresh output buffer of ``length`` bytes as its
        first argument (followed by ``x`` and ``y`` when supplied), and the raw
        contents of that buffer are returned.
        """
        buf = ctypes.create_string_buffer(length)
        if y is not None:
            function(buf, x, y)
        elif x is not None:
            function(buf, x)
        else:
            function(buf)
        return buf.raw

    def _call_variant_wrapped(_, function, x=None, y=None): # pragma: no cover
        """
        Wrapper to invoke an external (wrapped) function that returns its
        result directly; the buffer length in the first argument is ignored.
        """
        if y is not None:
            return function(x, y)
        if x is not None:
            return function(x)
        return function()

    _sodium = None
    _call_variant = _call_variant_unwrapped

    # Attempt to load libsodium shared/dynamic library file.
    xdll = ctypes.cdll if platform.system() != 'Windows' else ctypes.windll
    libf = ctypes.util.find_library('sodium') or ctypes.util.find_library('libsodium')
    if libf is not None:
        try:
            _sodium = xdll.LoadLibrary(libf)
        except OSError: # pragma: no cover
            # A located-but-unloadable file should not prevent the
            # remaining fallbacks (explicit search, rbcl) from being tried.
            _sodium = None
    if _sodium is None: # pragma: no cover
        # Perform explicit search in case `ld` is not present in environment
        # (or the library file it located could not be loaded).
        libf = 'libsodium.so' if platform.system() != 'Windows' else 'libsodium.dll'
        for var in ['PATH', 'LD_LIBRARY_PATH']:
            if _sodium is not None:
                break # Stop at the first library file that loads successfully.
            if var in os.environ:
                for path in os.environ[var].split(os.pathsep):
                    try:
                        _sodium = ctypes.cdll.LoadLibrary(path + os.path.sep + libf)
                        break
                    except OSError:
                        continue

    # Default to bindings exported by the rbcl library if the above attempts
    # failed and rbcl is available.
    if _sodium is None and rbcl is not None: # pragma: no cover
        _sodium = rbcl
        _call_variant = _call_variant_wrapped

    # Add method variants that are not present in libsodium.
    if _sodium is not rbcl and _sodium is not None: # pragma: no cover
        def _crypto_scalarmult_ristretto255_allow_scalar_zero(buf, s, p):
            """
            Variant of scalar-point multiplication function that permits
            a scalar corresponding to the zero residue (a ``-1`` return code
            is treated as an error only when the scalar is non-zero).
            """
            r = _sodium.crypto_scalarmult_ristretto255(buf, s, p)
            if (1 - _zero(s)) * int(r == -1):
                raise RuntimeError('libsodium error (possibly due to invalid input)')
            return buf

        def _crypto_scalarmult_ristretto255_base_allow_scalar_zero(buf, s):
            """
            Variant of scalar-base-point multiplication function that permits
            a scalar corresponding to the zero residue (a ``-1`` return code
            is treated as an error only when the scalar is non-zero).
            """
            r = _sodium.crypto_scalarmult_ristretto255_base(buf, s)
            if (1 - _zero(s)) * int(r == -1):
                raise RuntimeError('libsodium error (possibly due to invalid input)')
            return buf

        setattr(
            _sodium,
            'crypto_scalarmult_ristretto255_allow_scalar_zero',
            _crypto_scalarmult_ristretto255_allow_scalar_zero
        )
        setattr(
            _sodium,
            'crypto_scalarmult_ristretto255_base_allow_scalar_zero',
            _crypto_scalarmult_ristretto255_base_allow_scalar_zero
        )

    # Ensure the chosen version of libsodium (or its substitute) has the
    # necessary primitives. An explicit exception is raised (rather than using
    # ``assert``) so that the check still runs under ``python -O``; it is
    # caught below, causing ``sodium`` to be ``None``.
    if not all(
        hasattr(_sodium, _primitive)
        for _primitive in (
            'crypto_core_ristretto255_bytes',
            'crypto_core_ristretto255_scalarbytes',
            'crypto_core_ristretto255_scalar_random',
            'crypto_core_ristretto255_scalar_invert',
            'crypto_core_ristretto255_scalar_mul',
            'crypto_core_ristretto255_from_hash',
            'crypto_scalarmult_ristretto255_base',
            'crypto_scalarmult_ristretto255',
            'crypto_core_ristretto255_add',
            'crypto_core_ristretto255_sub',
        )
    ):
        raise ImportError('libsodium (or rbcl) lacks required ristretto255 primitives')

    # Exported symbol.
    class sodium:
        """
        Wrapper class for binary implementations of primitive operations.

        When this module is imported, it makes a number of attempts to
        locate an instance of the shared/dynamic library file of the
        `libsodium <https://doc.libsodium.org>`__ library on the host
        system. The sequence of attempts is listed below, in order.

        1. It uses ``ctypes.util.find_library`` to look for ``'sodium'`` or
           ``'libsodium'``.

        2. It attempts to find a file ``libsodium.so`` or ``libsodium.dll`` in
           the paths specified by the ``PATH`` and ``LD_LIBRARY_PATH``
           environment variables.

        3. If the `rbcl <https://pypi.org/project/rbcl>`__ package is
           installed, it reverts to the compiled subset of libsodium included
           in that package.

        If all of the above fail, then :obj:`sodium` is assigned the value
        ``None`` and all classes exported by this module default to their
        pure-Python variants (*i.e.*, those encapsulated within :obj:`python`).
        To confirm that a dynamic/shared library *has been found* when this
        module is imported, evaluate the expression ``sodium is not None``.

        If a shared/dynamic library file has been loaded successfully, this
        class encapsulates shared/dynamic library variants of both classes
        exported by this module and of all the underlying low-level operations:
        :obj:`sodium.scl <scl>`, :obj:`sodium.rnd <rnd>`,
        :obj:`sodium.inv <inv>`, :obj:`sodium.smu <smu>`,
        :obj:`sodium.pnt <pnt>`, :obj:`sodium.bas <bas>`,
        :obj:`sodium.can <can>`, :obj:`sodium.mul <mul>`,
        :obj:`sodium.add <add>`, :obj:`sodium.sub <sub>`,
        :obj:`sodium.neg <neg>`,
        :obj:`sodium.point <oblivious.ristretto.sodium.point>`, and
        :obj:`sodium.scalar <oblivious.ristretto.sodium.scalar>`.

        For example, you can perform addition of points using
        the point addition implementation found in the libsodium
        shared/dynamic library found on the host system.

        >>> p = sodium.pnt()
        >>> q = sodium.pnt()
        >>> sodium.add(p, q) == sodium.add(q, p)
        True

        Methods found in the shared/dynamic library variants of the
        :obj:`point` and :obj:`scalar` classes are wrappers for the
        shared/dynamic library implementations of the underlying
        operations.

        >>> p = sodium.point()
        >>> q = sodium.point()
        >>> p + q == q + p
        True

        Nevertheless, all bytes-like objects, :obj:`point` objects, and
        :obj:`scalar` objects accepted and emitted by the various operations
        and class methods in :obj:`sodium` are compatible with those accepted
        and emitted by the operations and class methods in :obj:`python`.
        """
        _lib = _sodium
        _call_unwrapped = _call_variant_unwrapped
        _call_wrapped = _call_variant_wrapped
        _call = _call_variant

        @staticmethod
        def pnt(h: Optional[bytes] = None) -> bytes:
            """
            Construct a point from its 64-byte vector representation (normally
            obtained via hashing). If no vector is supplied, the SHA-512
            digest of a random scalar is used.

            >>> p = sodium.pnt(hashlib.sha512('123'.encode()).digest())
            >>> p.hex()
            '047f39a6c6dd156531a25fa605f017d4bec13b0b6c42f0e9b641c8ee73359c5f'
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_bytes(),
                sodium._lib.crypto_core_ristretto255_from_hash,
                bytes(
                    hashlib.sha512(sodium.rnd()).digest() if h is None else h
                )
            )

        @staticmethod
        def bas(s: bytes) -> bytes:
            """
            Return the base point multiplied by the supplied scalar.

            >>> sodium.bas(scalar.hash('123'.encode())).hex()
            '4c207a5377f3badf358914f20b505cd1e2a6396720a9c240e5aff522e2446005'
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_bytes(), # Output is a point.
                sodium._lib.crypto_scalarmult_ristretto255_base_allow_scalar_zero,
                bytes(s)
            )

        @staticmethod
        def can(p: bytes) -> bytes:
            """
            Normalize the representation of a point into its canonical form.

            >>> p = point.hash('123'.encode())
            >>> sodium.can(p) == p
            True
            """
            # In this module, the canonical representation is used at all times.
            return p

        @staticmethod
        def mul(s: bytes, p: bytes) -> bytes:
            """
            Multiply a point by a scalar and return the result.

            >>> p = sodium.pnt(hashlib.sha512('123'.encode()).digest())
            >>> s = sodium.scl(bytes.fromhex(
            ...     '35c141f1c2c43543de9d188805a210abca3cd39a1e986304991ceded42b11709'
            ... ))
            >>> sodium.mul(s, p).hex()
            '183a06e0fe6af5d7913afb40baefc4dd52ae718fee77a3a0af8777c89fe16210'
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_bytes(), # Output is a point.
                sodium._lib.crypto_scalarmult_ristretto255_allow_scalar_zero,
                bytes(s), bytes(p)
            )

        @staticmethod
        def add(p: bytes, q: bytes) -> bytes:
            """
            Return the sum of the supplied points.

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> sodium.add(p, q).hex()
            '7076739c9df665d416e68b9512f5513bf1d0181a2aacefdeb1b7244528a4dd77'
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_bytes(), # Output is a point.
                sodium._lib.crypto_core_ristretto255_add,
                bytes(p), bytes(q)
            )

        @staticmethod
        def sub(p: bytes, q: bytes) -> bytes:
            """
            Return the result of subtracting the right-hand point from the
            left-hand point.

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> sodium.sub(p, q).hex()
            '1a3199ca7debfe31a90171696d8bab91b99eb23a541b822a7061b09776e1046c'
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_bytes(), # Output is a point.
                sodium._lib.crypto_core_ristretto255_sub,
                bytes(p), bytes(q)
            )

        @staticmethod
        def neg(p: bytes) -> bytes:
            """
            Return the additive inverse of a point (computed by subtracting it
            from the point encoded as 32 zero bytes).

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> sodium.add(sodium.neg(p), sodium.add(p, q)) == q
            True
            """
            return sodium.sub(bytes(32), p)

        @staticmethod
        def rnd() -> bytes:
            """
            Return random non-zero scalar.

            >>> len(sodium.rnd())
            32
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_scalarbytes(),
                sodium._lib.crypto_core_ristretto255_scalar_random
            )

        @classmethod
        def scl(cls, s: Optional[bytes] = None) -> Optional[bytes]:
            """
            Return supplied byte vector if it is a valid scalar; otherwise,
            return ``None``. If no byte vector is supplied, return a random
            scalar. The top three bits of the last byte are cleared before
            the validity check.

            >>> s = sodium.scl()
            >>> t = sodium.scl(s)
            >>> s == t
            True
            >>> sodium.scl(bytes([255] * 32)) is None
            True
            """
            if s is None:
                return cls.rnd()
            s = bytearray(s)
            s[-1] &= 0x1f
            return bytes(s) if _sc25519_is_canonical(s) else None

        @staticmethod
        def inv(s: bytes) -> bytes:
            """
            Return the inverse of a scalar (modulo
            ``2**252 + 27742317777372353535851937790883648493``).

            >>> s = sodium.scl()
            >>> p = sodium.pnt()
            >>> sodium.mul(sodium.inv(s), sodium.mul(s, p)) == p
            True
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_scalarbytes(),
                sodium._lib.crypto_core_ristretto255_scalar_invert,
                bytes(s)
            )

        @staticmethod
        def smu(s: bytes, t: bytes) -> bytes:
            """
            Return the product of two scalars.

            >>> s = sodium.scl()
            >>> t = sodium.scl()
            >>> sodium.smu(s, t) == sodium.smu(t, s)
            True
            """
            return sodium._call(
                sodium._lib.crypto_core_ristretto255_scalarbytes(),
                sodium._lib.crypto_core_ristretto255_scalar_mul,
                bytes(s), bytes(t)
            )
except Exception: # pylint: disable=broad-except # pragma: no cover
    # Any failure above (no usable library, missing primitives) leaves only
    # the pure-Python implementations available.
    # Exported symbol.
    sodium = None # pragma: no cover
#
# Dedicated point and scalar data structures derived from `bytes`.
#
for _implementation in [python] + ([sodium] if sodium is not None else []):
    # pylint: disable=cell-var-from-loop
    class point(bytes): # pylint: disable=E0102
        """
        Class for representing a point. Because this class is derived from
        :obj:`bytes`, it inherits methods such as :obj:`bytes.hex` and
        :obj:`bytes.fromhex`.

        >>> len(point.random())
        32
        >>> p = point.hash('123'.encode())
        >>> p.hex()
        '047f39a6c6dd156531a25fa605f017d4bec13b0b6c42f0e9b641c8ee73359c5f'
        >>> point.fromhex(p.hex()) == p
        True
        """
        # Capture this iteration's implementation as a class attribute so
        # methods do not depend on the (later rebound) loop variable.
        _implementation = _implementation

        @classmethod
        def random(cls) -> point:
            """
            Return random point object.

            >>> len(point.random())
            32
            """
            return bytes.__new__(cls, cls._implementation.pnt())

        @classmethod
        def bytes(cls, bs: bytes) -> point:
            """
            Return the point object obtained by transforming the supplied
            bytes-like object.

            >>> p = point.bytes(hashlib.sha512('123'.encode()).digest())
            >>> p.hex()
            '047f39a6c6dd156531a25fa605f017d4bec13b0b6c42f0e9b641c8ee73359c5f'
            """
            return bytes.__new__(cls, cls._implementation.pnt(bs))

        @classmethod
        def hash(cls, bs: bytes) -> point:
            """
            Return point object by hashing supplied bytes-like object.

            >>> point.hash('123'.encode()).hex()
            '047f39a6c6dd156531a25fa605f017d4bec13b0b6c42f0e9b641c8ee73359c5f'
            """
            return bytes.__new__(cls, cls._implementation.pnt(hashlib.sha512(bs).digest()))

        @classmethod
        def base(cls, s: scalar) -> point:
            """
            Return base point multiplied by supplied scalar.

            >>> point.base(scalar.hash('123'.encode())).hex()
            '4c207a5377f3badf358914f20b505cd1e2a6396720a9c240e5aff522e2446005'

            Use of the scalar corresponding to the zero residue is permitted.

            >>> p = point()
            >>> point.base(scalar.from_int(0)) + p == p
            True
            """
            return bytes.__new__(cls, cls._implementation.bas(s))

        @classmethod
        def from_bytes(cls, bs: bytes) -> point:
            """
            Return the instance corresponding to the supplied bytes-like object
            (no checking is performed to confirm that it is a valid point).

            >>> p = point.bytes(hashlib.sha512('123'.encode()).digest())
            >>> p == point.from_bytes(p.to_bytes())
            True
            >>> isinstance(point.from_bytes(p.to_bytes()), point)
            True
            """
            # Construct an instance of this class (not a plain ``bytes`` object)
            # so that point operators remain available on the result.
            return bytes.__new__(cls, bs)

        @classmethod
        def from_base64(cls, s: str) -> point:
            """
            Construct an instance from its Base64 UTF-8 string representation.

            >>> point.from_base64('hoVaKq3oIlxEndP2Nqv3Rdbmiu4iinZE6Iwo+kcKAik=').hex()
            '86855a2aade8225c449dd3f636abf745d6e68aee228a7644e88c28fa470a0229'
            """
            return bytes.__new__(cls, base64.standard_b64decode(s))

        def __new__(cls, bs: Optional[bytes] = None) -> point:
            """
            If a bytes-like object is supplied, return a point object
            corresponding to the supplied bytes-like object (no checking
            is performed to confirm that the bytes-like object is a valid
            point). If no argument is supplied, return a random point
            object.

            >>> bs = bytes.fromhex(
            ...     '86855a2aade8225c449dd3f636abf745d6e68aee228a7644e88c28fa470a0229'
            ... )
            >>> point(bs).hex()
            '86855a2aade8225c449dd3f636abf745d6e68aee228a7644e88c28fa470a0229'
            >>> len(point())
            32
            """
            return bytes.__new__(cls, bs) if bs is not None else cls.random()

        def canonical(self: point) -> point:
            """
            Normalize the representation of this instance into its canonical form.

            >>> p = point.hash('123'.encode())
            >>> p.canonical() == p
            True
            """
            # In this module, the canonical representation is used at all times.
            return self

        def __mul__(self: point, other: Any) -> NoReturn:
            """
            A point cannot be a left-hand argument for a multiplication operation.

            >>> point() * scalar()
            Traceback (most recent call last):
              ...
            TypeError: point must be on right-hand side of multiplication operator
            """
            raise TypeError('point must be on right-hand side of multiplication operator')

        def __rmul__(self: point, other: Any) -> NoReturn:
            """
            This functionality is implemented exclusively in the method
            :obj:`scalar.__mul__`, as that method pre-empts this method
            when the second argument has the correct type (*i.e.*, it is
            a :obj:`scalar` instance). This method is included so that an
            exception can be raised if an incorrect argument is supplied.

            >>> p = point.hash('123'.encode())
            >>> 2 * p
            Traceback (most recent call last):
              ...
            TypeError: point can only be multiplied by a scalar
            """
            raise TypeError('point can only be multiplied by a scalar')

        def __add__(self: point, other: point) -> Optional[point]:
            """
            Return the sum of this instance and another point.

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> (p + q).hex()
            '7076739c9df665d416e68b9512f5513bf1d0181a2aacefdeb1b7244528a4dd77'
            >>> p + (q - q) == p
            True
            """
            return self._implementation.point(self._implementation.add(self, other))

        def __sub__(self: point, other: point) -> Optional[point]:
            """
            Return the result of subtracting another point from this instance.

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> (p - q).hex()
            '1a3199ca7debfe31a90171696d8bab91b99eb23a541b822a7061b09776e1046c'
            >>> p - p == point.base(scalar.from_int(0))
            True
            """
            return self._implementation.point(self._implementation.sub(self, other))

        def __neg__(self: point) -> point:
            """
            Return the negation of this instance.

            >>> p = point.hash('123'.encode())
            >>> q = point.hash('456'.encode())
            >>> ((p + q) + (-q)) == p
            True
            """
            # Delegate to the implementation's dedicated negation primitive.
            return self._implementation.point(self._implementation.neg(self))

        def to_bytes(self: point) -> bytes:
            """
            Return the bytes-like object that represents this instance.

            >>> p = point()
            >>> p.to_bytes() == p
            True
            """
            return bytes(self)

        def to_base64(self: point) -> str:
            """
            Return the Base64 UTF-8 string representation of this instance.

            >>> p = point.from_base64('hoVaKq3oIlxEndP2Nqv3Rdbmiu4iinZE6Iwo+kcKAik=')
            >>> p.to_base64()
            'hoVaKq3oIlxEndP2Nqv3Rdbmiu4iinZE6Iwo+kcKAik='
            """
            return base64.standard_b64encode(self).decode('utf-8')

    class scalar(bytes):
        """
        Class for representing a scalar. Because this class is derived from
        :obj:`bytes`, it inherits methods such as :obj:`bytes.hex` and
        :obj:`bytes.fromhex`.

        >>> len(scalar.random())
        32
        >>> s = scalar.hash('123'.encode())
        >>> s.hex()
        'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27a03'
        >>> scalar.fromhex(s.hex()) == s
        True
        """
        # Capture this iteration's implementation as a class attribute.
        _implementation = _implementation

        @classmethod
        def random(cls) -> scalar:
            """
            Return random non-zero scalar object.

            >>> len(scalar.random())
            32
            """
            return bytes.__new__(cls, cls._implementation.rnd())

        @classmethod
        def bytes(cls, bs: bytes) -> Optional[scalar]:
            """
            Return scalar object obtained by transforming supplied bytes-like
            object if it is possible to do; otherwise, return ``None``.

            >>> s = python.scl()
            >>> t = scalar.bytes(s)
            >>> s.hex() == t.hex()
            True
            """
            s = cls._implementation.scl(bs)
            return bytes.__new__(cls, s) if s is not None else None

        @classmethod
        def hash(cls, bs: bytes) -> scalar:
            """
            Return scalar object by hashing supplied bytes-like object. The
            SHA-256 digest is re-hashed until it yields a valid scalar.

            >>> scalar.hash('123'.encode()).hex()
            'a665a45920422f9d417e4867efdc4fb8a04a1f3fff1fa07e998e86f7f7a27a03'
            """
            h = hashlib.sha256(bs).digest()
            s = cls._implementation.scl(h)
            while s is None:
                h = hashlib.sha256(h).digest()
                s = cls._implementation.scl(h)
            return bytes.__new__(cls, s)

        @classmethod
        def from_int(cls, i: int) -> scalar:
            """
            Construct an instance from its integer (*i.e.*, residue) representation.

            >>> p = point()
            >>> zero = scalar.from_int(0)
            >>> zero * p == p - p
            True
            >>> one = scalar.from_int(1)
            >>> one * p == p
            True
            >>> two = scalar.from_int(2)
            >>> two * p == p + p
            True

            Negative integers are supported (and automatically converted into their
            corresponding least nonnegative residues).

            >>> q = point()
            >>> p - p == scalar.from_int(0) * p
            True
            >>> q - p - p == q + (scalar.from_int(-2) * p)
            True
            """
            return bytes.__new__(
                cls,
                (
                    i % (pow(2, 252) + 27742317777372353535851937790883648493)
                ).to_bytes(32, 'little')
            )

        @classmethod
        def from_bytes(cls, bs: bytes) -> Optional[scalar]:
            """
            Return the instance corresponding to the supplied bytes-like object.

            >>> s = python.scl()
            >>> t = scalar.from_bytes(s)
            >>> s.hex() == t.hex()
            True
            """
            return cls(bs)

        @classmethod
        def from_base64(cls, s: str) -> scalar:
            """
            Construct an instance from its Base64 UTF-8 string representation.

            >>> scalar.from_base64('MS0MkTD2kVO+yfXQOGqVE160XuvxMK9fH+0cbtFfJQA=').hex()
            '312d0c9130f69153bec9f5d0386a95135eb45eebf130af5f1fed1c6ed15f2500'
            """
            return bytes.__new__(cls, base64.standard_b64decode(s))

        def __new__(cls, bs: Optional[bytes] = None) -> scalar:
            """
            If a bytes-like object is supplied, return a scalar object
            corresponding to the supplied bytes-like object (no checking
            is performed to confirm that the bytes-like object is a valid
            scalar). If no argument is supplied, return a random scalar
            object.

            >>> s = python.scl()
            >>> t = scalar(s)
            >>> s.hex() == t.hex()
            True
            >>> len(scalar())
            32
            """
            return bytes.__new__(cls, bs) if bs is not None else cls.random()

        def __invert__(self: scalar) -> scalar:
            """
            Return the inverse of this instance (modulo
            ``2**252 + 27742317777372353535851937790883648493``).

            >>> s = scalar()
            >>> p = point()
            >>> ((~s) * (s * p)) == p
            True

            The scalar corresponding to the zero residue cannot be inverted.

            >>> ~scalar.from_int(0)
            Traceback (most recent call last):
              ...
            ValueError: cannot invert scalar corresponding to zero
            """
            if _zero(self):
                raise ValueError('cannot invert scalar corresponding to zero')
            return self._implementation.scalar(self._implementation.inv(self))

        def __mul__(self: scalar, other: Union[scalar, point]) -> Union[scalar, point]:
            """
            Multiply the supplied scalar or point by this instance.

            >>> p = point.hash('123'.encode())
            >>> s = scalar.hash('456'.encode())
            >>> (s * p).hex()
            'f61b377aa86050aaa88c90f4a4a0f1e36b0000cf46f6a34232c2f1da7a799f16'
            >>> p = point.from_base64('hoVaKq3oIlxEndP2Nqv3Rdbmiu4iinZE6Iwo+kcKAik=')
            >>> s = scalar.from_base64('MS0MkTD2kVO+yfXQOGqVE160XuvxMK9fH+0cbtFfJQA=')
            >>> (s * s).hex()
            'd4aecf034f60edc5cb32cdd5a4be6d069959aa9fd133c51c9dcfd960ee865e0f'
            >>> isinstance(s * s, scalar)
            True
            >>> (s * p).hex()
            '2208082412921a67f42ea399748190d2b889228372509f2f2d9929813d074e1b'
            >>> isinstance(s * p, point)
            True

            Multiplying any point or scalar by the scalar corresponding to the
            zero residue yields the point or scalar corresponding to zero.

            >>> scalar.from_int(0) * point() == p - p
            True
            >>> scalar.from_int(0) * scalar() == scalar.from_int(0)
            True

            Any attempt to multiply a value or object of an incompatible type by this
            instance raises an exception.

            >>> s * 2
            Traceback (most recent call last):
              ...
            TypeError: multiplication by a scalar is defined only for scalars and points
            """
            if (
                isinstance(other, python.scalar) or
                (sodium is not None and isinstance(other, sodium.scalar))
            ):
                return self._implementation.scalar(self._implementation.smu(self, other))
            if (
                isinstance(other, python.point) or
                (sodium is not None and isinstance(other, sodium.point))
            ):
                return self._implementation.point(self._implementation.mul(self, other))
            raise TypeError(
                'multiplication by a scalar is defined only for scalars and points'
            )

        def __rmul__(self: scalar, other: Any) -> NoReturn:
            """
            A scalar cannot be on the right-hand side of a non-scalar.

            >>> 2 * scalar()
            Traceback (most recent call last):
              ...
            TypeError: scalar must be on left-hand side of multiplication operator
            """
            raise TypeError(
                'scalar must be on left-hand side of multiplication operator'
            )

        def to_bytes(self: scalar) -> bytes:
            """
            Return the bytes-like object that represents this instance.

            >>> s = scalar()
            >>> s.to_bytes() == s
            True
            """
            return bytes(self)

        def __int__(self: scalar) -> int:
            """
            Return the integer (*i.e.*, least nonnegative residue) representation
            of this instance.

            >>> s = scalar()
            >>> int(s * (~s))
            1
            """
            return int.from_bytes(self, 'little')

        def to_int(self: scalar) -> int:
            """
            Return the integer (*i.e.*, least nonnegative residue) representation
            of this instance.

            >>> s = scalar()
            >>> (s * (~s)).to_int()
            1
            """
            return int(self)

        def to_base64(self: scalar) -> str:
            """
            Return the Base64 UTF-8 string representation of this instance.

            >>> s = scalar.from_base64('MS0MkTD2kVO+yfXQOGqVE160XuvxMK9fH+0cbtFfJQA=')
            >>> s.to_base64()
            'MS0MkTD2kVO+yfXQOGqVE160XuvxMK9fH+0cbtFfJQA='
            """
            return base64.standard_b64encode(self).decode('utf-8')

    # Encapsulate classes for this implementation, regardless of which are
    # exported as the unqualified symbols.
    _implementation.point = point
    _implementation.scalar = scalar
# Redefine top-level wrapper classes to ensure that they appear at the end of
# the auto-generated documentation. The exact ``name = name`` form is kept
# deliberately, since the documentation ordering appears to depend on these
# assignment statements.
python = python # pylint: disable=self-assigning-variable
sodium = sodium # pylint: disable=self-assigning-variable
# Run this module's doctests when the file is executed directly.
if __name__ == '__main__':
    doctest.testmod() # pragma: no cover