-- |
-- Module      : Basement.Alg.XorShift
-- License     : BSD-style
--
-- XorShift variant: Xoroshiro128+
-- <https://en.wikipedia.org/wiki/Xoroshiro128%2B>
--
-- Xoroshiro128+ is a PRNG that uses a shift/rotate-based linear transformation.
-- It keeps 128 bits of state (two 'Word64' words) and produces a 64-bit
-- output per step.
--
-- C implementation at:
-- <http://xoroshiro.di.unimi.it/xoroshiro128plus.c>
--
module Basement.Alg.XorShift
    ( State(..)
    , next
    , nextDouble
    , jump
    ) where

import           Data.Word
import           Data.Bits
import           Basement.Compat.Base
import           Basement.Floating (wordToDouble)
import           Basement.Numerical.Additive
import           Basement.Numerical.Subtractive

-- | State of the Xoroshiro128+ generator: 128 bits held as two strict,
-- unpacked 64-bit words.
--
-- An all-zero state is a fixed point of the transition performed by 'next'
-- (it outputs 0 and maps to itself), so it must not be used as a seed.
data State = State {-# UNPACK #-} !Word64 {-# UNPACK #-} !Word64

-- | Step the generator once, in continuation-passing style.
--
-- The continuation receives the 64-bit output (the sum, modulo 2^64, of the
-- two state words as they were before the step) and the successor 'State'.
-- Both values are forced before the continuation is invoked.
next :: State -> (Word64 -> State -> a) -> a
next (State w0 w1) k = k output successor
  where
    !output    = w0 + w1
    -- Mix of the two input words; feeds both halves of the new state.
    !mixed     = w0 `xor` w1
    !successor =
        State (rotateL w0 55 `xor` mixed `xor` unsafeShiftL mixed 14)
              (rotateL mixed 36)

-- | Same as 'next', but the continuation receives a 'Double' in the
-- half-open interval [0.0, 1.0) instead of the raw 'Word64'.
--
-- The mantissa is taken from the /upper/ 52 bits of the generated word.
-- The reference Xoroshiro128+ implementation notes that the lowest output
-- bits have low linear complexity, so they are discarded here.
nextDouble :: State -> (Double -> State -> a) -> a
nextDouble st f = next st $ \w -> f (toDouble w)
  where
    -- Build a double in [1.0, 2.0) by bit manipulation: the sign/exponent
    -- bits of 1.0 (0x3FF in the exponent field) combined with 52 random
    -- mantissa bits. Subtracting 1.0 gives a multiple of 2^-52 in
    -- [0.0, 1.0).
    toDouble :: Word64 -> Double
    toDouble w = wordToDouble (upperMask .|. (w `unsafeShiftR` 12)) - 1.0
      where
        upperMask :: Word64
        upperMask = 0x3FF0000000000000

-- | Jump the state by 2^64 calls of 'next'.
--
-- Follows the reference Xoroshiro128+ jump algorithm. The bits of the two
-- jump constants are walked least-significant first. Whenever a bit is
-- set, the /current/ generator state is XORed into an accumulator. After
-- every bit (set or not), the generator is advanced once with 'next'. The
-- accumulator is the resulting state.
jump :: State -> State
jump start = result
  where
    (result, _) = withK 0xd86b048b86aa9922
                $ withK 0xbeac0467eba5facb
                $ (State 0 0, start)

    -- Process the 64 bits of one jump constant.
    -- The pair is (accumulator, current generator state).
    withK :: Word64 -> (State, State) -> (State, State)
    withK !k = loop 0
      where
        loop :: Int -> (State, State) -> (State, State)
        loop !i (acc@(State a0 a1), cur@(State c0 c1))
            | i == 64   = (acc, cur)
            | otherwise =
                let !acc' = if testBit k i
                                then State (a0 `xor` c0) (a1 `xor` c1)
                                else acc
                    -- Advance the generator regardless of the bit value.
                    !cur' = next cur (\_ s -> s)
                in loop (i + 1) (acc', cur')