module Data.Tree.LogTree (
newTreeData, dotLogTree
, buildTree, newFFTTree
, getLevels, getFlatten, getEval
, modes, values
, coProd
) where
import Data.Complex
import Data.Tree
import Data.List
import Text.Printf (printf, PrintfArg)
import Control.Monad.State.Lazy
import Data.Newtypes.PrettyDouble (PrettyDouble(..))
-- | How a computational node combines its coefficient-weighted inputs; the DOT
-- legend renders it as the operator joining the terms of each output.
data CompOp
  = Sum   -- ^ terms joined with @+@
  | Prod  -- ^ terms joined with @*@
  deriving (Eq)
-- | One output (y_j) of a computational node: the combining operation and the
-- coefficients applied to the node's inputs x0, x1, ... (this is how the DOT
-- legend renders it).
type CompNodeOutput a = (CompOp, [a])
-- | A computational node: its outputs y0, y1, ... in order.
type CompNode a = [CompNodeOutput a]
-- | Rose tree whose labels are @(value, offsets, skipFactor, dif)@. As filled
-- in by 'buildMixedRadixTree':
--
--   * value: 'Just' the input value at a leaf, 'Nothing' at an internal node;
--   * offsets: a leaf's own offset into the input, or an internal node's
--     children's offsets;
--   * skipFactor: 0 at a leaf, the children's skip factor at an internal node;
--   * dif: the node's @dif@ mode flag ('False' at leaves).
type GenericLogTree a = Tree (Maybe a, [Int], Int, Bool)
-- | Operations for evaluating and rendering a divide-and-conquer tree whose
-- type @t@ is 'GenericLogTree' over the element type @a@.
class (Show a, t ~ GenericLogTree a) => LogTree t a | t -> a where
  -- | Values computed at a node.
  evalNode :: t -> [a]
  -- | Rows of twiddle factors for a node.
  getTwiddles :: t -> [[a]]
  -- | Printable twiddle factors; by default every factor is 'show'n.
  getTwiddleStrs :: t -> [[String]]
  getTwiddleStrs node = [map show row | row <- getTwiddles node]
  -- | Computational-node descriptions for a node (drawn in the DOT output).
  getCompNodes :: t -> [CompNode a]
