module Math.Algebra.LinearAlgebra where
import Prelude hiding ( (*>), (<*>) )
import qualified Data.List as L
import Math.Core.Field
-- Precedences mirror Prelude's (*) (infixl 7) and (+) (infixl 6), with the
-- scalar scalings one level tighter, so that @k *> u <+> v@ parses as
-- @(k *> u) <+> v@.
infixr 8 *>>
infixr 8 *>
infixr 7 <<*>
infixl 7 <*>>
infixl 7 <<*>>
infixl 7 <*>
infixl 7 <.>
infixl 6 <<->>
infixl 6 <<+>>
infixl 6 <->
infixl 6 <+>
-- | Elementwise vector sum. The result is truncated to the shorter operand.
(<+>) :: (Num a) => [a] -> [a] -> [a]
(<+>) = zipWith (+)
-- | Elementwise vector difference @u - v@. The result is truncated to the
-- shorter operand, as with 'zipWith'.
(<->) :: (Num a) => [a] -> [a] -> [a]
u <-> v = zipWith (-) u v
-- | Scale every entry of a vector by the scalar @k@.
(*>) :: (Num a) => a -> [a] -> [a]
k *> v = [k * x | x <- v]
-- | Dot (inner) product. Extra entries of the longer vector are ignored.
(<.>) :: (Num a) => [a] -> [a] -> a
u <.> v = sum [x * y | (x, y) <- zip u v]
-- | Outer product: row @i@ of the result is @v@ scaled by the @i@-th entry
-- of @u@.
(<*>) :: (Num a) => [a] -> [a] -> [[a]]
u <*> v = map (\a -> map (a *) v) u
-- | Elementwise matrix sum (truncated to the common shape).
(<<+>>) :: (Num a) => [[a]] -> [[a]] -> [[a]]
a <<+>> b = zipWith (zipWith (+)) a b
-- | Elementwise matrix difference @a - b@ (truncated to the common shape).
(<<->>) :: (Num a) => [[a]] -> [[a]] -> [[a]]
a <<->> b = (zipWith . zipWith) (-) a b
-- | Matrix product: entry (i, j) is the dot product of row i of @a@ with
-- column j of @b@.
(<<*>>) :: (Num a) => [[a]] -> [[a]] -> [[a]]
a <<*>> b = [map (sum . zipWith (*) row) cols | row <- a]
  where
    cols = L.transpose b
-- | Scale every entry of a matrix by the scalar @k@.
(*>>) :: (Num a) => a -> [[a]] -> [[a]]
k *>> m = map (map (k *)) m
-- | Matrix times column vector: entry i is the dot product of row i of @m@
-- with @v@.
(<<*>) :: (Num a) => [[a]] -> [a] -> [a]
m <<*> v = [sum (zipWith (*) row v) | row <- m]
-- | Row vector times matrix: entry j is the dot product of @v@ with column j
-- of @m@.
(<*>>) :: (Num a) => [a] -> [[a]] -> [a]
v <*>> m = [sum (zipWith (*) v col) | col <- L.transpose m]
-- | Square matrix whose row-i, column-j entry is @f i j@, with i and j
-- ranging over @[1 .. n]@.
fMatrix :: (Enum t, Num t) => t -> (t -> t -> a) -> [[a]]
fMatrix n f = map (\i -> map (f i) indices) indices
  where
    indices = [1 .. n]
-- | Zero-based variant of fMatrix: the row-i, column-j entry is @f i j@,
-- with i and j ranging over @[0 .. n - 1]@.
fMatrix' :: (Enum t, Num t) => t -> (t -> t -> a) -> [[a]]
fMatrix' n f = [[f i j | j <- [0 .. n - 1]] | i <- [0 .. n - 1]]
-- | The n x n identity matrix.
--
-- For @n <= 0@ the result is the empty matrix. This matches 'replicate', so
-- it agrees with jMx and zMx rather than crashing.
idMx :: (Num t) => Int -> [[t]]
idMx n = [[if i == j then 1 else 0 | j <- [1 .. n]] | i <- [1 .. n]]
-- | The n x n identity matrix; another name for 'idMx'.
iMx :: (Num t) => Int -> [[t]]
iMx = idMx
-- | The n x n matrix with every entry equal to 1 (empty for @n <= 0@).
jMx :: (Num t) => Int -> [[t]]
jMx n = replicate n row
  where
    row = replicate n 1
