-- |
-- Module      : Crypto.Store.Keys
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
--
{-# LANGUAGE RecordWildCards #-}
module Crypto.Store.Keys
    ( KeyPair(..), keyPairFromPrivKey, keyPairToPrivKey, keyPairToPubKey
    , keyPairMatchesKey, keyPairMatchesCert
    ) where

import Data.Function (on)
import Data.Maybe (fromMaybe)

import qualified Data.X509 as X509
import Data.X509.EC

import qualified Crypto.PubKey.RSA.Types as RSA
import qualified Crypto.PubKey.DSA as DSA
import qualified Crypto.PubKey.Curve25519 as X25519
import qualified Crypto.PubKey.Curve448 as X448
import qualified Crypto.PubKey.Ed25519 as Ed25519
import qualified Crypto.PubKey.Ed448 as Ed448

import Crypto.Store.PKCS8.EC

-- | Holds a private and public key together, with a guarantee that they both
-- match.  The constructors are therefore not meant to be used directly.
-- Content may be accessed through functions 'keyPairToPrivKey' and
-- 'keyPairToPubKey'.
--
-- Call function 'keyPairFromPrivKey' to build a @KeyPair@.
data KeyPair
    -- Private key first, public key second; DSA bundles both in one value.
    = KeyPairRSA RSA.PrivateKey RSA.PublicKey            -- ^ RSA key pair
    | KeyPairDSA DSA.KeyPair                             -- ^ DSA key pair
    | KeyPairEC X509.PrivKeyEC X509.PubKeyEC             -- ^ EC key pair
    | KeyPairX25519 X25519.SecretKey X25519.PublicKey    -- ^ X25519 key pair
    | KeyPairX448 X448.SecretKey X448.PublicKey          -- ^ X448 key pair
    | KeyPairEd25519 Ed25519.SecretKey Ed25519.PublicKey -- ^ Ed25519 key pair
    | KeyPairEd448 Ed448.SecretKey Ed448.PublicKey       -- ^ Ed448 key pair

instance Show KeyPair where
    -- Shows the pair as the expression that rebuilds it from its private
    -- key, i.e. @keyPairFromPrivKey <private key>@.  Parentheses are added
    -- when the context precedence is above 10 (function-argument position).
    showsPrec prec pair
        | prec > 10 = showChar '(' . rebuild . showChar ')'
        | otherwise = rebuild
      where
        rebuild = showString "keyPairFromPrivKey "
                . showsPrec 11 (keyPairToPrivKey pair)

instance Eq KeyPair where
    -- Two pairs are considered equal when their private keys are equal;
    -- the public halves are not compared.
    (==) = on (==) keyPairToPrivKey

-- | Builds a key pair from an X.509 private key.  The public half of the
-- pair is computed from the supplied private key.
keyPairFromPrivKey :: X509.PrivKey -> KeyPair
keyPairFromPrivKey privKey =
    case privKey of
        X509.PrivKeyRSA k     -> KeyPairRSA k (RSA.toPublicKey (RSA.KeyPair k))
        X509.PrivKeyDSA k     -> KeyPairDSA (dsaPairFromPriv k)
        X509.PrivKeyEC k      -> KeyPairEC k (ecPubFromPriv k)
        X509.PrivKeyX25519 k  -> KeyPairX25519 k (X25519.toPublic k)
        X509.PrivKeyX448 k    -> KeyPairX448 k (X448.toPublic k)
        X509.PrivKeyEd25519 k -> KeyPairEd25519 k (Ed25519.toPublic k)
        X509.PrivKeyEd448 k   -> KeyPairEd448 k (Ed448.toPublic k)

-- | Completes a DSA private key into a key pair by recomputing the public
-- value from the key's domain parameters and its private value @x@.
dsaPairFromPriv :: DSA.PrivateKey -> DSA.KeyPair
dsaPairFromPriv priv =
    let params = DSA.private_params priv
        secret = DSA.private_x priv
     in DSA.KeyPair params (DSA.calculatePublic params secret) secret

-- | Derives the X.509 EC public key corresponding to an EC private key.
-- The curve description (a curve name, or explicit prime-field parameters)
-- is copied unchanged, and the public point is computed from the private
-- value with 'getSerializedPoint'.
--
-- Evaluating the public point calls 'error' when 'ecPrivKeyCurve' returns
-- 'Nothing' for the given private key.
ecPubFromPriv :: X509.PrivKeyEC -> X509.PubKeyEC
ecPubFromPriv priv =
    case priv of
        X509.PrivKeyEC_Named{..} ->
            X509.PubKeyEC_Named
                { X509.pubkeyEC_name = privkeyEC_name
                , X509.pubkeyEC_pub  = publicPoint privkeyEC_priv
                }
        X509.PrivKeyEC_Prime{..} ->
            X509.PubKeyEC_Prime
                { X509.pubkeyEC_pub       = publicPoint privkeyEC_priv
                , X509.pubkeyEC_a         = privkeyEC_a
                , X509.pubkeyEC_b         = privkeyEC_b
                , X509.pubkeyEC_prime     = privkeyEC_prime
                , X509.pubkeyEC_generator = privkeyEC_generator
                , X509.pubkeyEC_order     = privkeyEC_order
                , X509.pubkeyEC_cofactor  = privkeyEC_cofactor
                , X509.pubkeyEC_seed      = privkeyEC_seed
                }
  where
    -- Both branches serialize the public point on the same resolved curve.
    publicPoint = getSerializedPoint curve
    curve = fromMaybe (error "ecPubFromPriv: invalid EC parameters")
                      (ecPrivKeyCurve priv)

-- | Gets the X.509 private key in a key pair.  For DSA the private key is
-- extracted from the stored 'DSA.KeyPair'; for every other algorithm the
-- stored private key is returned as is.
keyPairToPrivKey :: KeyPair -> X509.PrivKey
keyPairToPrivKey pair =
    case pair of
        KeyPairRSA k _     -> X509.PrivKeyRSA k
        KeyPairDSA dsa     -> X509.PrivKeyDSA (DSA.toPrivateKey dsa)
        KeyPairEC k _      -> X509.PrivKeyEC k
        KeyPairX25519 k _  -> X509.PrivKeyX25519 k
        KeyPairX448 k _    -> X509.PrivKeyX448 k
        KeyPairEd25519 k _ -> X509.PrivKeyEd25519 k
        KeyPairEd448 k _   -> X509.PrivKeyEd448 k

-- | Gets the X.509 public key in a key pair.  For DSA the public key is
-- extracted from the stored 'DSA.KeyPair'; for every other algorithm the
-- stored public key is returned as is.
keyPairToPubKey :: KeyPair -> X509.PubKey
keyPairToPubKey pair =
    case pair of
        KeyPairRSA _ k     -> X509.PubKeyRSA k
        KeyPairDSA dsa     -> X509.PubKeyDSA (DSA.toPublicKey dsa)
        KeyPairEC _ k      -> X509.PubKeyEC k
        KeyPairX25519 _ k  -> X509.PubKeyX25519 k
        KeyPairX448 _ k    -> X509.PubKeyX448 k
        KeyPairEd25519 _ k -> X509.PubKeyEd25519 k
        KeyPairEd448 _ k   -> X509.PubKeyEd448 k

-- | Returns 'True' when the given X.509 public key is consistent with a key
-- pair, which means that the public key can be derived from the private key in
-- the key pair.
--
-- The check is an equality test between the given key and the public key
-- returned by 'keyPairToPubKey'; a key of a different algorithm never
-- matches.
keyPairMatchesKey :: KeyPair -> X509.PubKey -> Bool
keyPairMatchesKey keyPair other =
    -- Delegating to 'keyPairToPubKey' keeps the per-algorithm dispatch in a
    -- single exhaustive match.  A new 'KeyPair' constructor therefore triggers
    -- an incomplete-pattern warning there, instead of silently falling into a
    -- catch-all 'False' here.  Assumes 'X509.PubKey' has a structural
    -- (derived) 'Eq' instance, so different constructors compare unequal.
    --
    -- NOTE(review): the comparison is structural.  An EC key that describes
    -- the same curve differently (named curve vs. explicit parameters) will
    -- not match; confirm this is intended.
    keyPairToPubKey keyPair == other

-- | Returns 'True' when the public key embedded in a signed certificate is
-- consistent with the key pair, as decided by 'keyPairMatchesKey'.  Only the
-- certificate's public key is inspected; the signature is not verified.
keyPairMatchesCert :: KeyPair -> X509.SignedCertificate -> Bool
keyPairMatchesCert keyPair =
    keyPairMatchesKey keyPair
        . X509.certPubKey
        . X509.signedObject
        . X509.getSigned