-- | Generalization (anti-unification) of 'PrimExp's.
module Futhark.Analysis.PrimExp.Generalize
  ( leastGeneralGeneralization,
  )
where

import Data.List (elemIndex, mapAccumL)
import Futhark.Analysis.PrimExp
import Futhark.IR.Syntax.Core (Ext (..))

-- | Compute the least general generalization (anti-unification) of two
-- 'PrimExp's of the same type.
--
-- The two expressions are walked in parallel. Parts that agree (the same
-- variable, the same value, the same operator or function applied to the
-- same number of arguments) are kept. Every point where they disagree is
-- replaced by @LeafExp (Ext i) t@. Here @i@ is the position of the
-- mismatching pair in the returned list and @t@ is the type of the left
-- expression.
--
-- The list argument holds pairs that have already been generalized. A
-- mismatch equal to an existing pair reuses that pair's index; a new
-- mismatch is appended. The list is threaded left to right through
-- subexpressions, so indices follow that traversal order.
leastGeneralGeneralization ::
  (Eq v) =>
  [(PrimExp v, PrimExp v)] ->
  PrimExp v ->
  PrimExp v ->
  (PrimExp (Ext v), [(PrimExp v, PrimExp v)])
-- Each equation below only matches when its guard holds. A failing guard
-- falls through to the final catch-all equation, which generalizes the
-- whole pair.
leastGeneralGeneralization m (LeafExp v1 t1) (LeafExp v2 _)
  | v1 == v2 = (LeafExp (Free v1) t1, m)
leastGeneralGeneralization m (ValueExp v1) (ValueExp v2)
  | v1 == v2 = (ValueExp v1, m)
leastGeneralGeneralization m (BinOpExp op1 e11 e12) (BinOpExp op2 e21 e22)
  | op1 == op2 =
      let (e1, m1) = leastGeneralGeneralization m e11 e21
          (e2, m2) = leastGeneralGeneralization m1 e12 e22
       in (BinOpExp op1 e1 e2, m2)
leastGeneralGeneralization m (CmpOpExp op1 e11 e12) (CmpOpExp op2 e21 e22)
  | op1 == op2 =
      let (e1, m1) = leastGeneralGeneralization m e11 e21
          (e2, m2) = leastGeneralGeneralization m1 e12 e22
       in (CmpOpExp op1 e1 e2, m2)
leastGeneralGeneralization m (UnOpExp op1 e1) (UnOpExp op2 e2)
  | op1 == op2 =
      let (e, m1) = leastGeneralGeneralization m e1 e2
       in (UnOpExp op1 e, m1)
leastGeneralGeneralization m (ConvOpExp op1 e1) (ConvOpExp op2 e2)
  | op1 == op2 =
      let (e, m1) = leastGeneralGeneralization m e1 e2
       in (ConvOpExp op1 e, m1)
leastGeneralGeneralization m (FunExp s1 args1 t1) (FunExp s2 args2 _)
  -- The arity check makes sure 'zip' below does not silently drop arguments.
  | s1 == s2 && length args1 == length args2 =
      let step acc (a1, a2) =
            let (a, acc') = leastGeneralGeneralization acc a1 a2
             in (acc', a)
          -- mapAccumL threads the pair list left to right and keeps the
          -- argument order, so no accumulate-then-reverse is needed.
          (m', args) = mapAccumL step m (zip args1 args2)
       in (FunExp s1 args t1, m')
leastGeneralGeneralization m exp1 exp2 =
  generalize m exp1 exp2

-- | Replace a mismatching pair of expressions with an existential leaf.
--
-- If @(exp1, exp2)@ already occurs in @m@, the leaf refers to that
-- occurrence's index and @m@ is returned unchanged. Otherwise the pair is
-- appended to the end of @m@, and the leaf refers to the new last index
-- (@length m@). The leaf's type is always taken from @exp1@.
generalize ::
  (Eq v) =>
  [(PrimExp v, PrimExp v)] ->
  PrimExp v ->
  PrimExp v ->
  (PrimExp (Ext v), [(PrimExp v, PrimExp v)])
generalize m exp1 exp2 =
  maybe fresh reuse (elemIndex pair m)
  where
    pair = (exp1, exp2)
    leafAt i = LeafExp (Ext i) (primExpType exp1)
    reuse i = (leafAt i, m)
    fresh = (leafAt (length m), m ++ [pair])