{-# LANGUAGE NoMonoPatBinds #-}

-- This option enables polymorphic pattern bindings, a standard feature of
-- Haskell 98 that is not enabled by default in GHC.

module SparseCheck where

import Control.Applicative (Alternative (..), Applicative (..))
import Control.Monad
import Data.Maybe
import qualified Data.IntMap as IntMap

-- LP Monad

-- | The logic-programming monad.
--
-- A computation is written in continuation-passing style over an explicit
-- 'State': it receives the current state and a success continuation and
-- returns the answers produced along every successful path.  'mzero'
-- produces no answers and 'mplus' concatenates the answers of both
-- branches (the left branch's answers first).
newtype LP a = LP { unLP :: State -> (State -> a -> [[Exp]]) -> [[Exp]] }

-- Functor and Applicative are superclasses of Monad (and Alternative of
-- MonadPlus) in modern base, so they must exist for the instances below
-- to be accepted.
instance Functor LP where
  fmap f m = LP (\s k -> unLP m s (\s' a -> k s' (f a)))

instance Applicative LP where
  pure a = LP (\s k -> k s a)
  (<*>) = ap

instance Monad LP where
  return = pure
  m >>= f = LP (\s k -> unLP m s (\s' a -> unLP (f a) s' k))

instance Alternative LP where
  empty = LP (\_ _ -> [])
  m <|> n = LP (\s k -> unLP m s k ++ unLP n s k)

instance MonadPlus LP where
  mzero = empty
  mplus = (<|>)

-- | Run an 'LP' computation from 'initState' with depth bound @d@; every
-- successful path contributes its final result as one answer.
runLP :: Int -> LP [Exp] -> [[Exp]]
runLP d m = unLP m (initState d) (\_ answer -> [answer])

-- LP data types

-- | Solver state threaded through every 'LP' computation.
data State = State
  { env        :: IntMap.IntMap Status  -- ^ status of each allocated variable
  , fresh      :: Int                   -- ^ next unused variable number
  , caseType   :: CaseType              -- ^ how 'caseOf' handles alternatives
  , depthBound :: Int                   -- ^ depth budget given to new variables
  }

-- | Starting state: no variables allocated, 'caseOf' in 'Instantiate'
-- mode, and @d@ as the depth budget handed to each new variable.
initState :: Int -> State
initState d =
  State { depthBound = d, caseType = Instantiate, fresh = 0, env = IntMap.empty }

-- | What the environment records about a logic variable.
data Status
  = Unbound Int [Residual]  -- ^ remaining depth budget, and the residuals
                            --   suspended until the variable is bound
  | Bound Exp               -- ^ the value the variable is bound to

-- | A suspended computation, resumed with the value its variable receives.
type Residual = Exp -> LP ()

-- | An 'Exp' tagged with the Haskell type it represents.  The parameter
-- is phantom: it only guides type checking.
newtype Term a = Term { expr :: Exp }

-- | Untyped terms: a logic variable, or a constructor (identified by its
-- index within its data type) applied to argument terms.
data Exp
  = Var Var
  | Ctr Int [Exp]

-- | Logic variables are identified by number.
type Var = Int

-- | Mode used by 'caseOf': match one-way and suspend on unknown values
-- ('Residuate'), or enumerate alternatives by unification ('Instantiate').
data CaseType
  = Residuate
  | Instantiate

-- LP classes

-- | A pair with an infix constructor, used both as a 'Tuple' of logic
-- values and to group the constructor builders produced by '\/'.
data Pair a b
  = a :- b
  deriving (Eq, Show)

infix 5 ===

-- | Structures of logic terms ('Term', tuples, 'Pair', ...).
class Tuple a where
  -- | Allocate a structure made of fresh logic variables.
  free :: LP a

  -- | Unify two structures component-wise.
  (===) :: a -> a -> LP ()

  -- | Match a pattern (first argument) against a value (second); the
  -- continuation receives the action that performs the resulting
  -- pattern-variable bindings.
  match :: Tuple b => a -> a -> (LP () -> LP b) -> LP b

-- | Types whose values can move between Haskell and logic terms.
class Convert a where
  -- | Encode a Haskell value as a term.
  term :: a -> Term a

  -- | A predicate that instantiates a term to a ground value of this
  -- type, paired with a decoder from a ground 'Exp' back to a value.
  unterm :: (Term a -> LP (), Exp -> a)

-- Variables

-- | Allocate a fresh, unbound variable carrying the current depth bound
-- and no residuals.
newVar :: LP Var
newVar = LP allocate
  where
    allocate s k =
      let v = fresh s
          status = Unbound (depthBound s) []
          s' = s { fresh = v + 1, env = IntMap.insert v status (env s) }
      in k s' v

-- | Look up a variable's status.  Uses 'IntMap.!', so the variable must
-- have been allocated with 'newVar'.
readVar :: Var -> LP Status
readVar v = LP $ \s k -> k s ((IntMap.!) (env s) v)

-- | Overwrite a variable's status.
writeVar :: Var -> Status -> LP ()
writeVar v a = LP $ \s k ->
  let s' = s { env = IntMap.insert v a (env s) }
  in k s' ()

-- Unification

-- | @ifBound v t e@ runs @t@ on @v@'s value if it is bound, else @e@.
ifBound :: Var -> (Exp -> LP a) -> LP a -> LP a
ifBound v t e = do
  status <- readVar v
  case status of
    Bound a -> t a
    Unbound _ _ -> e

-- | Follow variable bindings until reaching an unbound variable or a
-- constructor.  Only the outermost layer is dereferenced.
root :: Exp -> LP Exp
root e@(Var v) = ifBound v root (return e)
root e = return e

-- | Dereference a term completely, replacing every bound variable (at any
-- depth) by its value.
deepRoot :: Exp -> LP Exp
deepRoot e = root e >>= expand
  where
    expand (Var v) = return (Var v)
    expand (Ctr c es) = liftM (Ctr c) (mapM deepRoot es)

-- | Unify two terms: variables are merged or bound, constructors must
-- agree and then have their arguments unified pairwise.
unifyExp :: Exp -> Exp -> LP ()
unifyExp a b = do
  ra <- root a
  rb <- root b
  case (ra, rb) of
    (Var v, Var w)
      | v == w -> return ()
      | otherwise -> unifyVar v w
    (Var v, _) -> bindVar v rb
    (_, Var w) -> bindVar w ra
    (Ctr c xs, Ctr d ys)
      | c == d -> zipWithM_ unifyExp xs ys
    _ -> mzero

-- Residuation

-- | Unify two variables.  When both are unbound, @v@ is redirected to @w@;
-- @w@ keeps the smaller of the two depth budgets and collects the
-- residuals of both.
--
-- The statuses are inspected with an explicit @case@ rather than a
-- refutable do-pattern, which would require a MonadFail instance.  If
-- either variable turns out to be bound already, the call falls back to
-- full term unification.  Unifying a variable with itself is a no-op; it
-- must not create a @v -> v@ cycle.
unifyVar :: Var -> Var -> LP ()
unifyVar v w
  | v == w = return ()
  | otherwise = do
      sv <- readVar v
      sw <- readVar w
      case (sv, sw) of
        (Unbound d0 rs0, Unbound d1 rs1) -> do
          writeVar w (Unbound (min d0 d1) (rs0 ++ rs1))
          writeVar v (Bound (Var w))
        _ -> unifyExp (Var v) (Var w)

-- | Bind the unbound variable @v@ to @a@, then resume @v@'s residuals on @a@.
--
-- Depth budget of @v@:
--
-- * 0: binding is not allowed and the branch fails;
-- * positive: the budget is pushed into @a@ with 'setDepth';
-- * negative: no depth checks at all (as used by 'lower', which passes -1).
--
-- An explicit @case@ replaces the refutable do-pattern, which needs a
-- MonadFail instance on modern GHC.  Binding an already-bound variable
-- is an internal invariant violation and is reported loudly.
bindVar :: Var -> Exp -> LP ()
bindVar v a = do
  status <- readVar v
  case status of
    Bound _ -> error "SparseCheck.bindVar: variable is already bound"
    Unbound d rs -> do
      when (d == 0) mzero
      when (d > 0) (setDepth d a)
      writeVar v (Bound a)
      rs `resumeOn` a

-- | @setDepth d e@ lowers the depth budgets of the unbound variables in
-- @e@.  Variables at the top of @e@ get at most @d@, those one constructor
-- deeper at most @d - 1@, and so on.  Budgets are never raised.
setDepth :: Int -> Exp -> LP ()
setDepth d0 x = case x of
  Ctr _ kids -> forM_ kids (setDepth (d0 - 1))
  Var w -> do
    status <- readVar w
    case status of
      Bound e -> setDepth d0 e
      Unbound d1 rs
        | d0 <= d1 -> writeVar w (Unbound d0 rs)
        | otherwise -> return ()

-- | @resumeOn rs a@ schedules the residuals @rs@ against @a@.
--
-- * If @a@ is, or dereferences to, an unbound variable, the residuals are
--   suspended on it, ahead of any residuals already waiting there.
-- * Otherwise they run now with the constructor value.
resumeOn :: [Residual] -> Exp -> LP ()
resumeOn rs (Var v) = do
  status <- readVar v
  case status of
    Unbound d ss -> writeVar v (Unbound d (rs ++ ss))
    -- The variable may have been bound since the caller looked at it
    -- (e.g. by a binding deferred in 'matchExp'); follow the binding
    -- instead of failing a refutable pattern.
    Bound e -> resumeOn rs e
resumeOn rs a = mapM_ ($ a) rs

-- | @rigidExp f a@ applies @f@ to @a@'s value once @a@ is bound to a
-- constructor.  While @a@ is an unbound variable the call is suspended as
-- a residual on it.  A fresh result structure is returned immediately;
-- it is unified with @f@'s result when the residual runs.
rigidExp :: Tuple b => (Exp -> LP b) -> Exp -> LP b
rigidExp f a = do
  ra <- root a
  out <- free
  let resume x = f x >>= (out ===)
  resumeOn [resume] ra
  return out

-- | Typed wrapper around 'rigidExp': suspend @f@ until its argument term is
-- bound to a constructor.
rigid :: Tuple b => (Term a -> LP b) -> (Term a -> LP b)
rigid f = rigidExp (f . Term) . expr

-- | Run a computation with 'caseOf' in 'Residuate' mode.  The caller's
-- mode is restored before the continuation runs.
resid :: LP a -> LP a
resid m = LP $ \s k ->
  let outerMode = caseType s
      continue s' result = k (s' { caseType = outerMode }) result
  in unLP m (s { caseType = Residuate }) continue

-- | @eq a b sk@ waits until both terms are bound to constructors.
--
-- * Different constructors: succeed.
-- * Same constructor: compare the arguments pairwise, and run @sk@ only
--   if every pair turns out equal.
eq :: Exp -> Exp -> LP () -> LP ()
eq a b sk = rigidExp (\x -> rigidExp (\y -> compareCtrs x y) b) a
  where
    compareCtrs (Ctr c xs) (Ctr d ys)
      | c == d = allEqual xs ys
    compareCtrs _ _ = return ()

    allEqual [] [] = sk
    allEqual (x:xs) (y:ys) = eq x y (allEqual xs ys)

-- | Disequality: fails when the two terms turn out to be equal, succeeds
-- as soon as they are seen to differ.
(=/=) :: Term a -> Term a -> LP ()
x =/= y = eq (expr x) (expr y) mzero

-- Matching

-- | @matchExp a b k@ matches pattern @a@ against value @b@.
--
-- Variables on the value side are never bound by matchExp itself.
-- Pattern-variable bindings are not performed immediately; instead the
-- action that performs them is passed to the continuation @k@.  If @a@
-- has a constructor where @b@ is still an unbound variable, the whole
-- match is suspended (via 'rigidExp') until @b@ is bound.
matchExp :: Tuple a => Exp -> Exp -> (LP () -> LP a) -> LP a
matchExp a b k = do ra <- root a ; rb <- root b ; m ra rb
  where
    m (Var v)    (Var w)    | v == w = k (return ())
    -- a pattern variable matches anything: defer binding it
    m (Var v)    b                   = k (bindVar v b)
    -- value not known yet: re-run the match (on the outer a/b) once bound
    m _          (Var w)             = rigidExp (\b -> matchExp a b k) b
    m (Ctr c as) (Ctr d bs) | c == d = mzip as bs (return ())
    m _          _                   = k mzero
      -- Constructor mismatch: k receives a failing action instead of this
      -- returning mzero directly.  In 'residuate' (k = \m -> (m >> rhs) ?
      -- residuate a rest) this still lets later alternatives be tried.
      -- NOTE(review): the original author suggested failing immediately;
      -- doing so must not skip the remaining alternatives.

    -- match argument lists pairwise, accumulating the deferred bindings
    mzip [] [] m = k m
    mzip (a:as) (b:bs) m = matchExp a b (\n -> mzip as bs (m >> n))

-- Tuple instances

-- | The empty structure: nothing to allocate, unify or match.
instance Tuple () where
  free = return ()
  (===) _ _ = return ()
  match _ _ = ($ return ())

-- | A single logic term.
instance Tuple (Term a) where
  free = liftM (Term . Var) newVar
  x === y = unifyExp (expr x) (expr y)
  match x y = matchExp (expr x) (expr y)

-- | ':-' pairs are handled component-wise, left to right.
instance (Tuple a, Tuple b) => Tuple (Pair a b) where
  free = liftM2 (:-) free free
  (x0 :- x1) === (y0 :- y1) = do
    x0 === y0
    x1 === y1
  match (x0 :- x1) (y0 :- y1) k =
    match x0 y0 $ \m0 ->
    match x1 y1 $ \m1 ->
    k (m0 >> m1)

instance (Tuple a, Tuple b) => Tuple (a, b) where
  free = liftM2 (,) free free
  (x0, x1) === (y0, y1) = do
    x0 === y0
    x1 === y1
  match (x0, x1) (y0, y1) k =
    match x0 y0 $ \m0 ->
    match x1 y1 $ \m1 ->
    k (m0 >> m1)

instance (Tuple a, Tuple b, Tuple c) => Tuple (a, b, c) where
  free = liftM3 (,,) free free free
  (x0, x1, x2) === (y0, y1, y2) = do
    x0 === y0
    x1 === y1
    x2 === y2
  match (x0, x1, x2) (y0, y1, y2) k =
    match x0 y0 $ \m0 ->
    match x1 y1 $ \m1 ->
    match x2 y2 $ \m2 ->
    k (m0 >> m1 >> m2)

instance (Tuple a, Tuple b, Tuple c, Tuple d) => Tuple (a, b, c, d) where
  free = liftM4 (,,,) free free free free
  (x0, x1, x2, x3) === (y0, y1, y2, y3) = do
    x0 === y0
    x1 === y1
    x2 === y2
    x3 === y3
  match (x0, x1, x2, x3) (y0, y1, y2, y3) k =
    match x0 y0 $ \m0 ->
    match x1 y1 $ \m1 ->
    match x2 y2 $ \m2 ->
    match x3 y3 $ \m3 ->
    k (m0 >> m1 >> m2 >> m3)

-- Pattern matching

infixr 1 :|:
infix 2 :->

-- | Case alternatives: @pat :-> rhs@ is a single alternative and ':|:'
-- chains them.
data Alts a b
  = a :-> b
  | Alts a b :|: Alts a b

-- | List the alternatives from left to right.
flattenAlts :: Alts a b -> [(a, b)]
flattenAlts alts = go alts []
  where
    go (pat :-> rhs) rest = (pat, rhs) : rest
    go (l :|: r) rest = go l (go r rest)

-- | Read the current 'caseOf' mode without changing the state.
getCaseType :: LP CaseType
getCaseType = LP $ \s k -> k s (caseType s)

-- | Try every alternative: unify its pattern with the scrutinee, then run
-- its right-hand side.  Answers from all alternatives are combined.
instantiate :: Tuple a => a -> [(a, LP b)] -> LP b
instantiate a as = msum [pat === a >> rhs | (pat, rhs) <- as]

-- | Try the alternatives by one-way matching ('match').  Each right-hand
-- side runs after the bindings produced by matching its pattern.  The
-- remaining alternatives are tried as well, via 'mplus'.
residuate :: (Tuple a, Tuple b) => a -> [(a, LP b)] -> LP b
residuate _ [] = mzero
residuate a ((pat, rhs) : rest) =
  match pat a $ \bind -> (bind >> rhs) `mplus` residuate a rest

-- | Logical case expression.  @alts@ receives fresh pattern variables and
-- builds the alternatives.  They are then handled by 'instantiate' or
-- 'residuate', depending on the current mode.
caseOf :: (Tuple a, Tuple b, Tuple c) => a -> (b -> Alts a (LP c)) -> LP c
caseOf scrutinee mkAlts = do
  mode <- getCaseType
  patVars <- free
  let branches = flattenAlts (mkAlts patVars)
      strategy = case mode of
        Instantiate -> instantiate
        Residuate -> residuate
  strategy scrutinee branches

-- Logical interface

infixr 3 ?
infixr 4 &

-- | Predicates are 'LP' computations with no interesting result.
type Pred = LP ()

-- | Disjunction: the answers of either branch.
(?) :: LP a -> LP a -> LP a
x ? y = x `mplus` y

-- | Conjunction: run the left, then the right.
(&) :: LP a -> LP b -> LP b
x & y = x >> y

-- | Introduce fresh logic variables for the body.
exists :: Tuple a => (a -> LP b) -> LP b
exists body = free >>= body

-- | Always succeeds once.
true :: LP ()
true = return ()

-- | Always fails.
false :: LP a
false = mzero

-- Solving interface

-- | Allocate @n@ fresh variables, run the goal on them, and return their
-- fully dereferenced values.
solveHelp :: Int -> ([Exp] -> LP a) -> LP [Exp]
solveHelp n p = do
  vars <- replicateM n (liftM Var newVar)
  _ <- p vars
  mapM deepRoot vars

-- | Enumerate values satisfying @p@ with depth bound @d@.  After @p@
-- succeeds, the solution term is instantiated with 'unterm' so that it
-- can be decoded.
solve :: Convert a => Int -> (Term a -> Pred) -> [a]
solve d p = map decode (runLP d (solveHelp 1 goal))
  where
    goal [x] = p (Term x) >> ground (Term x)
    decode [x] = fromExp x
    (ground, fromExp) = unterm

-- | Two-argument version of 'solve'.
solve2 :: (Convert a, Convert b) => Int ->
            (Term a -> Term b -> Pred) -> [(a, b)]
solve2 d p = map decode (runLP d (solveHelp 2 goal))
  where
    goal [x, y] = do
      p (Term x) (Term y)
      ground0 (Term x)
      ground1 (Term y)
    decode [x, y] = (fromExp0 x, fromExp1 y)
    (ground0, fromExp0) = unterm
    (ground1, fromExp1) = unterm

-- | Three-argument version of 'solve'.
solve3 :: (Convert a, Convert b, Convert c) => Int ->
            (Term a -> Term b -> Term c -> Pred) -> [(a, b, c)]
solve3 d p = map decode (runLP d (solveHelp 3 goal))
  where
    goal [x, y, z] = do
      p (Term x) (Term y) (Term z)
      ground0 (Term x)
      ground1 (Term y)
      ground2 (Term z)
    decode [x, y, z] = (fromExp0 x, fromExp1 y, fromExp2 z)
    (ground0, fromExp0) = unterm
    (ground1, fromExp1) = unterm
    (ground2, fromExp2) = unterm

-- | Four-argument version of 'solve'.
solve4 :: (Convert a, Convert b, Convert c, Convert d) => Int ->
            (Term a -> Term b -> Term c -> Term d -> Pred) -> [(a, b, c, d)]
solve4 d p = map decode (runLP d (solveHelp 4 goal))
  where
    goal [w, x, y, z] = do
      p (Term w) (Term x) (Term y) (Term z)
      ground0 (Term w)
      ground1 (Term x)
      ground2 (Term y)
      ground3 (Term z)
    decode [w, x, y, z] = (fromExp0 w, fromExp1 x, fromExp2 y, fromExp3 z)
    (ground0, fromExp0) = unterm
    (ground1, fromExp1) = unterm
    (ground2, fromExp2) = unterm
    (ground3, fromExp3) = unterm

-- Property checking interface

-- | @check n p q@ enumerates (with depth bound @n@) values satisfying the
-- precondition @p@ and keeps those that falsify the property @q@, i.e.
-- the counterexamples.
check :: Convert a => Int -> (Term a -> Pred) -> (a -> Bool) -> [a]
check n p q = filter (not . q) (solve n p)

-- | Two-argument version of 'check'.
check2 :: (Convert a, Convert b) =>
          Int -> (Term a -> Term b -> Pred) -> (a -> b -> Bool) -> [(a, b)]
check2 n p q = [ (a, b) | (a, b) <- solve2 n p, not (q a b) ]

-- | Three-argument version of 'check'.
check3 :: (Convert a, Convert b, Convert c) =>
          Int -> (Term a -> Term b -> Term c -> Pred) ->
          (a -> b -> c -> Bool) -> [(a, b, c)]
check3 n p q = [ (a, b, c) | (a, b, c) <- solve3 n p, not (q a b c) ]

-- | Four-argument version of 'check'.
check4 :: (Convert a, Convert b, Convert c, Convert d) =>
          Int -> (Term a -> Term b -> Term c -> Term d -> Pred) ->
          (a -> b -> c -> d -> Bool) -> [(a, b, c, d)]
check4 n p q =
  [ (a, b, c, d) | (a, b, c, d) <- solve4 n p, not (q a b c d) ]

-- Predicate to boolean function conversion

-- | Turn a predicate into a Boolean function on Haskell values.  The
-- result is 'True' iff the predicate has at least one solution for the
-- encoded arguments.  The search uses depth -1, which disables the depth
-- checks in 'bindVar'.  It may therefore fail to terminate when no
-- solution exists and the search space is infinite.
lower :: Convert a => (Term a -> LP r) -> a -> Bool
lower p a = not $ null $ solve (-1) (\x -> p (term a) & x === tru)

-- | Two-argument version of 'lower'.
lower2 :: (Convert a, Convert b) =>
          (Term a -> Term b -> LP r) -> a -> b -> Bool
lower2 p a b = not $ null $ solve (-1) (\x -> p (term a) (term b) & x === tru)

-- | Three-argument version of 'lower'.
lower3 :: (Convert a, Convert b, Convert c) =>
          (Term a -> Term b -> Term c -> LP r) -> a -> b -> c -> Bool
lower3 p a b c = not $ null $ solve (-1)
                   (\x -> p (term a) (term b) (term c) & x === tru)

-- | Four-argument version of 'lower'.
lower4 :: (Convert a, Convert b, Convert c, Convert d) =>
          (Term a -> Term b -> Term c -> Term d -> LP r) ->
          a -> b -> c -> d -> Bool
lower4 p a b c d = not $ null $ solve (-1)
                     (\x -> p (term a) (term b) (term c) (term d) & x === tru)

-- Data type construction

-- | Term builders for the constructor with index @n@.  The Haskell
-- constructor argument is never evaluated; it only fixes the types.
ctr0 :: a -> Int -> Term a
ctr0 _ n = Term (Ctr n [])

ctr1 :: (a -> b) -> Int -> Term a -> Term b
ctr1 _ n x = Term (Ctr n [expr x])

ctr2 :: (a -> b -> c) -> Int -> Term a -> Term b -> Term c
ctr2 _ n x y = Term (Ctr n [expr x, expr y])

ctr3 :: (a -> b -> c -> d) -> Int -> Term a -> Term b -> Term c -> Term d
ctr3 _ n x y z = Term (Ctr n [expr x, expr y, expr z])

ctr4 :: (a -> b -> c -> d -> e) -> Int ->
          Term a -> Term b -> Term c -> Term d -> Term e
ctr4 _ n w x y z = Term (Ctr n [expr w, expr x, expr y, expr z])

infixr 5 \/

-- | Combine the constructor builders of one data type.  The left builder
-- gets index @n@ and the right one @n + 1@; because '\/' nests to the
-- right, the constructors are numbered consecutively.
(\/) :: (Int -> b) -> (Int -> c) -> Int -> Pair b c
(\/) f g n = f n :- g (n + 1)

-- | Number a data type's constructors starting from 0.
datatype :: (Int -> b) -> b
datatype = ($ 0)

-- Conversion helpers

-- | Presumably (constructor index, arity) pairs, matching the argument
-- of 'instCtr'.  NOTE(review): not referenced in this file section.
type Family = [(Int, Int)]

-- | A converter: an instantiation predicate paired with a partial decoder
-- that yields 'Nothing' for the wrong constructor.
type Conv a = (Term a -> Pred, Exp -> Maybe a)

-- | Bind @v@ to constructor @c@ applied to @n@ fresh variables.
instCtr :: Var -> (Int, Int) -> Pred
instCtr v (c, n) = do
  args <- replicateM n newVar
  bindVar v (Ctr c (map Var args))

-- | @mkInst c n f t@ requires @t@ to be constructor @c@ (of arity @n@),
-- then runs @f@ on its fields.
--
-- * An unbound variable is first bound to @c@ applied to fresh variables.
-- * A different constructor fails.
mkInst :: Int -> Int -> ([Exp] -> Pred) -> Term a -> Pred
mkInst c n f t = case expr t of
  Ctr ctr es
    | ctr == c -> f es
    | otherwise -> mzero
  Var v -> do
    val <- root (Var v)
    case val of
      Var u -> instCtr u (c, n) >> mkInst c n f (Term (Var u))
      _ -> mkInst c n f (Term val)

-- | Decode an 'Exp' headed by constructor @n@ by applying @f@ to its
-- fields.  Any other term gives 'Nothing'.
mkConv :: Int -> ([Exp] -> a) -> Exp -> Maybe a
mkConv n f e = case e of
  Ctr c es | c == n -> Just (f es)
  _ -> Nothing

-- | Converters for constructors of arity 0 to 3.  @convK f n@ pairs:
--
-- * an instantiation predicate that puts a term in constructor @n@ and
--   then instantiates each field with that field type's 'unterm'
--   predicate, and
-- * a decoder that rebuilds the value with @f@ from an 'Exp' headed by
--   constructor @n@ ('Nothing' for any other constructor).
conv0 :: a -> Int -> Conv a
conv0 f n =
  ( mkInst n 0 (\[] -> return ())
  , mkConv n (\[] -> f)
  )

conv1 :: Convert a => (a -> b) -> Int -> Conv b
conv1 f n =
  ( mkInst n 1 (\[x] -> ground0 (Term x))
  , mkConv n (\[x] -> f (fromExp0 x))
  )
  where
    (ground0, fromExp0) = unterm

conv2 :: (Convert a, Convert b) => (a -> b -> c) -> Int -> Conv c
conv2 f n =
  ( mkInst n 2 (\[x, y] -> ground0 (Term x) >> ground1 (Term y))
  , mkConv n (\[x, y] -> f (fromExp0 x) (fromExp1 y))
  )
  where
    (ground0, fromExp0) = unterm
    (ground1, fromExp1) = unterm

conv3 :: (Convert a, Convert b, Convert c) => (a -> b -> c -> d) ->
           Int -> Conv d
conv3 f n =
  ( mkInst n 3 (\[x, y, z] -> ground0 (Term x) >> ground1 (Term y)
                                >> ground2 (Term z))
  , mkConv n (\[x, y, z] -> f (fromExp0 x) (fromExp1 y) (fromExp2 z))
  )
  where
    (ground0, fromExp0) = unterm
    (ground1, fromExp1) = unterm
    (ground2, fromExp2) = unterm

-- | Converter for the arity-4 constructor with index @n@, built from the
-- Haskell constructor @f@.
--
-- * The first component puts a term in this constructor and instantiates
--   its four fields.
-- * The second component decodes an 'Exp' headed by constructor @n@
--   ('Nothing' otherwise).
--
-- Each field is decoded by its own decoder; previously the fourth field
-- was decoded from the third field's expression.
conv4 :: (Convert a, Convert b, Convert c, Convert d) =>
           (a -> b -> c -> d -> e) -> Int -> Conv e
conv4 f n = ( mkInst n 4 (\[a, b, c, d] -> i0 (Term a) >> i1 (Term b)
                                        >> i2 (Term c) >> i3 (Term d))
            , mkConv n (\[a, b, c, d] -> f (c0 a) (c1 b) (c2 c) (c3 d)) )
  where
    (i0, c0) = unterm
    (i1, c1) = unterm
    (i2, c2) = unterm
    (i3, c3) = unterm

infixr 5 -+-

-- | Combine the converters of a data type's constructors.  The left
-- converter is numbered @n@ and the right one from @n + 1@.
--
-- * Instantiation tries both alternatives ('mplus').
-- * Decoding takes the first alternative that succeeds.
(l -+- r) n = (\x -> instL x `mplus` instR x, \x -> convL x `mplus` convR x)
  where
    (instL, convL) = l n
    (instR, convR) = r (n + 1)

-- | Number a converter's constructors from 0 and make its decoder total
-- via 'fromJust'.  The decoder is therefore only meant for terms already
-- instantiated to one of the converter's constructors.
converter f = (inst, fromJust . conv)
  where
    (inst, conv) = f 0

-- Algebraic types from Prelude

infixr 7 ***, :-

-- | Term builder for ':-', the single constructor of 'Pair'.
(***) = datatype (ctr2 (:-))

instance (Convert a, Convert b) => Convert (Pair a b) where
  term (x :- y) = term x *** term y
  unterm = converter (conv2 (:-))

-- | Term builder for Haskell pairs.
pair = datatype (ctr2 (,))

instance (Convert a, Convert b) => Convert (a, b) where
  term (x, y) = pair (term x) (term y)
  unterm = converter (conv2 (,))

-- | Terms for 'False' (index 0) and 'True' (index 1).
(fal :- tru) = datatype (ctr0 False \/ ctr0 True)

instance Convert Bool where
  term b = if b then tru else fal
  unterm = converter (conv0 False -+- conv0 True)

-- | Term builders for 'Nothing' (index 0) and 'Just' (index 1).
(nothing :- just) = datatype (ctr0 Nothing \/ ctr1 Just)

instance Convert a => Convert (Maybe a) where
  term = maybe nothing (just . term)
  unterm = converter (conv0 Nothing -+- conv1 Just)

-- | Term builders for @[]@ (index 0) and cons (index 1).
(nil :- (|>)) = datatype (ctr0 [] \/ ctr2 (:))

instance Convert a => Convert [a] where
  term = foldr (\x rest -> term x |> rest) nil
  unterm = converter (conv0 [] -+- conv2 (:))

-- | Term builders for 'Left' (index 0) and 'Right' (index 1).
(left :- right) = datatype (ctr1 Left \/ ctr1 Right)

instance (Convert a, Convert b) => Convert (Either a b) where
  term = either (left . term) (right . term)
  unterm = converter (conv1 Left -+- conv1 Right)

-- Int arithmetic (currently non-negative integers only)

-- | The Haskell value that 'zero' stands for.
zeroInt :: Int
zeroInt = 0

-- | The Haskell function that 'suc' stands for.
succInt :: Int -> Int
succInt n = succ n

-- | Peano-style term builders: 'zero' (index 0) and 'suc' (index 1).
(zero :- suc) = datatype (ctr0 zeroInt \/ ctr1 succInt)

instance Convert Int where
  term = i
  unterm = converter (conv0 zeroInt -+- conv1 succInt)

-- | @add a b c@ holds when @a + b = c@ (by recursion on @a@ and @c@).
add :: Term Int -> Term Int -> Term Int -> Pred
add a b c = caseOf (a, c) alts
  where
    alts (a', c') =
          (zero, c') :-> b === c'
      :|: (suc a', suc c') :-> add a' b c'

-- | @sub' a b r@ relates naturals @a@ and @b@ to their signed difference:
-- @r = right (a - b)@ when @a >= b@ and @r = left (b - a)@ when @a < b@.
--
-- The second alternative requires @b@ to be a successor.  Otherwise, at
-- @a = b = 0@, both @right 0@ and a spurious @left 0@ were produced; the
-- latter made 'quotrem' report e.g. 3 / 3 as quotient 0, remainder 3.
-- The three alternatives are now mutually exclusive.
sub' :: Term Int -> Term Int -> Term (Either Int Int) -> Pred
sub' a b c = caseOf (a, b, c) alts where
               alts (a, b, c) =
                     (a, zero, right a) :-> return ()
                 :|: (zero, suc b, left (suc b)) :-> return ()
                 :|: (suc a, suc b, c) :-> sub' a b c

-- | @sub a b c@ holds when @a - b = c@ (only for @a >= b@).
sub :: Term Int -> Term Int -> Term Int -> Pred
sub x y z = sub' x y (right z)

-- | @mul a b c@ holds when @a * b = c@ (by recursion on @a@).
mul :: Term Int -> Term Int -> Term Int -> Pred
mul a b c = caseOf a $ \(a', t) ->
      zero :-> c === zero
  :|: suc a' :-> mul a' b t & add b t c

-- | @pow a b c@ holds when @a ^ b = c@ (by recursion on @b@).
pow :: Term Int -> Term Int -> Term Int -> Pred
pow a b c = caseOf b $ \(b', t) ->
      zero :-> c === i 1
  :|: suc b' :-> pow a b' t & mul a t c

-- | @quotrem a b c d@: @c@ and @d@ are the quotient and remainder of @a@
-- divided by @b@, computed by repeated subtraction.  With @b = 0@ the
-- recursion repeats the same call, so only the depth bound stops it.
quotrem :: Term Int -> Term Int -> Term Int -> Term Int -> Pred
quotrem a b c d = exists $ \diff -> sub' a b diff >> caseOf diff branches
  where
    branches (a', q) =
          left a' :-> c === zero & d === a
      :|: right a' :-> quotrem a' b q d & c === suc q

-- | Ordering relations between logic terms.
class Ordered a where
  (|<|), (|>|), (|<=|), (|>=|) :: Term a -> Term a -> Pred


-- | Peano-style comparisons: each recursive alternative strips one 'suc'
-- from both sides.
instance Ordered Int where
  x |<| y = caseOf (x, y) $ \(x', y') ->
        (zero, suc y') :-> return ()
    :|: (suc x', suc y') :-> x' |<| y'

  x |<=| y = caseOf (x, y) $ \(x', y') ->
        (zero, y') :-> return ()
    :|: (suc x', suc y') :-> x' |<=| y'

  x |>| y = caseOf (x, y) $ \(x', y') ->
        (suc x', zero) :-> return ()
    :|: (suc x', suc y') :-> x' |>| y'

  x |>=| y = caseOf (x, y) $ \(x', y') ->
        (x', zero) :-> return ()
    :|: (suc x', suc y') :-> x' |>=| y'

-- | Encode a non-negative 'Int' as a Peano-style term built from 'zero'
-- and 'suc'.  Negative input is rejected explicitly.  Otherwise the
-- recursion never reaches 0 and lazily builds an infinite term, on which
-- unification never terminates.
i :: Int -> Term Int
i n
  | n < 0 = error "SparseCheck.i: negative argument (only non-negative Ints are supported)"
  | n == 0 = zero
  | otherwise = suc (i (n - 1))

-- A few important list predicates

-- | @append xs ys zs@ holds when @xs ++ ys = zs@.
append :: Term [a] -> Term [a] -> Term [a] -> Pred
append xs ys zs = caseOf (xs, zs) $ \(h, t, rest) ->
      (nil, rest) :-> ys === rest
  :|: (h |> t, h |> rest) :-> append t ys rest

-- | @len xs n@ holds when the list @xs@ has length @n@.
len :: Term [a] -> Term Int -> Pred
len xs n = caseOf xs $ \(h, t, m) ->
      nil :-> n === i 0
  :|: (h |> t) :-> len t m & add m (i 1) n

-- | Every element of the list satisfies @p@.
forall :: Term [a] -> (Term a -> Pred) -> Pred
forall xs p = caseOf xs $ \(h, t) ->
      nil :-> return ()
  :|: (h |> t) :-> p h & forall t p

-- | Some element of the list satisfies @p@ (one answer per such element).
forany :: Term [a] -> (Term a -> Pred) -> Pred
forany xs p = caseOf xs $ \(h, t) ->
      nil :-> mzero
  :|: (h |> t) :-> p h ? forany t p

-- | @mapP p xs ys@: the lists have equal length and @p@ relates the
-- corresponding elements.
mapP :: (Term a -> Term b -> Pred) -> Term [a] -> Term [b] -> Pred
mapP p xs ys = caseOf (xs, ys) $ \(x, xt, y, yt) ->
      (nil, nil) :-> return ()
  :|: (x |> xt, y |> yt) :-> p x y & mapP p xt yt