-- | The tree evaluated by the FFT 'LogTree' instance: a 'GenericLogTree' over
-- complex 'PrettyDouble' values.
type FFTTree = GenericLogTree (Complex PrettyDouble)
-- | Mixed-radix FFT evaluation of an 'FFTTree'.
--
-- All complex exponentials below use the negative-exponent convention
-- exp(-2*pi*i*...), the same sign in the phasors, the twiddles and the
-- computational-node coefficients, so the legend matches the evaluation.
instance LogTree FFTTree (Complex PrettyDouble) where
  -- A leaf evaluates to its own value. An internal node evaluates each child
  -- once per twiddle row (scaling the child's leaves via 'coProd'), combines
  -- those per-row results, multiplies them element-wise by this node's phasor
  -- rows (one row per child), and sums the rows element-wise.
  evalNode (Node (Just x, _, _, _) _) = [x]
  evalNode (Node (_, _, _, dif) children) =
    foldl' (zipWith (+)) (replicate nodeLen 0.0)
      $ zipWith (zipWith (*)) subTransforms phasors
    where
      -- One list per child, built from that child's results for every twiddle row.
      subTransforms =
        [ subCombFunc $ map evalNode [snd (coProd twiddle child) | twiddle <- twiddles]
        | child <- children
        ]
      -- DIF nodes interleave the per-row results; other nodes concatenate them.
      subCombFunc
        | dif = concat . transpose
        | otherwise = concat
      -- Assumes a non-empty, uniform-depth child list (as built by the builder).
      childLen = length $ last (levels $ head children)
      radix = length children
      nodeLen = childLen * radix
      -- Row r, column k: exp(-2*pi*i * r * k / degree).
      phasors =
        [ [ cis ((-2.0) * pi / degree * fromIntegral r * fromIntegral k)
          | k <- [0 .. nodeLen - 1] ]
        | r <- [0 .. radix - 1] ]
      degree
        | dif = fromIntegral radix
        | otherwise = fromIntegral nodeLen
      twiddles = getTwiddles (Node (Nothing, [], 0, dif) children)

  -- radix rows of childLen factors. DIF nodes use exp(-2*pi*i*m*n/nodeLen) in
  -- row m, column n; all other nodes use 1 everywhere.
  getTwiddles (Node (_, _, _, dif) children)
    | dif =
        [ [ cis ((-2.0) * pi / fromIntegral nodeLen * fromIntegral m * fromIntegral n)
          | n <- [0 .. childLen - 1] ]
        | m <- [0 .. radix - 1] ]
    | otherwise = replicate radix (replicate childLen (1.0 :+ 0.0))
    where
      nodeLen = childLen * radix
      childLen = length $ last (levels $ head children)
      radix = length children

  -- DIF nodes show each twiddle as " [<factor>]"; other nodes get one empty
  -- string per leaf of each child.
  getTwiddleStrs (Node (_, _, _, dif) children)
    | dif =
        map (map (\factor -> " [" ++ show factor ++ "]"))
          $ getTwiddles (Node (Nothing, [], 0, dif) children)
    | otherwise = [replicate (length (last (levels child))) "" | child <- children]

  -- Leaves have no computational nodes. An internal node has one per output
  -- position m (0 .. childLen-1); each has radix 'Sum' outputs, and output j
  -- uses coefficients exp(-2*pi*i * k * r / degree), r = 0 .. radix-1, with
  -- k = childLen * j + m.
  getCompNodes (Node (Just _, _, _, _) _) = []
  getCompNodes (Node (Nothing, _, _, dif) children) =
    [ [ (Sum, [ cis ((-2.0) * pi * k * fromIntegral r / degree) | r <- [0 .. radix - 1] ])
      | k <- [fromIntegral (childLen * j + m) | j <- [0 .. radix - 1]] ]
    | m <- [0 .. childLen - 1] ]
    where
      childLen = length $ last (levels $ head children)
      radix = length children
      nodeLen = childLen * radix
      degree
        | dif = fromIntegral radix
        | otherwise = fromIntegral nodeLen
-- | Multiply the leaves of a tree, left to right, by successive elements of a
-- list of factors. Returns the unused factors together with the scaled tree.
--
-- Leaves reached after the factors run out are left unchanged. Any children
-- attached to a leaf ('Just' value) are dropped in the result.
coProd :: (Num a, t ~ GenericLogTree a) => [a] -> t -> ([a], t)
coProd [] (Node (Just x, offsets, skipFactor, dif) _) =
  ([], Node (Just x, offsets, skipFactor, dif) [])
coProd (a : as) (Node (Just x, offsets, skipFactor, dif) _) =
  (as, Node (Just (a * x), offsets, skipFactor, dif) [])
coProd as (Node (_, offsets, skipFactor, dif) children) =
  (bs, Node (Nothing, offsets, skipFactor, dif) childProds)
  where
    -- Thread the remaining factors through the children in order; mapAccumL
    -- keeps the children's order without repeated list appends.
    (bs, childProds) = mapAccumL coProd as children
-- | Fold step for 'coProd': scale tree @t@ with the pending factors, append the
-- result to the accumulated trees, and pass on the factors that remain.
coProdStep :: (Num a, t ~ GenericLogTree a) => ([a], [t]) -> t -> ([a], [t])
coProdStep (as, ts) t =
  let (remaining, scaled) = coProd as t
  in (remaining, ts ++ [scaled])
-- | Input to the tree builders: one @(radix, dif)@ mode per tree level
-- (outermost first), plus the values to distribute over the leaves.
data TreeData a = TreeData
  { modes  :: [(Int, Bool)]  -- ^ per-level radix and @dif@ flag
  , values :: [a]            -- ^ input values
  } deriving (Show)
-- | Construct a 'TreeData' from its modes and values (no validation is done
-- here; the builders report inconsistencies).
newTreeData :: [(Int, Bool)] -> [a] -> TreeData a
newTreeData ms vs = TreeData { modes = ms, values = vs }
-- | Wraps a tree-construction function. 'buildTree' converts 'TreeData' into a
-- tree of type @t@, or returns an error message in 'Left'. The field's type
-- mentions a type variable @a@ not bound by the newtype, so it is polymorphic
-- in @a@ (constrained by @LogTree t a@).
newtype TreeBuilder t = TreeBuilder {
    buildTree :: LogTree t a => TreeData a -> Either String t
  }
-- | Builder producing 'FFTTree's with the mixed-radix decomposition of
-- 'buildMixedRadixTree'.
newFFTTree :: TreeBuilder FFTTree
newFFTTree = TreeBuilder { buildTree = buildMixedRadixTree }
-- | Build a mixed-radix divide-and-conquer tree from a 'TreeData'.
--
-- The modes are consumed one per level, outermost first. For a mode
-- @(radix, dif)@ the current values are split into @radix@ equal-length
-- sub-lists: contiguous blocks when @dif@ is 'True', or stride-@radix@
-- subsequences when it is 'False'.
--
-- Each leaf records its offset into the original input. Each internal node
-- records its children's offsets, the children's skip factor (their stride in
-- the original input) and its own @dif@ flag.
--
-- Returns 'Left' when the input is empty, or when the product of the radices
-- does not equal the number of input values.
buildMixedRadixTree :: TreeData a -> Either String (GenericLogTree a)
buildMixedRadixTree td = mixedRadixRecurse 0 1 (modes td) (values td)
  where
    productErr = "mixedRadixRecurse: Product of radices must equal length of input."

    mixedRadixRecurse :: Int -> Int -> [(Int, Bool)] -> [a] -> Either String (GenericLogTree a)
    mixedRadixRecurse _ _ _ [] = Left "mixedRadixRecurse(): called with empty list."
    mixedRadixRecurse _ _ ms xs
      -- Checked before the singleton case so that, e.g., modes [(2, False)]
      -- with a single value is rejected instead of silently becoming a leaf.
      | product (map fst ms) /= length xs = Left productErr
    mixedRadixRecurse myOffset _ _ [x] = Right $ Node (Just x, [myOffset], 0, False) []
    mixedRadixRecurse myOffset mySkipFactor ((radix, dif) : restModes) xs = do
      children <- traverse
        (\(childOffset, subList) -> mixedRadixRecurse childOffset childSkipFactor restModes subList)
        (zip childOffsets subLists)
      return $ Node (Nothing, childOffsets, childSkipFactor, dif) children
      where
        childLen = length xs `div` radix
        -- Where each sub-list starts within xs, and the stride inside it.
        skipFactor
          | dif = 1
          | otherwise = radix
        offsets
          | dif = [i * childLen | i <- [0 .. radix - 1]]
          | otherwise = [0 .. radix - 1]
        subLists =
          [ [xs !! (offset + i * skipFactor) | i <- [0 .. childLen - 1]]
          | offset <- offsets ]
        -- The same positions expressed in the original input's coordinates.
        childSkipFactor
          | dif = mySkipFactor
          | otherwise = mySkipFactor * radix
        childOffsets
          | dif = [myOffset + i * mySkipFactor * childLen | i <- [0 .. radix - 1]]
          | otherwise = [myOffset + i * mySkipFactor | i <- [0 .. radix - 1]]
    -- Unreachable: with no modes left, product [] == 1, so any input of
    -- length >= 2 has already been rejected above.
    mixedRadixRecurse _ _ [] _ = Left productErr
-- | Render a tree, or a build error, as Graphviz DOT source.
--
-- Returns a pair:
--
--   * the processing graph ('header' followed by the node/edge statements);
--   * a second digraph acting as a legend for the distinct computational-node
--     types met while rendering. It is empty for the 'Left' case.
dotLogTree :: (Show a, Eq a, Num a, LogTree t a) => Either String t -> (String, String)
dotLogTree (Left msg) =
  (header ++ "\"node0\" [label = \"" ++ msg ++ "\"]\n" ++ "}\n", "")
dotLogTree (Right tree) = (header ++ treeStr ++ "}\n", compNodeLegend)
  where
    (treeStr, compNodeTypes) =
      runState (dotLogTreeRecurse "0" (getCompNodes tree) tree twiddles) []
    -- The root has no parent to supply its twiddle annotations. They are taken
    -- from a non-DIF node over the root's children. A single-leaf root has no
    -- children, so it gets one empty annotation; previously the recursion
    -- failed on 'head' of an empty list there.
    twiddles = case subForest tree of
      [] -> [""]
      kids -> concat $ getTwiddleStrs $ Node (Nothing, [], 0, False) kids
    compNodeLegend = "digraph {\n"
      ++ "label = \"Computational Node Legend\" fontsize = \"24\"\n"
      ++ "\"node0L\""
      ++ " [label = <<table border=\"0\" cellborder=\"0\" cellpadding=\"3\" bgcolor=\"white\"> \\ \n"
      ++ unlines indexedStrs
      ++ "</table>>, shape = \"Mrecord\""
      ++ "];\n}\n"
    indexedStrs = map (\str -> "<tr> \\ \n" ++ str ++ "</tr> \\") legendStrs
    -- One table row per computational-node type, numbered by its state index.
    legendStrs =
      [ concat ((" <td align=\"right\">" ++ show typeInd ++ ":</td> \\ \n") : outSpecs nodeType)
      | (nodeType, typeInd) <- zip compNodeTypes [(0 :: Int) ..]
      ]
    -- One table cell per output y_j of a computational node.
    outSpecs :: (Show a) => CompNode a -> [String]
    outSpecs nodeOutputs =
      [ " <td align=\"left\">y" ++ show yInd ++ " = "
          ++ intercalate (opStr op)
               [ "(" ++ show coeff ++ printf ") * x%d" k
               | (coeff, k) <- zip coeffs [(0 :: Int) ..] ]
          ++ "</td> \\ \n"
      | ((op, coeffs), yInd) <- zip nodeOutputs [(0 :: Int) ..]
      ]
    opStr Sum = " + "
    opStr Prod = " * "
-- | Preamble of the processing-graph digraph. It opens @digraph g {@ and sets
-- the title, rank/node spacing, layout direction, and default node and edge
-- attributes. The caller appends statements and the closing brace.
header :: String
header = "digraph g { \n \
  \ ranksep = \"1.5\";\n \
  \ nodesep = \"0\";\n \
  \ label = \"Divide & Conquer Processing Graph\";\n \
  \ labelloc = \"t\";\n \
  \ fontsize = \"28\" \n \
  \ graph [ \n \
  \ rankdir = \"RL\" \n \
  \ splines = \"false\" \n \
  \ ]; \n \
  \ node [ \n \
  \ fontsize = \"16\" \n \
  \ shape = \"circle\" \n \
  \ height = \"0.3\" \n \
  \ ]; \n \
  \ edge [ \n \
  \ dir = \"back\" \n \
  \ ];\n"
-- | Render one node and, recursively, its subtree as DOT statements.
--
-- Arguments: this node's ID suffix; the computational nodes to draw for it;
-- the node itself; and the twiddle annotation strings for its output fields.
--
-- The state is the list of distinct computational-node types seen so far.
-- Each computational node is labelled with its index in that list, and new
-- types are appended. Children are rendered before this node's computational
-- nodes, so their types are registered first.
dotLogTreeRecurse :: (Show a, Eq a, Num a, LogTree t a) => String -> [CompNode a] -> t -> [String] -> State [CompNode a] String
dotLogTreeRecurse nodeID _ (Node (Just x, offsets, _, _) _) twiddleVec =
  return $ "\"node" ++ nodeID ++ "\" [label = \"<f0> "
    ++ "[" ++ concatMap show (take 1 offsets) ++ "] " ++ show x ++ concat (take 1 twiddleVec)
    ++ "\" shape = \"record\"];\n"
dotLogTreeRecurse nodeID compNodes (Node (_, childOffsets, skip, dif) children) twiddleVec = do
  childrenStr <- concat <$> mapM renderChild (zip3 childIDs scaledChildren childTwiddles)
  compNodeStrs <- mapM renderCompNode (zip compNodes [(0 :: Int) ..])
  return (selfStr ++ childrenStr ++ concat compNodeStrs ++ concat conexStrs)
  where
    -- A record node with one field per evaluated output, each followed by its
    -- twiddle annotation.
    selfStr =
      "\"node" ++ nodeID ++ "\" [label = \""
        ++ intercalate " | " fieldStrs
        ++ "\" shape = \"record\"];\n"
    fieldStrs =
      [ "<f" ++ show k ++ "> " ++ show val ++ twiddle
      | (k, val, twiddle) <- zip3 [(0 :: Int) ..] res twiddleVec ]
    -- Recursing in the same State threads the registered comp-node types.
    renderChild (childID, child, twiddles) =
      dotLogTreeRecurse childID (getCompNodes child) child twiddles
    renderCompNode (compNode, k') = do
      compNodeType <- getCompNodeType compNode
      return $ "\"node" ++ nodeID ++ "C" ++ show k' ++ "\""
        ++ " [label = \"" ++ show compNodeType ++ "\""
        ++ ", shape = \"circle\""
        ++ ", height = \"0.1\""
        ++ "];\n"
    conexStrs =
      [ "\"node" ++ nodeID ++ "\":f" ++ show (r * childLen + k')
          ++ " -> \"node" ++ nodeID ++ "C" ++ show k' ++ "\""
          ++ " [headlabel = \"y" ++ show r ++ "\" labelangle = \"-30\" labeldistance = \"2\"];\n"
          ++ "\"node" ++ nodeID ++ "C" ++ show k' ++ "\""
          ++ " -> \"node" ++ nodeID ++ show r ++ "\":f" ++ show k'
          ++ " [taillabel = \"x" ++ show r ++ "\" labelangle = \"20\" labeldistance = \"2.5\"];\n"
      | k' <- [0 .. length compNodes - 1]
      , r <- [0 .. length children - 1]
      ]
    childIDs = [nodeID ++ show i | i <- [0 .. length children - 1]]
    childLen = length $ last (levels $ head children)
    res = evalNode $ Node (Nothing, childOffsets, skip, dif) children
    childTwiddles = getTwiddleStrs $ Node (Nothing, [], 0, dif) children
    -- NOTE(review): children are displayed scaled by the last twiddle row only;
    -- confirm this is the intended visualisation.
    twiddleChoice = last $ getTwiddles $ Node (Nothing, [], 0, dif) children
    scaledChildren = map (snd . coProd twiddleChoice) children
-- | Look up the index of a computational node in the state's list of known
-- types, registering it at the end when it has not been seen before.
getCompNodeType :: Eq a => CompNode a -> State [CompNode a] Int
getCompNodeType compNode = state $ \known ->
  let (updated, index) = fetchCompNodeType compNode known
  in (index, updated)
-- | Pure core of 'getCompNodeType'. Returns the (possibly extended) list of
-- known computational nodes and the index of @compNode@ within it.
fetchCompNodeType :: Eq a => CompNode a -> [CompNode a] -> ([CompNode a], Int)
fetchCompNodeType compNode compNodes =
  maybe appended (\index -> (compNodes, index)) (findCompNode 0 compNode compNodes)
  where
    -- Unknown node: append it; its index is the old list length.
    appended = (compNodes ++ [compNode], length compNodes)
-- | Position of the first element equal to @compNode@, offset by @start@;
-- 'Nothing' when there is no such element.
findCompNode :: Eq a => Int -> CompNode a -> [CompNode a] -> Maybe Int
findCompNode start compNode compNodes = (start +) <$> elemIndex compNode compNodes
-- | The optional value stored in the first component of a node's label.
getValue :: LogTree t a => t -> Maybe a
getValue tree = case rootLabel tree of
  (value, _, _, _) -> value
-- | Evaluate the root of a successfully built tree; a build error yields '[]'.
getEval :: LogTree t a => Either e t -> [a]
getEval = either (const []) evalNode
-- | The nodes of a successfully built tree grouped by depth, root first; a
-- build error yields '[]'.
getLevels :: Either e (Tree b) -> [[b]]
getLevels = either (const []) levels
-- | The labels of a successfully built tree in pre-order; a build error
-- yields '[]'.
getFlatten :: Either e (Tree b) -> [b]
getFlatten = either (const []) flatten