-- | The Deutsch-Jozsa oracle algorithm
module DeutschJozsa where

import FunQ
import Control.Monad ( replicateM )

-- | An oracle is a quantum computation that takes a list of qubits paired
-- with one additional qubit and returns a pair of the same shape inside the
-- 'QM' monad. In this module 'deutschJozsa' passes the input register as the
-- list and a single extra qubit (created with @new 1@) as the second
-- component.
type Oracle = ([QBit], QBit) -> QM ([QBit], QBit)

-- | An oracle with a balanced function.
--
-- Applies 'pauliX' to every qubit in @xs@, then applies 'cnot' to each pair
-- @(q, y)@ for @q@ in @xs@ (presumably with @q@ as control and @y@ as
-- target -- confirm against the FunQ documentation), and finally applies
-- 'pauliX' to every qubit in @xs@ a second time. The argument pair is
-- returned unchanged as a value.
balanced :: Oracle
balanced (xs, y) = do
    mapM_ pauliX xs
    mapM_ (\q -> cnot (q, y)) xs
    -- Second round of X gates mirrors the first one on the input register.
    mapM_ pauliX xs
    return (xs, y)

-- | An oracle with a constant function.
--
-- Allocates one fresh qubit (via @new 0@) per element of @xs@, swaps each
-- input qubit with its fresh partner, applies 'cnot' to each pair @(q, y)@
-- for @q@ in @xs@, and then swaps the pairs back. The argument pair is
-- returned unchanged as a value.
--
-- NOTE(review): the intent appears to be that the 'cnot's act while the
-- input positions hold the freshly created @0@ qubits, so @y@ does not
-- depend on the original inputs -- confirm against FunQ's 'swap' semantics.
constant :: Oracle
constant (xs, y) = do
    zs <- replicateM (length xs) (new 0)
    mapM_ swap $ zip xs zs
    mapM_ (\q -> cnot (q, y)) xs
    -- Swap back so the original inputs end up in @xs@ again.
    mapM_ swap $ zip xs zs
    return (xs, y)

-- | Run the Deutsch-Jozsa circuit against the given oracle.
--
-- Will return a list of ones if the oracle is 'balanced' and a list of zeros
-- if it is 'constant'. @size@ is the number of qubit inputs to the oracle;
-- the returned list has one measured 'Bit' per input qubit.
--
-- Steps: create @size@ qubits with @new 0@ and one extra qubit with @new 1@,
-- apply 'hadamard' to all of them, run the oracle, apply 'hadamard' to the
-- input qubits again, and measure the input qubits.
deutschJozsa :: Int -> Oracle -> QM [Bit]
deutschJozsa size oracle = do
    xs <- replicateM size (new 0)
    y <- new 1
    mapM_ hadamard xs
    -- The returned handles are deliberately ignored; the original qubit
    -- references @xs@ and @y@ are reused for the remaining steps.
    _ <- hadamard y
    _ <- oracle (xs, y)
    mapM_ hadamard xs
    mapM measure xs