{-# LANGUAGE TypeOperators #-}
module Math.LinearCircuit (resistance) where

import qualified Data.Graph.Comfort as Graph
import Data.Graph.Comfort (Graph)

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import Data.Array.Comfort.Storable ((!))
import Data.Array.Comfort.Shape ((:+:))

import Control.Monad.Trans.Identity (IdentityT(IdentityT))

import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Set (Set)
import Data.Bool.HT (if')


-- | Wrapper around an edge value: @Wrap edge node@ is @IdentityT edge node@,
-- which wraps a value of type @edge node@.  The matrices below use
-- @Set (Wrap edge node)@ as the index set of the edges.
--
-- NOTE(review): presumably wrapped so that the edge index gets the 'Ord'
-- instance of 'IdentityT' required by the array shape classes — confirm.
type Wrap = IdentityT

-- | Edge-node incidence matrix of the circuit: one row per edge, one column
-- per node.
--
-- The entry for edge @e@ and node @n@ is the contribution of the start node
-- (@1@ if @n == from e@) minus the contribution of the end node (@1@ if
-- @n == to e@).  For ordinary edges this is @1@, @-1@ or @0@.
--
-- The two contributions are added rather than tested in sequence, so a loop
-- edge (start node equals end node) gets an all-zero row.  A loop resistor
-- connects a node to itself and therefore must not influence any potential.
-- Note that a loop whose resistance is @0@ still makes the full system
-- singular, because its equation row is then entirely zero.
voltageMatrix ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel ->
   Matrix.General (Set (Wrap edge node)) (Set node) a
voltageMatrix gr =
   Matrix.fromRowMajor $
   Array.sample
      (Set.mapMonotonic IdentityT $ Graph.edgeSet gr, Graph.nodeSet gr) $
      \(IdentityT e, n) ->
         incidence (Graph.from e == n) - incidence (Graph.to e == n)
   where
      -- 1 when the node sits at the inspected end of the edge, 0 otherwise.
      incidence isEnd = if' isEnd 1 0

-- | Diagonal matrix of the edge labels of the circuit, which are the
-- resistances of the edges.  There is one diagonal entry per edge, indexed
-- by the wrapped edge.
resistanceMatrix ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel ->
   Triangular.Symmetric (Set (Wrap edge node)) a
resistanceMatrix =
   Triangular.diagonal MatrixShape.RowMajor
   . Array.fromMap
   . Map.mapKeysMonotonic IdentityT
   . Graph.edgeLabels

-- | Complete symmetric system matrix of the circuit, with @src@ as the
-- reference node.
--
-- Let @R@ be the diagonal matrix from 'resistanceMatrix' and @V@ the
-- incidence matrix from 'voltageMatrix'.  The inner block is
--
-- > [ R    V ]
-- > [ V^T  0 ]
--
-- The unknowns are read as (edge currents, node potentials).  The edge rows
-- then express Ohm's law and the node rows express the current balance.
--
-- The inner block is bordered by one leading row and one leading column.
-- Their diagonal entry is @0@, and their only other nonzero is a @1@ at the
-- position of node @src@.
--
-- * The extra row fixes the potential of @src@ to the first right-hand-side
--   entry.
-- * The extra column adds one unknown to the current balance of @src@.
--
-- @src@ must be a node of @gr@.
-- NOTE(review): otherwise 'Vector.unit' presumably fails with an index
-- error — confirm.
fullMatrix ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel -> node ->
   Triangular.Symmetric (():+:(Set (Wrap edge node):+:Set node)) a
fullMatrix gr src =
   Triangular.stackSymmetric corner border inner
   where
      incidence = voltageMatrix gr
      -- zero block coupling node potentials with node potentials
      zeroBlock =
         Vector.constant
            (MatrixShape.symmetric MatrixShape.RowMajor (Matrix.width incidence))
            0
      inner =
         Triangular.stackSymmetric (resistanceMatrix gr) incidence zeroBlock
      corner = Triangular.symmetricFromList MatrixShape.RowMajor () [0]
      border =
         Matrix.singleRow MatrixShape.RowMajor $
         Vector.unit (Triangular.size inner) (Right src)

-- | @resistance gr src dst@ is the resistance between nodes @src@ and @dst@
-- of the circuit @gr@, whose edge labels are the edge resistances.
--
-- The function solves the system from 'fullMatrix' (with @src@ as the
-- reference node at potential zero).  The right-hand side is @1@ in the
-- current-balance row of @dst@ and @0@ elsewhere, i.e. a unit current is fed
-- in at @dst@ and balanced by the extra unknown at @src@.
--
-- The edge equations are @R*I + U_from - U_to = 0@, so the resulting
-- potential of @dst@ is the negated resistance; hence the final 'negate'.
--
-- Both nodes must belong to @gr@.  If some node of @gr@ is not connected to
-- @src@ (including isolated nodes), the matrix is singular.  This is because
-- the potentials of that component can be shifted freely.
-- NOTE(review): how the LAPACK solve reports a singular matrix is not
-- visible here.
resistance ::
   (Graph.Edge edge, Ord node, Class.Floating a) =>
   Graph edge node a nodeLabel -> node -> node -> a
resistance gr src dst =
   negate (solution ! target)
   where
      system = fullMatrix gr src
      target = Right (Right dst)
      rhs = Vector.unit (Triangular.size system) target
      solution =
         Matrix.unliftColumn MatrixShape.ColumnMajor
            (Triangular.solve system)
            rhs