-- | The n x n zero matrix (empty for @n <= 0@).
zMx :: (Num t) => Int -> [[t]]
zMx n = [[0 | _ <- [1 .. n]] | _ <- [1 .. n]]
-- | Inverse of a square matrix, by Gauss-Jordan elimination on the
-- augmented matrix @[m | I]@.
--
-- Returns 'Nothing' if @m@ is singular. It also returns 'Nothing' if @m@ is
-- not square, i.e. some row's length differs from the number of rows.
inverse :: (Eq a, Fractional a) => [[a]] -> Maybe [[a]]
inverse m
  | any ((/= d) . length) m = Nothing
  -- Forward elimination stops early (fewer than d rows) when some column
  -- has no non-zero pivot, i.e. when m is singular.
  | length reduced /= d = Nothing
  | otherwise = Just (inverse2 reduced)
  where
    d = length m
    -- Each row of m is extended with the matching row of the identity.
    reduced = inverse1 (zipWith (++) m (idMx d))
-- | Forward-elimination stage shared by inverse and solveLinearSystem.
--
-- Each output row begins with a pivot 1. Row k (counting from 0) has the
-- first k columns removed, so its leading 1 is the pivot for column k + 1.
--
-- Elimination stops, returning the rows produced so far, in two cases:
--
-- * the current leading column is zero in every remaining row, or
-- * the remaining rows have no entries left.
inverse1 :: (Eq a, Fractional a) => [[a]] -> [[a]]
inverse1 [] = []
-- All columns consumed but rows remain (e.g. more equations than unknowns
-- in solveLinearSystem): nothing left to eliminate.
inverse1 ([] : _) = []
inverse1 ((x : xs) : rs)
  | x /= 0 =
      -- Normalise the pivot row, then clear the leading column of the rest.
      let r' = (1 / x) *> xs
      in (1 : r') : inverse1 [ys <-> y *> r' | (y : ys) <- rs]
  | otherwise =
      -- Zero pivot: add a lower row with a non-zero lead to this row.
      case filter leadsNonZero rs of
        [] -> []
        r : _ -> inverse1 (((x : xs) <+> r) : rs)
  where
    leadsNonZero (h : _) = h /= 0
    leadsNonZero [] = False
-- | Back-substitution stage of inverse.
--
-- Expects rows shaped like inverse1's output: each starts with a pivot 1
-- and is one entry shorter than the row above. For each row, the pivot is
-- dropped. Then, for every row below it in turn, a multiple of that lower
-- row is subtracted to clear the entry in the lower row's pivot column, and
-- that column is dropped.
inverse2 :: (Eq a, Num a) => [[a]] -> [[a]]
inverse2 [] = []
inverse2 ((1 : row) : below) = clearWith row below : inverse2 below
  where
    clearWith acc [] = acc
    clearWith (c : cs) ((1 : pivotRow) : more) = clearWith (cs <-> c *> pivotRow) more
-- | One-based list lookup: @xs ! 1@ is the first element of @xs@.
--
-- Partial, like '!!': it fails when the index is outside
-- @1 .. length xs@.
(!) :: [a] -> Int -> a
xs ! i = xs !! (i - 1)
-- | Row echelon form by Gaussian elimination.
--
-- Every non-zero row of the result starts (after any zeros) with a leading
-- 1, and each leading 1 is to the right of the one in the row above.
-- Entries above the pivots are not cleared; reducedRowEchelonForm does
-- that. For rectangular input the result has the same dimensions as the
-- input.
rowEchelonForm :: (Eq a, Fractional a) => [[a]] -> [[a]]
rowEchelonForm [] = []
rowEchelonForm zs@([] : _) = zs
rowEchelonForm ((x : xs) : rs)
  | x /= 0 =
      let pivotRow = (1 / x) *> xs
          below = [ys <-> y *> pivotRow | (y : ys) <- rs]
      in (1 : pivotRow) : map (0 :) (rowEchelonForm below)
  | otherwise =
      case filter ((/= 0) . head) rs of
        -- Leading column is all zero: reduce the rest, then restore it.
        [] -> map (0 :) (rowEchelonForm (xs : map tail rs))
        -- Bring a non-zero lead into the top row by row addition.
        r : _ -> rowEchelonForm (((x : xs) <+> r) : rs)
-- | Reduced row echelon form: the row echelon form with the entries above
-- each leading 1 also cleared.
--
-- The rows of rowEchelonForm are walked bottom-up. Each row's pivot column
-- is eliminated from the rows above it.
reducedRowEchelonForm :: (Eq a, Fractional a) => [[a]] -> [[a]]
reducedRowEchelonForm m = reverse (backReduce (reverse (rowEchelonForm m)))
  where
    backReduce [] = []
    backReduce rows =
      let (pivot : others) = clearAbove rows
      in pivot : backReduce others
    -- Use the first row's leading 1 to zero that column in the other rows.
    -- Columns where the first row is 0 are set aside and restored.
    clearAbove ((1 : xs) : rs) = (1 : xs) : [0 : (ys <-> y *> xs) | y : ys <- rs]
    clearAbove rows@((0 : _) : _) = zipWith (:) (map head rows) (clearAbove (map tail rows))
    clearAbove rows = rows
