module ProjectM36.AtomFunctions.Primitive where
import ProjectM36.Base
import ProjectM36.Relation (relFold)
import ProjectM36.Tuple
import ProjectM36.AtomFunctionError
import ProjectM36.AtomFunction
import qualified Data.HashSet as HS
import qualified Data.Vector as V
import qualified Data.ByteString.Base64 as B64
import qualified Data.Text.Encoding as TE
-- | The built-in atom functions.
--
-- Each entry is a name, a type list and a body. In the type list the
-- argument types come first and the result type is the last element. Every
-- body is wrapped in 'AtomFunctionBody' with 'Nothing' as its first field
-- (presumably "no attached source/script" — confirm against the
-- 'AtomFunctionBody' definition).
--
-- NOTE(review): the argument patterns below are non-exhaustive. A call with
-- unexpected atoms fails at runtime with a pattern-match error instead of
-- returning a 'Left'.
primitiveAtomFunctions :: AtomFunctions
primitiveAtomFunctions = HS.fromList
  [ builtin "add" [IntAtomType, IntAtomType, IntAtomType]
      (\(IntAtom i1 : IntAtom i2 : _) -> pure (IntAtom (i1 + i2)))
  , builtin "id" [TypeVariableType "a", TypeVariableType "a"]
      (\(x : _) -> pure x)
  , builtin "sum" (foldAtomFuncType IntAtomType IntAtomType)
      (\(RelationAtom rel : _) -> relationSum rel)
  , builtin "count" (foldAtomFuncType (TypeVariableType "a") IntAtomType)
      (\(RelationAtom relIn : _) -> relationCount relIn)
  , builtin "max" (foldAtomFuncType IntAtomType IntAtomType)
      (\(RelationAtom relIn : _) -> relationMax relIn)
  , builtin "min" (foldAtomFuncType IntAtomType IntAtomType)
      (\(RelationAtom relIn : _) -> relationMin relIn)
  , builtin "lt" intComparisonType (intAtomFuncLessThan False)
  , builtin "lte" intComparisonType (intAtomFuncLessThan True)
    -- a >= b is expressed as not (a < b)
  , builtin "gte" intComparisonType ((>>= boolAtomNot) . intAtomFuncLessThan False)
    -- a > b is expressed as not (a <= b)
  , builtin "gt" intComparisonType ((>>= boolAtomNot) . intAtomFuncLessThan True)
  , builtin "not" [BoolAtomType, BoolAtomType]
      (\(b : _) -> boolAtomNot b)
    -- the input text is treated as base64 and decoded into raw bytes
  , builtin "makeByteString" [TextAtomType, ByteStringAtomType]
      (\(TextAtom textIn : _) ->
         either (Left . AtomFunctionBytesDecodingError)
                (pure . ByteStringAtom)
                (B64.decode (TE.encodeUtf8 textIn)))
  ]
  where
    -- shared signature of the Int comparison functions: Int -> Int -> Bool
    intComparisonType = [IntAtomType, IntAtomType, BoolAtomType]
    -- assemble one function record with a plain (Nothing-tagged) body
    builtin name ty fn =
      AtomFunction { atomFuncName = name
                   , atomFuncType = ty
                   , atomFuncBody = AtomFunctionBody Nothing fn }
-- | Compare the first two atoms as Ints and wrap the outcome in a 'BoolAtom'.
--
-- When the flag is 'True' the comparison is @<=@; otherwise it is strict @<@.
-- Atoms after the second are ignored. If the first two atoms are not both
-- 'IntAtom's, or fewer than two are given, the result is @BoolAtom False@
-- rather than a 'Left'.
intAtomFuncLessThan :: Bool -> [Atom] -> Either AtomFunctionError Atom
intAtomFuncLessThan orEqual atoms =
  case atoms of
    IntAtom lhs : IntAtom rhs : _ ->
      pure (BoolAtom (if orEqual then lhs <= rhs else lhs < rhs))
    _ ->
      pure (BoolAtom False)
-- | Logical negation of a 'BoolAtom'.
--
-- Any other atom is treated as a programming error and aborts via 'error'
-- instead of producing a 'Left'.
boolAtomNot :: Atom -> Either AtomFunctionError Atom
boolAtomNot atom =
  case atom of
    BoolAtom flag -> Right (BoolAtom (not flag))
    _ -> error "boolAtomNot called on non-Bool atom"
-- | Sum the Int stored at index 0 of each tuple's atoms.
--
-- The fold starts from 0. If a tuple's first atom is not an 'IntAtom',
-- 'castInt' aborts with 'error'.
relationSum :: Relation -> Either AtomFunctionError Atom
relationSum relIn = pure (IntAtom total)
  where
    total :: Int
    total = relFold addFirst 0 relIn
    -- add this tuple's leading Int to the running total
    addFirst tup acc = acc + castInt (tupleAtoms tup V.! 0)
-- | Count the tuples in a relation. Tuple contents are never inspected.
relationCount :: Relation -> Either AtomFunctionError Atom
relationCount relIn = pure (IntAtom tupleCount)
  where
    tupleCount = relFold (const (+ 1)) (0 :: Int) relIn
-- | Largest Int stored at index 0 of each tuple's atoms.
--
-- The fold is seeded with 'minBound', so an empty relation presumably yields
-- @IntAtom minBound@ (the seed) rather than an error. A non-'IntAtom' first
-- atom aborts via 'castInt'.
relationMax :: Relation -> Either AtomFunctionError Atom
relationMax relIn = pure (IntAtom largest)
  where
    largest = relFold (\tup best -> max best (firstInt tup)) minBound relIn
    firstInt tup = castInt (tupleAtoms tup V.! 0)
-- | Smallest Int stored at index 0 of each tuple's atoms.
--
-- The fold is seeded with 'maxBound', so an empty relation presumably yields
-- @IntAtom maxBound@ (the seed) rather than an error. A non-'IntAtom' first
-- atom aborts via 'castInt'.
relationMin :: Relation -> Either AtomFunctionError Atom
relationMin relIn = pure (IntAtom smallest)
  where
    smallest = relFold (\tup best -> min best (firstInt tup)) maxBound relIn
    firstInt tup = castInt (tupleAtoms tup V.! 0)
-- | Extract the payload of an 'IntAtom'.
--
-- Partial: any other atom aborts with 'error'.
castInt :: Atom -> Int
castInt atom =
  case atom of
    IntAtom n -> n
    _ -> error "attempted to cast non-IntAtom to Int"