-- | In-memory TLS 1.2/1.3 session manager.
--
-- * Size limit: the maximum number of entries in the session database is configurable ('dbMaxSize').
-- * Automatic pruning: session data past their lifetime are pruned automatically.
-- * Energy saving: no dedicated pruning thread runs while the session database is empty.
-- * Replay resistance: session data resumed via 'sessionResumeOnlyOnce' can be used at most once, preventing replay attacks against TLS 1.3 0-RTT early data.
module Network.TLS.SessionManager (
    newSessionManager,
    Config,
    defaultConfig,
    ticketLifetime,
    pruningDelay,
    dbMaxSize,
) where

import Basement.Block (Block)
import Control.Exception (assert)
import Control.Reaper
import Data.ByteArray (convert)
import Data.ByteString (ByteString)
import Data.IORef
import Data.OrdPSQ (OrdPSQ)
import qualified Data.OrdPSQ as Q
import Network.TLS
import qualified System.Clock as C

import Network.TLS.Imports

----------------------------------------------------------------

-- | Configuration for session managers.
--
-- The constructor is not exported: start from 'defaultConfig' and update
-- the fields you want to change.
data Config = Config
    { -- | How long a stored session stays valid, in seconds.
      ticketLifetime :: Int
    , -- | Interval between pruning passes, in seconds. It is multiplied by
      -- 1000000 and used as 'reaperDelay'.
      pruningDelay :: Int
    , -- | Maximum number of session entries kept in the database.
      dbMaxSize :: Int
    }

-- | Default configuration:
--
-- * 'ticketLifetime': 2 hours (7200 seconds)
-- * 'pruningDelay': 10 minutes (600 seconds)
-- * 'dbMaxSize': 1000 entries
defaultConfig :: Config
defaultConfig =
    Config
        { ticketLifetime = twoHours
        , pruningDelay = tenMinutes
        , dbMaxSize = 1000
        }
  where
    twoHours, tenMinutes :: Int
    twoHours = 2 * 60 * 60
    tenMinutes = 10 * 60

----------------------------------------------------------------

-- | Converts a session ID into the 'Block' representation used as the
-- database key.
toKey :: ByteString -> Block Word8
toKey sid = convert sid

-- | Wraps session data for storage, passing the secret and the optional
-- ALPN value through 'convert' so that the stored record holds converted
-- copies rather than the caller's original 'ByteString's.
toValue :: SessionData -> SessionDataCopy
toValue sd = SessionDataCopy sd{sessionSecret = secret', sessionALPN = alpn'}
  where
    secret' = convert (sessionSecret sd)
    alpn' = fmap convert (sessionALPN sd)

-- | Unwraps stored session data, again passing the secret and the optional
-- ALPN value through 'convert' before handing the record back to the caller.
fromValue :: SessionDataCopy -> SessionData
fromValue copy = case copy of
    SessionDataCopy stored ->
        stored
            { sessionSecret = convert (sessionSecret stored)
            , sessionALPN = fmap convert (sessionALPN stored)
            }

----------------------------------------------------------------

-- | Database key: the session ID converted to a 'Block' (see 'toKey').
type SessionIDCopy = Block Word8

-- | Session data as stored in the database (built by 'toValue').
newtype SessionDataCopy = SessionDataCopy SessionData
    deriving (Eq, Show)

-- | Seconds, used both for monotonic-clock timestamps and for durations.
type Sec = Int64

-- | A stored entry: the session data plus its single-use availability flag.
type Value = (SessionDataCopy, IORef Availability)

-- | The session database, keyed by session ID and prioritised by the
-- entry's expiry time in seconds.
type DB = OrdPSQ SessionIDCopy Sec Value

-- | An update queued on the reaper: key, expiry, value and operation.
type Item = (SessionIDCopy, Sec, Value, Operation)

-- | Whether an 'Item' inserts or removes an entry.
data Operation = Add | Del

-- | Whether a resumption consumes the session ('SingleUse') or not.
data Use = SingleUse | MultipleUse

-- | Whether a session has already been consumed by a single-use resumption.
data Availability = Fresh | Used

----------------------------------------------------------------

-- | Creating an in-memory session manager.
--
-- Sessions live in a priority search queue driven by a "Control.Reaper":
-- new and removed sessions are queued as 'Item's, and the reaper runs
-- 'clean' every 'pruningDelay' seconds to drop expired entries.
newSessionManager :: Config -> IO SessionManager
newSessionManager conf = do
    reaper <- mkReaper settings
    return
        noSessionManager
            { sessionResume = resume reaper MultipleUse
            , sessionResumeOnlyOnce = resume reaper SingleUse
            , sessionEstablish = \sid sdata ->
                Nothing <$ establish reaper lifetime sid sdata
            , sessionInvalidate = invalidate reaper
            , -- Sessions are looked up by ID in the in-memory database.
              sessionUseTicket = False
            }
  where
    lifetime :: Sec
    lifetime = fromIntegral (ticketLifetime conf)

    settings :: ReaperSettings DB Item
    settings =
        defaultReaperSettings
            { reaperEmpty = Q.empty
            , reaperCons = cons (dbMaxSize conf)
            , reaperAction = clean
            , reaperNull = Q.null
            , -- 'reaperDelay' is in microseconds.
              reaperDelay = pruningDelay conf * 1000000
            }

