module Michelson.TypeCheck.Helpers
( onLeft
, deriveSpecialVN
, deriveSpecialFNs
, deriveVN
, deriveNsOr
, deriveNsOption
, convergeHSTEl
, convergeHST
, hstToTs
, eqHST
, eqHST1
, lengthHST
, ensureDistinctAsc
, eqType
, checkEqT
, checkEqHST
, onTypeCheckInstrAnnErr
, onTypeCheckInstrErr
, typeCheckInstrErr
, typeCheckInstrErr'
, typeCheckImpl
, matchTypes
, memImpl
, getImpl
, updImpl
, sliceImpl
, concatImpl
, concatImpl'
, sizeImpl
, arithImpl
, addImpl
, subImpl
, mulImpl
, edivImpl
, unaryArithImpl
) where
import Prelude hiding (EQ, GT, LT)
import Control.Monad.Except (MonadError, liftEither, throwError)
import Data.Default (def)
import Data.Singletons (SingI(sing), demote)
import qualified Data.Text as T
import Data.Typeable ((:~:)(..), eqT)
import Fmt ((+||), (||+))
import Michelson.ErrorPos (InstrCallStack)
import Michelson.TypeCheck.Error (TCError(..), TCTypeError(..))
import Michelson.TypeCheck.TypeCheck
import Michelson.TypeCheck.Types
import Michelson.Typed
(CT(..), Instr(..), Notes(..), PackedNotes(..), Sing(..), T(..), converge, fromSingT, orAnn,
starNotes)
import Michelson.Typed.Annotation (AnnConvergeError)
import Michelson.Typed.Arith (Add, ArithOp(..), Mul, Sub, UnaryArithOp(..))
import Michelson.Typed.Polymorphic
(ConcatOp, EDivOp(..), GetOp(..), MemOp(..), SizeOp, SliceOp, UpdOp(..))
import qualified Michelson.Untyped as Un
import Michelson.Untyped.Annotation (Annotation(..), FieldAnn, VarAnn, ann)
-- | Derive the annotations for a pair from the two requested field
-- annotations and the variable annotations of the two operands.
--
-- * Both field annotations are the special @"\@"@ marker: both variable
--   annotations are split on @"."@; the longest common leading run of
--   components becomes the resulting variable annotation, and whatever
--   remains on each side becomes that side's field annotation.
-- * Only one field annotation is @"\@"@: that side's field annotation is
--   converted from its variable annotation; the resulting variable
--   annotation is 'def'.
-- * Otherwise both field annotations are returned unchanged together with a
--   'def' variable annotation.
deriveSpecialFNs
  :: FieldAnn -> FieldAnn
  -> VarAnn -> VarAnn
  -> (VarAnn, FieldAnn, FieldAnn)
deriveSpecialFNs "@" "@" pvn qvn =
  ( ann (rejoin shared)
  , ann (rejoin (drop sharedLen pParts))
  , ann (rejoin (drop sharedLen qParts))
  )
  where
    pParts = T.splitOn "." (unAnnotation pvn)
    qParts = T.splitOn "." (unAnnotation qvn)
    -- Components on which both variable annotations agree, from the left.
    shared = [p | (p, _) <- takeWhile (uncurry (==)) (zip pParts qParts)]
    sharedLen = length shared
    rejoin = T.intercalate "."
deriveSpecialFNs "@" qfn pvn _ = (def, Un.convAnn pvn, qfn)
deriveSpecialFNs pfn "@" _ qvn = (def, pfn, Un.convAnn qvn)
deriveSpecialFNs pfn qfn _ _ = (def, pfn, qfn)
-- | Resolve the special variable annotations @"%"@ and @"%%"@ against the
-- field annotation of the accessed element and the variable annotation of
-- the enclosing pair.
--
-- * @"%"@ is replaced by the element's field annotation.
-- * @"%%"@ is replaced by the pair's variable annotation combined with the
--   element's field annotation, but only when that field annotation is
--   non-default.
-- * Anything else is returned unchanged.
deriveSpecialVN :: VarAnn -> FieldAnn -> VarAnn -> VarAnn
deriveSpecialVN vn elFn pairVN = case vn of
  "%" -> Un.convAnn elFn
  "%%" | elFn /= def -> pairVN <> Un.convAnn elFn
  _ -> vn
