module Combinatorics.Battleship.Count.ShortenShip.Distribution where import qualified Combinatorics.Battleship.Count.ShortenShip as ShortenShip import qualified Combinatorics.Battleship.Count.CountMap as CountMap import qualified Combinatorics.Battleship.Count.Counter as Counter import qualified Combinatorics.Battleship.Count.Frontier as Frontier import qualified Combinatorics.Battleship.Fleet as Fleet import qualified Combinatorics.Battleship.Size as Size import Combinatorics.Battleship.Size (Nat, Zero, Succ, N10, Size(Size), size) import Combinatorics.Battleship.Count.ShortenShip (countBoundedFromMap) import Foreign.Storable (Storable, sizeOf, alignment, poke, peek) import Foreign.Ptr (Ptr, castPtr) import Control.Monad.HT (void) import Control.Monad (when) import Control.Applicative ((<$>)) import Control.DeepSeq (NFData, rnf, ($!!)) import qualified Data.StorableVector as SV import qualified Data.List.HT as ListHT import qualified Data.List as List import Data.Tuple.HT (mapFst) import Data.Word (Word64) import qualified Test.QuickCheck.Monadic as QCM import qualified Test.QuickCheck as QC {- | We need to encode the height in the type since the Storable instance requires that the size of the binary data can be infered from the Distribution type. -} newtype Distr w h a = Distr {getDistr :: SV.Vector a} instance (Storable a) => NFData (Distr w h a) where rnf = rnf . getDistr countFromDistr :: (Storable a) => Distr w h a -> a countFromDistr = SV.head . getDistr rowsFromDistr :: (Storable a) => Size w -> Distr w h a -> [Row w a] rowsFromDistr (Size width) = map Row . SV.sliceVertical width . SV.tail . getDistr heightType :: Size h -> Distr w h a -> Distr w h a heightType _ = id newtype Size2 w h = Size2 Int size2FromSizes :: Size w -> Size h -> Size2 w h size2FromSizes (Size width) (Size height) = Size2 (1 + width*height) size2 :: (Nat w, Nat h) => Size2 w h size2 = size2FromSizes size size instance (Nat w, Nat h, Storable a) => Storable (Distr w h a) where sizeOf = sizeOfWithSize size2 alignment (Distr xs) = alignment (SV.head xs) poke ptr (Distr xs) = SV.poke (castPtr ptr) xs peek = peekWithSize size2 -- not correct if padding is needed sizeOfWithSize :: (Storable a) => Size2 w h -> Distr w h a -> Int sizeOfWithSize (Size2 n) (Distr xs) = n * sizeOf (SV.head xs) peekWithSize :: (Storable a) => Size2 w h -> Ptr (Distr w h a) -> IO (Distr w h a) peekWithSize (Size2 n) ptr = fmap Distr $ SV.peek n (castPtr ptr) instance (Nat w, Nat h, Counter.C a, Storable a) => Counter.C (Distr w h a) where zero = constant size2 Counter.zero one = constant size2 Counter.one add (Distr x) (Distr y) = Distr $ SV.zipWith Counter.add x y constant :: (Storable a) => Size2 w h -> a -> Distr w h a constant (Size2 n) = Distr . SV.replicate n newtype Row w a = Row {getRow :: SV.Vector a} avg :: (Integral a) => a -> a -> a avg x y = case divMod (x+y) 2 of (z,0) -> z _ -> error "avg: odd sum" symmetric :: (Integral a, Storable a) => Row w a -> Row w a symmetric (Row xs) = Row $ SV.zipWith avg xs (SV.reverse xs) type Count = Word64 type CountMap = CountMap.T Count {-# SPECIALISE CountMap.mergeMany :: [CountDistrMap N10 Zero] -> CountDistrMap N10 Zero #-} type CountDistr w h = Distr w h Count type CountDistrMap w h = CountMap.T w (CountDistr w h) type CountDistrPath w h = CountMap.Path w (CountDistr w h) rowFromFrontier :: (Nat w) => Size w -> Count -> Frontier.T w -> Row w Count rowFromFrontier width cnt = Row . Frontier.mapToVector width (\x -> if x == Frontier.Free then 0 else cnt) addRowToDistr :: Row w Count -> CountDistr w h -> CountDistr w (Succ h) addRowToDistr (Row row) (Distr xs) = Distr $ SV.concat [SV.take 1 xs, row, SV.tail xs] addFrontierToDistr :: (Nat w) => Frontier.T w -> CountDistr w h -> CountDistr w (Succ h) addFrontierToDistr frontier cntDistr = addRowToDistr (rowFromFrontier size (countFromDistr cntDistr) frontier) cntDistr baseCase :: (Nat w) => CountDistrMap w Zero baseCase = CountMap.singleton (Frontier.empty, Fleet.empty) Counter.one nextFrontierBoundedExternal :: (Nat w, Nat h) => Size w -> Fleet.T -> CountDistrPath w (Succ h) -> CountDistrMap w h -> IO () nextFrontierBoundedExternal width maxFleet dst = CountMap.writeSorted dst . map (concatMap (\((frontier,fleet), cntDistr) -> map (\key -> (mapFst (ShortenShip.canonicalFrontier . Frontier.dilate width) key, addFrontierToDistr (fst key) cntDistr)) $ ShortenShip.transitionFrontierBounded width maxFleet frontier fleet)) . ListHT.sliceVertical bucketSize . CountMap.toAscList bucketSize :: Int bucketSize = 2^(11::Int) tmpPath :: Size h -> CountDistrPath w h tmpPath (Size height) = ShortenShip.tmpPath height reportCount :: (Nat w, Nat h) => Fleet.T -> CountDistrPath w h -> IO () reportCount fleet path = do putStrLn "" cd <- countBoundedFromMap fleet <$> CountMap.readFile path print $ countFromDistr cd putStr $ unlines $ map (unwords . map show . SV.unpack . getRow . symmetric) $ rowsFromDistr size cd withReport :: (Nat w, Nat h) => Bool -> Fleet.T -> (CountDistrPath w h -> IO ()) -> IOCountDistrPath w h withReport report fleet act = IOCountDistrPath $ case tmpPath size of path -> do act path when report $ reportCount fleet path return path newtype IOCountDistrPath w h = IOCountDistrPath {runIOCountDistrPath :: IO (CountDistrPath w h)} distributionBoundedExternal :: (Nat w, Nat h) => Bool -> Fleet.T -> IO (CountDistrPath w h) distributionBoundedExternal report fleet = runIOCountDistrPath $ Size.switch (withReport report fleet $ \path -> CountMap.writeFile path baseCase) (withReport report fleet $ \path -> nextFrontierBoundedExternal size fleet path =<< CountMap.readFile =<< distributionBoundedExternal report fleet) countExternal :: IO () countExternal = void (distributionBoundedExternal True Fleet.german :: IO (CountDistrPath N10 N10)) distributionExternalList :: (Nat w, Nat h) => Size w -> Size h -> Fleet.T -> IO (Count, [[Count]]) distributionExternalList w h fleet = do cdm <- (return $!!) . countBoundedFromMap fleet =<< CountMap.readFile =<< distributionBoundedExternal False fleet return (countFromDistr cdm, map (SV.unpack . getRow . symmetric) $ rowsFromDistr w $ heightType h cdm) propCountExternalTotal :: QC.Property propCountExternalTotal = QC.forAllShrink (QC.choose (0,6)) QC.shrink $ \width -> QC.forAllShrink (QC.choose (0,10)) QC.shrink $ \height -> QC.forAllShrink ShortenShip.genFleet QC.shrink $ \fleet -> Size.reifyInt width $ \w -> Size.reifyInt height $ \h -> QCM.monadicIO $ do (c,cd) <- QCM.run $ distributionExternalList w h fleet QCM.assert $ Counter.toInteger c * fromIntegral (sum $ map (uncurry (*)) $ Fleet.toList fleet) == (sum $ map (Counter.toInteger . Counter.sum) cd) propCountExternalSimple :: QC.Property propCountExternalSimple = QC.forAllShrink (QC.choose (0,6)) QC.shrink $ \width -> QC.forAllShrink (QC.choose (0,10)) QC.shrink $ \height -> QC.forAllShrink ShortenShip.genFleet QC.shrink $ \fleet -> Size.reifyInt width $ \w -> Size.reifyInt height $ \h -> QCM.monadicIO $ do (c,_cd) <- QCM.run $ distributionExternalList w h fleet ce <- QCM.run $ ShortenShip.countExternalReturn (w,height) fleet QCM.assert $ Counter.toInteger ce == Counter.toInteger c propCountExternalSymmetric :: QC.Property propCountExternalSymmetric = QC.forAllShrink (QC.choose (0,6)) QC.shrink $ \sz -> QC.forAllShrink ShortenShip.genFleet QC.shrink $ \fleet -> Size.reifyInt sz $ \n -> QCM.monadicIO $ do (_c,cd) <- QCM.run $ distributionExternalList n n fleet QCM.assert $ cd == List.transpose cd propCountExternalTransposed :: QC.Property propCountExternalTransposed = QC.forAllShrink (QC.choose (0,6)) QC.shrink $ \width -> QC.forAllShrink (QC.choose (0,6)) QC.shrink $ \height -> QC.forAllShrink ShortenShip.genFleet QC.shrink $ \fleet -> Size.reifyInt width $ \w -> Size.reifyInt height $ \h -> QCM.monadicIO $ do (c0,cd0) <- QCM.run $ distributionExternalList w h fleet (c1,cd1) <- QCM.run $ distributionExternalList h w fleet QCM.assert $ c0 == c1 QCM.assert $ List.transpose cd0 == cd1