{-# LANGUAGE ScopedTypeVariables #-}
-- | A list diff.
module Test.StateMachine.TreeDiff.List (diffBy, Edit (..)) where

import           Data.List.Compat
                   (sortOn)
import qualified Data.MemoTrie    as M
import qualified Data.Vector      as V

-- | List edit operations
--
-- The 'Swp' constructor is redundant (it is equivalent to a 'Del' followed by
-- an 'Ins'), but it lets us spot a recursion point when performing tree diffs.
data Edit a
    = Ins a    -- ^ insert
    | Del a    -- ^ delete
    | Cpy a    -- ^ copy unchanged
    | Swp a a  -- ^ swap, i.e. delete + insert
  -- 'Eq' lets callers (and tests) compare edit scripts directly;
  -- 'Show' output matches the doctest examples on 'diffBy'.
  deriving (Eq, Show)

-- | List difference.
--
-- >>> diffBy (==) "hello" "world"
-- [Swp 'h' 'w',Swp 'e' 'o',Swp 'l' 'r',Cpy 'l',Swp 'o' 'd']
--
-- >>> diffBy (==) "kitten" "sitting"
-- [Swp 'k' 's',Cpy 'i',Cpy 't',Cpy 't',Swp 'e' 'i',Cpy 'n',Ins 'g']
--
-- prop> \xs ys -> length (diffBy (==) xs ys) >= max (length xs) (length (ys :: String))
-- prop> \xs ys -> length (diffBy (==) xs ys) <= length xs + length (ys :: String)
--
-- /Note:/ this currently requires O(n*m) memory, for the sake of a
-- more obviously correct implementation.
--
diffBy :: forall a. (a -> a -> Bool) -> [a] -> [a] -> [Edit a]
diffBy eq xs' ys' = reverse (snd (lcs (V.length xs) (V.length ys)))
  where
    -- The recurrence indexes the inputs by position, so convert them to
    -- vectors for O(1) access.
    xs :: V.Vector a
    xs = V.fromList xs'

    ys :: V.Vector a
    ys = V.fromList ys'

    -- @lcs n m@ is the cheapest edit script turning the first @n@ elements
    -- of @xs@ into the first @m@ elements of @ys@, paired with its cost.
    -- 'Cpy' costs 0; 'Ins', 'Del' and 'Swp' cost 1 each. The script is built
    -- last-edit-first, which is why the result above is 'reverse'd.
    -- Memoised on both indices via 'M.memo2', so each (n, m) pair is solved
    -- once; this table is the O(n*m) memory mentioned in the docs.
    lcs :: Int -> Int -> (Int, [Edit a])
    lcs = M.memo2 impl

    -- One step of the recurrence. Recursive calls go through the memoised
    -- 'lcs', never directly back into 'impl'.
    impl :: Int -> Int -> (Int, [Edit a])
    impl 0 0 = (0, [])
    -- No @xs@ left: the remaining @ys@ elements can only be inserted.
    impl 0 m = bimap (+ 1) (Ins (ys V.! (m - 1)) :) (lcs 0 (m - 1))
    -- No @ys@ left: the remaining @xs@ elements can only be deleted.
    impl n 0 = bimap (+ 1) (Del (xs V.! (n - 1)) :) (lcs (n - 1) 0)
    -- General case: pick the cheapest of the three moves. 'sortOn' is
    -- stable, so on equal cost the earlier candidate wins (diagonal, then
    -- insert, then delete). 'head' is safe: the list is a non-empty literal.
    impl n m = head $ sortOn fst
        [ diagonal
        , bimap (+ 1) (Ins y :) (lcs n (m - 1))
        , bimap (+ 1) (Del x :) (lcs (n - 1) m)
        ]
      where
        x :: a
        x = xs V.! (n - 1)

        y :: a
        y = ys V.! (m - 1)

        -- Consume one element from each side: a free copy when they are
        -- equal under @eq@, otherwise a swap costing 1.
        diagonal :: (Int, [Edit a])
        diagonal
            | eq x y    = bimap id (Cpy x :) (lcs (n - 1) (m - 1))
            | otherwise = bimap (+ 1) (Swp x y :) (lcs (n - 1) (m - 1))

-- | Map both components of a pair: a tuple-only analogue of
-- 'Data.Bifunctor.bimap'. The pair pattern is matched strictly.
bimap :: (a -> c) -> (b -> d) -> (a, b) -> (c, d)
bimap f g (x, y) = (f x, g y)