-- Terms and substitutions, implemented using flatterms.
-- This module contains all the low-level icky bits
-- and provides primitives for building higher-level stuff.
{-# LANGUAGE CPP, PatternSynonyms, ViewPatterns,
    MagicHash, UnboxedTuples, BangPatterns,
    RankNTypes, RecordWildCards, GeneralizedNewtypeDeriving,
    OverloadedStrings, RoleAnnotations #-}
{-# OPTIONS_GHC -O2 -fmax-worker-args=100 #-}
#ifdef USE_LLVM
{-# OPTIONS_GHC -fllvm #-}
#endif
module Twee.Term.Core where

import Data.Primitive(sizeOf)
#ifdef BOUNDS_CHECKS
import Data.Primitive.ByteArray.Checked
#else
import Data.Primitive.ByteArray
#endif
import Control.Monad.ST.Strict
import Data.Bits
import Data.Int
import GHC.Types(Int(..))
import GHC.Prim
import GHC.ST hiding (liftST)
import Data.Ord
import Data.Semigroup(Semigroup(..))
import Twee.Profile

--------------------------------------------------------------------------------
-- Symbols. A symbol is a single function or variable in a flatterm.
--------------------------------------------------------------------------------

-- A decoded flatterm symbol: the unpacked form of one Int64 array slot
-- (see toSymbol/fromSymbol for the packed encoding).
data Symbol =
  Symbol {
    -- Is it a function?
    isFun :: Bool,
    -- What is its number?
    index :: Int,
    -- What is the size of the term rooted at this symbol?
    size  :: Int }

-- Debug rendering of a symbol.
-- Functions print as "f<index>=<size>"; variables are shown via 'Var'.
instance Show Symbol where
  show Symbol{..}
    | isFun = "f" ++ show index ++ "=" ++ show size
    | otherwise = show (V index)

-- Convert symbols to/from Int64 for storage in flatterms.
-- The encoding:
--   * bits 0-30: size
--   * bit  31: 0 (variable) or 1 (function)
--   * bits 32-63: index

-- Decode a packed Int64 into its three fields.
{-# INLINE toSymbol #-}
toSymbol :: Int64 -> Symbol
toSymbol n =
  Symbol (testBit n 31)                  -- function/variable flag
    (fromIntegral (n `unsafeShiftR` 32)) -- index: the upper 32 bits
    (fromIntegral (n .&. 0x7fffffff))    -- size: the lower 31 bits

-- Pack a symbol into an Int64: size in the low bits, the function flag
-- at bit 31 and the index shifted up by 32 bits.
-- The fields are added together without masking, so the result is only
-- meaningful when size fits in 31 bits; any bits of index that do not
-- fit in the upper 32 bits are shifted out.
{-# INLINE fromSymbol #-}
fromSymbol :: Symbol -> Int64
fromSymbol Symbol{..} =
  fromIntegral size +
  (fromIntegral index `unsafeShiftL` 32) +
  (fromIntegral (fromEnum isFun) `unsafeShiftL` 31)

-- The number of bytes one packed symbol occupies in the array.
-- The argument to sizeOf only serves to select the type (Int64, the
-- result type of fromSymbol).
{-# INLINE symbolSize #-}
symbolSize :: Int
symbolSize = sizeOf (fromSymbol undefined)

--------------------------------------------------------------------------------
-- Flatterms, or rather lists of terms.
--------------------------------------------------------------------------------

-- | @'TermList' f@ is a list of terms whose function symbols have type @f@.
-- It is either a 'Cons' or an 'Empty'. You can turn it into a @['Term' f]@
-- with 'Twee.Term.unpack'.

-- A TermList is a slice of an unboxed array of symbols:
-- the symbols at positions low (inclusive) up to high (exclusive).
data TermList f =
  TermList {
    -- Position of the first symbol in the slice.
    low   :: {-# UNPACK #-} !Int,
    -- Position just past the last symbol in the slice.
    high  :: {-# UNPACK #-} !Int,
    -- The underlying array of packed (Int64) symbols.
    array :: {-# UNPACK #-} !ByteArray }

type role TermList nominal

-- | Index into a termlist.
--
-- @at n t@ returns the term rooted at the @n@th symbol of @t@ (counting
-- from 0). Calls 'error' if @n@ is negative or not less than the number
-- of symbols in @t@.
at :: Int -> TermList f -> Term f
at n t
  | n < 0 || low t + n >= high t = error "term index out of bounds"
  | otherwise = unsafeAt n t

-- | Index into a termlist, without bounds checking.
--
-- Advances the start of the slice by @n@ symbols and returns the term
-- found there; the caller must ensure that position is inside the termlist.
unsafeAt :: Int -> TermList f -> Term f
unsafeAt n (TermList lo hi arr) =
  case TermList (lo+n) hi arr of
    UnsafeCons t _ -> t

{-# INLINE lenList #-}
-- | The length of (number of symbols in) a termlist.
lenList :: TermList f -> Int
lenList (TermList low high _) = high - low

-- | @'Term' f@ is a term whose function symbols have type @f@.
-- It is either a 'Var' or an 'App'.

-- A term is a special case of a termlist.
-- We store it as the termlist together with the root symbol.
data Term f =
  Term {
    -- The packed encoding of the term's root symbol.
    root     :: {-# UNPACK #-} !Int64,
    -- The slice holding the root symbol followed by the rest of the term.
    termlist :: {-# UNPACK #-} !(TermList f) }

type role Term nominal

-- Terms are equal when their underlying termlists are equal;
-- the separately stored root field is not consulted.
instance Eq (Term f) where
  x == y = termlist x == termlist y

-- Terms are ordered by their underlying termlists.
instance Ord (Term f) where
  compare = comparing termlist

-- Pattern synonyms for termlists:
-- * Empty :: TermList f
--   Empty is the empty termlist.
-- * Cons t ts :: Term f -> TermList f -> TermList f
--   Cons t ts is the termlist t:ts.
-- * ConsSym{hd, tl, rest} :: Term f -> TermList f -> TermList f -> TermList f
--   ConsSym is like Cons (hd is the head and tl the tail) but also binds rest,
--   which additionally includes hd's children
--   (operationally, rest seeks one symbol to the right in the termlist).
-- * UnsafeCons/UnsafeConsSym: like Cons and ConsSym but don't check
--   that the termlist is non-empty.

-- | Matches the empty termlist.
pattern Empty :: TermList f
pattern Empty <- (patHead -> Nothing)

-- | Matches a non-empty termlist, unpacking it into head and tail.
pattern Cons :: Term f -> TermList f -> TermList f
pattern Cons t ts <- (patHead -> Just (t, _, ts))

{-# COMPLETE Empty, Cons #-}
{-# COMPLETE Empty, ConsSym #-}

-- | Like 'Cons', but does not check that the termlist is non-empty. Use only if
-- you are sure the termlist is non-empty.
pattern UnsafeCons :: Term f -> TermList f -> TermList f
pattern UnsafeCons t ts <- (unsafePatHead -> (t, _, ts))

-- | Matches a non-empty termlist, unpacking it into head, tail and
-- /everything except the root symbol of the head/.
-- Useful for iterating through terms one symbol at a time.
--
-- For example, if @ts@ is the termlist @[f(x,y), g(z)]@,
-- then @let ConsSym{hd = u, tl = us, rest = vs} = ts@ results in the
-- following bindings:
--
-- > u  = f(x,y)
-- > us = [g(z)]
-- > vs = [x, y, g(z)]
pattern ConsSym :: Term f -> TermList f -> TermList f -> TermList f
pattern ConsSym{hd, tl, rest} <- (patHead -> Just (hd, rest, tl))

-- | Like 'ConsSym', but does not check that the termlist is non-empty. Use only
-- if you are sure the termlist is non-empty.
--
-- The fields are @uhd@ (the head), @utl@ (the tail) and @urest@
-- (everything except the root symbol of the head).
pattern UnsafeConsSym :: Term f -> TermList f -> TermList f -> TermList f
pattern UnsafeConsSym{uhd, utl, urest} <- (unsafePatHead -> (uhd, urest, utl))

-- A helper for UnsafeCons/UnsafeConsSym.
-- Reads the symbol at position low without checking that the termlist
-- is non-empty, and returns a triple:
--   * the head term (the slice of `size` symbols starting at low),
--   * the termlist with only the head's root symbol dropped,
--   * the termlist with the whole head term dropped.
{-# INLINE unsafePatHead #-}
unsafePatHead :: TermList f -> (Term f, TermList f, TermList f)
unsafePatHead TermList{..} =
  (Term x (TermList low (low+size) array),
   TermList (low+1) high array,
   TermList (low+size) high array)
  where
    -- The packed root symbol; forced so the array read happens eagerly.
    !x = indexByteArray array low
    -- Only the size field of the decoded symbol is used here.
    Symbol{..} = toSymbol x

-- A helper for Cons/ConsSym.
-- Returns Nothing for the empty termlist (low == high); otherwise the
-- triple produced by unsafePatHead.
{-# INLINE patHead #-}
patHead :: TermList f -> Maybe (Term f, TermList f, TermList f)
patHead t@TermList{..}
  | low == high = Nothing
  | otherwise = Just (unsafePatHead t)

-- Pattern synonyms for single terms.
-- * Var :: Var -> Term f
-- * App :: Fun f -> TermList f -> Term f

-- | A function symbol. @f@ is the underlying type of function symbols defined
-- by the user; @'Fun' f@ is an @f@ together with an automatically-generated unique number.
newtype Fun f =
  F {
    -- | The unique number of a 'Fun'. Must fit in 32 bits.
    fun_id :: Int }
  deriving (Eq, Ord)

type role Fun nominal

-- | A variable.
newtype Var =
  V {
    -- | The variable's number.
    -- Don't use huge variable numbers:
    -- they will be truncated to 32 bits when stored in a term.
    var_id :: Int } deriving (Eq, Ord, Enum)
-- Variables are shown as "x" followed by their number, e.g. "x3".
instance Show Var where
  show x = "x" ++ show (var_id x)

-- | Matches a variable.
pattern Var :: Var -> Term f
pattern Var x <- (patTerm -> Left x)

-- | Matches a function application, binding the function symbol and
-- the termlist of its arguments.
pattern App :: Fun f -> TermList f -> Term f
pattern App f ts <- (patTerm -> Right (f, ts))

{-# COMPLETE Var, App #-}

-- A helper function for Var and App.
-- Decodes the root symbol of a term: a variable symbol gives Left with
-- the variable, a function symbol gives Right with the function and
-- the termlist of its arguments.
{-# INLINE patTerm #-}
patTerm :: Term f -> Either Var (Fun f, TermList f)
patTerm Term{..}
  | isFun     = Right (F index, ts)
  | otherwise = Left (V index)
  where
    Symbol{..} = toSymbol root
    -- ts is the term's own slice minus its root symbol, i.e. the
    -- arguments. A term's slice always contains at least the root,
    -- so the unchecked pattern is used.
    !UnsafeConsSym{urest = ts} = termlist

-- | Convert a term to a termlist.
--
-- The result is the term's own underlying slice, so no copying takes place.
{-# INLINE singleton #-}
singleton :: Term f -> TermList f
singleton Term{..} = termlist

-- Two termlists are equal when they contain the same number of symbols
-- and their contents compare equal. The length is checked first because
-- compareSameLength only looks at as many symbols as its first
-- argument contains.
instance Eq (TermList f) where
  t == u =
    lenList t == lenList u &&
    compareSameLength t u == EQ

-- Termlists are ordered first by length (number of symbols); termlists
-- of equal length are then ordered by comparing their contents.
instance Ord (TermList f) where
  {-# INLINE compare #-}
  compare t u =
    compare (lenList t) (lenList u) `mappend`
    compareSameLength t u

-- Compare the contents of two termlists with compareByteArrays.
-- The number of symbols compared is taken from the first argument, so
-- the caller must make sure the second termlist is at least as long
-- (in practice: that both have the same length).
{-# INLINE compareSameLength #-}
compareSameLength :: TermList f -> TermList f -> Ordering
compareSameLength t u =
  compareByteArrays (array t) (low t * k)
    (array u) (low u * k) ((high t - low t) * k)
  where
    -- Offsets and lengths are converted from symbols to bytes.
    k = symbolSize

--------------------------------------------------------------------------------
-- Building terms.
--------------------------------------------------------------------------------

-- | A monoid for building terms.
-- 'mempty' represents the empty termlist, while 'mappend' appends two termlists.
newtype Builder f =
  Builder {
    unBuilder ::
      -- Takes: the term array, and current position in the term.
      -- Returns the final array and position.
      forall s. Builder1 s f }

type role Builder nominal

-- The raw state-passing function underlying a Builder: it takes the
-- state token, the array being filled and the current position (in
-- symbols), and returns the new state token, the (possibly
-- reallocated) array and the new position.
type Builder1 s f = State# s -> MutableByteArray# s -> Int# -> (# State# s, MutableByteArray# s, Int# #)

-- Appending two builders runs the first and then the second, which
-- continues from the array and position the first one left behind.
instance Semigroup (Builder f) where
  {-# INLINE (<>) #-}
  Builder m1 <> Builder m2 = Builder (m1 `then_` m2)
-- The empty builder emits nothing: it returns the array and position
-- unchanged.
instance Monoid (Builder f) where
  {-# INLINE mempty #-}
  mempty = Builder built
  {-# INLINE mappend #-}
  mappend = (<>)

-- Build a termlist from a Builder.
-- Runs the builder against a freshly allocated array starting at
-- position 0, shrinks the array to the number of symbols actually
-- written, freezes it and returns it as a termlist covering positions
-- 0 up to that count.
-- The whole computation is wrapped in stamp "build term" (from
-- Twee.Profile; presumably a profiling label -- confirm there).
{-# INLINE buildTermList #-}
buildTermList :: Builder f -> TermList f
buildTermList (Builder m) = stamp "build term" $ runST $ do
  MutableByteArray marr# <-
    -- Start with a capacity of 16 symbols (arbitrary choice)
    newByteArray (16 * symbolSize)
  -- Run the raw builder; it may hand back a different (reallocated)
  -- array from the one we passed in.
  (marr, n) <-
    ST $ \s ->
      case m s marr# 0# of
        (# s, marr#, n# #) ->
          (# s, (MutableByteArray marr#, I# n#) #)
  shrinkMutableByteArray marr (n * symbolSize)
  -- Safe to freeze without copying: marr is not used after this point.
  !arr <- unsafeFreezeByteArray marr
  return (TermList 0 n arr)

-- A builder which does nothing: it returns the state token, array and
-- position exactly as it received them.
{-# INLINE built #-}
built :: Builder1 s f
built s arr# n# = (# s, arr#, n# #)

-- Sequence two builder operations: run m1, then run m2 on the state
-- token, array and position that m1 returned.
{-# INLINE then_ #-}
then_ :: Builder1 s f -> Builder1 s f -> Builder1 s f
m1 `then_` m2 = \s arr# n# ->
  case m1 s arr# n# of
    (# s, arr#, n# #) ->
      m2 s arr# n#

-- Emit an arbitrary symbol, with given arguments.
-- The symbol is written at the current position n and the arguments
-- (produced by inner) directly after it. The size field of the symbol
-- passed in is ignored: it is replaced by the number of symbols
-- actually emitted, counting the symbol itself.
{-# INLINE emitSymbolBuilder #-}
emitSymbolBuilder :: Symbol -> Builder f -> Builder f
emitSymbolBuilder x (Builder inner) =
  Builder $ \s arr# n# ->
    let n = I# n# in
    -- Reserve space for the symbol
    case reserve s arr# (unInt (n + 1)) of
      (# s, arr# #) ->
        -- Fill in the argument list
        case inner s arr# (unInt (n + 1)) of
          (# s, arr#, m# #) ->
            let arr = MutableByteArray arr#
                m = I# m# in
            -- Check the length of the argument list in symbols,
            -- then write the symbol, with the correct size
            case unST (writeByteArray arr n (fromSymbol x { size = m - n })) s of
              (# s, () #) ->
                (# s, arr#, m# #)

-- Emit a function application: the function's symbol followed by the
-- arguments produced by inner.
-- The size 0 given here is only a placeholder; emitSymbolBuilder fills
-- in the real size.
{-# INLINE emitApp #-}
emitApp :: Fun f -> Builder f -> Builder f
emitApp (F n) inner = emitSymbolBuilder (Symbol True n 0) inner

-- Emit a variable: a single non-function symbol with no arguments.
{-# INLINE emitVar #-}
emitVar :: Var -> Builder f
emitVar x = emitSymbolBuilder (Symbol False (var_id x) 1) mempty

-- Emit a whole termlist.
-- The symbols of the termlist are copied verbatim into the array at the
-- current position, which then advances by the termlist's length.
{-# INLINE emitTermList #-}
emitTermList :: TermList f -> Builder f
emitTermList (TermList lo hi array) =
  Builder $ \s arr# n# ->
    let n = I# n# in
    -- Reserve space for the termlist
    case reserve s arr# (unInt (n + hi - lo)) of
      (# s, arr# #) ->
        let k = symbolSize
            arr = MutableByteArray arr# in
        -- copyByteArray works in bytes, so scale everything by k.
        case unST (copyByteArray arr (n*k) array (lo*k) ((hi - lo)*k)) s of
          (# s, () #) ->
            (# s, arr#, unInt (n + hi - lo) #)

-- Make sure that the term array has enough space to hold the given
-- number of symbols in total (callers pass the position just past what
-- they are about to write, not an increment).
-- Returns the array to use from now on: the same one if it was already
-- big enough, otherwise the result of resizing it. The old array must
-- not be used after a resize.
{-# NOINLINE reserve #-}
reserve :: State# s -> MutableByteArray# s -> Int# -> (# State# s, MutableByteArray# s #)
reserve s arr# n# =
  case reserve' (MutableByteArray arr#) (I# n#) of
    ST m ->
      case m s of
        (# s, MutableByteArray arr# #) ->
          (# s, arr# #)
  where
    {-# INLINE reserve' #-}
    reserve' arr n = do
      -- The required capacity in bytes.
      let !m = n*symbolSize
      size <- getSizeofMutableByteArray arr
      -- Start doubling from at least one byte: doubling a capacity of
      -- zero would never reach m and the loop would not terminate.
      if size >= m then return arr else expand arr ((max 1 size)*2) m
    -- Keep doubling the candidate capacity until it is large enough,
    -- then resize once.
    expand arr size m
      | size >= m = resizeMutableByteArray arr size
      | otherwise = expand arr (size*2) m

-- Unwrap an ST action into its underlying state-passing function.
unST :: ST s a -> State# s -> (# State# s, a #)
unST (ST m) = m

-- Unbox an Int into the raw Int# it wraps.
unInt :: Int -> Int#
unInt (I# n) = n

----------------------------------------------------------------------
-- Efficient subterm testing.
----------------------------------------------------------------------

-- | Is a term contained as a subterm in a given termlist?
--
-- Compares the term against every window of the termlist that has the
-- same number of symbols as the term. If the term is longer than the
-- termlist there are no such windows and the result is 'False'.
{-# INLINE isSubtermOfList #-}
isSubtermOfList :: Term f -> TermList f -> Bool
isSubtermOfList t u =
  or [ singleton t == u{low = low u + i, high = low u + i + n}
     | i <- [0..lenList u - n]]
  where
    -- The number of symbols in t.
    n = lenList (singleton t)

-- | Check if a variable occurs in a termlist.
--
-- Searches for the packed encoding of the variable's symbol
-- (a non-function symbol of size 1).
{-# INLINE occursList #-}
occursList :: Var -> TermList f -> Bool
occursList (V x) t = symbolOccursList (fromSymbol (Symbol False x 1)) t

-- Does the given packed symbol occur anywhere in the termlist?
-- Walks the termlist one symbol at a time (via the rest field of
-- ConsSym) and compares each root's packed encoding with n. The
-- encoding includes the size field, so only symbols whose size also
-- matches are found.
symbolOccursList :: Int64 -> TermList f -> Bool
symbolOccursList !_ Empty = False
symbolOccursList n ConsSym{hd = t, rest = ts} = root t == n || symbolOccursList n ts