{-# LANGUAGE GADTs #-}

module Control.Concurrent.STM.TVar.Zoom (
    TVar(..)
  , zoomTVar
  , newTVar, newTVarIO
  , pairTVars
  , readTVar, readTVarIO
  , modifyTVar, modifyTVar'
  , writeTVar
  , swapTVar
  , STM.STM, STM.atomically
) where

import Control.Concurrent.STM (STM)
import qualified Control.Concurrent.STM as STM
import Control.Lens 

-- | A transactional variable presented as a view: either a lens-focused
-- window onto a plain 'STM.TVar', or a lens-focused window onto the pair of
-- values held by two other 'TVar's.
data TVar a where
  -- | View the contents of an underlying 'STM.TVar' through a lens.
  Leaf :: STM.TVar s -> ALens' s a -> TVar a
  -- | View the pair formed by two 'TVar's (first, second) through a lens.
  Branch :: TVar l -> TVar r -> ALens' (l, r) a -> TVar a

-- | Allocate a fresh underlying 'STM.TVar' holding the given value and
-- expose it through the identity lens, so the view is the whole value.
newTVar :: a -> STM (TVar a)
newTVar initial = do
  ref <- STM.newTVar initial
  pure (Leaf ref id)

-- | 'IO' counterpart of 'newTVar': allocates the underlying variable with
-- 'STM.newTVarIO' and views it through the identity lens.
newTVarIO :: a -> IO (TVar a)
newTVarIO initial = do
  ref <- STM.newTVarIO initial
  pure (Leaf ref id)

-- | Narrow a 'TVar' through an additional lens.
--
-- The result refers to the same underlying storage as the input; only the
-- stored lens changes, becoming the existing lens followed by @l1@. The
-- composition is wrapped in 'fusing' (lens' Yoneda-based fmap fusion).
zoomTVar :: ALens' a b -> TVar a -> TVar b
zoomTVar l1 tv = case tv of
  Leaf ref inner ->
    Leaf ref (fusing (cloneLens inner . cloneLens l1))
  Branch lhs rhs inner ->
    Branch lhs rhs (fusing (cloneLens inner . cloneLens l1))

-- | Combine two 'TVar's into one whose value is the pair of their values
-- (first, second), viewed through the identity lens.
pairTVars :: TVar a -> TVar b -> TVar (a,b)
pairTVars = \lhs rhs -> Branch lhs rhs id

-- | Read the value currently visible through the view.
--
-- For a 'Branch', both component 'TVar's are read (first, then second) in
-- the current transaction and the pair is viewed through the branch lens.
readTVar :: TVar a -> STM a
readTVar (Leaf ref l) = do
  whole <- STM.readTVar ref
  pure (whole ^# l)
readTVar (Branch lhs rhs l) = do
  pair <- readBranch lhs rhs
  pure (pair ^# l)

-- | 'IO' counterpart of 'readTVar'.
--
-- A 'Leaf' is read with 'STM.readTVarIO'; a 'Branch' performs both of its
-- component reads inside a single 'STM.atomically' transaction.
readTVarIO :: TVar a -> IO a
readTVarIO tv = case tv of
  Leaf ref l       -> fmap (^# l) (STM.readTVarIO ref)
  Branch lhs rhs l -> STM.atomically (fmap (^# l) (readBranch lhs rhs))

-- | Apply a function to the value visible through the view.
--
-- For a 'Leaf' this is 'STM.modifyTVar' on the underlying variable with the
-- function lifted through the lens. For a 'Branch', both components are
-- read, the function is applied to the pair through the branch lens, and the
-- resulting components are written back with 'writeTVar': first component
-- first, second component last.
modifyTVar :: TVar a -> (a -> a) -> STM ()
modifyTVar tv f = case tv of
  Leaf ref l -> STM.modifyTVar ref (l #%~ f)
  Branch lhs rhs l ->
    readBranch lhs rhs >>= \pair ->
      -- 'case' forces the updated pair to WHNF before any write happens.
      case (l #%~ f) pair of
        (lhs', rhs') -> writeTVar lhs lhs' >> writeTVar rhs rhs'

-- | Strict version of 'modifyTVar': the new value visible through the view
-- is forced to weak head normal form before it is stored, matching the
-- contract of 'STM.modifyTVar''.
--
-- The update goes through 'readTVar' / 'writeTVar' on the view itself
-- rather than applying 'STM.modifyTVar'' to the underlying variable. For a
-- zoomed 'Leaf', forcing only the outer structure rebuilt by the lens may
-- leave @f old@ as an unevaluated thunk inside it, and a 'Branch' has no
-- single underlying variable at all. Forcing the viewed value directly
-- works the same way for both constructors.
modifyTVar' :: TVar a -> (a -> a) -> STM ()
modifyTVar' tv f = do
  old <- readTVar tv
  -- '$!' evaluates the new value here, before 'writeTVar' stores it.
  writeTVar tv $! f old

-- | Replace the value visible through the view.
--
-- For a 'Leaf', only the part of the underlying value selected by the lens
-- is replaced, using the strict 'STM.modifyTVar''. For a 'Branch', this is
-- 'modifyTVar' with a constant function, so the components are read and
-- then written back.
writeTVar :: TVar a -> a -> STM ()
writeTVar tv new = case tv of
  Leaf ref l -> STM.modifyTVar' ref (l #~ new)
  Branch {}  -> modifyTVar tv (const new)

-- | Store a new value through the view and return the value that was
-- visible just before the write, all within the current transaction.
swapTVar :: TVar a -> a -> STM a
swapTVar tv new = readTVar tv <* writeTVar tv new

-- | Read two 'TVar's in the current transaction (first, then second) and
-- return their values as a pair.
readBranch :: TVar a -> TVar b -> STM (a,b)
readBranch lhs rhs = do
  l <- readTVar lhs
  r <- readTVar rhs
  pure (l, r)