{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module Horname.Internal.SMT where

import           Data.Data
import           Data.Either (partitionEithers)
import           Data.List (foldl', find)
import qualified Data.List as List
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Maybe
import           Data.Text (Text)

-- | Name of a variable, e.g. a formal parameter of a 'DefineFun', kept
-- distinct from arbitrary 'Text' values by a newtype wrapper.
newtype VarName = VarName Text
  deriving (Show, Eq, Ord, Data)

-- | Name of a sort, used for parameter sorts ('Arg') and function result
-- sorts ('DefineFun').
newtype Sort = Sort Text
  deriving (Show, Eq, Ord, Data)

-- | A formal parameter of a 'DefineFun': a variable name together with
-- its sort. Both fields are strict.
data Arg = Arg
  { argName :: !VarName -- ^ name of the parameter
  , argSort :: !Sort -- ^ sort of the parameter
  }
  deriving (Show, Eq, Ord, Data)

-- | A function definition: name, typed parameters, result sort and body.
-- NOTE(review): presumably the in-memory form of an SMT-LIB @define-fun@
-- command; the parser is not visible here.
data DefineFun = DefineFun
  { funName :: !Text -- ^ name of the defined function
  , arguments :: ![Arg] -- ^ formal parameters, in order
  , returnSort :: !Sort -- ^ sort of the result
  , body :: !SExpr -- ^ defining expression
  }
  deriving (Show, Eq, Ord, Data)

-- | A minimal S-expression tree for SMT terms.
--
-- Atoms such as operator names (@and@, @+@, @let@, …) and variable names
-- are stored as 'StringLit', integer literals as 'IntLit', and every
-- parenthesised form as 'List'. The derived 'Ord' instance orders by
-- constructor position first, so the constructor order below is significant.
data SExpr
  = IntLit !Integer -- ^ integer literal
  | StringLit !Text -- ^ atom: operator, keyword or variable name
  | List ![SExpr] -- ^ parenthesised list of sub-expressions
  deriving (Show, Eq, Ord, Data)

-- | Rename the formal parameters of a function definition.
--
-- The @i@-th parameter receives the @i@-th name from the list, keeping its
-- sort. Every 'StringLit' atom in the body equal to an old parameter name
-- is replaced by the corresponding new name; all renamings are done in a
-- single simultaneous substitution. The function name and return sort are
-- left unchanged.
--
-- Both the new parameter list and the substitution are built by zipping,
-- so surplus names are ignored. If fewer names than parameters are given,
-- the surplus parameters are dropped from the result.
--
-- NOTE(review): the substitution is purely syntactic; it does not respect
-- binders (e.g. @let@ or quantifiers) inside the body.
renameDefineFun :: [Text] -> DefineFun -> DefineFun
renameDefineFun newNames def =
  def
    { arguments = zipWith rename (arguments def) newNames
    , body = substitute (body def)
    }
  where
    rename (Arg _ paramSort) newName = Arg (VarName newName) paramSort
    oldNames = [old | Arg (VarName old) _ <- arguments def]
    substitution :: Map Text Text
    substitution = Map.fromList (zip oldNames newNames)
    substitute :: SExpr -> SExpr
    substitute e =
      case e of
        StringLit s -> StringLit (Map.findWithDefault s s substitution)
        List es -> List (map substitute es)
        IntLit i -> IntLit i

-- | Extend a scope map with all bindings of one @let@.
--
-- Every binding's value is inlined against the original map @outer@, not
-- against the partially extended one, so bindings of the same @let@ cannot
-- see each other. This matches the parallel semantics of SMT-LIB @let@.
-- If a name is bound twice, the later binding wins.
insertBindings :: Map Text SExpr -> [SExpr] -> Map Text SExpr
insertBindings outer = foldl' (insertBinding outer) outer

-- | Add a single @(name value)@ let binding to the accumulator @acc@.
--
-- The value is inlined against @outer@ (the scope outside the @let@), not
-- against @acc@. A binding of any other shape aborts with 'error'.
insertBinding :: Map Text SExpr -> Map Text SExpr -> SExpr -> Map Text SExpr
insertBinding outer acc binding =
  case binding of
    List [StringLit name, value] ->
      Map.insert name (inlineLets' outer value) acc
    _ -> error "Syntax error in let bindings"

-- | Inline every @let@ in an expression: each bound name is replaced by
-- its (already inlined) value, and the @let@ form itself disappears.
--
-- Known limitation: quantifier binders are not respected. A quantified
-- variable that shares its name with an enclosing let-bound variable is
-- substituted as well. Ignoring this keeps the implementation simple and
-- seems to be sufficient for z3 and eldarica output.
inlineLets :: SExpr -> SExpr
inlineLets expr = inlineLets' Map.empty expr

-- | Worker for 'inlineLets'.
--
-- The map holds the let-bound names currently in scope, each with its
-- already-inlined value. Behaviour by form:
--
-- * A bound atom is replaced by its value.
-- * A @(let (bindings…) body)@ form is replaced by @body@, processed with
--   the bindings added to the map (see 'insertBindings').
-- * Every other list is traversed element-wise.
inlineLets' :: Map Text SExpr -> SExpr -> SExpr
inlineLets' env expr =
  case expr of
    IntLit _ -> expr
    StringLit name -> fromMaybe expr (Map.lookup name env)
    List [StringLit "let", List bindings, letBody] ->
      inlineLets' (insertBindings env bindings) letBody
    List children -> List (map (inlineLets' env) children)

-- | Relational operators whose operands 'simplify' rearranges, moving
-- negated summands to the opposite side.
comparisonOps :: [Text]
comparisonOps =
  [ "="
  , "<"
  , "<="
  , ">"
  , ">="
  ]

-- | Split a term into positive and negative summands.
--
-- * For a sum @(+ …)@, each summand of the form @(- e)@ contributes @e@ to
--   the negative list and every other summand goes to the positive list.
--   Relative order is kept in both lists.
-- * A lone negation @(- e)@ yields @([], [e])@.
-- * Any other term is a single positive summand.
partitionPosNeg :: SExpr -> ([SExpr],[SExpr])
partitionPosNeg expr =
  case expr of
    List (StringLit "+" : summands) -> partition classify summands
    Neg negated -> ([], [negated])
    _ -> ([expr], [])
  where
    classify (Neg e) = Right e
    classify e = Left e

-- | Split a list with a classifier.
--
-- Elements mapped to 'Left' go to the first list and those mapped to
-- 'Right' go to the second; both lists keep the original relative order.
-- This is @'partitionEithers' . 'map' p@. Unlike a fold that pattern
-- matches the accumulator pair strictly, it is lazy, so it also works on
-- infinite or very long lists.
partition :: (a -> Either b c) -> [a] -> ([b], [c])
partition p = partitionEithers . map p

-- | One step of a right fold that prepends @x@ to the left or right
-- accumulator, depending on whether @p x@ is 'Left' or 'Right'.
-- The accumulator pair is pattern matched (forced) before @p@ is applied.
select :: (a -> Either b c) -> a -> ([b], [c]) -> ([b], [c])
select p x (lefts, rights) =
  either (\l -> (l : lefts, rights)) (\r -> (lefts, r : rights)) (p x)

-- | 'False' only for the literal integer @0@; every other expression
-- (including non-literals that might evaluate to 0) counts as non-zero.
nonZero :: SExpr -> Bool
nonZero = (/= IntLit 0)

-- | Build a sum from a list of summands, avoiding degenerate sums:
-- no summands gives the literal @0@, one summand gives that summand itself,
-- and anything longer gives @(+ s1 s2 …)@.
sumExprs :: [SExpr] -> SExpr
sumExprs summands =
  case summands of
    [] -> IntLit 0
    [single] -> single
    _ -> List (StringLit "+" : summands)

-- | Conjunction @(and c1 … cn)@, usable both as a pattern and as a
-- constructor. Matches any number of conjuncts, including none.
pattern And :: [SExpr] -> SExpr
pattern And conjuncts = List (StringLit "and" : conjuncts)

-- | Unary minus @(- e)@, usable both as a pattern and as a constructor.
-- Only the two-element form matches; binary subtraction @(- a b)@ does not.
pattern Neg :: SExpr -> SExpr
pattern Neg negated = List [StringLit "-", negated]

-- | Disjunction @(or d1 … dn)@, usable both as a pattern and as a
-- constructor. Matches any number of disjuncts, including none.
pattern Or :: [SExpr] -> SExpr
pattern Or disjuncts = List (StringLit "or" : disjuncts)

-- | Any three-element list headed by an atom, @(operator lhs rhs)@,
-- e.g. @(* a b)@ or @(<= a b)@. Usable both as a pattern and as a
-- constructor.
pattern BinOp :: Text -> SExpr -> SExpr -> SExpr
pattern BinOp operator lhs rhs = List [StringLit operator, lhs, rhs]

-- | First simplification pass.
--
-- Rewrites only the root of the given expression; 'simplify' never calls
-- itself on sub-expressions. NOTE(review): presumably the caller applies it
-- to every node (e.g. via the 'Data' instance of 'SExpr') — confirm.
-- The equations are tried in order and the first match wins. Anything
-- not matched is returned unchanged.
simplify :: SExpr -> SExpr
-- (* (- i) x) → (- (* i x)); in particular (* (- 1) x) → (- x)
simplify (BinOp "*" (Neg (IntLit i)) expr) =
  let expr' =
        if i == 1
          then expr
          else BinOp "*" (IntLit i) expr
  in Neg expr'
-- merge nested ands: (and (and a b) c (and d)) → (and a b d c).
-- The flattened conjuncts come first; only one level of nesting is removed.
simplify (And args) = And (andArgs ++ others)
  where
    (ands, others) =
      partition
        (\case
           And args' -> Left args'
           e -> Right e)
        args
    andArgs = concat ands
-- pull out common subexpressions of disjunctions:
-- (or (and a b) (and a c)) → (and a (or (and b) (and c))).
-- Candidates are the conjuncts of the first disjunct (which must be an
-- and). A candidate is common only if every other disjunct is an and that
-- contains it.
-- NOTE(review): with a single disjunct, e.g. (or (and a b)), every
-- conjunct counts as common and the result is (and a b (or (and))),
-- which is equivalent but not simpler.
simplify (Or (arg:args)) =
  if null commonSubExprs
    then Or (arg : args)
    else And
           (commonSubExprs ++
            [Or (map (removeSubExprs commonSubExprs) (arg : args))])
  where
    commonSubExprs =
      filter
        (\arg' -> all (arg' `subsumedBy`) args)
        (case arg of
           And exprs -> exprs
           _ -> [])
    -- True when e is one of the conjuncts of an and; False for non-ands
    subsumedBy :: SExpr -> SExpr -> Bool
    subsumedBy e (And args') = e `elem` args'
    subsumedBy _ _ = False
    -- drop the given conjuncts from an and; non-ands are left untouched
    removeSubExprs :: [SExpr] -> SExpr -> SExpr
    removeSubExprs subExprs (And exprs) =
      And (filter (\e -> not (e `elem` subExprs)) exprs)
    removeSubExprs _ e = e
-- Move negative and positive arguments to the same side of a comparison:
-- negated summands on one side become positive summands on the other,
-- e.g. (<= (+ a (- b)) c) → (<= a (+ c b)). Literal zeros are dropped and
-- an empty side becomes 0. If opName is not in comparisonOps, the guard
-- fails and the following equations are tried.
simplify (BinOp opName arg1 arg2)
  | opName `elem` comparisonOps =
    case (partitionPosNeg arg1, partitionPosNeg arg2) of
      ((posLeft, negLeft), (posRight, negRight)) ->
        List
          [ StringLit opName
          , sumExprs . filter nonZero $ (posLeft ++ negRight)
          , sumExprs . filter nonZero $ (posRight ++ negLeft)
          ]
-- Transform (+ a (- b c)) to (+ a b (- c)). Only binary subtractions that
-- are direct summands are split; unary (- x) summands are kept as they are.
simplify (List (StringLit "+":args)) =
  List (StringLit "+" : (sepSubtraction =<< args))
  where
    sepSubtraction :: SExpr -> [SExpr]
    sepSubtraction (BinOp "-" arg1 arg2) = [arg1, Neg arg2]
    sepSubtraction e = [e]
-- De Morgan: (not (or a b …)) → (and (not a) (not b) …). negateExpr
-- cancels double negations, so (not (or (not a) b)) → (and a (not b)).
simplify (List [StringLit "not", Or args]) = And (map negateExpr args)
simplify e = e

-- | Operators for which @(op a b)@ together with @(op b a)@ implies
-- @a = b@; only @<=@ and @>=@ are recognised.
antiSymmetricOp :: Text -> Bool
antiSymmetricOp op = op == "<=" || op == ">="
-- | Second simplification pass, applied to the root only.
--
-- Inside an @and@, two conjuncts @(op a b)@ and @(op b a)@, where @op@ is
-- @<=@ or @>=@, are replaced by a single @(= a b)@. Further copies of
-- either inequality are dropped too. The remaining conjuncts come first,
-- followed by the merged or unmatched inequalities. Expressions other than
-- @and@ are returned unchanged.
simplify' :: SExpr -> SExpr
simplify' expr =
  case expr of
    And conjuncts ->
      let isInequality c =
            case c of
              List [StringLit op, _, _] -> antiSymmetricOp op
              _ -> False
          (inequalities, others) = List.partition isInequality conjuncts
      in And (others ++ merge inequalities)
    _ -> expr
  where
    merge :: [SExpr] -> [SExpr]
    merge [] = []
    merge (c@(List [StringLit op, lhs, rhs]) : cs)
      | flipped `elem` cs =
        BinOp "=" lhs rhs : merge [x | x <- cs, x /= c, x /= flipped]
      | otherwise = c : merge cs
      where
        flipped = List [StringLit op, rhs, lhs]
    merge (c : cs) = c : merge cs


-- | Logical negation that cancels an existing @not@: @(not x)@ becomes
-- @x@, and any other expression @e@ becomes @(not e)@.
negateExpr :: SExpr -> SExpr
negateExpr expr =
  case expr of
    List [StringLit "not", inner] -> inner
    _ -> List [StringLit "not", expr]

-- | Match declarations with their definitions.
--
-- Declarations are visited in ascending key order. For each entry
-- @(name, argNames)@, take the first definition in @defs@ whose 'funName'
-- is @name@ and rename its parameters to @argNames@ with 'renameDefineFun'.
-- Declarations without a matching definition are skipped.
extractDefinitions :: Map Text [Text] -> [DefineFun] -> [DefineFun]
extractDefinitions decls defs =
  [ renameDefineFun argNames def
  | (name, argNames) <- Map.toList decls
  , Just def <- [find ((== name) . funName) defs]
  ]