{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ormolu.Printer.Operators
( OpTree (..),
opTreeLoc,
reassociateOpTree,
)
where
import BasicTypes (Fixity (..), SourceText (NoSourceText), compareFixity, defaultFixity)
import Data.Function (on)
import Data.List
import Data.Maybe (fromMaybe)
import GHC
import OccName (mkVarOcc)
import RdrName (mkRdrUnqual)
import SrcLoc (combineSrcSpans)
-- | A tree of infix operator applications.
--
-- Leaves ('OpNode') hold operands of type @ty@; every 'OpBranch' holds a
-- left subtree, an operator of type @op@, and a right subtree.
data OpTree ty op
  = OpNode ty
  | OpBranch
      (OpTree ty op)
      op
      (OpTree ty op)
-- | Source span covered by an operator tree.
--
-- A leaf contributes its own location; a branch combines the spans of its
-- two operand subtrees (the operator's own span is not consulted).
opTreeLoc :: OpTree (Located a) b -> SrcSpan
opTreeLoc tree = case tree of
  OpNode located -> getLoc located
  OpBranch lhs _ rhs -> opTreeLoc lhs `combineSrcSpans` opTreeLoc rhs
-- | Re-associate an operator tree using fixities inferred from the layout
-- of its operators.
--
-- The tree is first normalized with 'normalizeOpTree'. Fixities are then
-- inferred from that normalized tree by 'buildFixityMap', and the same
-- normalized tree is rebuilt by 'reassociateOpTreeWith'. The first argument
-- extracts an operator's name, if it has one.
reassociateOpTree ::
  (op -> Maybe RdrName) ->
  OpTree (Located ty) (Located op) ->
  OpTree (Located ty) (Located op)
reassociateOpTree getOpName = reassociate . normalizeOpTree
  where
    reassociate normalized =
      let fixities = buildFixityMap getOpName normalized
       in reassociateOpTreeWith fixities (getOpName . unLoc) normalized
-- | Rebuild an operator tree so that its shape follows the given fixities.
--
-- An operator whose name cannot be obtained, or whose name is absent from
-- the fixity list, is treated as having 'defaultFixity'.
--
-- The traversal moves operators one at a time from the right-hand side into
-- an already-associated left subtree. It presumably expects the right-nested
-- shape produced by 'normalizeOpTree', which is what 'reassociateOpTree'
-- passes in. Confirm this before using it with other callers.
reassociateOpTreeWith ::
  forall ty op.
  [(RdrName, Fixity)] ->
  (op -> Maybe RdrName) ->
  OpTree ty op ->
  OpTree ty op
reassociateOpTreeWith fixityMap getOpName = insertAll
  where
    lookupFixity :: op -> Fixity
    lookupFixity o =
      fromMaybe defaultFixity (getOpName o >>= (`lookup` fixityMap))

    -- Given @a `lhsOp` b `rhsOp` c@, True when it should be grouped as
    -- @a `lhsOp` (b `rhsOp` c)@. This is the second component of GHC's
    -- 'compareFixity' result.
    groupsRight :: op -> op -> Bool
    groupsRight lhsOp rhsOp =
      snd (compareFixity (lookupFixity lhsOp) (lookupFixity rhsOp))

    insertAll :: OpTree ty op -> OpTree ty op
    insertAll tree = case tree of
      -- Nothing left to move.
      OpNode _ -> tree
      OpBranch (OpNode _) _ (OpNode _) -> tree
      -- Start by pulling the first operator over to the left.
      OpBranch x@(OpNode _) o (OpBranch y o' rest) ->
        insertAll (OpBranch (OpBranch x o y) o' rest)
      -- Last operator: place it in the left tree and stop.
      OpBranch lhs@(OpBranch a o b) o' c@(OpNode _)
        | groupsRight o o' -> OpBranch a o (insertAll (OpBranch b o' c))
        | otherwise -> OpBranch lhs o' c
      -- Place the next operator in the left tree, then keep going.
      OpBranch lhs@(OpBranch a o b) o' (OpBranch c o'' rest)
        | groupsRight o o' ->
            let lhs' = OpBranch a o (insertAll (OpBranch b o' c))
             in insertAll (OpBranch lhs' o'' rest)
        | otherwise ->
            insertAll (OpBranch (OpBranch lhs o' c) o'' rest)
-- | Infer a fixity for the operators on the right spine of an operator tree
-- from how those operators are laid out in the source.
--
-- Every inferred fixity is left-associative ('InfixL'); only the precedence
-- differs. The steps are:
--
-- * Each operator occurrence is scored by its layout.
-- * Scores are averaged per operator name.
-- * Names whose averages lie within @0.00001@ of each other share a
--   precedence level. Each name is compared against the first name of its
--   run, as 'groupBy' does.
-- * Levels are numbered from 0 upwards in order of increasing average
--   score, so a lower score means a looser-binding operator.
--
-- @($)@ is always placed alone on level 0.
--
-- Left children are never scored, so this presumably expects the
-- right-nested shape produced by 'normalizeOpTree'. That is what
-- 'reassociateOpTree' passes in; confirm this for other callers.
buildFixityMap ::
  forall ty op.
  (op -> Maybe RdrName) ->
  OpTree (Located ty) (Located op) ->
  [(RdrName, Fixity)]
buildFixityMap getOpName opTree =
  concatMap (\(i, ns) -> map (\(n, _) -> (n, fixity i InfixL)) ns)
    . zip [0 ..]
    . groupBy (doubleWithinEps 0.00001 `on` snd)
    . (overrides ++)
    . avgScores
    $ score opTree
  where
    -- ($) gets the score -1. The layout scores are all non-negative, so
    -- ($) always forms its own group on level 0. If ($) also occurs in the
    -- tree it gets a second, later entry; a first-match lookup such as
    -- 'lookup' only sees this one.
    overrides :: [(RdrName, Double)]
    overrides =
      [ (mkRdrUnqual $ mkVarOcc "$", -1)
      ]

    -- Score each operator on the right spine by its layout:
    --
    -- * The operator starts a line (the left operand ends on an earlier
    --   line): its start column scaled by (maxCol + 1). This lies in [0, 1),
    --   because maxCol includes this operator's column.
    -- * Otherwise, the operator ends a line (the right operand starts on a
    --   later line): 1.
    -- * Otherwise: 2.
    --
    -- An occurrence is skipped if it has an unhelpful span, if an operand
    -- has one, or if it has no name.
    score :: OpTree (Located ty) (Located op) -> [(RdrName, Double)]
    score (OpNode _) = []
    score (OpBranch l o r) = fromMaybe (score r) $ do
      leftEnd <- srcSpanEndLine <$> unSrcSpan (opTreeLoc l)
      opBegin <- srcSpanStartLine <$> unSrcSpan (getLoc o)
      opEnd <- srcSpanEndLine <$> unSrcSpan (getLoc o)
      rightBegin <- srcSpanStartLine <$> unSrcSpan (opTreeLoc r)
      opCol <- srcSpanStartCol <$> unSrcSpan (getLoc o)
      opName <- getOpName (unLoc o)
      let s
            | leftEnd < opBegin = fromIntegral opCol / fromIntegral (maxCol + 1)
            | opEnd < rightBegin = 1
            | otherwise = 2
      return $ (opName, s) : score r

    -- Average the scores of each name, then order names by that average.
    avgScores :: [(RdrName, Double)] -> [(RdrName, Double)]
    avgScores =
      sortOn snd
        . concatMap averageGroup
        . groupBy ((==) `on` fst)
        . sort

    -- Collapse a run of same-named scores into one averaged entry.
    -- 'groupBy' never produces an empty run; the empty case only keeps the
    -- match total, avoiding an incomplete-pattern lambda.
    averageGroup :: [(RdrName, Double)] -> [(RdrName, Double)]
    averageGroup grp = case grp of
      [] -> []
      (n, _) : _ -> [(n, avg (map snd grp))]

    avg :: [Double] -> Double
    avg i = sum i / fromIntegral (length i)

    -- Largest operator start column anywhere in the tree (0 for unhelpful
    -- spans). It is used to scale the column-based scores.
    maxCol :: Int
    maxCol = go opTree
      where
        go (OpNode (L _ _)) = 0
        go (OpBranch l (L o _) r) =
          maximum
            [ go l,
              maybe 0 srcSpanStartCol (unSrcSpan o),
              go r
            ]

    unSrcSpan (RealSrcSpan r) = Just r
    unSrcSpan (UnhelpfulSpan _) = Nothing
-- | Rotate an operator tree into right-nested form.
--
-- Every left child of the result is an 'OpNode'. Operands and operators
-- keep their left-to-right order.
normalizeOpTree :: OpTree ty op -> OpTree ty op
normalizeOpTree tree = case tree of
  OpNode _ -> tree
  OpBranch lhs o rhs -> case lhs of
    OpNode _ -> OpBranch lhs o (normalizeOpTree rhs)
    -- Rotate right: (a `o'` b) `o` rhs  ==>  a `o'` (b `o` rhs).
    OpBranch a o' b -> normalizeOpTree (OpBranch a o' (OpBranch b o rhs))
-- | Build a 'Fixity' from a precedence and a direction, with no source
-- text attached.
fixity :: Int -> FixityDirection -> Fixity
fixity precedence direction = Fixity NoSourceText precedence direction
-- | Whether two doubles differ by strictly less than the given tolerance.
--
-- The first argument is the tolerance. A NaN difference never counts as
-- close.
doubleWithinEps :: Double -> Double -> Double -> Bool
doubleWithinEps eps a b = eps > distance
  where
    distance = abs (a - b)