{-# LANGUAGE RankNTypes #-}

module System.Zfs.Iter (
       getRootCount,
       getChildrenCount,
       getZpoolCount,
       getRoots,
       getChildren,
       getZpools,
       forRoots,
       forChildren,
       forFilesystems,
       forSnapshots,
       forZpools,
       forZpools_,
       forVdevs
       ) where

import qualified Control.Exception as E
import Control.Monad
import Control.Monad.IO.Class
import Data.IORef (newIORef, readIORef, writeIORef)
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.StablePtr
import Foreign.Storable

import System.Zfs.Errors
import qualified System.Zfs.Lowlevel as L
import System.Zfs.Types
import System.Zfs.Zpool

-- | Shared implementation of the @get*Count@ functions: count how many times
-- @iterfun@ invokes its callback for @inptr@ (e.g. 'L.zfs_iter_root',
-- 'L.zfs_iter_children').
--
-- Each handle passed to the callback is closed immediately with
-- 'L.zfs_close'. The callback always returns 0. The return code of
-- @iterfun@ itself is ignored.
getIterCount :: MonadIO m => ptr -> (ptr -> L.ZfsIterF Int -> Ptr Int -> IO Int) -> ZfsT z m Int
getIterCount inptr iterfun = Zfs $ \_ ->
  liftIO $ alloca $ \counter -> do
    poke counter (0 :: Int)
    -- A FunPtr made by the "wrapper" import is not garbage collected;
    -- bracket guarantees it is released even if iterfun throws.
    _ <- E.bracket (L.wrap_zfs_iter countOne) freeHaskellFunPtr $ \cb ->
      iterfun inptr cb counter
    fmap Right (peek counter)
  where
    -- Bump the counter stored behind the user-data pointer and release the
    -- handle we were given.
    countOne zfs counter = do
      n <- peek counter
      poke counter (n + 1 :: Int)
      L.zfs_close zfs
      return 0

-- | Count the root datasets, i.e. how many handles 'L.zfs_iter_root'
-- reports for the library handle held in the current 'ZfsContext'.
getRootCount :: Zfs z Int
getRootCount =
  Zfs (\(ZfsContext h) -> return (Right h)) >>= \h ->
    getIterCount h L.zfs_iter_root

-- | Count the children of a dataset, as reported by 'L.zfs_iter_children'.
-- The dataset's foreign pointer is kept alive for the whole count.
getChildrenCount :: Zdataset z -> Zfs z Int
getChildrenCount (Zdataset node) = Zfs countUnder
  where
    countUnder ctx = liftIO . withForeignPtr node $ \parent ->
      runZfs' (getIterCount parent L.zfs_iter_children) ctx

-- | Shared implementation of the @for*@ iterators: run @f@ on every dataset
-- handle that @iterfun@ (e.g. 'L.zfs_iter_root', 'L.zfs_iter_children')
-- passes to its callback, and collect the results in iteration order.
--
-- Every handle is wrapped in a 'Zdataset' whose finalizer is 'L.zfs_close_',
-- so it is closed once no longer referenced.
--
-- If @f@ fails, the callback returns 1 and the first error becomes the
-- result of the whole iteration. Libzfs iterators are expected to stop on a
-- non-zero return (NOTE(review): confirm for every @iterfun@ used). Any
-- handles reported after a failure are still wrapped, so they get closed,
-- but @f@ is not run on them.
--
-- Results are accumulated in an 'IORef' rather than a buffer sized by a
-- separate counting pass. The set of datasets may change between two
-- passes, so a pre-computed count can be too small (buffer overflow) or too
-- large (reading unwritten slots).
forIter :: MonadIO m => ptr -> (forall b. ptr -> L.ZfsIterF b -> Ptr b -> IO Int) -> (Zdataset z -> ZfsT z IO a) -> ZfsT z m [a]
forIter inptr iterfun f = Zfs $ \z -> liftIO $ do
  -- Holds Left e once an action failed; otherwise the results so far, newest first.
  acc <- newIORef (Right [])
  let visit znode' _ = do
        -- Wrap first, so the handle is closed even when f is skipped.
        znode <- liftM Zdataset $ newForeignPtr L.zfs_close_ znode'
        st <- readIORef acc
        case st of
          Left _ -> return 1
          Right done -> do
            res <- runZfs' (f znode) z
            case res of
              Left e -> writeIORef acc (Left e) >> return 1
              Right a -> writeIORef acc (Right (a : done)) >> return 0
  -- A FunPtr made by the "wrapper" import must be freed explicitly.
  -- The user-data pointer is unused: all state lives in the IORef.
  _ <- E.bracket (L.wrap_zfs_iter visit) freeHaskellFunPtr $ \cb ->
    iterfun inptr cb (nullPtr :: Ptr ())
  liftM (fmap reverse) (readIORef acc)

-- | Run an action on every root dataset (as reported by 'L.zfs_iter_root'
-- for the library handle in the current 'ZfsContext') and collect the
-- results in iteration order.
forRoots :: MonadIO m => (Zdataset z -> ZfsT z IO a) -> ZfsT z m [a]
forRoots fun =
  Zfs (\(ZfsContext h) -> return (Right h)) >>= \h ->
    forIter h L.zfs_iter_root fun

-- | Run an action on every child of the given dataset (as reported by
-- 'L.zfs_iter_children') and collect the results in iteration order.
-- The parent's foreign pointer is kept alive for the whole iteration.
forChildren :: MonadIO m => Zdataset z -> (Zdataset z -> ZfsT z IO a) -> ZfsT z m [a]
forChildren (Zdataset node) fun = Zfs walk
  where
    walk ctx = liftIO . withForeignPtr node $ \parent ->
      runZfs' (forIter parent L.zfs_iter_children fun) ctx

-- | Run an action on every child filesystem of the given dataset (as
-- reported by 'L.zfs_iter_filesystems') and collect the results in
-- iteration order. The parent's foreign pointer is kept alive for the
-- whole iteration.
forFilesystems :: MonadIO m => Zdataset z -> (Zdataset z -> ZfsT z IO a) -> ZfsT z m [a]
forFilesystems (Zdataset node) fun = Zfs walkFilesystems
  where
    walkFilesystems ctx = liftIO . withForeignPtr node $ \parent ->
      runZfs' (forIter parent L.zfs_iter_filesystems fun) ctx

-- | Run an action on every snapshot of the given dataset (as reported by
-- 'L.zfs_iter_snapshots') and collect the results in iteration order.
-- The dataset's foreign pointer is kept alive for the whole iteration.
forSnapshots :: MonadIO m => Zdataset z -> (Zdataset z -> ZfsT z IO a) -> ZfsT z m [a]
forSnapshots (Zdataset node) fun = Zfs walkSnapshots
  where
    walkSnapshots ctx = liftIO . withForeignPtr node $ \dataset ->
      runZfs' (forIter dataset L.zfs_iter_snapshots fun) ctx


-- | List all root datasets as 'Zdataset' handles, in iteration order.
-- This is 'forRoots' with an action that just returns each handle.
getRoots :: Zfs z [Zdataset z]
getRoots = forRoots pure

-- | List the children of a dataset as 'Zdataset' handles, in iteration
-- order. This is 'forChildren' with an action that just returns each handle.
getChildren :: Zdataset z -> Zfs z [Zdataset z]
getChildren dataset = forChildren dataset pure