-- | Solve the linear system @m x = b@ by Gaussian elimination followed by
-- back-substitution.
--
-- Returns @Just x@ when back-substitution yields exactly one value per
-- entry of @b@. Otherwise (e.g. when @m@ is singular) it returns 'Nothing'.
solveLinearSystem :: (Eq a, Fractional a) => [[a]] -> [a] -> Maybe [a]
solveLinearSystem m b =
  let augmented = zipWith (\r x -> r ++ [x]) m b
      trisystem = inverse1 augmented
      solution = reverse $ solveTriSystem $ reverse $ map reverse trisystem
  in if length solution == length b then Just solution else Nothing
  where
    -- Input rows are reversed: constant term first, then coefficients from
    -- the last unknown backwards, starting with the last equation. Each
    -- solved unknown is substituted into the remaining rows.
    solveTriSystem ([v, c] : rs) =
      let x = v / c
          rs' = map (\(v' : c' : r) -> (v' - c' * x) : r) rs
      in x : solveTriSystem rs'
    solveTriSystem _ = []
-- | True when every entry is zero (vacuously True for an empty vector).
isZero :: (Foldable t, Eq a, Num a) => t a -> Bool
isZero v = not (any (/= 0) v)
-- | Does the vector lie in the row space of the given rows?
--
-- The rows must be in row echelon form with leading 1s, such as the output
-- of rowEchelonForm. Columns are consumed left to right. A pivot row is
-- used to cancel the vector's entry in that column. A column with no pivot
-- requires the vector's entry there to be 0.
inSpanRE :: (Eq a, Num a) => [[a]] -> [a] -> Bool
inSpanRE ((1 : pivot) : rest) (y : ys) = inSpanRE (map tail rest) residual
  where
    residual = if y == 0 then ys else ys <-> y *> pivot
inSpanRE ((0 : xs) : bs) (y : ys) = y == 0 && inSpanRE (xs : map tail bs) ys
inSpanRE _ ys = isZero ys
-- | Rank of a matrix: the number of non-zero rows in its row echelon form.
rank :: (Eq a, Fractional a) => [[a]] -> Int
rank m = length [r | r <- rowEchelonForm m, not (isZero r)]
-- | Basis for the null space of a matrix: each returned vector @v@ satisfies
-- @m <<*> v@ being all zeros. The matrix is first brought to reduced row
-- echelon form.
kernel m = kernelRRE (reducedRowEchelonForm m)
-- | Basis for the null space of a matrix that is already in reduced row
-- echelon form.
--
-- There is one basis vector per non-pivot ("free") column, and each vector
-- has one entry per column of @m@:
--
-- * the free variable for that column is 1 and the other free variables
--   are 0;
-- * each pivot variable is minus the corresponding entry of its pivot row.
--
-- An empty matrix yields an empty basis.
kernelRRE :: (Eq a, Num a) => [[a]] -> [[a]]
kernelRRE m =
  let -- Width of the first row; 0 (rather than a crash) when m is empty.
      nc = length (concat (take 1 m))
      pivotCols = findLeadingCols 1 (L.transpose m)
      freeVars = [1 .. nc] L.\\ pivotCols
      -- For each pivot column: its value in every basis vector.
      pivotRows =
        let nonZeroRows = take (length pivotCols) m
        in zip pivotCols $ L.transpose [map (negate . (! j)) nonZeroRows | j <- freeVars]
      -- For each free column: 1 in its own basis vector, 0 in the others.
      freeRows = zip freeVars (idMx $ length freeVars)
  -- Column indices are unique, so sorting on them alone gives the same order
  -- as a full sort without requiring Ord on the entries.
  in L.transpose $ map snd $ L.sortOn fst $ pivotRows ++ freeRows
  where
    -- Walk the columns (1-based). A column whose top entry is 1 is a pivot
    -- column, and that row is consumed. A 0 means a non-pivot column.
    -- Anything else, or running out of rows, stops the search.
    findLeadingCols i ((1 : _) : cs) = i : findLeadingCols (i + 1) (map tail cs)
    findLeadingCols i ((0 : _) : cs) = findLeadingCols (i + 1) cs
    findLeadingCols _ _ = []
-- | Determinant of a square matrix, by Gaussian elimination.
--
-- The empty (0 x 0) matrix has determinant 1, the empty product.
det :: (Eq a, Fractional a) => [[a]] -> a
det [] = 1
det [[x]] = x
det ((x : xs) : rs)
  | x /= 0 =
      -- Factor x out of the first row, then clear the first column below it.
      let r' = (1 / x) *> xs
      in x * det [ys <-> y *> r' | (y : ys) <- rs]
  | otherwise =
      -- Adding another row leaves the determinant unchanged. If the whole
      -- first column is zero, the matrix is singular.
      case filter leadsNonZero rs of
        [] -> 0
        r : _ -> det (((x : xs) <+> r) : rs)
  where
    leadsNonZero (h : _) = h /= 0
    leadsNonZero [] = False