-- | Combine @suffix@ with a variable annotation.
--
-- An empty (default) @vn@ stays empty; otherwise the result is
-- @suffix <> vn@.
deriveVN :: VarAnn -> VarAnn -> VarAnn
deriveVN suffix vn
  | vn == def = def
  | otherwise = suffix <> vn
-- | Take apart the notes of an @or@ type.
--
-- Returns the notes of both branches together with a variable annotation
-- for each branch.  A branch's annotation is derived (via 'deriveVN') from
-- the @or@ value's own variable annotation and the branch's field
-- annotation, with @"left"@ / @"right"@ supplied as the alternative passed
-- to 'orAnn'.
deriveNsOr :: Notes ('TOr a b) -> VarAnn -> (Notes a, Notes b, VarAnn, VarAnn)
deriveNsOr (NTOr _ afn bfn an bn) ovn =
  (an, bn, branchVN afn "left", branchVN bfn "right")
  where
    branchVN fn fallback = deriveVN (Un.convAnn fn `orAnn` fallback) ovn
-- | Take apart the notes of an @option@ type.
--
-- Returns the notes of the wrapped type and a variable annotation for the
-- wrapped value, derived with 'deriveVN' from @"some"@ and the option's own
-- variable annotation.
deriveNsOption :: Notes ('TOption a) -> VarAnn -> (Notes a, VarAnn)
deriveNsOption (NTOption _ an) ovn = (an, deriveVN "some" ovn)
-- | Merge two descriptions of the same stack element.
--
-- The singleton is taken from the first description, the notes are merged
-- with 'converge' (whose failure is propagated), and the variable
-- annotation is kept only when both sides carry the same one; otherwise it
-- is reset to 'def'.
convergeHSTEl
  :: (Sing t, Notes t, VarAnn)
  -> (Sing t, Notes t, VarAnn)
  -> Either AnnConvergeError (Sing t, Notes t, VarAnn)
convergeHSTEl (at, an, avn) (_, bn, bvn) = do
  merged <- converge an bn
  pure (at, merged, if avn == bvn then avn else def)
-- | Merge two stack descriptions of identical shape, element by element,
-- using 'convergeHSTEl'.  Fails with the first annotation conflict.
convergeHST :: HST ts -> HST ts -> Either AnnConvergeError (HST ts)
convergeHST SNil SNil = pure SNil
convergeHST (x ::& xs) (y ::& ys) =
  (::&) <$> convergeHSTEl x y <*> convergeHST xs ys
-- | Apply a function to the 'Left' payload of an 'Either', leaving a
-- 'Right' untouched.  The argument order makes it convenient to use infix:
-- @action \`onLeft\` wrapError@.
onLeft :: Either a c -> (a -> b) -> Either b c
onLeft e f = either (Left . f) Right e
-- | Forget the type-level information of a stack description and list the
-- (value-level) types of its elements, top of the stack first.
hstToTs :: HST st -> [T]
hstToTs SNil = []
hstToTs ((s, _, _) ::& rest) = fromSingT s : hstToTs rest
-- | Check that two stack descriptions have the same type and that their
-- annotations can be converged.
--
-- Returns a proof of type equality on success.  A type mismatch is reported
-- as 'StackEqError' carrying both stacks; an annotation conflict is
-- reported as 'AnnError'.  The converged stack itself is discarded.
eqHST
  :: forall as bs.
      (Typeable as, Typeable bs)
  => HST as -> HST bs -> Either TCTypeError (as :~: bs)
eqHST hst hst' = case eqT @as @bs of
  Nothing -> Left $ StackEqError (hstToTs hst) (hstToTs hst')
  Just Refl -> Refl <$ (convergeHST hst hst' `onLeft` AnnError)
-- | Check that a stack consists of exactly one element of type @t@.
--
-- Only types are compared; annotations are not inspected.  On mismatch the
-- error lists the expected single-element stack first and the actual stack
-- second.
eqHST1
  :: forall t st.
      (Typeable st, Typeable t, SingI t)
  => HST st -> Either TCTypeError (st :~: '[t])
eqHST1 hst = case eqT @'[t] @st of
  Just Refl -> Right Refl
  Nothing -> Left $ StackEqError (hstToTs expected) (hstToTs hst)
  where
    expected = sing @t -:& SNil
-- | Number of elements in a stack description.
lengthHST :: HST xs -> Natural
lengthHST = \case
  SNil -> 0
  _ ::& rest -> 1 + lengthHST rest
-- | Check that a list is strictly ascending with respect to the given
-- projection (which also rules out duplicates).
--
-- Returns the list unchanged on success.  On failure returns a message
-- showing the first adjacent pair that is out of order or equal.
ensureDistinctAsc :: (Ord b, Show a) => (a -> b) -> [a] -> Either Text [a]
ensureDistinctAsc toCmp entries = case entries of
  x : rest@(y : _)
    | toCmp x < toCmp y -> (x :) <$> ensureDistinctAsc toCmp rest
    | otherwise -> Left $ "Entries are unordered (" +|| x ||+ " >= " +|| y ||+ ")"
  -- Zero or one element: trivially ordered.
  _ -> Right entries
-- | Check, inside the type checking monad, that types @a@ and @b@ are equal.
--
-- On mismatch throws 'TCFailedOnInstr' for the given instruction and input
-- stack, using the supplied message followed by @": "@, the current
-- instruction position and the underlying 'TCTypeError'.
checkEqT
  :: forall (a :: T) (b :: T) ts m .
  ( Each [Typeable, SingI] [a, b], Typeable ts
  , MonadReader InstrCallStack m, MonadError TCError m
  )
  => Un.ExpandedInstr
  -> HST ts
  -> Text
  -> m (a :~: b)
checkEqT instr i m =
  ask >>= \pos ->
    either
      (throwError . TCFailedOnInstr instr (SomeHST i) (m <> ": ") pos . Just)
      pure
      (eqType @a @b)
-- | Pure check that types @a@ and @b@ are equal.
--
-- Returns the equality proof, or a 'TypeEqError' carrying the demoted
-- representations of both types.
eqType
  :: forall (a :: T) (b :: T).
      (Each [Typeable, SingI] [a, b])
  => Either TCTypeError (a :~: b)
eqType = case eqT @a @b of
  Just proof -> Right proof
  Nothing -> Left $ TypeEqError (demote @a) (demote @b)
-- | Check, inside the type checking monad, that two stack descriptions are
-- equal (see 'eqHST').
--
-- On failure throws 'TCFailedOnInstr' for the given instruction and input
-- stack, using the supplied message followed by @": "@, the current
-- instruction position and the underlying 'TCTypeError'.
checkEqHST
  :: forall (a :: [T]) (b :: [T]) ts m .
  ( Typeable a, Typeable b, Typeable ts
  , MonadReader InstrCallStack m, MonadError TCError m
  )
  => HST a
  -> HST b
  -> Un.ExpandedInstr
  -> HST ts
  -> Text
  -> m (a :~: b)
checkEqHST a b instr i m =
  ask >>= \pos ->
    either
      (throwError . TCFailedOnInstr instr (SomeHST i) (m <> ": ") pos . Just)
      pure
      (eqHST a b)
-- | Lift a pure type checking result into the type checking monad.
--
-- A 'Right' value is returned as is; a 'Left' error is rethrown through
-- 'typeCheckInstrErr'' with the given instruction, stack and message.
onTypeCheckInstrErr
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Text -> Either TCTypeError a -> m a
onTypeCheckInstrErr instr hst msg = \case
  Left err -> typeCheckInstrErr' instr hst msg err
  Right res -> pure res
-- | Throw 'TCFailedOnInstr' for the given instruction, stack and message at
-- the current instruction position, without an underlying 'TCTypeError'.
typeCheckInstrErr
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Text -> m a
typeCheckInstrErr instr hst msg =
  ask >>= \pos -> throwError (TCFailedOnInstr instr hst msg pos Nothing)
-- | Throw 'TCFailedOnInstr' for the given instruction, stack and message at
-- the current instruction position, attaching the given 'TCTypeError' as
-- the underlying cause.
typeCheckInstrErr'
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Text -> TCTypeError -> m a
typeCheckInstrErr' instr hst msg err =
  ask >>= \pos -> throwError (TCFailedOnInstr instr hst msg pos (Just err))
-- | Like 'onTypeCheckInstrErr', but for results that may fail with an
-- annotation convergence error; such an error is wrapped in 'AnnError'
-- before being rethrown.  Takes the input stack unwrapped.
onTypeCheckInstrAnnErr
  :: (MonadReader InstrCallStack m, MonadError TCError m, Typeable ts)
  => Un.ExpandedInstr -> HST ts -> Text -> Either AnnConvergeError a -> m a
onTypeCheckInstrAnnErr instr i msg =
  onTypeCheckInstrErr instr (SomeHST i) msg . first AnnError
-- | Type check a list of untyped operations against an input stack, using
-- the supplied handler for individual primitive instructions.
--
-- The list is consumed from the front:
--
-- * a 'Un.WithSrcEx' directly wrapping another 'Un.WithSrcEx' is unwrapped
--   (the outer call stack is dropped) and processing restarts;
-- * a primitive is passed to the handler;
-- * a nested sequence is type checked recursively and its instruction is
--   wrapped in 'Nested';
-- * an empty list yields 'Nop' with the output stack equal to the input.
--
-- Consecutive results are glued with 'Seq'.  If an instruction that may
-- produce any output stack ('AnyOutInstr') is followed by more operations,
-- 'TCUnreachableCode' is thrown.
typeCheckImpl
  :: forall inp . Typeable inp
  => TcInstrHandler
  -> [Un.ExpandedOp]
  -> HST inp
  -> TypeCheckInstr (SomeInstr inp)
typeCheckImpl tcInstr instrs t@(a :: HST a) =
  case instrs of
    Un.WithSrcEx _ (i@(Un.WithSrcEx _ _)) : rs -> typeCheckImpl tcInstr (i : rs) t
    Un.WithSrcEx cs (Un.PrimEx i) : rs -> typeCheckPrim (Just cs) i rs
    Un.WithSrcEx cs (Un.SeqEx sq) : rs -> typeCheckSeq (Just cs) sq rs
    Un.PrimEx i : rs -> typeCheckPrim Nothing i rs
    Un.SeqEx sq : rs -> typeCheckSeq Nothing sq rs
    [] -> pure $ a :/ Nop ::: a
  where
    -- Type check one primitive followed by the remaining operations.  When
    -- a call stack is known it replaces the ambient one (via 'local').  A
    -- primitive that is last in the list is returned as the handler produced
    -- it, without being combined through 'typeCheckImplDo'.
    -- NOTE(review): in the non-empty case 'local' encloses the whole
    -- 'typeCheckImplDo' call, which also type checks @rs@; following
    -- operations without their own 'Un.WithSrcEx' therefore run with @cs@
    -- as the ambient call stack - confirm this is intended.
    typeCheckPrim (Just cs) i [] = local (const cs) $ tcInstr i t
    typeCheckPrim (Just cs) i rs = local (const cs) $ typeCheckImplDo (tcInstr i t) id rs
    typeCheckPrim Nothing i [] = tcInstr i t
    typeCheckPrim Nothing i rs = typeCheckImplDo (tcInstr i t) id rs
    -- Type check a nested sequence against the current input stack, wrap it
    -- in 'Nested' and continue with the remaining operations.
    typeCheckSeq (Just cs) sq = local (const cs) . typeCheckImplDo (typeCheckImpl tcInstr sq t) Nested
    typeCheckSeq Nothing sq = typeCheckImplDo (typeCheckImpl tcInstr sq t) Nested
    -- Run the action producing the first instruction, then type check the
    -- remaining operations against its output stack and combine both with
    -- 'Seq'.  @wrap@ is applied to the first instruction before combining.
    typeCheckImplDo
      :: TypeCheckInstr (SomeInstr inp)
      -> (forall inp' out . Instr inp' out -> Instr inp' out)
      -> [Un.ExpandedOp]
      -> TypeCheckInstr (SomeInstr inp)
    typeCheckImplDo f wrap rs = do
      _ :/ pi' <- f
      case pi' of
        p ::: b -> do
          _ :/ qi <- typeCheckImpl tcInstr rs b
          case qi of
            q ::: c -> case q of
              -- The continuation is itself a 'Seq': it is used as is,
              -- without an extra notes wrapper around it.
              Seq _ _ ->
                pure $ a :/ Seq (wrapWithNotes b (wrap p)) q ::: c
              _ ->
                pure $ a :/ Seq (wrapWithNotes b (wrap p)) (wrapWithNotes c q) ::: c
            AnyOutInstr q ->
              pure $ a :/ AnyOutInstr (Seq (wrapWithNotes b (wrap p)) q)
        AnyOutInstr instr ->
          case rs of
            -- Nothing follows: the instruction is returned unchanged
            -- (note that @wrap@ is not applied in this branch).
            [] ->
              pure $ a :/ AnyOutInstr instr
            -- Operations after such an instruction are reported as
            -- unreachable, positioned at the first of them.
            r : rr ->
              throwError $ TCUnreachableCode (extractInstrPos r) (r :| rr)
    -- Attach the notes of the top element of the output stack to the
    -- instruction; with an empty output stack the instruction is unchanged.
    wrapWithNotes :: HST d -> Instr c d -> Instr c d
    wrapWithNotes h ins = case h of
      ((s, n, _) ::& _) -> InstrWithNotes (PackedNotes n s) ins
      SNil -> ins
    -- Call stack recorded on an operation, or 'def' when it has none.
    extractInstrPos :: Un.ExpandedOp -> InstrCallStack
    extractInstrPos (Un.WithSrcEx cs _) = cs
    extractInstrPos _ = def
-- | Check that two annotated types are equal and converge their notes.
--
-- Returns the equality proof together with the converged notes.  A type
-- mismatch yields the error produced by 'eqType'; an annotation conflict is
-- wrapped in 'AnnError'.
matchTypes
  :: forall t1 t2.
    (Each [Typeable, SingI] [t1, t2])
  => Notes t1 -> Notes t2 -> Either TCTypeError (t1 :~: t2, Notes t1)
matchTypes n1 n2 = case eqType @t1 @t2 of
  Left err -> Left err
  Right Refl -> (,) Refl <$> (converge n1 n2 `onLeft` AnnError)
-- | Type check the @MEM@ instruction.
--
-- The top stack element must have the key type of the container lying
-- below it ('MemOpKey').  On success both are replaced by a @bool@ carrying
-- the given variable annotation; otherwise 'TCFailedOnInstr' is thrown.
memImpl
  :: forall (q :: CT) (c :: T) ts inp m .
    ( MonadReader InstrCallStack m, MonadError TCError m, Typeable ts
    , Typeable (MemOpKey c), SingI (MemOpKey c), MemOp c
    , inp ~ ('Tc q : c : ts)
    )
  => Un.ExpandedInstr
  -> HST inp
  -> VarAnn
  -> m (SomeInstr inp)
memImpl instr i@(_ ::& _ ::& rest) vn =
  ask >>= \pos -> case eqType @('Tc q) @('Tc (MemOpKey c)) of
    Left err ->
      throwError $
        TCFailedOnInstr instr (SomeHST i) "query element type is not equal to set's element type" pos (Just err)
    Right Refl ->
      pure $ i :/ MEM ::: ((STc SCBool, starNotes, vn) ::& rest)
-- | Type check the @GET@ instruction.
--
-- The top stack element must have the key type of the container lying
-- below it ('GetOpKey').  On success both are replaced by an @option@ of
-- the container's value type, built from the supplied singleton and notes
-- and carrying the given variable annotation; otherwise 'TCFailedOnInstr'
-- is thrown.
getImpl
  :: forall c getKey rs inp m .
    ( GetOp c, Typeable (GetOpKey c)
    , Typeable (GetOpVal c)
    , SingI (GetOpVal c), SingI (GetOpKey c)
    , inp ~ (getKey : c : rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Un.ExpandedInstr
  -> HST (getKey ': c ': rs)
  -> Sing (GetOpVal c)
  -> Notes (GetOpVal c)
  -> VarAnn
  -> m (SomeInstr inp)
getImpl instr i@(_ ::& _ ::& rest) rt vns vn =
  ask >>= \pos -> case eqType @getKey @('Tc (GetOpKey c)) of
    Left err ->
      throwError $ TCFailedOnInstr instr (SomeHST i) "wrong key stack type" pos (Just err)
    Right Refl ->
      pure $ i :/ GET ::: ((STOption rt, NTOption def vns, vn) ::& rest)
-- | Type check the @UPDATE@ instruction.
--
-- The top stack element must have the container's key type ('UpdOpKey')
-- and the second one the container's update parameter type
-- ('UpdOpParams').  On success both are consumed and the container stays on
-- the stack with its variable annotation replaced by the given one.  A key
-- mismatch is reported in preference to an update value mismatch.
updImpl
  :: forall c updKey updParams rs inp m .
    ( UpdOp c
    , Typeable (UpdOpKey c), SingI (UpdOpKey c)
    , Typeable (UpdOpParams c), SingI (UpdOpParams c)
    , inp ~ (updKey : updParams : c : rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Un.ExpandedInstr
  -> HST (updKey ': updParams ': c ': rs)
  -> VarAnn
  -> m (SomeInstr inp)
updImpl instr i@(_ ::& _ ::& cTuple ::& rest) vn =
  ask >>= \pos -> case eqType @updKey @('Tc (UpdOpKey c)) of
    Left err ->
      throwError $
        TCFailedOnInstr instr (SomeHST i) "wrong key stack type" pos (Just err)
    Right Refl -> case eqType @updParams @(UpdOpParams c) of
      Left err ->
        throwError $
          TCFailedOnInstr instr (SomeHST i) "wrong update value stack type" pos (Just err)
      Right Refl ->
        pure $ i :/ UPDATE ::: ((cTuple & _3 .~ vn) ::& rest)
-- | Type check the @SIZE@ instruction: the top stack element is replaced by
-- a @nat@ with default notes and the given variable annotation.
sizeImpl
  :: (SizeOp c, inp ~ (c ': rs), Monad m)
  => HST inp
  -> VarAnn
  -> m (SomeInstr inp)
sizeImpl i@(_ ::& rest) vn =
  let resultStack = (STc SCNat, starNotes, vn) ::& rest
  in pure (i :/ SIZE ::: resultStack)
-- | Type check the @SLICE@ instruction.
--
-- The two @nat@ arguments and the sliced value are replaced by an @option@
-- of the sliced value's type, reusing that value's notes.  The resulting
-- variable annotation is the given one combined through 'orAnn' with an
-- annotation derived (via 'deriveVN' and @"slice"@) from the sliced value's
-- annotation.
sliceImpl
  :: (SliceOp c, Typeable c, inp ~ ('Tc 'CNat ': 'Tc 'CNat ': c ': rs), Monad m)
  => HST inp
  -> Un.VarAnn
  -> m (SomeInstr inp)
sliceImpl i@(_ ::& _ ::& (c, cn, cvn) ::& rest) vn =
  pure $
    i :/ SLICE :::
      ((STOption c, NTOption def cn, vn `orAnn` deriveVN "slice" cvn) ::& rest)
-- | Type check the list form of @CONCAT@: a list of concatenable values on
-- top of the stack is replaced by a single value of the element type,
-- keeping the element notes and using the given variable annotation.
concatImpl'
  :: (ConcatOp c, Typeable c, inp ~ ('TList c : rs), Monad m)
  => HST inp
  -> Un.VarAnn
  -> m (SomeInstr inp)
concatImpl' i@((STList el, NTList _ elNotes, _) ::& rest) vn =
  let resultStack = (el, elNotes, vn) ::& rest
  in pure (i :/ CONCAT' ::: resultStack)
-- | Type check the binary form of @CONCAT@: two values of the same type on
-- top of the stack are replaced by one value of that type.
--
-- The notes of both operands are converged; a conflict is reported through
-- 'onTypeCheckInstrAnnErr'.  The result carries the given variable
-- annotation.
concatImpl
  :: ( ConcatOp c, Typeable c, inp ~ (c ': c ': rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => HST inp
  -> Un.VarAnn
  -> m (SomeInstr inp)
concatImpl i@((c, cn1, _) ::& (_, cn2, _) ::& rest) vn =
  onTypeCheckInstrAnnErr (Un.CONCAT vn) i "wrong operand types for concat operation" (converge cn1 cn2)
    >>= \merged -> pure $ i :/ CONCAT ::: ((c, merged, vn) ::& rest)
-- | Build the type checking result for a binary arithmetic instruction.
--
-- The two operands on top of the stack are replaced by a single value whose
-- type is dictated by the supplied instruction's output; it gets default
-- notes and the given variable annotation.
arithImpl
  :: ( Typeable (ArithRes aop n m)
     , SingI (ArithRes aop n m)
     , Typeable ('Tc (ArithRes aop n m) ': s)
     , inp ~ ('Tc n ': 'Tc m ': s)
     , Monad t
     )
  => Instr inp ('Tc (ArithRes aop n m) ': s)
  -> HST inp
  -> VarAnn
  -> t (SomeInstr inp)
arithImpl mkInstr i@(_ ::& _ ::& rest) vn =
  let resultStack = (sing, starNotes, vn) ::& rest
  in pure (i :/ mkInstr ::: resultStack)
-- | Type check the @ADD@ instruction for the given pair of operand types.
--
-- Accepted combinations: any mix of @int@ and @nat@, @int@ with
-- @timestamp@ (in either order), and @mutez@ with @mutez@.  Any other
-- combination fails with 'UnsupportedTypes' listing both operand types.
addImpl
  :: forall a b inp rs m.
    ( Typeable rs
    , Each [Typeable, SingI] [a, b]
    , inp ~ ('Tc a ': 'Tc b ': rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Sing a -> Sing b
  -> HST inp
  -> VarAnn
  -> m (SomeInstr inp)
addImpl lhs rhs = case (lhs, rhs) of
  (SCInt, SCInt) -> arithImpl @Add ADD
  (SCInt, SCNat) -> arithImpl @Add ADD
  (SCNat, SCInt) -> arithImpl @Add ADD
  (SCNat, SCNat) -> arithImpl @Add ADD
  (SCInt, SCTimestamp) -> arithImpl @Add ADD
  (SCTimestamp, SCInt) -> arithImpl @Add ADD
  (SCMutez, SCMutez) -> arithImpl @Add ADD
  _ -> \i vn ->
    onTypeCheckInstrErr (Un.ADD vn) (SomeHST i)
      "wrong operand types for add operation"
      (Left $ UnsupportedTypes [demote @('Tc a), demote @('Tc b)])
-- | Type check the @EDIV@ instruction for the given pair of operand types.
--
-- Accepted combinations: any mix of @int@ and @nat@, @mutez@ with @mutez@,
-- and @mutez@ with @nat@.  Any other combination fails with
-- 'UnsupportedTypes' listing both operand types.
edivImpl
  :: forall a b inp rs m.
    ( Typeable rs
    , Each [Typeable, SingI] [a, b]
    , inp ~ ('Tc a ': 'Tc b ': rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Sing a -> Sing b
  -> HST inp
  -> VarAnn
  -> m (SomeInstr inp)
edivImpl lhs rhs = case (lhs, rhs) of
  (SCInt, SCInt) -> edivImplDo
  (SCInt, SCNat) -> edivImplDo
  (SCNat, SCInt) -> edivImplDo
  (SCNat, SCNat) -> edivImplDo
  (SCMutez, SCMutez) -> edivImplDo
  (SCMutez, SCNat) -> edivImplDo
  _ -> \i vn ->
    onTypeCheckInstrErr (Un.EDIV vn) (SomeHST i)
      "wrong operand types for ediv operation"
      (Left $ UnsupportedTypes [demote @('Tc a), demote @('Tc b)])
-- | Build the type checking result for @EDIV@ once the operand types are
-- known to be supported: the two operands are replaced by a single value of
-- the instruction's result type, with default notes and the given variable
-- annotation.
edivImplDo
  :: ( EDivOp n m
     , SingI (EModOpRes n m)
     , Typeable (EModOpRes n m)
     , SingI (EDivOpRes n m)
     , Typeable (EDivOpRes n m)
     , inp ~ ('Tc n ': 'Tc m ': s)
     , Monad t
     )
  => HST inp
  -> VarAnn
  -> t (SomeInstr inp)
edivImplDo i@(_ ::& _ ::& rest) vn =
  let resultStack = (sing, starNotes, vn) ::& rest
  in pure (i :/ EDIV ::: resultStack)
-- | Type check the @SUB@ instruction for the given pair of operand types.
--
-- Accepted combinations: any mix of @int@ and @nat@, @timestamp@ with
-- @timestamp@, @timestamp@ with @int@, and @mutez@ with @mutez@.  Any other
-- combination fails with 'UnsupportedTypes' listing both operand types.
subImpl
  :: forall a b inp rs m.
    ( Typeable rs
    , Each [Typeable, SingI] [a, b]
    , inp ~ ('Tc a ': 'Tc b ': rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Sing a -> Sing b
  -> HST inp
  -> VarAnn
  -> m (SomeInstr inp)
subImpl lhs rhs = case (lhs, rhs) of
  (SCInt, SCInt) -> arithImpl @Sub SUB
  (SCInt, SCNat) -> arithImpl @Sub SUB
  (SCNat, SCInt) -> arithImpl @Sub SUB
  (SCNat, SCNat) -> arithImpl @Sub SUB
  (SCTimestamp, SCTimestamp) -> arithImpl @Sub SUB
  (SCTimestamp, SCInt) -> arithImpl @Sub SUB
  (SCMutez, SCMutez) -> arithImpl @Sub SUB
  _ -> \i vn ->
    onTypeCheckInstrErr (Un.SUB vn) (SomeHST i)
      "wrong operand types for sub operation"
      (Left $ UnsupportedTypes [demote @('Tc a), demote @('Tc b)])
-- | Type check the @MUL@ instruction for the given pair of operand types.
--
-- Accepted combinations: any mix of @int@ and @nat@, and @nat@ with
-- @mutez@ (in either order).  Any other combination fails with
-- 'UnsupportedTypes' listing both operand types.
mulImpl
  :: forall a b inp rs m.
    ( Typeable rs
    , Each [Typeable, SingI] [a, b]
    , inp ~ ('Tc a ': 'Tc b ': rs)
    , MonadReader InstrCallStack m
    , MonadError TCError m
    )
  => Sing a -> Sing b
  -> HST inp
  -> VarAnn
  -> m (SomeInstr inp)
mulImpl lhs rhs = case (lhs, rhs) of
  (SCInt, SCInt) -> arithImpl @Mul MUL
  (SCInt, SCNat) -> arithImpl @Mul MUL
  (SCNat, SCInt) -> arithImpl @Mul MUL
  (SCNat, SCNat) -> arithImpl @Mul MUL
  (SCNat, SCMutez) -> arithImpl @Mul MUL
  (SCMutez, SCNat) -> arithImpl @Mul MUL
  _ -> \i vn ->
    onTypeCheckInstrErr (Un.MUL vn) (SomeHST i)
      "wrong operand types for mul operation"
      (Left $ UnsupportedTypes [demote @('Tc a), demote @('Tc b)])
-- | Build the type checking result for a unary arithmetic instruction.
--
-- The operand on top of the stack is replaced by a value whose type is
-- dictated by the supplied instruction's output; it gets default notes and
-- the given variable annotation.
unaryArithImpl
  :: ( Typeable (UnaryArithRes aop n)
     , SingI (UnaryArithRes aop n)
     , Typeable ('Tc (UnaryArithRes aop n) ': s)
     , inp ~ ('Tc n ': s)
     , Monad t
     )
  => Instr inp ('Tc (UnaryArithRes aop n) ': s)
  -> HST inp
  -> VarAnn
  -> t (SomeInstr inp)
unaryArithImpl mkInstr i@(_ ::& rest) vn =
  let resultStack = (sing, starNotes, vn) ::& rest
  in pure (i :/ mkInstr ::: resultStack)