{-# LANGUAGE OverloadedLists #-}

{-|
Module      : Lib.Internal.Core
Description : Core language internals
Stability   : experimental

Internal matrix and measurement operations
-}
module Lib.Internal.Core where

import Data.Bit ( Bit )
import Lib.QM ( QState(QState), Ix, QBit (Ptr), stateSize )
import Data.Bits ( Bits((.&.)) )
import Numeric.LinearAlgebra
    ( Complex,
      magnitude,
      flatten,
      outer,
      normalize,
      size,
      toList,
      fromList,
      C,
      Vector )
import qualified Control.Monad.Random as Rand ( fromList, evalRandIO )

-- | Extends the second state with the qubits of the first one.
--
-- An empty second state acts as "no qubits yet", so the first state is
-- returned unchanged. Otherwise the result is @tensorVector fresh existing@,
-- where the existing state's amplitudes index the high-order part of each
-- basis index and the fresh qubits occupy the low-order part.
appendState :: QState -> QState -> QState
appendState (QState fresh) (QState existing) =
  case existing of
    [] -> QState fresh
    _  -> QState (tensorVector fresh existing)


-- | Tensor (Kronecker) product of two state vectors, returned flat.
--
-- The result is @second (x) first@. It is built from the outer product
-- @outer second first@, whose rows (one per entry of the second argument)
-- 'flatten' concatenates. Entry @i@ of the second argument and entry @j@ of
-- the first therefore land at index @i * size first + j@.
tensorVector :: Vector C -> Vector C -> Vector C
tensorVector fresh existing = flatten outerProduct
  where
    outerProduct = outer existing fresh

-- | Single-qubit basis state that collapses to the given bit with
-- probability 1: bit @0@ gives @[1, 0]@ (|0>) and bit @1@ gives @[0, 1]@ (|1>).
--
-- Written with guards rather than literal patterns. GHC cannot tell that
-- matching @0@ and @1@ covers every 'Bit', so literal patterns trigger an
-- incomplete-patterns warning.
newVector :: Bit -> QState
newVector b
  | b == 1    = QState [0, 1]
  | otherwise = QState [1, 0]


-- | Probabilities are modelled as exact rational numbers.
type Prob = Rational

-- | Probability that measuring @qbit@ yields 1.
--
-- Sums 'ampToProb' over every amplitude whose basis index has that qubit
-- set, as selected by 'findMarginAmps1'.
findQbitProb1 :: QBit -> QState -> Prob
findQbitProb1 qbit qstate = sum probs
  where
    probs = [ampToProb amp | amp <- findMarginAmps1 qbit qstate]

-- | A complex-valued amplitude (the same type as hmatrix's @C@).
type Amplitude = Complex Double

-- | Amplitudes of every basis state in which @qbit@ is 1, in index order.
findMarginAmps1 :: QBit -> QState -> [Amplitude]
findMarginAmps1 qbit qstate =
  [amp | (ix, amp) <- qstateAmps qstate, maskMatch ix mask]
  where
    -- Bit pattern that selects @qbit@ inside a basis index. An index
    -- matching it is a basis state where that qubit measures to 1.
    mask :: Ix
    mask = qbitMask qbit qstate

-- | True when every bit set in the mask (second argument) is also set in the
-- value (first argument).
--
-- For example, @maskMatch 5 4@ is True (101 contains 100), while
-- @maskMatch 5 2@ is False (101 does not contain 010).
maskMatch :: Int -> Int -> Bool
maskMatch value mask = mask == value .&. mask

-- | Bit mask that selects the given qubit within a basis-state index.
--
-- Qubit 0 is the most significant bit. In a 3-qubit state, qubit 0 has mask
-- 100 (binary) = 4, qubit 1 has mask 010 = 2 and qubit 2 has mask 001 = 1.
-- The mask of qubit 0 matches both |100> and |101>.
--
-- Calls 'error' when the qubit index lies outside @[0, numQbits)@. Without
-- the check, a too-large index fails with an opaque "Negative exponent". A
-- negative index silently yields an oversized mask, which can even wrap to
-- 0 and then match every index.
qbitMask :: QBit -> QState -> Int
qbitMask (Ptr qbitIx) qstate
  | qbitIx < 0 || qbitIx >= numQbits =
      error ("qbitMask: qubit index " ++ show qbitIx
             ++ " out of range for a state of " ++ show numQbits ++ " qubits")
  | otherwise = 2 ^ (numQbits - 1 - qbitIx)
  where
    -- NOTE(review): assumes 'stateSize' returns the qubit count (as the
    -- original naming implies), not the vector length. Confirm in Lib.QM.
    numQbits :: Ix
    numQbits = stateSize qstate

-- | Probability of a single amplitude: its squared magnitude, converted
-- exactly from 'Double' to 'Rational'.
ampToProb :: Amplitude -> Prob
ampToProb amp =
  let m = magnitude amp
  in toRational (m * m)

-- | Collapses the state after measuring @qbit@ as @bit@.
--
-- Every amplitude whose basis index contradicts the measured outcome is set
-- to 0. The surviving vector is then rescaled to length one with
-- 'normalize'.
--
-- Calls 'error' when the outcome has zero probability, i.e. every remaining
-- amplitude is 0. Normalizing would then divide by a zero norm and produce
-- a NaN-filled state. An empty state is still passed through unchanged.
remImpossibleStates :: QState -> QBit -> Bit -> QState
remImpossibleStates qstate qbit bit
  | null collapsed || any (/= 0) collapsed =
      QState . normalize . fromList $ collapsed
  | otherwise =
      error "remImpossibleStates: measured outcome has zero probability"
  where
    ixMask :: Ix
    ixMask = qbitMask qbit qstate

    -- Amplitudes after the measurement, before renormalization.
    collapsed :: [Amplitude]
    collapsed = map transformAmp (qstateAmps qstate)

    -- If the index is at an impossible state, the amplitude is set to 0;
    -- otherwise it is kept.
    transformAmp :: (Ix, Amplitude) -> Amplitude
    transformAmp (ix, amp)
      | impossibleState ix = 0
      | otherwise          = amp

    -- A mask match means the qubit is 1 at this index, which contradicts a
    -- measured 0. A non-match means the qubit is 0, which contradicts a
    -- measured 1.
    impossibleState :: Ix -> Bool
    impossibleState ix
      | maskMatch ix ixMask = bit == 0
      | otherwise           = bit == 1


-- | Pairs every amplitude of the state vector with its basis-state index,
-- counting from 0.
qstateAmps :: QState -> [(Ix, Amplitude)]
qstateAmps (QState vec) = zipWith (,) [0 ..] (toList vec)

-- | Samples a measurement outcome in IO via 'Rand.evalRandIO'. It returns
-- 1 with probability @p1@ and 0 with probability @1 - p1@.
--
-- @p1@ is clamped to [0, 1] first. Probabilities summed from floating-point
-- amplitudes can land slightly outside that range, which would hand a
-- negative weight to 'Rand.fromList'.
rngQbit :: Prob -> IO Bit
rngQbit p1 = Rand.evalRandIO $ Rand.fromList [(0, 1 - p), (1, p)]
  where
    p = max 0 (min 1 p1)