-- | Applies one queued 'Item' to the session database.
--
-- * 'Add' inserts the entry, replacing any entry with the same key. If the
--   key is new and the database already holds @lim@ or more entries, the
--   entries with the smallest priority (the earliest expiry time) are
--   evicted until there is room for exactly one more. A non-positive
--   @lim@ disables storage and always yields an empty database.
-- * 'Del' removes the entry with the given key, if present.
cons :: Int -> Item -> DB -> DB
cons lim (k, t, v, Add) db
    | lim <= 0 = Q.empty
    -- Replacing an existing key does not grow the database, so nothing
    -- needs to be evicted.
    | Q.member k db = Q.insert k t v db
    | otherwise = Q.insert k t v (shrink db)
  where
    -- The size can already exceed @lim@ ('clean' merges entries queued
    -- during a pruning pass without checking it), so keep evicting until
    -- there is room, instead of evicting only when the size equals @lim@.
    shrink :: DB -> DB
    shrink d
        | Q.size d < lim = d
        | otherwise = case Q.minView d of
            -- Unreachable: size d >= lim > 0 implies d is non-empty.
            Nothing -> assert False d
            Just (_, _, _, d') -> shrink d'
cons _ (k, _, _, Del) db = Q.delete k db

-- | Reaper action: drops expired entries from the database.
--
-- Every entry whose priority (expiry time) is at or before the current
-- monotonic second is removed. The returned function receives the database
-- built from items queued while this pass ran, and re-inserts those entries
-- into the pruned database; on a key collision the newer entry wins. The
-- merged result is not trimmed to 'dbMaxSize' here.
--
-- NOTE(review): a 'Del' queued during the pass appears to be applied only
-- to the newer database (which starts as 'Q.empty'), so it cannot remove a
-- matching key that survives in the pruned one — confirm against
-- "Control.Reaper" semantics.
clean :: DB -> IO (DB -> DB)
clean olddb = do
    now <- C.sec <$> C.getTime C.Monotonic
    let (_expired, live) = Q.atMostView now olddb
    return $ \newdb ->
        -- There is no union API for 'OrdPSQ', so the (hopefully few) newer
        -- entries are inserted one by one into the pruned database.
        foldl' (\acc (k, p, v) -> Q.insert k p v acc) live (Q.toList newdb)

----------------------------------------------------------------

-- | Queues a newly established session for insertion.
--
-- The entry's priority is its expiry time: the current monotonic second
-- plus @lifetime@. It starts out 'Fresh', i.e. not yet consumed by a
-- single-use resumption.
establish
    :: Reaper DB Item
    -> Sec
    -> SessionID
    -> SessionData
    -> IO ()
establish reaper lifetime sid sdata = do
    now <- C.sec <$> C.getTime C.Monotonic
    availability <- newIORef Fresh
    let expiry = now + lifetime
        entry = (toValue sdata, availability)
    reaperAdd reaper (toKey sid, expiry, entry, Add)

-- | Looks up a session by ID.
--
-- Returns 'Nothing' when the ID is unknown or when the entry's expiry time
-- has already passed. Expired entries are otherwise only removed by the
-- periodic pruning pass, so without this check they could be resumed for
-- up to 'pruningDelay' seconds past their lifetime.
--
-- With 'SingleUse', the entry's availability flag is flipped atomically
-- from 'Fresh' to 'Used' and a deletion is queued. Only the caller that
-- observed 'Fresh' receives the data; concurrent or later callers get
-- 'Nothing'. With 'MultipleUse', the entry is returned and left in place.
resume
    :: Reaper DB Item
    -> Use
    -> SessionID
    -> IO (Maybe SessionData)
resume reaper use k = do
    db <- reaperRead reaper
    case Q.lookup k' db of
        Nothing -> return Nothing
        Just (p, v@(SessionDataCopy sd, ref)) -> do
            now <- C.sec <$> C.getTime C.Monotonic
            -- Same expiry rule as 'clean': priority <= now means expired.
            if p <= now
                then return Nothing
                else case use of
                    SingleUse -> do
                        available <- atomicModifyIORef' ref check
                        reaperAdd reaper (k', p, v, Del)
                        return $ if available then Just (fromValue sd) else Nothing
                    MultipleUse -> return $ Just (fromValue sd)
  where
    -- Marks the entry as consumed; reports whether it was still fresh.
    check :: Availability -> (Availability, Bool)
    check Fresh = (Used, True)
    check Used = (Used, False)
    k' = toKey k

-- | Queues the removal of a session; unknown IDs are ignored.
--
-- The entry is looked up first because a 'Del' 'Item' carries the stored
-- priority and value alongside the key.
invalidate
    :: Reaper DB Item
    -> SessionID
    -> IO ()
invalidate reaper sid = do
    db <- reaperRead reaper
    maybe (return ()) queueDelete (Q.lookup key db)
  where
    key = toKey sid
    queueDelete (expiry, val) = reaperAdd reaper (key, expiry, val, Del)