-- |
-- Module      : Crypto.PubKey.ECC.P256
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
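-- Arithmetic on the P-256 (secp256r1) curve: points, scalars and their
-- binary / 'Integer' conversions, implemented on top of the C
-- @ccryptonite_p256*@ primitives.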
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}
module Crypto.PubKey.ECC.P256
    ( Scalar
    , Point
    -- * Point arithmetic
    , pointBase
    , pointAdd
    , pointNegate
    , pointMul
    , pointDh
    , pointsMulVarTime
    , pointIsValid
    , pointIsAtInfinity
    , toPoint
    , pointX
    , pointToIntegers
    , pointFromIntegers
    , pointToBinary
    , pointFromBinary
    , unsafePointFromBinary
    -- * Scalar arithmetic
    , scalarGenerate
    , scalarZero
    , scalarN
    , scalarIsZero
    , scalarAdd
    , scalarSub
    , scalarMul
    , scalarInv
    , scalarInvSafe
    , scalarCmp
    , scalarFromBinary
    , scalarToBinary
    , scalarFromInteger
    , scalarToInteger
    ) where

import           Data.Word
import           Foreign.Ptr
import           Foreign.C.Types

import           Crypto.Internal.Compat
import           Crypto.Internal.Imports
import           Crypto.Internal.ByteArray
import qualified Crypto.Internal.ByteArray as B
import           Data.Memory.PtrMethods (memSet)
import           Crypto.Error
import           Crypto.Random
import           Crypto.Number.Serialize.Internal (os2ip, i2ospOf)
import qualified Crypto.Number.Serialize as S (os2ip, i2ospOf)

-- | A P256 scalar.
--
-- The wrapped 'ScrubbedBytes' hold the scalar in the format written by
-- @ccryptonite_p256_from_bin@; use 'scalarFromBinary' / 'scalarToBinary'
-- to convert from / to a 'scalarSize'-byte big-endian encoding.
newtype Scalar = Scalar ScrubbedBytes
    deriving (Show,Eq,ByteArrayAccess,NFData)

-- | A P256 point.
--
-- The point at infinity is represented by an all-zero buffer
-- (see 'pointIsAtInfinity').
newtype Point = Point Bytes
    deriving (Show,Eq,NFData)

-- | Size in bytes of a serialized 'Scalar', and of one serialized point
-- coordinate.
scalarSize :: Int
scalarSize = 32

-- | Size in bytes of a serialized 'Point': the x coordinate followed by the
-- y coordinate, each 'scalarSize' (32) bytes long.
pointSize :: Int
pointSize = 64

-- | Digit type used by the C bindings. For example, it is the extra argument
-- of @ccryptonite_p256_modmul@, which 'scalarMul' passes as @0@.
type P256Digit  = Word32

-- Phantom tags for the raw pointers handed to the C primitives. They are
-- never constructed; they only keep scalar, x-coordinate and y-coordinate
-- buffers apart at the type level.
data P256Scalar
data P256Y
data P256X

-- | The P256 curve order /n/ as an 'Integer' (the value wrapped by 'scalarN').
order :: Integer
order = 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551

------------------------------------------------------------------------
-- Point methods
------------------------------------------------------------------------

-- | Get the base point (generator) /G/ of the P256 curve.
--
-- Computed as @1 * G@ via 'toPoint'.
pointBase :: Point
pointBase =
    case scalarFromInteger 1 of
        CryptoPassed s -> toPoint s
        CryptoFailed _ -> error "pointBase: assumption failed"

-- | Lift a scalar to the curve.
--
-- Using the curve generator as base point, compute:
--
-- > scalar * G
--
-- Partial: calls 'error' when the scalar is zero, because the result would
-- be the point at infinity.
toPoint :: Scalar -> Point
toPoint s
    | scalarIsZero s = error "cannot create point from zero"
    | otherwise      =
        withNewPoint $ \px py -> withScalar s $ \p ->
            ccryptonite_p256_basepoint_mul p px py

-- | Add a point to another point.
--
-- > a + b
pointAdd :: Point -> Point -> Point
pointAdd a b = withNewPoint $ \dx dy ->
    withPoint a $ \ax ay -> withPoint b $ \bx by ->
        ccryptonite_p256e_point_add ax ay bx by dx dy

-- | Negate a point.
--
-- > -a
pointNegate :: Point -> Point
pointNegate a = withNewPoint $ \dx dy ->
    withPoint a $ \ax ay ->
        ccryptonite_p256e_point_negate ax ay dx dy

-- | Multiply a point by a scalar.
--
-- > scalar * p
--
-- warning: variable time
pointMul :: Scalar -> Point -> Point
pointMul scalar p = withNewPoint $ \dx dy ->
    withScalar scalar $ \n -> withPoint p $ \px py ->
        ccryptonite_p256e_point_mul n px py dx dy

-- | Similar to 'pointMul', but returns only the x coordinate of the product,
-- serialized as 'scalarSize' bytes of binary.
--
-- The product is computed into a temporary point that is not returned.
-- When the scalar is a multiple of the point order, the result is all zero.
pointDh :: ByteArray binary => Scalar -> Point -> binary
pointDh scalar p =
    B.unsafeCreate scalarSize $ \dst -> withTempPoint $ \dx dy -> do
        withScalar scalar $ \n -> withPoint p $ \px py ->
            ccryptonite_p256e_point_mul n px py dx dy
        -- serialize only the x coordinate of the product
        ccryptonite_p256_to_bin (castPtr dx) dst

-- | Multiply the point @p@ by @n2@ and add @n1@ lifted to the curve:
--
-- > n1 * G + n2 * p
--
-- warning: variable time
pointsMulVarTime :: Scalar -> Scalar -> Point -> Point
pointsMulVarTime n1 n2 p = withNewPoint $ \dx dy ->
    withScalar n1 $ \pn1 -> withScalar n2 $ \pn2 -> withPoint p $ \px py ->
        ccryptonite_p256_points_mul_vartime pn1 pn2 px py dx dy

-- | Check if a 'Point' is valid.
--
-- Returns 'True' when @ccryptonite_p256_is_valid_point@ returns non-zero.
pointIsValid :: Point -> Bool
pointIsValid p = unsafeDoIO $ withPoint p $ \px py ->
    (/= 0) <$> ccryptonite_p256_is_valid_point px py

-- | Check if a 'Point' is the point at infinity.
--
-- The point at infinity is encoded as an all-zero buffer, so this tests the
-- underlying bytes with 'constAllZero'.
pointIsAtInfinity :: Point -> Bool
pointIsAtInfinity (Point b) = constAllZero b

-- | Return the x coordinate as a 'Scalar', unless the point is at infinity.
--
-- The coordinate is reduced modulo @ccryptonite_SECP256r1_n@, the same
-- modulus used by the scalar arithmetic ('scalarAdd', 'scalarMul', ...).
-- The result is therefore a valid scalar and not necessarily the raw
-- x value.
pointX :: Point -> Maybe Scalar
pointX p
    | pointIsAtInfinity p = Nothing
    | otherwise           = Just $ withNewScalarFreeze $ \d ->
        withPoint p $ \px _ ->
            -- d is already a Ptr P256Scalar; only the x buffer needs a cast
            ccryptonite_p256_mod ccryptonite_SECP256r1_n (castPtr px) d

-- | Convert a point to its @(x, y)@ coordinates as 'Integer's.
pointToIntegers :: Point -> (Integer, Integer)
pointToIntegers p = unsafeDoIO $ withPoint p $ \px py ->
    -- one 'scalarSize' scratch buffer, reused for both coordinates
    allocTemp scalarSize (serialize (castPtr px) (castPtr py))
  where
    serialize :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr Word8 -> IO (Integer, Integer)
    serialize px py temp = do
        x <- coordinateOf px
        y <- coordinateOf py
        return (x, y)
      where
        -- write the coordinate as big-endian bytes into temp, then read
        -- those bytes back as an Integer
        coordinateOf c = ccryptonite_p256_to_bin c temp >> os2ip temp scalarSize

-- | Convert @(x, y)@ 'Integer' coordinates to a point.
--
-- This function does not check that the point lies on the curve. Apply
-- 'pointIsValid' to the result when the coordinates are untrusted.
--
-- Partial: calls 'error' when serializing a coordinate into 'scalarSize'
-- bytes fails (@i2ospOf@ returns 0). NOTE(review): this presumably happens
-- for values that do not fit; confirm.
pointFromIntegers :: (Integer, Integer) -> Point
pointFromIntegers (x, y) = withNewPoint $ \dx dy ->
    allocTemp scalarSize $ \temp -> do
        fill temp (castPtr dx) x
        fill temp (castPtr dy) y
  where
    -- Put @n@ into @temp@ in big-endian format, then copy it from @temp@
    -- into @dest@ in P256 scalar format.
    fill :: Ptr Word8 -> Ptr P256Scalar -> Integer -> IO ()
    fill temp dest n = do
        -- clear temp, then write the integer into it in big-endian format
        memSet temp 0 scalarSize
        e <- i2ospOf n temp scalarSize
        if e == 0
            then error "pointFromIntegers: filling failed"
            else return ()
        -- then fill dest with the P256 scalar read from temp
        ccryptonite_p256_from_bin temp dest

-- | Convert a point to its 'pointSize'-byte binary representation.
--
-- The output is the x coordinate followed by the y coordinate, each
-- 'scalarSize' bytes. This is the layout 'unsafePointFromBinary' reads.
pointToBinary :: ByteArray ba => Point -> ba
pointToBinary p = B.unsafeCreate pointSize $ \dst -> withPoint p $ \px py -> do
    ccryptonite_p256_to_bin (castPtr px) dst
    ccryptonite_p256_to_bin (castPtr py) (dst `plusPtr` scalarSize)

-- | Convert from binary to a valid point.
--
-- Fails with 'CryptoError_PublicKeySizeInvalid' when the input is not
-- 'pointSize' bytes long, and with 'CryptoError_PointCoordinatesInvalid'
-- when the decoded point does not pass 'pointIsValid'.
pointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
pointFromBinary ba = unsafePointFromBinary ba >>= validatePoint
  where
    validatePoint :: Point -> CryptoFailable Point
    validatePoint p
        | pointIsValid p = CryptoPassed p
        | otherwise      = CryptoFailed CryptoError_PointCoordinatesInvalid

-- | Convert from binary to a point, possibly invalid.
--
-- The input must be 'pointSize' bytes: the x coordinate followed by the
-- y coordinate, each 'scalarSize' bytes. Other lengths fail with
-- 'CryptoError_PublicKeySizeInvalid'. No curve-membership check is done;
-- use 'pointFromBinary' for that.
unsafePointFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Point
unsafePointFromBinary ba
    | B.length ba /= pointSize = CryptoFailed CryptoError_PublicKeySizeInvalid
    | otherwise                =
        CryptoPassed $ withNewPoint $ \px py -> B.withByteArray ba $ \src -> do
            ccryptonite_p256_from_bin src                        (castPtr px)
            ccryptonite_p256_from_bin (src `plusPtr` scalarSize) (castPtr py)

------------------------------------------------------------------------
-- Scalar methods
------------------------------------------------------------------------

-- | Generate a new random scalar from 'scalarSize' random bytes.
--
-- NOTE(review): the random bytes go through 'scalarFromBinary', which only
-- checks the length. The result is therefore neither checked against nor
-- reduced modulo the curve order; roughly 2^-32 of outputs are >= n.
-- Confirm whether callers rely on the scalar being in range.
scalarGenerate :: MonadRandom randomly => randomly Scalar
scalarGenerate = unwrap . scalarFromBinary . witness <$> getRandomBytes scalarSize
  where
    -- cannot fail: getRandomBytes returns exactly scalarSize bytes
    unwrap :: CryptoFailable a -> a
    unwrap (CryptoFailed _) = error "scalarGenerate: assumption failed"
    unwrap (CryptoPassed s) = s
    -- pins the random buffer type to ScrubbedBytes
    witness :: ScrubbedBytes -> ScrubbedBytes
    witness = id

-- | The scalar representing 0, initialised by @ccryptonite_p256_init@.
scalarZero :: Scalar
scalarZero = withNewScalarFreeze ccryptonite_p256_init

-- | The scalar representing the curve order /n/ ('order').
scalarN :: Scalar
scalarN = throwCryptoError (scalarFromInteger order)

-- | Check if the scalar is 0.
--
-- Returns 'True' when @ccryptonite_p256_is_zero@ returns non-zero.
scalarIsZero :: Scalar -> Bool
scalarIsZero s = unsafeDoIO $ withScalar s $ \d ->
    (/= 0) <$> ccryptonite_p256_is_zero d

-- | Perform addition between two scalars, modulo @ccryptonite_SECP256r1_n@.
--
-- > a + b
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd a b =
    withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb ->
        ccryptonite_p256e_modadd ccryptonite_SECP256r1_n pa pb d

-- | Perform subtraction between two scalars, modulo @ccryptonite_SECP256r1_n@.
--
-- > a - b
scalarSub :: Scalar -> Scalar -> Scalar
scalarSub a b =
    withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb ->
        ccryptonite_p256e_modsub ccryptonite_SECP256r1_n pa pb d

-- | Perform multiplication between two scalars, modulo
-- @ccryptonite_SECP256r1_n@.
--
-- > a * b
scalarMul :: Scalar -> Scalar -> Scalar
scalarMul a b =
    withNewScalarFreeze $ \d -> withScalar a $ \pa -> withScalar b $ \pb ->
        -- NOTE(review): the 0 is the extra 'P256Digit' argument of
        -- ccryptonite_p256_modmul (presumably a top digit of the first
        -- operand); confirm against the C implementation.
        ccryptonite_p256_modmul ccryptonite_SECP256r1_n pa 0 pb d

-- | Give the inverse of the scalar modulo @ccryptonite_SECP256r1_n@.
--
-- > 1 / a
--
-- warning: variable time; see 'scalarInvSafe' for the alternative.
scalarInv :: Scalar -> Scalar
scalarInv a =
    withNewScalarFreeze $ \b -> withScalar a $ \pa ->
        ccryptonite_p256_modinv_vartime ccryptonite_SECP256r1_n pa b

-- | Give the inverse of the scalar using safe exponentiation.
--
-- > 1 / a
--
-- Delegates to @ccryptonite_p256e_scalar_invert@. NOTE(review): "safe"
-- presumably means constant time, unlike 'scalarInv'; confirm in the C code.
scalarInvSafe :: Scalar -> Scalar
scalarInvSafe a =
    withNewScalarFreeze $ \b -> withScalar a $ \pa ->
        ccryptonite_p256e_scalar_invert pa b

-- | Compare two scalars.
--
-- The 'Ordering' is the sign of the value returned by
-- @ccryptonite_p256_cmp@ (negative, zero or positive).
scalarCmp :: Scalar -> Scalar -> Ordering
scalarCmp a b = unsafeDoIO $
    withScalar a $ \pa -> withScalar b $ \pb ->
        (`compare` 0) <$> ccryptonite_p256_cmp pa pb

-- | Convert a scalar from its 'scalarSize'-byte big-endian binary form.
--
-- Fails with 'CryptoError_SecretKeySizeInvalid' when the input length is not
-- 'scalarSize'. This function makes no check against the curve order.
scalarFromBinary :: ByteArrayAccess ba => ba -> CryptoFailable Scalar
scalarFromBinary ba
    | B.length ba /= scalarSize = CryptoFailed CryptoError_SecretKeySizeInvalid
    | otherwise                 =
        CryptoPassed $ withNewScalarFreeze $ \p -> B.withByteArray ba $ \b ->
            ccryptonite_p256_from_bin b p
{-# NOINLINE scalarFromBinary #-}

-- | Convert a scalar to its 'scalarSize'-byte big-endian binary form.
scalarToBinary :: ByteArray ba => Scalar -> ba
scalarToBinary s = B.unsafeCreate scalarSize $ \b -> withScalar s $ \p ->
    ccryptonite_p256_to_bin p b
{-# NOINLINE scalarToBinary #-}

-- | Convert an 'Integer' to a P256 'Scalar'.
--
-- The integer is serialised with 'S.i2ospOf' into exactly 'scalarSize'
-- bytes and then decoded by 'scalarFromBinary'. If 'S.i2ospOf' cannot
-- produce such an encoding (e.g. the value does not fit in 'scalarSize'
-- bytes), the result is @'CryptoFailed' 'CryptoError_SecretKeySizeInvalid'@.
--
-- This function itself performs no range check against the curve order.
scalarFromInteger :: Integer -> CryptoFailable Scalar
scalarFromInteger i =
    -- Encode to 'scalarSize' (not a hard-coded width) so the encoding always
    -- has exactly the length 'scalarFromBinary' insists on.
    maybe (CryptoFailed CryptoError_SecretKeySizeInvalid)
          scalarFromBinary
          (S.i2ospOf scalarSize i :: Maybe Bytes)

-- | Convert a P256 'Scalar' to an 'Integer'.
--
-- The scalar is encoded with 'scalarToBinary' and the resulting bytes are
-- read back as an integer with 'S.os2ip'.
scalarToInteger :: Scalar -> Integer
scalarToInteger s =
    let encoded = scalarToBinary s :: Bytes
     in S.os2ip encoded

------------------------------------------------------------------------
-- Memory Helpers
------------------------------------------------------------------------
-- | Build a new 'Point' by allocating a 'pointSize'-byte buffer and letting
-- the callback fill it through an X pointer (start of the buffer) and a
-- Y pointer (derived from X by 'pxToPy').
withNewPoint :: (Ptr P256X -> Ptr P256Y -> IO ()) -> Point
withNewPoint f = Point (B.unsafeCreate pointSize fill)
  where
    fill xPtr = f xPtr (pxToPy xPtr)
{-# NOINLINE withNewPoint #-}

-- | Run an action with pointers to the X and Y coordinates stored in an
-- existing 'Point' (Y located via 'pxToPy'). The pointers refer to the
-- point's own storage; the action should treat them as read-only.
withPoint :: Point -> (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withPoint (Point bytes) action = B.withByteArray bytes splitCoords
  where
    splitCoords xPtr = action xPtr (pxToPy xPtr)

-- | Given a pointer to the X coordinate inside a point buffer, return the
-- pointer to the Y coordinate, which sits 'scalarSize' bytes further on.
pxToPy :: Ptr P256X -> Ptr P256Y
pxToPy px =
    -- 'plusPtr' already returns a pointer of any element type, so no
    -- separate 'castPtr' is needed.
    plusPtr px scalarSize

-- | Allocate a 'scalarSize'-byte 'ScrubbedBytes' buffer, let the callback
-- initialise it, then freeze it and wrap it as a 'Scalar'.
withNewScalarFreeze :: (Ptr P256Scalar -> IO ()) -> Scalar
withNewScalarFreeze fill = Scalar (B.allocAndFreeze scalarSize fill)
{-# NOINLINE withNewScalarFreeze #-}

-- | Run an action with a temporary 'pointSize'-byte scratch point, given as
-- X and Y coordinate pointers (Y located via 'pxToPy'). The backing buffer
-- comes from 'allocTempScrubbed' and is discarded afterwards, so the
-- pointers must not be retained once the action returns.
withTempPoint :: (Ptr P256X -> Ptr P256Y -> IO a) -> IO a
withTempPoint f = allocTempScrubbed pointSize $ \raw ->
    let xPtr = castPtr raw :: Ptr P256X
     in f xPtr (pxToPy xPtr)

-- | Run an action with a pointer to the bytes backing a 'Scalar'.
withScalar :: Scalar -> (Ptr P256Scalar -> IO a) -> IO a
withScalar (Scalar bytes) = B.withByteArray bytes

-- | Run an action on a temporary @n@-byte 'Bytes' buffer and return the
-- action's result; the buffer itself is discarded.
allocTemp :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTemp n f = fmap keepResult (B.allocRet n f)
  where
    -- The signature pins the buffer type to 'Bytes'; without it the
    -- 'ByteArray' instance chosen for 'B.allocRet' would be ambiguous.
    keepResult :: (r, Bytes) -> r
    keepResult (result, _scratch) = result

-- | Like 'allocTemp', but the temporary buffer is a 'ScrubbedBytes'
-- (presumably wiped when released — see the memory package docs), which
-- suits scratch space that may hold secret material.
allocTempScrubbed :: Int -> (Ptr Word8 -> IO a) -> IO a
allocTempScrubbed n f = fmap keepResult (B.allocRet n f)
  where
    -- The signature pins the buffer type to 'ScrubbedBytes'; without it the
    -- 'ByteArray' instance chosen for 'B.allocRet' would be ambiguous.
    keepResult :: (r, ScrubbedBytes) -> r
    keepResult (result, _scratch) = result

------------------------------------------------------------------------
-- Foreign bindings
------------------------------------------------------------------------
-- Addresses of constant scalars exported by the C code: the "&" form
-- imports a symbol's address rather than a function. By their names these
-- are presumably the SECP256r1 group order n, field prime p and curve
-- coefficient b -- confirm against the C sources.
foreign import ccall "&cryptonite_SECP256r1_n"
    ccryptonite_SECP256r1_n :: Ptr P256Scalar
foreign import ccall "&cryptonite_SECP256r1_p"
    ccryptonite_SECP256r1_p :: Ptr P256Scalar
foreign import ccall "&cryptonite_SECP256r1_b"
    ccryptonite_SECP256r1_b :: Ptr P256Scalar

-- Scalar primitives. NOTE(review): which pointer is the output and where
-- the modulus goes is fixed by the C prototypes and is not visible here.
foreign import ccall "cryptonite_p256_init"
    ccryptonite_p256_init :: Ptr P256Scalar -> IO ()
-- NOTE(review): presumably returns non-zero when the scalar is zero.
foreign import ccall "cryptonite_p256_is_zero"
    ccryptonite_p256_is_zero :: Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_clear"
    ccryptonite_p256_clear :: Ptr P256Scalar -> IO ()
foreign import ccall "cryptonite_p256e_modadd"
    ccryptonite_p256e_modadd :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Adds a single digit; the CInt result is presumably the carry.
foreign import ccall "cryptonite_p256_add_d"
    ccryptonite_p256_add_d :: Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256e_modsub"
    ccryptonite_p256e_modsub :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Three-way comparison; scalarCmp maps the sign of the result (compared
-- with 0) to an Ordering.
foreign import ccall "cryptonite_p256_cmp"
    ccryptonite_p256_cmp :: Ptr P256Scalar -> Ptr P256Scalar -> IO CInt
foreign import ccall "cryptonite_p256_mod"
    ccryptonite_p256_mod :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
foreign import ccall "cryptonite_p256_modmul"
    ccryptonite_p256_modmul :: Ptr P256Scalar -> Ptr P256Scalar -> P256Digit -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
foreign import ccall "cryptonite_p256e_scalar_invert"
    ccryptonite_p256e_scalar_invert :: Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Disabled binding, kept for reference.
--foreign import ccall "cryptonite_p256_modinv"
--    ccryptonite_p256_modinv :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- NOTE(review): the "_vartime" suffix suggests running time depends on the
-- inputs; presumably unsuitable for secret values -- confirm.
foreign import ccall "cryptonite_p256_modinv_vartime"
    ccryptonite_p256_modinv_vartime :: Ptr P256Scalar -> Ptr P256Scalar -> Ptr P256Scalar -> IO ()
-- Note: the Haskell name (basepoint) differs from the C symbol (base_point).
foreign import ccall "cryptonite_p256_base_point_mul"
    ccryptonite_p256_basepoint_mul :: Ptr P256Scalar
                                   -> Ptr P256X -> Ptr P256Y
                                   -> IO ()

-- Point primitives; each point is passed as separate X and Y pointers.
foreign import ccall "cryptonite_p256e_point_add"
    ccryptonite_p256e_point_add :: Ptr P256X -> Ptr P256Y
                                -> Ptr P256X -> Ptr P256Y
                                -> Ptr P256X -> Ptr P256Y
                                -> IO ()

foreign import ccall "cryptonite_p256e_point_negate"
    ccryptonite_p256e_point_negate :: Ptr P256X -> Ptr P256Y
                                   -> Ptr P256X -> Ptr P256Y
                                   -> IO ()

-- compute (out_x,out_y) = n * (in_x,in_y)
foreign import ccall "cryptonite_p256e_point_mul"
    ccryptonite_p256e_point_mul :: Ptr P256Scalar -- n
                                -> Ptr P256X -> Ptr P256Y -- in_{x,y}
                                -> Ptr P256X -> Ptr P256Y -- out_{x,y}
                                -> IO ()

-- compute (out_x,out_y) = n1 * G + n2 * (in_x,in_y)
-- NOTE(review): variable-time by its name; presumably only for public inputs.
foreign import ccall "cryptonite_p256_points_mul_vartime"
    ccryptonite_p256_points_mul_vartime :: Ptr P256Scalar -- n1
                                        -> Ptr P256Scalar -- n2
                                        -> Ptr P256X -> Ptr P256Y -- in_{x,y}
                                        -> Ptr P256X -> Ptr P256Y -- out_{x,y}
                                        -> IO ()
-- NOTE(review): presumably returns non-zero when (x,y) lies on the curve.
foreign import ccall "cryptonite_p256_is_valid_point"
    ccryptonite_p256_is_valid_point :: Ptr P256X -> Ptr P256Y -> IO CInt

-- Serialisation: the scalar is written to (to_bin) or read from (from_bin)
-- a byte buffer; the Haskell callers size that buffer as scalarSize bytes.
foreign import ccall "cryptonite_p256_to_bin"
    ccryptonite_p256_to_bin :: Ptr P256Scalar -> Ptr Word8 -> IO ()

foreign import ccall "cryptonite_p256_from_bin"
    ccryptonite_p256_from_bin :: Ptr Word8 -> Ptr P256Scalar -> IO ()