{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE GADTs #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Type.Vector
-- Copyright   :  Copyright (C) 2015 Kyle Carter
-- License     :  BSD3
--
-- Maintainer  :  Kyle Carter <kylcarte@indiana.edu>
-- Stability   :  experimental
-- Portability :  RankNTypes
--
-- 'Vec' and its combinator analog 'VecT' represent lists
-- of known length, characterized by the index @(n :: N)@ in
-- @'Vec' n a@ or @'VecT' n f a@.
--
-- The classic example used ad nauseam for type-level programming.
--
-- The operations on 'Vec' and 'VecT' correspond to the type level arithmetic
-- operations on the kind 'N'.
--
-----------------------------------------------------------------------------

module Data.Type.Vector where

import Data.Type.Combinator
import Data.Type.Fin
import Data.Type.Length
import Data.Type.Nat
import Data.Type.Product (Prod(..),curry',pattern (:>))

import Type.Class.Higher
import Type.Class.Known
import Type.Class.Witness

import Type.Family.Constraint
import Type.Family.List
import Type.Family.Nat

import qualified Data.List as L
import Data.Monoid

-- | A length-indexed sequence of @f a@ values.  The index @n@ is the
-- type-level length, @f@ is the wrapper applied to every element, and @a@
-- is the parameter handed to @f@.
data VecT (n :: N) (f :: k -> *) :: k -> * where
  -- | The empty vector; its length index is 'Z'.
  ØV   :: VecT Z f a
  -- | Prepend one element.  Both fields are strict, and the length index
  -- of the result is the successor of the tail's index.
  (:*) :: !(f a) -> !(VecT n f a) -> VecT (S n) f a
-- ':*' is right-associative, so @a :* b :* ØV@ needs no parentheses.
infixr 4 :*

-- | Build a two-element 'VecT' from its first and second elements.
(*:) :: f a -> f a -> VecT (S (S Z)) f a
x *: y = x :* (y :* ØV)
infix 5 *:

-- | Eliminator for 'VecT': replace 'ØV' by the first argument and every
-- ':*' by the second, producing a result indexed by the vector's length.
elimVecT :: p Z
         -> (forall x. f a -> p x -> p (S x))
         -> VecT n f a
         -> p n
elimVecT nil _    ØV          = nil
elimVecT nil cons (x :* rest) = cons x (elimVecT nil cons rest)

-- | 'elimVecT' specialised to 'Vec': the step function receives each
-- element already unwrapped with 'getI'.
elimV :: p Z
      -> (forall x. a -> p x -> p (S x))
      -> Vec n a
      -> p n
elimV nil cons = elimVecT nil (\wrapped -> cons (getI wrapped))

-- | A 'VecT' whose elements are wrapped in 'I'.
type Vec n = VecT n I

-- | Cons for 'Vec': matches or builds @I a :* as@, so callers never have
-- to mention the 'I' wrapper themselves.
pattern (:+) :: a -> Vec n a -> Vec (S n) a
pattern a :+ as = I a :* as
-- Same fixity as ':*'.
infixr 4 :+

-- | Build a two-element 'Vec' from two plain values, wrapping each in 'I'.
(+:) :: a -> a -> Vec (S (S Z)) a
x +: y = I x *: I y
infix 5 +:

-- Standalone-derived instances: each one only requires the corresponding
-- instance for the wrapped element type @f a@.
deriving instance Eq   (f a) => Eq   (VecT n f a)
deriving instance Ord  (f a) => Ord  (VecT n f a)
deriving instance Show (f a) => Show (VecT n f a)

-- | Append two vectors; the length of the result is the type-level sum of
-- the two input lengths.
(.++) :: VecT x f a -> VecT y f a -> VecT (x + y) f a
ØV          .++ ys = ys
(x :* rest) .++ ys = x :* (rest .++ ys)
infixr 5 .++

-- | A vector in which every position holds the same element.  The length
-- is taken from the 'Known' 'Nat' singleton for @n@.
vrep :: forall n f a. Known Nat n => f a -> VecT n f a
vrep x = fill (known :: Nat n)
  where
  -- Recurse over the length singleton, emitting one copy per successor.
  fill :: Nat m -> VecT m f a
  fill Z_     = ØV
  fill (S_ m) = x :* fill m

-- | First element of a vector whose type guarantees it is non-empty.
head' :: VecT (S n) f a -> f a
head' = \case
  x :* _ -> x

-- | Everything after the first element of a vector whose type guarantees
-- it is non-empty.
tail' :: VecT (S n) f a -> VecT n f a
tail' = \case
  _ :* rest -> rest

-- | Apply a function to the tail of a non-empty vector, keeping the head
-- unchanged.
onTail :: (VecT m f a -> VecT n f a) -> VecT (S m) f a -> VecT (S n) f a
onTail g = \case
  x :* rest -> x :* g rest

-- | Delete the element at the given position, shortening the vector by
-- one.
--
-- In the successor case the recursive deletion is applied to the tail;
-- the @\\\\ i@ uses the 'Witness' instance of the smaller index to supply
-- the constraint needed there (presumably that its bound is itself a
-- successor -- confirm against "Data.Type.Fin").
vDel :: Fin n -> VecT n f a -> VecT (Pred n) f a
vDel FZ     = tail'
vDel (FS i) = onTail (vDel i) \\ i

-- | Map over a vector with access to each element's position.
imap :: (Fin n -> f a -> g b) -> VecT n f a -> VecT n g b
imap _ ØV          = ØV
imap g (x :* rest) = g FZ x :* imap (g . FS) rest

-- | Map every element, together with its position, into a monoid and
-- combine the results from left to right.
ifoldMap :: Monoid m => (Fin n -> f a -> m) -> VecT n f a -> m
ifoldMap _ ØV          = mempty
ifoldMap g (x :* rest) = g FZ x <> ifoldMap (g . FS) rest

-- | Traverse a vector with an effectful function that also receives each
-- element's position; effects run from the head towards the tail.
itraverse :: Applicative h => (Fin n -> f a -> h (g b)) -> VecT n f a -> h (VecT n g b)
itraverse _ ØV          = pure ØV
itraverse g (x :* rest) = (:*) <$> g FZ x <*> itraverse (g . FS) rest

-- | Look up the element at the given position.  The 'Fin' index is bounded
-- by the vector's length, so the lookup cannot go out of range.
index :: Fin n -> VecT n f a -> f a
index FZ     = head'
index (FS i) = index i . tail'

-- | 'index' for 'Vec', returning the element with its 'I' wrapper removed.
index' :: Fin n -> Vec n a -> a
index' i v = getI (index i v)

-- | Apply a function to every element; the function may change both the
-- wrapper and its parameter.
vmap :: (f a -> g b) -> VecT n f a -> VecT n g b
vmap _ ØV          = ØV
vmap g (x :* rest) = g x :* vmap g rest

-- | Combine two vectors of the same length position by position.
vap :: (f a -> g b -> h c) -> VecT n f a -> VecT n g b -> VecT n h c
-- The second vector is left unexamined when the first is empty.
vap _ ØV        _         = ØV
vap k (x :* xs) (y :* ys) = k x y :* vap k xs ys

-- | Right fold over the elements of a vector.
vfoldr :: (f a -> b -> b) -> b -> VecT n f a -> b
vfoldr _    nil ØV          = nil
vfoldr step nil (x :* rest) = step x (vfoldr step nil rest)

-- | Map every element and join the results with an explicit combining
-- function.  The fallback value is only produced for the empty vector: a
-- non-empty vector ends with its last mapped element, not with the
-- fallback.
vfoldMap' :: (b -> b -> b) -> b -> (f a -> b) -> VecT n f a -> b
vfoldMap' _    empty _ ØV          = empty
vfoldMap' _    _     g (x :* ØV)   = g x
vfoldMap' join empty g (x :* rest) = join (g x) (vfoldMap' join empty g rest)

-- | Map every element into a monoid and combine the results from left to
-- right.
vfoldMap :: Monoid m => (f a -> m) -> VecT n f a -> m
vfoldMap _ ØV          = mempty
vfoldMap g (x :* rest) = g x <> vfoldMap g rest

-- | Convert a list to a 'VecT' of some (statically unknown) length and
-- hand it to a continuation that must work for every length.
withVecT :: [f a] -> (forall n. VecT n f a -> r) -> r
withVecT []       k = k ØV
withVecT (x : xs) k = withVecT xs (\rest -> k (x :* rest))

-- | 'withVecT' for plain lists: each element is wrapped in 'I' first.
withV :: [a] -> (forall n. Vec n a -> r) -> r
withV xs k = withVecT (map I xs) k

-- | Position of the first element of a 'Vec' equal to the given value, if
-- there is one.
findV :: Eq a => a -> Vec n a -> Maybe (Fin n)
findV x v = findVecT (I x) v

-- | Position of the first element equal to the given one, or 'Nothing'
-- when no element matches.
findVecT :: Eq (f a) => f a -> VecT n f a -> Maybe (Fin n)
findVecT _      ØV = Nothing
findVecT needle (x :* rest)
  | needle == x = Just FZ
  | otherwise   = fmap FS (findVecT needle rest)

-- | Change the element wrapper at every position.
instance Functor1 (VecT n) where
  map1 _ ØV          = ØV
  map1 g (x :* rest) = g x :* map1 g rest

-- | Fold the wrapped elements into a monoid, left to right.
instance Foldable1 (VecT n) where
  foldMap1 _ ØV          = mempty
  foldMap1 g (x :* rest) = g x <> foldMap1 g rest

-- | Effectfully change the element wrapper at every position, running
-- effects from the head towards the tail.
instance Traversable1 (VecT n) where
  traverse1 _ ØV          = pure ØV
  traverse1 g (x :* rest) = (:*) <$> g x <*> traverse1 g rest

-- | Map through both layers: over the vector, then inside each @f@.
instance Functor f => Functor (VecT n f) where
  fmap g = vmap (fmap g)

-- | Position-wise applicative: 'pure' replicates a value to every
-- position, and '<*>' applies the function at each position to the
-- argument at the same position.
instance (Applicative f, Known Nat n) => Applicative (VecT n f) where
  pure x    = vrep (pure x)
  fs <*> xs = vap (<*>) fs xs

-- | Position-wise bind: the element at position @i@ is bound, in @f@, to
-- the @i@-th component of the vector produced by the continuation.
instance (Monad f, Known Nat n) => Monad (VecT n f) where
  v >>= k = imap (\i m -> m >>= index i . k) v

-- | Fold through both layers: every value inside every @f@, in vector
-- order.
instance Foldable f => Foldable (VecT n f) where
  foldMap _ ØV          = mempty
  foldMap g (x :* rest) = foldMap g x <> foldMap g rest

-- | Traverse through both layers, running effects in vector order.
instance Traversable f => Traversable (VecT n f) where
  traverse _ ØV          = pure ØV
  traverse g (x :* rest) = (:*) <$> traverse g x <*> traverse g rest

{-
instance (Witness p q (f a), n ~ S x) => Witness p q (VecT n f a) where
  type WitnessC p q (VecT n f a) = Witness p q (f a)
  (\\) r = \case
    a :* _ -> r \\ a
    _      -> error "impossible type"
-}

-- | A vector is evidence that its own length is 'Known': given any
-- @VecT n f a@, a computation that needs @Known Nat n@ can be run.
instance Witness ØC (Known Nat n) (VecT n f a) where
  (\\) r = \case
    -- Matching on 'ØV' fixes the length to 'Z'; the continuation is
    -- returned directly in that branch.
    ØV      -> r
    -- For a non-empty vector the evidence comes from recursing on the
    -- tail (presumably combined with a 'Known' instance for successors
    -- defined elsewhere -- confirm in "Type.Class.Known").
    _ :* as -> r \\ as

-- | Element-wise arithmetic.  Binary operations combine the two vectors
-- position by position, unary ones are mapped over every element, and an
-- integer literal is replicated to every position.
instance (Num (f a), Known Nat n) => Num (VecT n f a) where
  xs * ys       = vap (*) xs ys
  xs + ys       = vap (+) xs ys
  xs - ys       = vap (-) xs ys
  negate xs     = vmap negate xs
  abs xs        = vmap abs xs
  signum xs     = vmap signum xs
  fromInteger k = vrep (fromInteger k)

-- | Newtype wrapper around the 'Matrix' type family, so that instances
-- can be written against a proper type constructor indexed by the list of
-- dimensions @ns@.
newtype M ns a = M { getMatrix :: Matrix ns a }

-- Standalone-derived instances for 'M': each one requires the
-- corresponding instance for the underlying @Matrix ns a@.
deriving instance Eq   (Matrix ns a) => Eq   (M ns a)
deriving instance Ord  (Matrix ns a) => Ord  (M ns a)
deriving instance Show (Matrix ns a) => Show (M ns a)

-- | Numeric operations on 'M': every method unwraps to the underlying
-- 'Matrix' representation, delegates to its 'Num' instance and re-wraps
-- the result.
instance Num (Matrix ns a) => Num (M ns a) where
  fromInteger  = M . fromInteger
  M a * M b    = M $ a * b
  M a + M b    = M $ a + b
  M a - M b    = M $ a - b
  -- Defined explicitly so that negation delegates to the underlying
  -- instance, like every other method here, instead of falling back to
  -- the class default @fromInteger 0 - x@, which first builds a constant
  -- matrix.
  negate (M a) = M $ negate a
  abs (M a)    = M $ abs a
  signum (M a) = M $ signum a

-- | The representation of a matrix with dimensions @ns@, as a functor of
-- kind @* -> *@.
type family Matrix (ns :: [N]) :: * -> * where
  -- No dimensions: a single value wrapped in 'I'.
  Matrix Ø         = I
  -- Outermost dimension @n@: a vector of @n@ matrices of the remaining
  -- dimensions.
  Matrix (n :< ns) = VecT n (Matrix ns)

-- | 'vgen' with the length singleton supplied by the 'Known' instance.
vgen_ :: Known Nat n => (Fin n -> f a) -> VecT n f a
vgen_ mk = vgen known mk

-- | Build a vector of the given length by calling the generator once for
-- every position.
vgen :: Nat n -> (Fin n -> f a) -> VecT n f a
vgen Z_     _  = ØV
vgen (S_ m) mk = mk FZ :* vgen m (mk . FS)

-- | 'mgen' with the dimension singletons supplied by the 'Known'
-- instance.
mgen_ :: Known (Prod Nat) ns => (Prod Fin ns -> a) -> M ns a
mgen_ mk = mgen known mk

-- | Build a matrix with the given dimensions by calling the generator
-- once for every full index (one 'Fin' per dimension).
mgen :: Prod Nat ns -> (Prod Fin ns -> a) -> M ns a
-- No dimensions: a single entry at the empty index.
mgen Ø           mk = M (I (mk Ø))
-- Generate the outermost vector; each slot is the sub-matrix obtained by
-- fixing the first index component with 'curry''.
mgen (n :< dims) mk = M (vgen n (\i -> getMatrix (mgen dims (curry' mk i))))

-- | Lift a function on the underlying 'Matrix' representation to the 'M'
-- wrapper.
onMatrix :: (Matrix ms a -> Matrix ns b) -> M ms a -> M ns b
onMatrix g (M raw) = M (g raw)

-- | The diagonal of a square vector-of-vectors: from the row at position
-- @i@, take the element at position @i@.
diagonal :: VecT n (VecT n f) a -> VecT n f a
diagonal rows = imap (\i row -> index i row) rows

-- | Swap the two outermost levels of a vector-of-vectors.  The inner
-- length must be 'Known' because the result is generated position by
-- position along it.
vtranspose :: Known Nat n => VecT m (VecT n f) a -> VecT n (VecT m f) a
vtranspose rows = vgen_ column
  where
  -- The j-th output vector collects the j-th element of every input row.
  column j = vmap (index j) rows

-- | Swap the two outermost dimensions of a matrix with at least two
-- dimensions.
transpose :: Known Nat n => M (m :< n :< ns) a -> M (n :< m :< ns) a
transpose (M rows) = M (vtranspose rows)

-- Example matrices of increasing rank.

-- | A matrix with no dimensions, built from the literal @1@ through the
-- 'Num' instance of 'M'.
m0 :: M Ø Int
m0 = 1

-- | A one-dimensional matrix built from the literal @2@; 'fromInteger'
-- for 'VecT' replicates the literal to every position.
m1 :: M '[N2] Int
m1 = 2

-- | A two-dimensional matrix built from the literal @3@.
m2 :: M '[N2,N4] Int
m2 = 3

-- | A three-dimensional matrix in which every entry is computed from its
-- own index: each component is passed through 'fin' and the results are
-- tupled.
m3 :: M '[N2,N3,N4] (Int,Int,Int)
m3 = mgen_ $ \(x :< y :> z) -> (fin x,fin y,fin z)

-- | The four-dimensional analogue of 'm3'.
m4 :: M '[N2,N3,N4,N5] (Int,Int,Int,Int)
m4 = mgen_ $ \(w :< x :< y :> z) -> (fin w,fin x,fin y,fin z)

-- | Pretty-print a vector in two steps: render every element to a 'ShowS'
-- with the second argument, then let the first argument combine the
-- rendered elements.
ppVec :: (VecT n ((->) String) String -> ShowS) -> (f a -> ShowS) -> VecT n f a -> ShowS
ppVec combine renderElem v = combine (vmap renderElem v)

-- | Print a matrix to standard output, followed by a newline.  The layout
-- is produced by 'ppMatrix'', driven by the 'Known' length of the
-- dimension list.
ppMatrix :: forall ns a. (Show a, Known Length ns) => M ns a -> IO ()
ppMatrix (M raw) = putStrLn (ppMatrix' (known :: Length ns) raw "")

-- | Render a matrix as a 'ShowS', recursing over the 'Length' of its
-- dimension list.
--
-- With no dimensions left ('LZ') the value is a single 'I'-wrapped entry
-- and is rendered with 'shows'.  Otherwise every element of the outermost
-- vector is rendered recursively and the pieces are joined with
-- 'vfoldMap''; an empty vector is rendered as @[]@.
ppMatrix' :: Show a => Length ns -> Matrix ns a -> ShowS
ppMatrix' = \case
  LZ   -> shows . getI
  LS l -> ppVec
    ( vfoldMap'
      -- The joiner depends on 'lEven' of the remaining dimensions
      -- (presumably a parity test on their count -- confirm in
      -- "Data.Type.Length"): when it holds, neighbours are placed side by
      -- side, line by line, separated by @|@; otherwise they are stacked,
      -- separated by a newline.
      ( if lEven l
        then zipLines $ \x y -> x . showChar '|'  . y
        else            \x y -> x . showChar '\n' . y
      ) (showString "[]") id
    ) $ ppMatrix' l

-- | Zip two lists with a combining function, continuing until BOTH lists
-- are exhausted: once one list runs out, 'mempty' stands in for its
-- missing elements.
mzipWith :: Monoid a => (a -> a -> b) -> [a] -> [a] -> [b]
mzipWith _ []       []       = []
mzipWith g (x : xs) []       = g x mempty : mzipWith g xs []
mzipWith g []       (y : ys) = g mempty y : mzipWith g [] ys
mzipWith g (x : xs) (y : ys) = g x y      : mzipWith g xs ys

-- | Combine two rendered blocks of text line by line.  Each block is run
-- on the empty string and split into lines; corresponding lines are
-- merged with the given function (a missing line acts as the identity
-- 'ShowS'), and the merged lines are joined with newlines.
zipLines :: (ShowS -> ShowS -> ShowS) -> ShowS -> ShowS -> ShowS
zipLines f a b = compose (L.intersperse (showChar '\n') merged)
  where
  -- Render a block and wrap each of its lines as an 'Endo', so that
  -- 'mzipWith' can pad the shorter block with 'mempty'.
  rowsOf block = map (Endo . showString) (lines (block ""))
  merged       = mzipWith (\(Endo x) (Endo y) -> f x y) (rowsOf a) (rowsOf b)

{-
juxtLines :: (ShowS -> ShowS -> ShowS) -> ShowS -> ShowS -> ShowS
juxtLines f a b = appEndo $ foldMap id $ mzip (\x y -> Endo $ f (appEndo x) (appEndo y)) as bs
  where
  as = map (Endo . showString) $ lines $ a ""
  bs = map (Endo . showString) $ lines $ b ""
-}

-- | Compose a collection of endofunctions into one.  In the result the
-- first function of the collection is the outermost, i.e. it is applied
-- last; an empty collection yields the identity.
compose :: Foldable f => f (a -> a) -> a -> a
compose fs = appEndo (foldMap Endo fs)

{-
-- Linear {{{

class Functor f => Additive f where
  zero   :: Num a => f a
  (^+^)  :: Num a => f a -> f a -> f a
  (^-^)  :: Num a => f a -> f a -> f a
  lerp   :: Num a =>   a -> f a -> f a -> f a
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
  --------
  default zero :: (Applicative f, Num a) => f a
  zero = pure 0
  (^+^)   = liftU2 (+)
  a ^-^ b = a ^+^ negated b
  lerp alpha a b = alpha *^ a ^+^ (1 - alpha) *^ b
  default liftU2 :: Applicative f => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
  default liftI2 :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
  liftI2 = liftA2
infixl 6 ^+^, ^-^

instance Additive I
instance (Additive f, Known Nat n) => Additive (VecT n f) where
  zero   = vrep zero
  liftU2 = vap . liftU2
  liftI2 = vap . liftI2

class Additive (Diff f) => Affine f where
  type Diff f :: * -> *
  type Diff f = f
  (.-.) :: Num a => f a -> f a -> Diff f a
  (.+^) :: Num a => f a -> Diff f a -> f a
  (.-^) :: Num a => f a -> Diff f a -> f a
  --------
  p .-^ d = p .+^ negated d
  default (.-.) :: (Affine f, Diff f ~ f, Num a) => f a -> f a -> Diff f a
  (.-.) = (^-^)
  default (.+^) :: (Affine f, Diff f ~ f, Num a) => f a -> f a -> Diff f a
  (.+^) = (^+^)
infixl 6 .-., .+^, .-^

instance Affine I
instance (Affine f, Known Nat n) => Affine (VecT n f) where
  type Diff (VecT n f) = VecT n (Diff f)
  (.-.) = vap (.-.)
  (.+^) = vap (.+^)
  (.-^) = vap (.-^)

class Additive f => Metric f where
  dot       :: Num a => f a -> f a -> a
  quadrance :: Num a => f a -> a
  qd        :: Num a => f a -> f a -> a
  distance  :: Floating a => f a -> f a -> a
  norm      :: Floating a => f a -> a
  signorm   :: Floating a => f a -> f a
  --------
  default dot :: (Foldable f, Num a) => f a -> f a -> a
  dot       a b = F.sum $ liftI2 (*) a b
  quadrance     = join dot
  qd        a b = quadrance $ a ^-^ b
  distance  a b = norm $ a ^-^ b
  norm          = sqrt . quadrance
  signorm   a   = (/ norm a) <$> a

instance Metric I where
  dot (I a) (I b) = a * b

instance (Metric f, Known Nat n) => Metric (VecT n f) where
  dot a b = getSum $ foldMap Sum $ vap ((I .) . dot) a b

(*^) :: (Functor f, Num a) => a -> f a -> f a
(*^) a = fmap (a*)
infixl 7 *^

negated :: (Functor f, Num a) => f a -> f a
negated = fmap negate

qdA :: (Affine f, Foldable (Diff f), Num a) => f a -> f a -> a
qdA a b = F.sum $ join (*) <$> a .-. b

distanceA :: (Affine f, Foldable (Diff f), Floating a) => f a -> f a -> a
distanceA a b = sqrt $ qdA a b

-- }}}
-}