-- |
-- Module      : Foundation.Random.XorShift
-- License     : BSD-style
--
-- XorShift variant: Xoroshiro128+
-- <https://en.wikipedia.org/wiki/Xoroshiro128%2B>
--
-- C implementation at:
-- <http://xoroshiro.di.unimi.it/xoroshiro128plus.c>
--
{-# LANGUAGE MagicHash #-}
module Foundation.Random.XorShift
    ( State
    , initialize
    , next
    , nextList
    , nextDouble
    ) where

import           Basement.Imports
import           Basement.PrimType
import           Basement.Types.OffsetSize
import           Foundation.Numerical
import           Foundation.Bits
import           Foundation.Random.Class
import           Foundation.Random.DRG
import           Basement.Compat.Bifunctor
import           Basement.Compat.ExtList (reverse)
import qualified Basement.UArray as A
import qualified Prelude
import           GHC.Prim
import           GHC.Float


-- | Generator state for Xoroshiro128+.
--
-- Holds the two 64-bit words the algorithm operates on. Both fields are
-- strict and unpacked, so a 'State' in weak head normal form carries two
-- fully evaluated machine words.
data State = State {-# UNPACK #-} !Word64 {-# UNPACK #-} !Word64

instance RandomGen State where
    -- Draw two random 64-bit words from the ambient 'MonadRandom' and use
    -- them, in that order, as the two state words.
    randomNew = State <$> getRandomWord64 <*> getRandomWord64

    -- Build a generator from exactly 16 seed bytes, reinterpreted as two
    -- 64-bit words; any other length is rejected with 'Nothing'.
    -- NOTE(review): an all-zero seed is accepted here; confirm whether it
    -- should be rejected for this generator.
    randomNewFrom bs
        | A.length bs == 16 = Just (State (seedWord 0) (seedWord 1))
        | otherwise         = Nothing
      where
        seedWords :: UArray Word64
        seedWords = A.recast bs
        seedWord  = A.index seedWords

    -- The remaining methods delegate to the top-level functions below.
    randomGenerate       = generate
    randomGenerateWord64 = next
    randomGenerateF32    = nextFloat
    randomGenerateF64    = nextDouble

-- | Build a generator 'State' directly from its two 64-bit words.
--
-- The first argument becomes the first state word, the second argument the
-- second; no mixing or validation is applied.
initialize :: Word64 -> Word64 -> State
initialize = State

-- | Produce @c@ bytes together with the advanced generator state.
--
-- Whole 64-bit words are drawn with 'nextList', the resulting array is
-- viewed as bytes and then truncated to the @c@ bytes requested.
generate :: CountOf Word8 -> State -> (UArray Word8, State)
generate c st = (A.take c (A.unsafeRecast (fromList ws)), st')
  where
    -- Lazy pattern binding: the pair is only taken apart on demand.
    (ws, st') = nextList wordCount st
    -- Byte count passed through 'countOfRoundUp' with 8 and then converted
    -- to a count of 'Word64' (presumably rounding up to whole words).
    wordCount = sizeRecast (countOfRoundUp 8 c)

-- | Advance the generator by one step.
--
-- Returns the sum of the two current state words as the output, paired with
-- the successor state.
next :: State -> (Word64, State)
next (State a b) =
    let -- xor of both words, shared by the two new state words; kept strict
        !mixed = a `xor` b
        output = a + b
        a'     = rotateL a 55 `xor` mixed `xor` (mixed .<<. 14)
        b'     = rotateL mixed 36
     in (output, State a' b')

-- | Draw @c@ consecutive outputs from the generator.
--
-- The words are returned in the order they were generated, together with
-- the state left after the last draw.
nextList :: CountOf Word64 -> State -> ([Word64], State)
nextList c state = go 0 state []
  where
    -- @produced@ counts the words drawn so far; @revAcc@ holds them newest
    -- first and is reversed once at the end.
    go produced st revAcc
        | produced .==# c = (reverse revAcc, st)
        | otherwise       =
            let (w, st') = next st
             in go (produced + 1) st' (w : revAcc)

-- | Draw a 'Float' by narrowing the result of 'nextDouble'.
--
-- The state returned is exactly the one 'nextDouble' returns.
nextFloat :: State -> (Float, State)
nextFloat st = (narrow d, st')
  where
    (d, st') = nextDouble st
    -- Primitive Double -> Float conversion on the unboxed value.
    narrow (D# x) = F# (double2Float# x)

-- | Draw a 'Double' in the half-open interval @[0, 1)@ and return it with
-- the advanced generator state.
--
-- One 64-bit output is taken from 'next'. Its low 52 bits are used as the
-- mantissa of an IEEE-754 double whose sign and exponent bits are those of
-- @1.0@ (@0x3FF@ in the exponent field), which gives a value in @[1, 2)@;
-- subtracting @1.0@ shifts that to @[0, 1)@.
--
-- The assembled word must be /reinterpreted/ as a double, not converted
-- numerically: a numeric conversion of @0x3FF…@ yields a value around
-- @4.6e18@ instead of one in @[1, 2)@. 'castWord64ToDouble' performs the
-- bit-level reinterpretation.
nextDouble :: State -> (Double, State)
nextDouble !st = (castWord64ToDouble bits - 1.0, st')
  where
    !(w, st') = next st
    -- Sign 0, biased exponent 0x3FF: the bit pattern of 1.0 with an empty
    -- mantissa.
    exponentOne :: Word64
    exponentOne = 0x3FF0000000000000
    -- Selects the 52 mantissa bits.
    -- NOTE(review): this keeps the low bits of the output, as the original
    -- code did; the reference implementation's notes suggest using the high
    -- bits (a right shift by 12) instead -- consider switching.
    mantissaMask :: Word64
    mantissaMask = 0x000FFFFFFFFFFFFF
    bits :: Word64
    bits = exponentOne .|. (w .&. mantissaMask)