module Michelson.TypeCheck.Helpers ( onLeft , deriveSpecialVN , deriveSpecialFNs , deriveVN , deriveNsOr , deriveNsOption , convergeHSTEl , convergeHST , hstToTs , eqHST , eqHST1 , lengthHST , ensureDistinctAsc , eqType , onTypeCheckInstrAnnErr , onTypeCheckInstrErr , onScopeCheckInstrErr , typeCheckInstrErr , typeCheckInstrErr' , typeCheckImpl , matchTypes , memImpl , getImpl , updImpl , sliceImpl , concatImpl , concatImpl' , sizeImpl , arithImpl , addImpl , subImpl , mulImpl , edivImpl , unaryArithImpl , withCompareableCheck ) where import Prelude hiding (EQ, GT, LT) import Control.Monad.Except (MonadError, throwError) import Data.Constraint (Dict(..), withDict) import Data.Default (def) import Data.Singletons (Sing, SingI(sing), demote) import qualified Data.Text as T import Data.Typeable ((:~:)(..), eqT) import Fmt ((+||), (||+)) import Michelson.ErrorPos (InstrCallStack) import Michelson.TypeCheck.Error (TCError(..), TCTypeError(..), TypeContext(..)) import Michelson.TypeCheck.TypeCheck import Michelson.TypeCheck.Types import Michelson.Typed (BadTypeForScope(..), Comparable, Instr(..), Notes(..), PackedNotes(..), SingT(..), T(..), WellTyped, converge , getComparableProofS, notesT, orAnn, starNotes) import Michelson.Typed.Annotation (AnnConvergeError, isStar) import Michelson.Typed.Arith (Add, ArithOp(..), Mul, Sub, UnaryArithOp(..)) import Michelson.Typed.Polymorphic (ConcatOp, EDivOp(..), GetOp(..), MemOp(..), SizeOp, SliceOp, UpdOp(..)) import qualified Michelson.Untyped as Un import Michelson.Untyped.Annotation (Annotation(..), FieldAnn, VarAnn, ann) -- | Function which derives special annotations -- for PAIR instruction. -- -- Namely, it does following transformation: -- @ -- PAIR %@@ %@@ [ @@p.a int : @@p.b int : .. ] -- ~ -- [ @@p (pair (int %a) (int %b) : .. ] -- @ -- -- All relevant cases (e.g. @PAIR %myf %@@ @) -- are handled as they should be according to spec. 
deriveSpecialFNs :: FieldAnn -> FieldAnn -> VarAnn -> VarAnn -> (VarAnn, FieldAnn, FieldAnn)
deriveSpecialFNs "@" "@" pvn qvn = (vn, pfn, qfn)
  where
    -- Dot-separated components of both variable annotations.
    ps = T.splitOn "." $ unAnnotation pvn
    qs = T.splitOn "." $ unAnnotation qvn
    -- The longest common prefix of components becomes the variable
    -- annotation of the pair; the remainders become the field annotations.
    fns = fst <$> takeWhile (uncurry (==)) (zip ps qs)
    prefixLen = length fns
    vn = ann $ T.intercalate "." fns
    pfn = ann $ T.intercalate "." $ drop prefixLen ps
    qfn = ann $ T.intercalate "." $ drop prefixLen qs
deriveSpecialFNs "@" qfn pvn _ = (def, Un.convAnn pvn, qfn)
deriveSpecialFNs pfn "@" _ qvn = (def, pfn, Un.convAnn qvn)
deriveSpecialFNs pfn qfn _ _ = (def, pfn, qfn)

-- | Function which derives special annotations
-- for CDR / CAR instructions.
--
-- * A variable annotation equal to @%@ is replaced by the field annotation
--   of the accessed element.
-- * A variable annotation equal to @%%@, when the field annotation is
--   non-empty, yields the pair's variable annotation combined (via '<>')
--   with that field annotation.
-- * Any other variable annotation (including @%%@ with an empty field
--   annotation) is returned unchanged.
deriveSpecialVN :: VarAnn -> FieldAnn -> VarAnn -> VarAnn
deriveSpecialVN vn elFn pairVN
  | vn == "%" = Un.convAnn elFn
  | vn == "%%" && elFn /= def = pairVN <> Un.convAnn elFn
  | otherwise = vn

-- | Derive a variable annotation from @vn@ using @suffix@.
--
-- Returns the empty annotation when @vn@ is empty; otherwise returns
-- @suffix <> vn@.
--
-- NOTE(review): callers in this module pass the literal suffixes @left@,
-- @right@, @some@ and @slice@ (or a field annotation) as the first
-- argument, so the suffix is the left operand of '<>'. Whether it ends up
-- before or after @vn@ in the resulting text depends on the 'Semigroup'
-- instance of 'VarAnn', which is not visible here. Confirm this yields the
-- intended order.
deriveVN :: VarAnn -> VarAnn -> VarAnn
deriveVN suffix vn =
  if vn == def
    then def
    else suffix <> vn

-- | Function which extracts annotations for @or@ type
-- (for left and right parts).
--
-- Returns the notes of both branches together with a variable annotation
-- for each of them. Each branch's variable annotation is derived (via
-- 'deriveVN') from the given variable annotation of the @or@ value. The
-- suffix is the branch's field annotation, combined through 'orAnn' with
-- @left@ \/ @right@ (presumably used as the fallback when the field
-- annotation is empty; 'orAnn' is defined elsewhere).
-- If the given variable annotation is empty, both derived variable
-- annotations are empty as well; nothing is generated in that case.
deriveNsOr :: Notes ('TOr a b) -> VarAnn -> (Notes a, Notes b, VarAnn, VarAnn)
deriveNsOr (NTOr _ afn bfn an bn) ovn =
  let avn = deriveVN (Un.convAnn afn `orAnn` "left") ovn
      bvn = deriveVN (Un.convAnn bfn `orAnn` "right") ovn
  in (an, bn, avn, bvn)

-- | Function which extracts annotations for @option t@ type.
--
-- Returns the notes of the wrapped type together with a variable annotation
-- for the @Some@ case. That annotation is derived (via 'deriveVN') from the
-- given variable annotation of the option value, using the suffix @some@.
-- If the given variable annotation is empty, the derived one is empty as
-- well.
deriveNsOption :: Notes ('TOption a) -> VarAnn -> (Notes a, VarAnn)
deriveNsOption (NTOption _ an) ovn = (an, deriveVN "some" ovn)

-- | Converge two descriptions of the same stack element.
--
-- The notes are converged with 'converge', failing with its error. The
-- 'WellTyped' dictionary of the first element is kept. The variable
-- annotation is kept only when both sides carry the same one; otherwise it
-- is reset to the empty annotation.
convergeHSTEl
  :: (Notes t, Dict (WellTyped t), VarAnn)
  -> (Notes t, Dict (WellTyped t), VarAnn)
  -> Either AnnConvergeError (Notes t, Dict (WellTyped t), VarAnn)
convergeHSTEl (an, d@Dict, avn) (bn, _, bvn) =
  (,,) <$> converge an bn <*> pure d <*> pure (bool def avn $ avn == bvn)

-- | Combine annotations from two given stack types, element by element,
-- using 'convergeHSTEl'. The error of the first (topmost) element that
-- fails to converge is returned.
convergeHST :: HST ts -> HST ts -> Either AnnConvergeError (HST ts)
convergeHST SNil SNil = pure SNil
convergeHST (a ::& as) (b ::& bs) =
  liftA2 (::&) (convergeHSTEl a b) (convergeHST as bs)

-- | Flipped 'first': map over the 'Left' value of an 'Either'.
-- TODO move to Util module
onLeft :: Either a c -> (a -> b) -> Either b c
onLeft = flip first

-- | Extract singleton for each single type of the given stack, in order,
-- starting from the head of the stack.
hstToTs :: HST st -> [T]
hstToTs = \case
  SNil -> []
  (notes, _, _) ::& hst -> notesT notes : hstToTs hst

-- | Check whether the given stack types are equal.
--
-- When the types are equal, the annotations of both stacks must also
-- converge (see 'convergeHST'); the converged stack itself is discarded.
-- Fails with 'StackEqError' when the types differ and with 'AnnError'
-- when the annotations do not converge.
eqHST
  :: forall as bs. (Typeable as, Typeable bs)
  => HST as -> HST bs -> Either TCTypeError (as :~: bs)
eqHST hst hst' =
  case eqT @as @bs of
    Nothing -> Left $ StackEqError (hstToTs hst) (hstToTs hst')
    Just Refl -> do
      void $ convergeHST hst hst' `onLeft` AnnError
      return Refl

-- | Check whether the given stack has size 1 and its only element has the
-- given type @t@.
--
-- Unlike 'eqHST', only the types are compared; annotations are not
-- converged. On mismatch it fails with 'StackEqError', listing the
-- expected single-element stack first and the given stack second.
--
-- NOTE(review): 'eqHST' passes its stacks to 'StackEqError' in argument
-- order, whereas here the expected stack comes first. Confirm which order
-- 'StackEqError' expects.
eqHST1
  :: forall t st. (Typeable st, Typeable t, SingI t, WellTyped t)
  => HST st -> Either TCTypeError (st :~: '[t])
eqHST1 hst = do
  -- Only needed to render the expected stack in the error message.
  let hst' = sing @t -:& SNil
  case eqT @'[t] @st of
    Nothing -> Left $ StackEqError (hstToTs hst') (hstToTs hst)
    Just Refl -> Right Refl

-- | Number of elements in the given stack.
lengthHST :: HST xs -> Natural
lengthHST (_ ::& xs) = 1 + lengthHST xs
lengthHST SNil = 0

--------------------------------------------
-- Typechecker auxiliary
--------------------------------------------

-- | Check whether elements go in strictly ascending order and
-- return the original list (to keep only one pass on the original list).
--
-- Elements are compared by the key computed with @toCmp@. For the first
-- adjacent pair whose keys are not strictly ascending, it fails with a
-- message showing both elements.
ensureDistinctAsc :: (Ord b, Show a) => (a -> b) -> [a] -> Either Text [a]
ensureDistinctAsc toCmp = \case
  (e1 : e2 : l) ->
    if toCmp e1 < toCmp e2
      then (e1 :) <$> ensureDistinctAsc toCmp (e2 : l)
      else Left $ "Entries are unordered (" +|| e1 ||+ " >= " +|| e2 ||+ ")"
  l -> Right l

-- | Function @eqType@ is a simple wrapper around @Data.Typeable.eqT@ suited
-- for use within @Either TCTypeError a@ applicative.
--
-- Fails with 'TypeEqError' carrying both demoted types when they differ.
eqType
  :: forall (a :: T) (b :: T). (Each [Typeable, SingI] [a, b])
  => Either TCTypeError (a :~: b)
eqType = maybe (Left $ TypeEqError (demote @a) (demote @b)) pure eqT

-- | Lift a typecheck result into the typechecker monad. A 'Left' is
-- reported as a failure of the given instruction (see 'typeCheckInstrErr'').
onTypeCheckInstrErr
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Maybe TypeContext -> Either TCTypeError a -> m a
onTypeCheckInstrErr instr hst mContext =
  either (typeCheckInstrErr' instr hst mContext) pure

-- | Lift a scope-check result into the typechecker monad. A 'Left' is
-- reported as 'UnsupportedTypeForScope' for the type @t@, as a failure of
-- the given instruction (see 'typeCheckInstrErr'').
onScopeCheckInstrErr
  :: forall (t :: T) m a.
     (MonadReader InstrCallStack m, MonadError TCError m, SingI t)
  => Un.ExpandedInstr -> SomeHST -> Maybe TypeContext -> Either BadTypeForScope a -> m a
onScopeCheckInstrErr instr hst mContext =
  either (typeCheckInstrErr' instr hst mContext . UnsupportedTypeForScope (demote @t)) pure

-- | Fail on the given instruction with 'TCFailedOnInstr'. The error carries
-- the current source position from the reader environment and no specific
-- type error.
typeCheckInstrErr
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Maybe TypeContext -> m a
typeCheckInstrErr instr hst mContext = do
  pos <- ask
  throwError $ TCFailedOnInstr instr hst pos mContext Nothing

-- | Like 'typeCheckInstrErr', but attaches the given type error.
typeCheckInstrErr'
  :: (MonadReader InstrCallStack m, MonadError TCError m)
  => Un.ExpandedInstr -> SomeHST -> Maybe TypeContext -> TCTypeError -> m a
typeCheckInstrErr' instr hst mContext err = do
  pos <- ask
  throwError $ TCFailedOnInstr instr hst pos mContext (Just err)

-- | Lift an annotation-convergence result into the typechecker monad.
-- A 'Left' is wrapped in 'AnnError' and reported via 'onTypeCheckInstrErr'.
onTypeCheckInstrAnnErr
  :: (MonadReader InstrCallStack m, MonadError TCError m, Typeable ts)
  => Un.ExpandedInstr -> HST ts -> Maybe TypeContext -> Either AnnConvergeError a -> m a
onTypeCheckInstrAnnErr instr i mContext ei =
  onTypeCheckInstrErr instr (SomeHST i) mContext (ei `onLeft` AnnError)

-- | Run @act@ with a 'Comparable' proof for the type behind the given
-- singleton. If no proof exists (see 'getComparableProofS'), typechecking
-- of the given instruction fails in the 'ComparisonArguments' context.
withCompareableCheck
  :: forall a m v ts. (Typeable ts, MonadReader InstrCallStack m, MonadError TCError m)
  => Sing a -> Un.ExpandedInstr -> HST ts -> (Comparable a => v) -> m v
withCompareableCheck sng instr i act = case getComparableProofS sng of
  Just d@Dict -> pure $ withDict d act
  Nothing -> typeCheckInstrErr instr (SomeHST i) $ Just ComparisonArguments

-- | Typecheck a sequence of untyped operations against the given input
-- stack, using the handler @tcInstr@ for individual primitive instructions.
--
-- * Source positions from 'Un.WithSrcEx' are installed with 'local'. For
--   directly nested 'Un.WithSrcEx' wrappers only the innermost position
--   is used.
-- * Instructions are combined with 'Seq'. A nested sequence is wrapped in
--   'Nested'. Where an output stack has non-star notes on top, the
--   instruction producing it is wrapped in 'InstrWithNotes' (see
--   @wrapWithNotes@ below for the exact rule).
-- * If an instruction yields 'AnyOutInstr', any operations after it are
--   rejected with 'TCUnreachableCode'.
-- * An empty sequence typechecks to 'Nop'.
--
-- NOTE(review): the 'local' scope installed for an operation also covers
-- the typechecking of the operations that follow it. A following operation
-- without its own 'Un.WithSrcEx' therefore inherits that position. Confirm
-- this is intended.
typeCheckImpl
  :: forall inp . Typeable inp
  => TcInstrHandler
  -> [Un.ExpandedOp]
  -> HST inp
  -> TypeCheckInstr (SomeInstr inp)
typeCheckImpl tcInstr instrs t@(a :: HST a) =
  case instrs of
    -- Directly nested position wrappers: drop the outer one.
    Un.WithSrcEx _ (i@(Un.WithSrcEx _ _)) : rs -> typeCheckImpl tcInstr (i : rs) t
    Un.WithSrcEx cs (Un.PrimEx i) : rs -> typeCheckPrim (Just cs) i rs
    Un.WithSrcEx cs (Un.SeqEx sq) : rs -> typeCheckSeq (Just cs) sq rs
    Un.PrimEx i : rs -> typeCheckPrim Nothing i rs
    Un.SeqEx sq : rs -> typeCheckSeq Nothing sq rs
    [] -> pure $ a :/ Nop ::: a
  where
    -- Typecheck one primitive instruction (optionally under the given
    -- source position), followed by the remaining operations if any.
    typeCheckPrim (Just cs) i [] = local (const cs) $ tcInstr i t
    typeCheckPrim (Just cs) i rs = local (const cs) $ typeCheckImplDo (tcInstr i t) id rs
    typeCheckPrim Nothing i [] = tcInstr i t
    typeCheckPrim Nothing i rs = typeCheckImplDo (tcInstr i t) id rs

    -- Typecheck a nested sequence (wrapped in Nested), optionally under the
    -- given source position, followed by the remaining operations.
    typeCheckSeq (Just cs) sq = local (const cs) . typeCheckImplDo (typeCheckImpl tcInstr sq t) Nested
    typeCheckSeq Nothing sq = typeCheckImplDo (typeCheckImpl tcInstr sq t) Nested

    -- Run f, then typecheck the remaining operations on its output stack
    -- and combine both parts with Seq. The wrap function is applied to the
    -- instruction produced by f.
    typeCheckImplDo
      :: TypeCheckInstr (SomeInstr inp)
      -> (forall inp' out . Instr inp' out -> Instr inp' out)
      -> [Un.ExpandedOp]
      -> TypeCheckInstr (SomeInstr inp)
    typeCheckImplDo f wrap rs = do
      _ :/ pi' <- f
      case pi' of
        p ::: b -> do
          _ :/ qi <- typeCheckImpl tcInstr rs b
          case qi of
            q ::: c -> case q of
              Seq _ _ ->
                pure $ a :/ Seq (wrapWithNotes b (wrap p)) q ::: c
              -- Wrap the RHS if it is a single instruction and not a
              -- sequence
              _ ->
                pure $ a :/ Seq (wrapWithNotes b (wrap p)) (wrapWithNotes c q) ::: c
            AnyOutInstr q ->
              pure $ a :/ AnyOutInstr (Seq (wrapWithNotes b (wrap p)) q)
        AnyOutInstr instr -> case rs of
          [] -> pure $ a :/ AnyOutInstr instr
          -- Nothing may follow an instruction with an arbitrary output.
          r : rr -> throwError $ TCUnreachableCode (extractInstrPos r) (r :| rr)

    -- Attach the notes of the top of the output stack to the instruction,
    -- unless the stack is empty or those notes are star.
    wrapWithNotes :: HST d -> Instr c d -> Instr c d
    wrapWithNotes h ins = case h of
      -- do not wrap in notes if the notes are "star"
      ((n, _, _) ::& _) | isStar n -> ins
      ((n, _, _) ::& _) -> InstrWithNotes (PackedNotes n) ins
      SNil -> ins

    -- Source position of an operation, or the default if it carries none.
    extractInstrPos :: Un.ExpandedOp -> InstrCallStack
    extractInstrPos (Un.WithSrcEx cs _) = cs
    extractInstrPos _ = def
-- | Check whether given types are structurally equal and annotations converge.
--
-- Fails with the error of 'eqType' when the types differ, and with
-- 'AnnError' when the notes do not converge. On success it returns the
-- type equality proof together with the converged notes.
matchTypes
  :: forall t1 t2. (Each [Typeable, SingI] [t1, t2])
  => Notes t1 -> Notes t2 -> Either TCTypeError (t1 :~: t2, Notes t1)
matchTypes n1 n2 = do
  Refl <- eqType @t1 @t2
  nr <- converge n1 n2 `onLeft` AnnError
  return (Refl, nr)

--------------------------------------------
-- Some generic instruction implementation
--------------------------------------------

-- | Generic implementation of the @MEM@ instruction.
--
-- The element on top of the stack must have the container's key type
-- ('MemOpKey'); otherwise typechecking fails in the 'ContainerKeyType'
-- context. On success both operands are replaced by the result, which gets
-- star notes and the given variable annotation.
memImpl
  :: forall (q :: T) (c :: T) ts inp m .
     ( MonadReader InstrCallStack m, MonadError TCError m, Typeable ts
     , Typeable (MemOpKey c), SingI (MemOpKey c), MemOp c
     , inp ~ (q : c : ts)
     )
  => Un.ExpandedInstr -> HST inp -> VarAnn -> m (SomeInstr inp)
memImpl instr i@(_ ::& _ ::& rs) vn =
  case eqType @q @(MemOpKey c) of
    Right Refl -> pure $ i :/ MEM ::: ((starNotes, Dict, vn) ::& rs)
    Left err -> do
      -- The source position is only needed to report the failure.
      pos <- ask
      throwError $ TCFailedOnInstr instr (SomeHST i) pos (Just ContainerKeyType) (Just err)

-- | Generic implementation of the @GET@ instruction.
--
-- The element on top of the stack must have the container's key type
-- ('GetOpKey'); otherwise typechecking fails in the 'ContainerKeyType'
-- context. On success both operands are replaced by an option whose element
-- notes are @vns@, with the given variable annotation.
getImpl
  :: forall c getKey rs inp m .
     ( GetOp c, Typeable (GetOpKey c)
     , Typeable (GetOpVal c)
     , SingI (GetOpVal c), SingI (GetOpKey c)
     , WellTyped (GetOpVal c)
     , inp ~ (getKey : c : rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Un.ExpandedInstr
  -> HST (getKey ': c ': rs)
  -> Notes (GetOpVal c)
  -> VarAnn
  -> m (SomeInstr inp)
getImpl instr i@(_ ::& _ ::& rs) vns vn =
  case eqType @getKey @(GetOpKey c) of
    Right Refl -> do
      let rn = NTOption def vns
      pure $ i :/ GET ::: ((rn, Dict, vn) ::& rs)
    Left err -> do
      pos <- ask
      throwError $ TCFailedOnInstr instr (SomeHST i) pos (Just ContainerKeyType) (Just err)

-- | Generic implementation of the @UPDATE@ instruction.
--
-- The first operand must have the container's key type ('UpdOpKey') and
-- the second its update parameter type ('UpdOpParams'). Mismatches fail in
-- the 'ContainerKeyType' and 'ContainerValueType' contexts respectively;
-- the key is checked first. On success the container element remains on
-- the stack with its notes, but its variable annotation is replaced by @vn@.
updImpl
  :: forall c updKey updParams rs inp m .
     ( UpdOp c
     , Typeable (UpdOpKey c), SingI (UpdOpKey c)
     , Typeable (UpdOpParams c), SingI (UpdOpParams c)
     , inp ~ (updKey : updParams : c : rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Un.ExpandedInstr
  -> HST (updKey ': updParams ': c ': rs)
  -> VarAnn
  -> m (SomeInstr inp)
updImpl instr i@(_ ::& _ ::& cTuple ::& rest) vn =
  case (eqType @updKey @(UpdOpKey c), eqType @updParams @(UpdOpParams c)) of
    (Right Refl, Right Refl) ->
      pure $ i :/ UPDATE ::: ((cTuple & _3 .~ vn) ::& rest)
    (Left err, _) -> do
      pos <- ask
      throwError $ TCFailedOnInstr instr (SomeHST i) pos (Just ContainerKeyType) (Just err)
    (_, Left err) -> do
      pos <- ask
      throwError $ TCFailedOnInstr instr (SomeHST i) pos (Just ContainerValueType) (Just err)

-- | Generic implementation of the @SIZE@ instruction. The container on top
-- of the stack is replaced by a result with star notes and the given
-- variable annotation.
sizeImpl :: (SizeOp c, inp ~ (c ': rs), Monad m) => HST inp -> VarAnn -> m (SomeInstr inp)
sizeImpl i@(_ ::& rs) vn = pure $ i :/ SIZE ::: ((starNotes, Dict, vn) ::& rs)

-- | Generic implementation of the @SLICE@ instruction.
--
-- The result is an option whose element notes are the sliced container's
-- notes. Its variable annotation is @vn@ combined via 'orAnn' with an
-- annotation derived (see 'deriveVN') from the container's variable
-- annotation using the suffix @slice@.
sliceImpl
  :: (SliceOp c, Typeable c, inp ~ ('TNat ': 'TNat ': c ': rs), Monad m)
  => HST inp -> Un.VarAnn -> m (SomeInstr inp)
sliceImpl i@(_ ::& _ ::& (cn, Dict, cvn) ::& rs) vn = do
  let vn' = vn `orAnn` deriveVN "slice" cvn
      rn = NTOption def cn
  pure $ i :/ SLICE ::: ((rn, Dict, vn') ::& rs)

-- | Generic implementation of @CONCAT@ applied to a list of @c@. The list is
-- replaced by a single @c@ carrying the list's element notes and the given
-- variable annotation.
concatImpl'
  :: (ConcatOp c, Typeable c, SingI c, WellTyped c, inp ~ ('TList c : rs), Monad m)
  => HST inp -> Un.VarAnn -> m (SomeInstr inp)
concatImpl' i@((NTList _ n, Dict, _) ::& rs) vn =
  pure $ i :/ CONCAT' ::: ((n, Dict, vn) ::& rs)

-- | Generic implementation of @CONCAT@ applied to two values of the same
-- type @c@. Their notes must converge; otherwise typechecking fails in the
-- 'ConcatArgument' context. The result carries the converged notes and the
-- given variable annotation.
--
-- NOTE(review): failures are reported against @Un.CONCAT vn@, rebuilt
-- here, rather than against the instruction value the caller is checking.
-- Confirm this matches for all callers.
concatImpl
  :: ( ConcatOp c, Typeable c, inp ~ (c ': c ': rs)
     , WellTyped c
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => HST inp -> Un.VarAnn -> m (SomeInstr inp)
concatImpl i@((cn1, _, _) ::& (cn2, _, _) ::& rs) vn = do
  cn <- onTypeCheckInstrAnnErr (Un.CONCAT vn) i (Just ConcatArgument) (converge cn1 cn2)
  pure $ i :/ CONCAT ::: ((cn, Dict, vn) ::& rs)

-- | Helper function to construct instructions for binary arithmetic
-- operations. The two operands are replaced by the result, which gets star
-- notes and the given variable annotation.
arithImpl ::
  ( Typeable (ArithRes aop n m)
  , SingI (ArithRes aop n m)
  , Typeable (ArithRes aop n m ': s)
  , WellTyped (ArithRes aop n m)
  , inp ~ (n ': m ': s)
  , Monad t
  )
  => Instr inp (ArithRes aop n m ': s)
  -> HST inp
  -> VarAnn
  -> t (SomeInstr inp)
arithImpl mkInstr i@(_ ::& _ ::& rs) vn =
  pure $ i :/ mkInstr ::: ((starNotes, Dict, vn) ::& rs)

-- | Typecheck @ADD@ for the operand types given by the singletons.
--
-- Supported operand pairs: int\/nat in any combination, int + timestamp,
-- timestamp + int and mutez + mutez. Any other pair fails with
-- 'NotNumericTypes' in the 'ArithmeticOperation' context.
addImpl
  :: forall a b inp rs m.
     ( Typeable rs
     , Each [Typeable, SingI] [a, b]
     , inp ~ (a ': b ': rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Sing a -> Sing b -> HST inp -> VarAnn -> m (SomeInstr inp)
addImpl t1 t2 = case (t1, t2) of
  (STInt, STInt) -> arithImpl @Add ADD
  (STInt, STNat) -> arithImpl @Add ADD
  (STNat, STInt) -> arithImpl @Add ADD
  (STNat, STNat) -> arithImpl @Add ADD
  (STInt, STTimestamp) -> arithImpl @Add ADD
  (STTimestamp, STInt) -> arithImpl @Add ADD
  (STMutez, STMutez) -> arithImpl @Add ADD
  _ -> \i vn ->
    typeCheckInstrErr' (Un.ADD vn) (SomeHST i) (Just ArithmeticOperation) $
      NotNumericTypes (demote @a) (demote @b)

-- | Typecheck @EDIV@ for the operand types given by the singletons.
--
-- Supported operand pairs: int\/nat in any combination, mutez \/ mutez and
-- mutez \/ nat. Any other pair fails with 'NotNumericTypes' in the
-- 'ArithmeticOperation' context.
edivImpl
  :: forall a b inp rs m.
     ( Typeable rs
     , Each [Typeable, SingI] [a, b]
     , inp ~ (a ': b ': rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Sing a -> Sing b -> HST inp -> VarAnn -> m (SomeInstr inp)
edivImpl t1 t2 = case (t1, t2) of
  (STInt, STInt) -> edivImplDo
  (STInt, STNat) -> edivImplDo
  (STNat, STInt) -> edivImplDo
  (STNat, STNat) -> edivImplDo
  (STMutez, STMutez) -> edivImplDo
  (STMutez, STNat) -> edivImplDo
  _ -> \i vn ->
    typeCheckInstrErr' (Un.EDIV vn) (SomeHST i) (Just ArithmeticOperation) $
      NotNumericTypes (demote @a) (demote @b)

-- | Build an @EDIV@ instruction for operand types with an 'EDivOp'
-- instance. The two operands are replaced by the result, which gets star
-- notes and the given variable annotation.
edivImplDo
  :: ( EDivOp n m
     , SingI (EModOpRes n m)
     , Typeable (EModOpRes n m)
     , SingI (EDivOpRes n m)
     , Typeable (EDivOpRes n m)
     , WellTyped (EModOpRes n m)
     , WellTyped (EDivOpRes n m)
     , inp ~ (n ': m ': s)
     , Monad t
     )
  => HST inp
  -> VarAnn
  -> t (SomeInstr inp)
edivImplDo i@(_ ::& _ ::& rs) vn =
  pure $ i :/ EDIV ::: ((starNotes, Dict, vn) ::& rs)

-- | Typecheck @SUB@ for the operand types given by the singletons.
--
-- Supported operand pairs: int\/nat in any combination, timestamp -
-- timestamp, timestamp - int and mutez - mutez. Any other pair fails with
-- 'NotNumericTypes' in the 'ArithmeticOperation' context.
subImpl
  :: forall a b inp rs m.
     ( Typeable rs
     , Each [Typeable, SingI] [a, b]
     , inp ~ (a ': b ': rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Sing a -> Sing b -> HST inp -> VarAnn -> m (SomeInstr inp)
subImpl t1 t2 = case (t1, t2) of
  (STInt, STInt) -> arithImpl @Sub SUB
  (STInt, STNat) -> arithImpl @Sub SUB
  (STNat, STInt) -> arithImpl @Sub SUB
  (STNat, STNat) -> arithImpl @Sub SUB
  (STTimestamp, STTimestamp) -> arithImpl @Sub SUB
  (STTimestamp, STInt) -> arithImpl @Sub SUB
  (STMutez, STMutez) -> arithImpl @Sub SUB
  _ -> \i vn ->
    typeCheckInstrErr' (Un.SUB vn) (SomeHST i) (Just ArithmeticOperation) $
      NotNumericTypes (demote @a) (demote @b)

-- | Typecheck @MUL@ for the operand types given by the singletons.
--
-- Supported operand pairs: int\/nat in any combination, nat * mutez and
-- mutez * nat. Any other pair fails with 'NotNumericTypes' in the
-- 'ArithmeticOperation' context.
mulImpl
  :: forall a b inp rs m.
     ( Typeable rs
     , Each [Typeable, SingI] [a, b]
     , inp ~ (a ': b ': rs)
     , MonadReader InstrCallStack m
     , MonadError TCError m
     )
  => Sing a -> Sing b -> HST inp -> VarAnn -> m (SomeInstr inp)
mulImpl t1 t2 = case (t1, t2) of
  (STInt, STInt) -> arithImpl @Mul MUL
  (STInt, STNat) -> arithImpl @Mul MUL
  (STNat, STInt) -> arithImpl @Mul MUL
  (STNat, STNat) -> arithImpl @Mul MUL
  (STNat, STMutez) -> arithImpl @Mul MUL
  (STMutez, STNat) -> arithImpl @Mul MUL
  _ -> \i vn ->
    typeCheckInstrErr' (Un.MUL vn) (SomeHST i) (Just ArithmeticOperation) $
      NotNumericTypes (demote @a) (demote @b)

-- | Helper function to construct instructions for unary arithmetic
-- operations. The operand on top of the stack is replaced by the result,
-- which gets star notes and the given variable annotation.
unaryArithImpl ::
  ( Typeable (UnaryArithRes aop n)
  , SingI (UnaryArithRes aop n)
  , Typeable (UnaryArithRes aop n ': s)
  , WellTyped (UnaryArithRes aop n)
  , inp ~ (n ': s)
  , Monad t
  )
  => Instr inp (UnaryArithRes aop n ': s)
  -> HST inp
  -> VarAnn
  -> t (SomeInstr inp)
unaryArithImpl mkInstr i@(_ ::& rs) vn =
  pure $ i :/ mkInstr ::: ((starNotes, Dict, vn) ::& rs)