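{- |
Solver for the Magico puzzles
<http://www.spectra-verlag.de/>.

A puzzle is a 3x3 board of non-negative integers
of which one cell is given,
together with the sums of all rows, all columns and both diagonals.
The task is to reconstruct the remaining cells.
-}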
module Main where

import qualified Numeric.Container as NC
import qualified Data.Packed.Matrix as Matrix
import qualified Data.Packed.Vector as Vector
import Numeric.LinearAlgebra.HMatrix (null1, nullspace, rank)
import Numeric.LinearAlgebra ((<\>))
import Data.Packed.Matrix (Matrix)
import Data.Packed.Vector (Vector)

import Text.Printf (printf)

import qualified Control.Monad.Trans.State as MS
import Control.Applicative (Applicative, liftA2, liftA3, pure, (<*>))
import Control.Functor.HT (outerProduct)

import qualified Data.Foldable as Fold
import qualified Data.List as List
import Data.Traversable (Traversable, traverse, sequenceA)
import Data.Foldable (Foldable, foldMap)
import Data.Monoid ((<>))


-- | Three values of the same type,
-- e.g. the cells of one board row or the cells along one sum.
data Triple a =
   Triple a a a
   deriving (Show)

-- | The eight prescribed sums of a board:
-- main diagonal, anti-diagonal, the three rows and the three columns.
data Sums a =
   Sums a a (Triple a) (Triple a)
   deriving (Show)

-- | A 3x3 board, stored as a triple of rows.
type Solution a = Triple (Triple a)


-- | Apply a function to each of the three components.
instance Functor Triple where
   fmap g (Triple a b c) = Triple (g a) (g b) (g c)

-- | Component-wise application; 'pure' replicates a value three times.
instance Applicative Triple where
   pure a = Triple a a a
   (<*>) (Triple g h k) (Triple a b c) = Triple (g a) (h b) (k c)

-- | Folds visit the components from first to third.
instance Foldable Triple where
   foldMap g (Triple a b c) = g a <> (g b <> g c)

-- | Effects are run from the first to the third component.
instance Traversable Triple where
   traverse g (Triple a b c) = liftA3 Triple (g a) (g b) (g c)


-- | Map over every sum, keeping the layout.
instance Functor Sums where
   fmap g s =
      case s of
         Sums d0 d1 rows cols ->
            Sums (g d0) (g d1) (fmap g rows) (fmap g cols)

-- | Visits main diagonal, anti-diagonal, rows, then columns.
instance Foldable Sums where
   foldMap g (Sums d0 d1 rows cols) =
      g d0 <> (g d1 <> (foldMap g rows <> foldMap g cols))


-- | Pop the head of the list held in the state.
--
-- Fails with a descriptive 'error' when the list is already exhausted,
-- instead of the anonymous pattern-match failure of a partial lambda.
getFromList :: MS.State [a] a
getFromList =
   MS.state $ \xs ->
      case xs of
         x:rest -> (x, rest)
         [] -> error "getFromList: list exhausted"

-- | Build a board from a list in row-major order,
-- consuming the first nine elements (three rows of three cells).
solFromList :: [a] -> Solution a
solFromList =
   MS.evalState (traverse (const readRow) indices)
   where
      -- one row: three consecutive list elements
      readRow = traverse (const getFromList) indices

-- | Position within a 'Triple': first, second or third component.
data Index =
     I0
   | I1
   | I2
   deriving (Eq, Ord, Enum, Show)

-- | Board coordinate as @(row, column)@.
type Index2 = (Index, Index)

-- | The three positions of a 'Triple', in order.
indices :: Triple Index
indices =
   Triple
      I0
      I1
      I2

-- | Select one component of a 'Triple'.
index :: Index -> Triple a -> a
index I0 (Triple first _ _) = first
index I1 (Triple _ second _) = second
index I2 (Triple _ _ third) = third

-- | Look up the cell at @(row, column)@ of a board stored as a triple of rows.
index2 :: Index2 -> Triple (Triple a) -> a
index2 (row, col) board = index col (index row board)

-- | Row-major position (0 .. 8) of a cell in the flat list of nine cells.
flattenIndex2 :: Index2 -> Int
flattenIndex2 (row, col) = fromEnum row * 3 + fromEnum col

-- | For each of the eight sums, the three board cells it adds up.
sumIndices :: Sums (Triple Index2)
sumIndices =
   Sums mainDiagonal antiDiagonal rows columns
   where
      mainDiagonal = fmap (\i -> (i, i)) indices
      -- pairs (I0,I2), (I1,I1), (I2,I0)
      antiDiagonal = liftA2 (,) indices (Triple I2 I1 I0)
      -- row i consists of the cells (i, j) for all j
      rows = fmap (\i -> fmap (\j -> (i, j)) indices) indices
      -- column j consists of the cells (i, j) for all i
      columns = fmap (\j -> fmap (\i -> (i, j)) indices) indices

-- | Add up the cells of @sol@ at the three given positions.
sumOnBoard :: (Num a) => Solution a -> Triple Index2 -> a
sumOnBoard sol is =
   Fold.sum (fmap cellAt is)
   where
      cellAt ij = index2 ij sol

-- | All eight sums of a board, laid out like 'sumIndices'.
sums :: (Num a) => Solution a -> Sums a
sums sol =
   fmap sumAlong sumIndices
   where
      sumAlong = sumOnBoard sol

-- | Coefficient matrix of the linear system "board cells -> sums".
--
-- One row per sum, in the order of the 'Sums' fold,
-- and one column per cell, in row-major order.
-- An entry is 1 if the cell takes part in the sum, and 0 otherwise.
fullMatrix :: Matrix Double
fullMatrix =
   Matrix.fromLists
      [ [ if Fold.elem cell sumCells then 1 else 0 | cell <- cells ]
      | sumCells <- Fold.toList sumIndices
      ]
   where
      cells = [ (i, j) | i <- Fold.toList indices, j <- Fold.toList indices ]

-- | Remove the element at position @k@ (0-based).
--
-- Returns the removed element together with the remaining list.
-- Valid positions are @0 .. length xs - 1@. Any other position is reported
-- with an 'error'. Previously a negative @k@ silently removed the head,
-- because 'splitAt' clamps negative counts to zero.
removeAt :: Int -> [a] -> (a, [a])
removeAt k xs
   | k < 0 = error "removeAt: negative index"
   | otherwise =
      case splitAt k xs of
         (_, []) -> error "removeAt: index too large"
         (leftXs, pivot:rightXs) -> (pivot, leftXs ++ rightXs)

-- | Insert @x@ so that it ends up at position @k@ (0-based) of the result.
--
-- This is the inverse of 'removeAt'. Valid positions are @0 .. length xs@.
-- Any other position is reported with an 'error', consistent with
-- 'removeAt', instead of silently prepending (negative @k@) or appending
-- (too large @k@).
insertAt :: Int -> a -> [a] -> [a]
insertAt k x xs
   | k < 0 = error "insertAt: negative index"
   | otherwise =
      case splitAt k xs of
         (leftXs, rightXs)
            -- 'splitAt' returns a shorter prefix when k exceeds the length
            | length leftXs < k -> error "insertAt: index too large"
            | otherwise -> leftXs ++ x : rightXs


-- | True if every element lies within 1e-7 of an integer.
isIntegerVector :: Vector Double -> Bool
isIntegerVector v =
   all nearInteger (Vector.toList v)
   where
      nearInteger x = abs (x - fromInteger (round x)) < 1e-7

-- | Enumerate the integer points on the line @start + c * offset@.
--
-- Let @k = NC.maxIndex offset@. For each integer @x@ in @0 .. maxN@, the
-- scalar @c@ is chosen so that component @k@ of the point equals @x@.
-- A point is kept only if all of its components are (numerically) integers.
--
-- Precondition: @NC.sumElements offset == 0@. Presumably this guarantees
-- that the maximal entry of a non-zero @offset@ is positive, so the division
-- by it is well-defined.
integerSolutions :: Int -> Vector Double -> Vector Double -> [Vector Double]
integerSolutions maxN offset start =
   filter isIntegerVector (map pointWith [0 .. maxN])
   where
      k = NC.maxIndex offset
      startK = NC.atIndex start k
      offsetK = NC.atIndex offset k
      pointWith x =
         NC.add start (NC.scale ((fromIntegral x - startK) / offsetK) offset)

-- | Split 'fullMatrix' at the column of the given cell.
--
-- Returns the column's flat position and the column itself, together with
-- the matrix formed by the remaining eight columns.
splittedMatrix :: Index2 -> ((Int, Vector Double), Matrix Double)
splittedMatrix givenI2 =
   ((pos, column), Matrix.fromColumns others)
   where
      pos = flattenIndex2 givenI2
      (column, others) = removeAt pos (Matrix.toColumns fullMatrix)

-- | For every possible position of the given cell, the rank of the
-- remaining 8x8 coefficient matrix.
ranks :: Triple (Triple Int)
ranks =
   outerProduct
      (\row col -> rank (snd (splittedMatrix (row, col))))
      indices indices

-- | For every possible position of the given cell, a null-space basis
-- of the remaining 8x8 coefficient matrix.
masks :: Triple (Triple (Matrix Double))
masks =
   outerProduct
      (\row col -> nullspace (snd (splittedMatrix (row, col))))
      indices indices

{-
Vectors of the null space look like

 0  1 -1
-1  0  1
 1 -1  0

or

-1  2 -1
 0  0  0
 1 -2  1

The second one is the first one added to its horizontally flipped counterpart.

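For givenI2 = (I1,I1) the null space is two-dimensional:
both vectors above have a zero in the center,
so every linear combination of them is a null vector, too.
In this case 'solve' misses some solutions,
because 'null1' yields only one of these directions
and 'integerSolutions' searches along that single direction only.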
-}
-- | All boards with non-negative cells that have the sums @ss@
-- and hold the value @given@ at position @givenI2@.
--
-- The given cell is moved to the right-hand side of the linear system
-- (see 'splittedMatrix'). A particular solution is computed with '<\>', and
-- integer candidates are enumerated along the null-space vector returned by
-- 'null1' (see 'integerSolutions').
--
-- Every candidate is checked against @ss@ at the end. According to the
-- hmatrix documentation, '<\>' computes a least-squares solution, so it
-- returns a vector even for contradictory sums. Rounding that vector could
-- otherwise yield boards whose sums differ from the requested ones.
solve :: (Index2, Int) -> Sums Int -> [Solution Int]
solve (givenI2, given) ss =
   let vec = Vector.fromList $ Fold.toList $ fmap fromIntegral ss
       ((splitPos, givenCol), mat) = splittedMatrix givenI2
       -- exact integer check; drops artefacts of the least-squares fit
       hasRequestedSums sol = Fold.toList (sums sol) == Fold.toList ss
   in  filter hasRequestedSums $
       map (solFromList . insertAt splitPos given) $
       filter (all (>=0)) $
       map (map round . Vector.toList) $
       integerSolutions (Fold.maximum ss) (null1 mat) $
          mat <\> NC.sub vec (NC.scale (fromIntegral given) givenCol)

-- | Print a board one row per line, with each cell right-aligned in a
-- two-character field and cells separated by a space. A blank line follows
-- the board.
printSolution :: Solution Int -> IO ()
printSolution sol =
   do mapM_ printRow (Fold.toList sol)
      putStrLn ""
   where
      printRow row =
         putStrLn (unwords [printf "%2d" cell | cell <- Fold.toList row])


-- Sample puzzles. The names presumably refer to puzzle numbers of the
-- Magico series (TODO confirm). Each one fixes a single cell and lists the
-- sums in the order: main diagonal, anti-diagonal, rows, columns.
exampleA1 :: [Solution Int]
exampleA1 = solve ((I0, I2), 2) $ Sums 5 8 (Triple 5 8 8) (Triple 9 6 6)

exampleA2 :: [Solution Int]
exampleA2 = solve ((I2, I1), 3) $ Sums 9 6 (Triple 8 9 7) (Triple 6 9 9)

exampleA3 :: [Solution Int]
exampleA3 = solve ((I2, I1), 4) $ Sums 9 4 (Triple 6 8 8) (Triple 9 7 6)

exampleA4 :: [Solution Int]
exampleA4 = solve ((I1, I2), 4) $ Sums 9 6 (Triple 8 9 7) (Triple 9 9 6)

exampleA7 :: [Solution Int]
exampleA7 = solve ((I0, I0), 5) $ Sums 8 4 (Triple 9 6 6) (Triple 9 9 3)

exampleA20 :: [Solution Int]
exampleA20 = solve ((I1, I2), 4) $ Sums 7 7 (Triple 9 6 8) (Triple 8 6 9)

exampleB1 :: [Solution Int]
exampleB1 = solve ((I0, I1), 4) $ Sums 11 11 (Triple 10 7 9) (Triple 9 9 8)

exampleD7 :: [Solution Int]
exampleD7 =
   solve ((I0, I0), 20) $ Sums 44 34 (Triple 43 40 36) (Triple 37 49 33)

exampleD20 :: [Solution Int]
exampleD20 =
   solve ((I2, I2), 9) $ Sums 37 43 (Triple 46 38 39) (Triple 49 35 39)


-- | Print every solution of puzzle D7.
main :: IO ()
main = Fold.forM_ exampleD7 printSolution