{-# LANGUAGE TypeOperators #-}
module Math.LinearCircuit (resistance) where

import qualified Data.Graph.Comfort as Graph
import Data.Graph.Comfort (Graph)

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix.Triangular ((#%%%#))
import Numeric.LAPACK.Matrix ((#\|))

import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Storable as Array
import Data.Array.Comfort.Storable ((!))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Control.Monad.Trans.Identity (IdentityT(IdentityT))

import qualified Data.Map as Map
import Data.Set (Set)


-- | Shorthand for 'IdentityT': @Wrap edge node@ is a newtype layer around a
-- value of type @edge node@.
--
-- NOTE(review): presumably this exists so that edges can be used as keys of
-- 'Set' \/ 'Map.Map' shapes via the instances of 'IdentityT' -- confirm.
type Wrap = IdentityT

-- | Signed incidence matrix of a set of edges: one row per edge, one column
-- per node.  The row of an edge holds @1@ in the column of its 'Graph.from'
-- node, @-1@ in the column of its 'Graph.to' node and @0@ in every other
-- column.
voltageMatrix ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Set node -> Set (Wrap edge node) ->
   Matrix.General (Set (Wrap edge node)) (Set node) a
voltageMatrix nodes edgeSet =
   Matrix.fromRowArray nodes $ fmap edgeRow $ BoxedArray.indices edgeSet
   where
      -- Row vector over all nodes that describes a single (unwrapped) edge.
      edgeRow (IdentityT e) =
         Array.fromAssociations 0 nodes
            [(Graph.from e, 1), (Graph.to e, -1)]
-- | Assemble the symmetric block matrix of the linear system for graph @gr@
-- and source node @src@.  Only the diagonal blocks and the blocks to their
-- right are given explicitly:
--
-- > ()     |  0    unit row (1 at node src, 0 elsewhere)
-- > edges  |       diagonal of edge labels     voltageMatrix
-- > nodes  |                                   0
--
-- The unit row ranges over edges and nodes together; its single non-zero
-- entry sits in the node part, at @src@.
--
-- NOTE(review): the edge labels are presumably the resistances of the
-- edges -- confirm against the callers.
fullMatrix ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel -> node ->
   Triangular.Symmetric (():+:(Set (Wrap edge node):+:Set node)) a
fullMatrix gr src =
   (zeroBlock (), Matrix.singleRow rowMajor sourceRow)
   #%%%#
   (Triangular.diagonal rowMajor (Array.fromMap labelledEdges),
         voltageMatrix nodeSet edgeSet)
   #%%%#
   zeroBlock nodeSet
   where
      rowMajor = MatrixShape.RowMajor
      nodeSet = Graph.nodeSet gr
      -- Edge labels keyed by the wrapped edges.
      labelledEdges = Map.mapKeysMonotonic IdentityT (Graph.edgeLabels gr)
      edgeSet = Map.keysSet labelledEdges
      -- Unit vector over (edges :+: nodes) selecting the source node.
      sourceRow = Vector.unit (edgeSet :+: nodeSet) (Right src)
      -- All-zero symmetric block; used at two different shapes above.
      zeroBlock shape = Matrix.zero (MatrixShape.symmetric rowMajor shape)
-- | Solve the linear system given by @'fullMatrix' gr src@ for a right-hand
-- side that is the unit vector at the position of node @dst@, and return the
-- negated component of the solution at that same position.
--
-- NOTE(review): presumably this is the effective resistance between @src@
-- and @dst@ when the edge labels of @gr@ are resistances -- confirm.
resistance ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel -> node -> node -> a
resistance gr src dst = negate (solution ! dstIx)
   where
      system = fullMatrix gr src
      -- Position of the node dst inside the shape () :+: (edges :+: nodes).
      dstIx = Right (Right dst)
      rhs = Vector.unit (Triangular.size system) dstIx
      solution = system #\| rhs