{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

module Internal.Quasi.Operator.Quasi where

import Data.List.Split (chunksOf)
import Data.Proxy
import qualified GHC.Natural as Natural
import GHC.TypeNats
import Internal.Matrix
import qualified Internal.Quasi.Operator.Parser as Parser
import qualified Internal.Quasi.Parser as Parser
import Internal.Quasi.Quasi
import Language.Haskell.TH.Quote
import Language.Haskell.TH.Syntax
import QLinear.Identity

{- | Macro constructor for operator

>>> [operator| (x, y) => (y, x) |]
[0,1]
[1,0]
>>> [operator| (x, y) => (2 * x, y + x) |] ~*~ [vector| 3 4 |]
[6]
[7]

Note that the constructor __does not prove__ linearity:
it only builds the matrix of the given operator.

Only expression quotes are supported. Pattern, type and declaration
positions are delegated to 'isNotDefinedAs'.
-}
operator :: QuasiQuoter
operator =
  QuasiQuoter
    { quoteExp = expr,
      quotePat = notDefined "Pattern",
      quoteType = notDefined "Type",
      quoteDec = notDefined "Declaration"
    }
  where
    -- The explicit polymorphic signature lets one helper serve the
    -- Q Pat, Q Type and Q [Dec] fields alike.
    notDefined :: String -> a
    notDefined = isNotDefinedAs "operator"

-- | Build the expression that 'operator' splices in.
--
-- For a definition such as @(x, y) => (y, x)@ this generates, roughly,
--
-- > matrixOfOperator (Matrix @2 @1 @_ (2, 1) [[\[x, y] -> y], [\[x, y] -> x]])
--
-- i.e. an @n × 1@ column of functions, which 'matrixOfOperator' then
-- expands into the @n × n@ operator matrix. Parse / validation failures
-- from 'parse' are handed to 'unwrap'.
expr :: String -> Q Exp
expr source = do
  let (params, lams, n) = unwrap $ parse source
      sizeType = LitT . NumTyLit
      -- Term-level dimensions; they mirror the type-level ones below.
      size = TupE $ map (LitE . IntegerL) [n, 1]
      func = VarE 'matrixOfOperator
      -- Matrix @n @1 @_
      constructor = foldl AppTypeE (ConE 'Matrix) [sizeType n, sizeType 1, WildCardT]
      -- One single-element row per output component. Parameters are bound
      -- through a list pattern because 'matrixOfOperator' feeds each
      -- function a row of the identity matrix, given as a list.
      value = ListE $ map (ListE . pure . LamE [ListP params]) lams
  pure $ AppE func $ foldl AppE constructor [size, value]

-- | Parse an operator definition and validate its shape.
--
-- On success returns:
--
-- * the parameter patterns,
-- * the output expressions (one per component),
-- * the operator dimension computed by 'checkSize'.
--
-- On failure returns the error messages from the parser or from 'checkSize'.
parse :: String -> Either [String] ([Pat], [Exp], Integer)
parse source = do
  -- NOTE(review): "QLinear" is passed as the parser's second argument,
  -- presumably a source name used in error messages; confirm in Parser.parse.
  (params, lams) <- Parser.parse Parser.definition "QLinear" source
  size <- checkSize (params, lams)
  pure (params, lams, size)

-- | Validate the shape of an operator definition.
--
-- The checks, in order:
--
-- * at least one parameter,
-- * at least one output expression,
-- * exactly as many outputs as parameters, since the resulting matrix is
--   square (@n × n@).
--
-- On success returns that common count as the operator dimension.
checkSize :: ([Pat], [Exp]) -> Either [String] Integer
checkSize ([], _) = Left ["Parameters of operator cannot be empty"]
checkSize (_, []) = Left ["Body of operator cannot be empty"]
checkSize (names, exprs)
  | namesLength == exprsLength = Right (fromIntegral namesLength)
  | otherwise = Left ["Number of arguments and number of lambdas must be equal"]
  where
    namesLength = length names
    exprsLength = length exprs

-- | Expand a column of @n@ component functions into the @n × n@ matrix
-- of the operator.
--
-- Row @i@ of the result is @[f_i e_1, ..., f_i e_n]@. Here @f_i@ is the
-- @i@-th function and @e_j@ is the @j@-th row of the identity matrix 'e',
-- passed to the function as a list of length @n@.
matrixOfOperator ::
  forall n a b.
  (KnownNat n, HasIdentity a) =>
  Matrix n 1 ([a] -> b) ->
  Matrix n n b
matrixOfOperator (Matrix _ fs) =
  Matrix (n, n) $ chunksOf n [f line | f <- concat fs, line <- identity]
  where
    (Matrix _ identity) = e :: Matrix n n a
    -- fromIntegral avoids depending on GHC.Natural's internal conversion API.
    n :: Int
    n = fromIntegral $ natVal (Proxy @n)