--------------------------------------------------------------------------------
{-# LANGUAGE LambdaCase    #-}
{-# LANGUAGE PatternGuards #-}
module Language.Haskell.Stylish.Util
    ( indent
    , padRight
    , everything
    , trimLeft
    , trimRight
    , wrap
    , wrapRest
    , wrapMaybe
    , wrapRestMaybe

    -- * Extra list functions
    , withHead
    , withInit
    , withTail
    , withLast
    , flagEnds

    , traceOutputable
    , traceOutputableM

    , unguardedRhsBody
    , rhsBody

    , getGuards
    ) where


--------------------------------------------------------------------------------
import           Data.Char                     (isSpace)
import           Data.Data                     (Data)
import qualified Data.Generics                 as G
import           Data.Maybe                    (maybeToList)
import           Data.Typeable                 (cast)
import           Debug.Trace                   (trace)
import qualified GHC.Hs                        as Hs
import qualified GHC.Types.SrcLoc              as GHC
import qualified GHC.Utils.Outputable          as GHC


--------------------------------------------------------------------------------
import           Language.Haskell.Stylish.GHC  (showOutputable)
import           Language.Haskell.Stylish.Step


--------------------------------------------------------------------------------
-- | Prefix a string with @len@ spaces (see 'indentPrefix').
indent :: Int -> String -> String
indent len str = indentPrefix len ++ str


--------------------------------------------------------------------------------
-- | A run of @n@ space characters; empty when @n <= 0@.
indentPrefix :: Int -> String
indentPrefix n = replicate n ' '


--------------------------------------------------------------------------------
-- | Pad a string on the right with spaces until it is @len@ characters wide.
-- Strings that are already @len@ characters or longer are returned unchanged
-- (they are never truncated).
padRight :: Int -> String -> String
padRight len str = str ++ padding
  where
    -- 'replicate' yields "" for a non-positive count.
    padding = replicate (len - length str) ' '


--------------------------------------------------------------------------------
-- | Collect every value of type @b@ found anywhere inside @a@ (including @a@
-- itself), in pre-order.
--
-- 'G.everything' combines a node's result with its children's results using
-- a left fold. Combining with '(++)' therefore re-copies the accumulated
-- prefix once per later sibling, which is quadratic on left-nested
-- structures. Instead each node contributes a difference list
-- (@[b] -> [b]@). Composing those is O(1) per combine, and applying the final
-- function to @[]@ yields the same elements in the same order.
everything :: (Data a, Data b) => a -> [b]
everything x = G.everything (.) (\node -> (maybeToList (cast node) ++)) x []


--------------------------------------------------------------------------------
{-
infoPoints :: [S.Located pass] -> [((Int, Int), (Int, Int))]
infoPoints = fmap (helper . S.getLoc)
  where
    helper :: S.SrcSpan -> ((Int, Int), (Int, Int))
    helper (S.RealSrcSpan s) = do
               let
                start = S.realSrcSpanStart s
                end = S.realSrcSpanEnd s
               ((S.srcLocLine start, S.srcLocCol start), (S.srcLocLine end, S.srcLocCol end))
    helper _                   = ((-1,-1), (-1,-1))
-}

--------------------------------------------------------------------------------
-- | Remove leading whitespace, as classified by 'isSpace'.
trimLeft :: String -> String
trimLeft str = dropWhile isSpace str


--------------------------------------------------------------------------------
-- | Remove trailing whitespace by left-trimming the reversed string.
trimRight :: String -> String
trimRight str = reverse (trimLeft (reverse str))


--------------------------------------------------------------------------------
-- | Greedily pack words onto a first line that begins with a leading string.
-- Once the first line is full, all remaining words are handed to 'wrapRest',
-- which lays them out on continuation lines indented by @ind@.
--
-- A word ends the first line when either of these holds:
--
-- * the first line is already wider than @maxWidth@; or
-- * appending the word (plus one separating space) would push the line past
--   @maxWidth@, /and/ the word would fit within @maxWidth@ after indentation.
--
-- So a word too long to fit even on a continuation line stays on the first
-- line rather than being wrapped for no gain.
wrap :: Int       -- ^ Maximum line width
     -> String    -- ^ Leading string
     -> Int       -- ^ Indentation
     -> [String]  -- ^ Strings to add/wrap
     -> Lines     -- ^ Resulting lines
wrap maxWidth leading ind = go leading
  where
    go line []            = [line]
    go line ws@(w : rest)
        | shouldBreak line w = line : wrapRest maxWidth ind ws
        | otherwise          = go (line ++ " " ++ w) rest

    -- @lineLen + wordLen >= maxWidth@ is the same as
    -- @lineLen + 1 + wordLen > maxWidth@: the joined line would be too wide.
    shouldBreak line w =
        lineLen > maxWidth ||
        (lineLen + wordLen >= maxWidth && ind + wordLen <= maxWidth)
      where
        lineLen = length line
        wordLen = length w


--------------------------------------------------------------------------------
-- | 'wrap' with an optional maximum width. A width of 'Nothing' disables
-- wrapping and delegates to 'noWrap'.
wrapMaybe :: Maybe Int -- ^ Maximum line width (maybe)
          -> String    -- ^ Leading string
          -> Int       -- ^ Indentation
          -> [String]  -- ^ Strings to add/wrap
          -> Lines     -- ^ Resulting lines
wrapMaybe = maybe noWrap wrap


--------------------------------------------------------------------------------
-- | Join the leading string and all the strings into exactly one line,
-- separated by single spaces. This never wraps. The indentation argument is
-- ignored; it is only there so that 'noWrap' has the same shape as
-- @'wrap' maxWidth@ (the two are interchangeable in 'wrapMaybe').
--
-- 'unwords' builds the line in a single right-nested pass. Accumulating with
-- @acc ++ " " ++ str@ would re-copy the growing line for every string, which
-- is quadratic in the number of strings.
noWrap :: String    -- ^ Leading string
       -> Int       -- ^ Indentation (ignored)
       -> [String]  -- ^ Strings to add
       -> Lines     -- ^ Resulting lines (always exactly one)
noWrap leading _ind strs = [unwords (leading : strs)]


--------------------------------------------------------------------------------
-- | Greedily pack words into continuation lines, each opened with @ind@
-- spaces before its first word. Arguments are the maximum width, the
-- indentation, and the words.
--
-- A word is appended to the current line (after one space) only while the
-- result stays strictly narrower than the maximum width. Otherwise the
-- current line is closed and the word opens a new one. A word that is too
-- long on its own therefore still gets a line to itself. No words yields no
-- lines.
wrapRest :: Int
         -> Int
         -> [String]
         -> Lines
wrapRest maxWidth ind strs = reverse (pack [] "" strs)
  where
    -- @done@ holds the finished lines, most recent first.
    -- @cur@ is the line being filled; "" means no line is open yet.
    pack done cur []
        | null cur  = done
        | otherwise = cur : done
    pack done cur ws@(w : rest)
        | null cur      = pack done (indent ind w) rest
        | tooLong cur w = pack (cur : done) "" ws
        | otherwise     = pack done (cur ++ " " ++ w) rest

    -- NOTE: '>=' caps continuation lines at @maxWidth - 1@ characters.
    -- That is one narrower than the first line built by 'wrap' may reach.
    tooLong cur w = length cur + 1 + length w >= maxWidth


--------------------------------------------------------------------------------
-- | 'wrapRest' with an optional maximum width. A width of 'Nothing' disables
-- wrapping and delegates to 'noWrapRest'.
wrapRestMaybe :: Maybe Int
              -> Int
              -> [String]
              -> Lines
wrapRestMaybe = maybe noWrapRest wrapRest


--------------------------------------------------------------------------------
-- | Join all strings onto a single continuation line that starts with @ind@
-- spaces, separated by single spaces. No strings yields no lines.
--
-- Leading strings that would open an empty line are skipped. Given 'indent',
-- that happens only when @ind <= 0@ and the string itself is empty. This
-- matches how 'wrapRest' opens a new line.
--
-- The line is built with 'unwords', in one right-nested pass. Appending each
-- string to an accumulator would re-copy the growing line every time, which
-- is quadratic in the number of strings.
noWrapRest :: Int
           -> [String]
           -> Lines
noWrapRest ind strs =
    case dropWhile (null . indent ind) strs of
        []         -> []
        str : rest -> [unwords (indent ind str : rest)]


--------------------------------------------------------------------------------
-- | Apply a function to the first element of a list, if there is one.
withHead :: (a -> a) -> [a] -> [a]
withHead f xs = case xs of
    []     -> []
    y : ys -> f y : ys


--------------------------------------------------------------------------------
-- | Apply a function to the last element of a list, if there is one.
withLast :: (a -> a) -> [a] -> [a]
withLast f = go
  where
    go []       = []
    go [x]      = [f x]
    go (x : xs) = x : go xs


--------------------------------------------------------------------------------
-- | Apply a function to every element except the last one.
withInit :: (a -> a) -> [a] -> [a]
withInit f = go
  where
    go []       = []
    go [x]      = [x]
    go (x : xs) = f x : go xs


--------------------------------------------------------------------------------
-- | Apply a function to every element except the first one.
withTail :: (a -> a) -> [a] -> [a]
withTail f xs = case xs of
    []     -> []
    y : ys -> y : fmap f ys



--------------------------------------------------------------------------------
-- | Tag each element of a list with two flags: whether it is the first
-- element and whether it is the last. A singleton element is both. Useful
-- when traversing a list and treating its ends specially.
flagEnds :: [a] -> [(a, Bool, Bool)]
flagEnds []       = []
flagEnds (x : xs) = (x, True, null xs) : middle xs
  where
    -- Every element after the first: never first, last iff nothing follows.
    middle []       = []
    middle (y : ys) = (y, False, null ys) : middle ys


--------------------------------------------------------------------------------
-- | Debugging helper: emit @title: <rendered x>@ via 'trace' (rendering with
-- 'showOutputable'), then return the final argument unchanged.
traceOutputable :: GHC.Outputable a => String -> a -> b -> b
traceOutputable title x = trace message
  where
    message = concat [title, ": ", showOutputable x]


--------------------------------------------------------------------------------
-- | Monadic variant of 'traceOutputable'. The message is traced when the
-- returned action is forced, and the action yields @()@.
traceOutputableM :: (GHC.Outputable a, Monad m) => String -> a -> m ()
traceOutputableM title x = traceOutputable title x (pure ())


--------------------------------------------------------------------------------
-- | Extract the body of a right-hand side made of exactly one alternative
-- with no guards, i.e. a plain @= body@. Guarded or multi-alternative
-- right-hand sides give 'Nothing'.
unguardedRhsBody :: Hs.GRHSs Hs.GhcPs a -> Maybe a
unguardedRhsBody rhss = case rhss of
    Hs.GRHSs _ [grhs] _
        | Hs.GRHS _ [] body <- GHC.unLoc grhs -> Just body
    _ -> Nothing


-- | Extract the body of a right-hand side made of exactly one alternative,
-- whether or not it is guarded; any guards are discarded. Zero or several
-- alternatives give 'Nothing'.
rhsBody :: Hs.GRHSs Hs.GhcPs a -> Maybe a
rhsBody rhss = case rhss of
    Hs.GRHSs _ [grhs] _
        | Hs.GRHS _ _guards body <- GHC.unLoc grhs -> Just body
    _ -> Nothing


--------------------------------------------------------------------------------
-- | Collect the guard statements of every alternative in a 'Hs.Match''s
-- right-hand side, concatenated in the order the alternatives appear.
getGuards :: Hs.Match Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.GuardLStmt Hs.GhcPs]
getGuards (Hs.Match _ _ _ grhss) =
    concatMap (getGuardLStmts . GHC.unLoc) (getLocGRHS grhss)


-- | The located alternatives (guarded right-hand sides) of a 'Hs.GRHSs'.
getLocGRHS :: Hs.GRHSs Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.LGRHS Hs.GhcPs (Hs.LHsExpr Hs.GhcPs)]
getLocGRHS grhss = case grhss of
    Hs.GRHSs _ alternatives _ -> alternatives


-- | The guard statements of a single guarded right-hand side.
getGuardLStmts :: Hs.GRHS Hs.GhcPs (Hs.LHsExpr Hs.GhcPs) -> [Hs.GuardLStmt Hs.GhcPs]
getGuardLStmts grhs = case grhs of
    Hs.GRHS _ stmts _ -> stmts