{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Diagrams.Solve.Tridiagonal
-- Copyright   :  (c) 2011-2015 diagrams-solve team (see LICENSE)
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  diagrams-discuss@googlegroups.com
--
-- Solving of tridiagonal and cyclic tridiagonal linear systems.
--
-----------------------------------------------------------------------------
module Diagrams.Solve.Tridiagonal
       ( solveTriDiagonal
       , solveCyclicTriDiagonal
       ) where

-- | @solveTriDiagonal as bs cs ds@ solves a system of the form @A*X = ds@
--   where @A@ is an @n@ by @n@ matrix with @bs@ as the main diagonal,
--   @as@ the diagonal below and @cs@ the diagonal above.  See:
--   <http://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm>
--
--   The size of the system is taken from @as@: the result has
--   @length as + 1@ elements.  @bs@ and @ds@ must supply at least that
--   many elements and @cs@ at least @length as@ (and at least one, even
--   for a 1 by 1 system); surplus elements are ignored.  Lists that are
--   too short lead to one of the \"impossible!\" errors below.
--
--   Calls 'error' if @bs@, @cs@ or @ds@ is empty.
solveTriDiagonal :: Fractional a => [a] -> [a] -> [a] -> [a] -> [a]
solveTriDiagonal as (b0:bs) (c0:cs) (d0:ds) = h cs' ds'
  where
    -- Forward sweep, part 1: the modified super-diagonal coefficients.
    -- The list is defined in terms of itself: every entry needs the one
    -- before it, which 'f' receives (lazily) as its first argument.
    -- 'f' stops as soon as a single sub-diagonal entry is left, so cs'
    -- has exactly @length as@ elements when @as@ is nonempty.  When @as@
    -- is empty the tail of cs' is an error thunk, but 'h' never demands it.
    cs' = c0 / b0 : f cs' as bs cs
    f _ [_] _ _ = []
    f (c':cs') (a:as) (b:bs) (c:cs) = c / (b - c' * a) : f cs' as bs cs
    f _ _ _ _ = error "solveTriDiagonal.f: impossible!"

    -- Forward sweep, part 2: the modified right-hand side, built with the
    -- same self-referential trick.  'g' runs until @as@ is exhausted, so
    -- ds' has @length as + 1@ elements.  The first argument must stay
    -- unforced in the base case because it is the list being defined.
    ds' = d0 / b0 : g ds' as bs cs' ds
    g _ [] _ _ _ = []
    g (d':ds') (a:as) (b:bs) (c':cs') (d:ds) = (d - d' * a) / (b - c' * a) : g ds' as bs cs' ds
    g _ _ _ _ _ = error "solveTriDiagonal.g: impossible!"

    -- Back substitution: the last unknown equals the last modified
    -- right-hand side; every earlier unknown is obtained from the one
    -- after it.
    h _ [d] = [d]
    h (c:cs) (d:ds) = let xs@(x:_) = h cs ds in d - c * x : xs
    h _ _ = error "solveTriDiagonal.h: impossible!"

solveTriDiagonal _ _ _ _ = error "arguments 2,3,4 to solveTriDiagonal must be nonempty"

-- Helper that applies the passed function only to the last element of a
-- list.  Every other element is passed through unchanged, and the empty
-- list is returned as is.
modifyLast :: (a -> a) -> [a] -> [a]
modifyLast _ []     = []
modifyLast f [x]    = [f x]
modifyLast f (x:xs) = x : modifyLast f xs

-- Helper that builds a list of length n of the form: '[s,m,m,...,m,m,e]'
--
--   * n < 1  gives the empty list;
--   * n == 1 gives '[s]' (there is no room for the end element);
--   * n >= 2 gives s, then (n - 2) copies of m, then e.
--
-- The n == 1 case is handled explicitly: counting down from (n - 1) to 1
-- would start at 0, never reach the end element and produce an infinite
-- list.
sparseVector :: Int -> a -> a -> a -> [a]
sparseVector n s m e
    | n < 1     = []
    | n == 1    = [s]
    | otherwise = s : replicate (n - 2) m ++ [e]

-- | Solves a system similar to the tri-diagonal system using a special case
--   of the Sherman-Morrison formula (<http://en.wikipedia.org/wiki/Sherman-Morrison_formula>).
--   This code is based on /Numerical Recipes in C/'s @cyclic@ function in section 2.7.
--
--   @solveCyclicTriDiagonal as bs cs ds alpha beta@ takes the sub-diagonal
--   @as@, main diagonal @bs@, super-diagonal @cs@ and right-hand side @ds@
--   in the same layout as 'solveTriDiagonal', plus the two extra corner
--   entries of the cyclic matrix: with the rank-one correction used below,
--   @alpha@ is the first entry of the last row and @beta@ is the last
--   entry of the first row.
--
--   The first main-diagonal entry must be non-zero, because its negation
--   is used as a divisor.  Calls 'error' if @bs@ is empty.
solveCyclicTriDiagonal :: Fractional a => [a] -> [a] -> [a] -> [a] -> a -> a -> [a]
solveCyclicTriDiagonal as (b0:bs) cs ds alpha beta = zipWith (\z x -> fact * z + x) zs xs
  where
    -- Free parameter of the rank-one splitting, fixed to the negated
    -- first main-diagonal entry.
    gamma = -b0

    -- Column vector of the rank-one correction: gamma first, alpha last,
    -- zeros in between; as long as the right-hand side.
    us = sparseVector (length ds) gamma 0 alpha

    -- Main diagonal of the plain tridiagonal matrix that remains after the
    -- rank-one correction has been taken out: only the first and the last
    -- entries change.
    bs' = (b0 - gamma) : modifyLast (subtract (alpha * beta / gamma)) bs

    -- Two ordinary tridiagonal solves with the same matrix: one for the
    -- real right-hand side and one for the correction vector.
    xs@(x:_) = solveTriDiagonal as bs' cs ds
    zs@(z:_) = solveTriDiagonal as bs' cs us

    -- Sherman-Morrison scale factor; the final answer is xs + fact * zs.
    fact = -(x + beta * last xs / gamma) / (1.0 + z + beta * last zs / gamma)

solveCyclicTriDiagonal _ _ _ _ _ _ = error "second argument to solveCyclicTriDiagonal must be nonempty"