module EVM.Flatten (flatten) where
import EVM.Dapp (DappInfo, dappSources)
import EVM.Solidity (sourceAsts)
import EVM.Demand (demand)
import Control.Lens (preview, view, universe)
import Data.Aeson (Value (String))
import Data.Aeson.Lens (key, _String, _Array, _Integer)
import qualified Data.Graph.Inductive.Graph as Fgl
import qualified Data.Graph.Inductive.PatriciaTree as Fgl
import qualified Data.Graph.Inductive.Query.BFS as Fgl
import qualified Data.Graph.Inductive.Query.DFS as Fgl
import Data.SemVer (SemVerRange, parseSemVerRange)
import qualified Data.SemVer as SemVer
import Control.Monad (forM)
import Data.ByteString (ByteString)
import Data.Foldable (foldl', toList)
import Data.List (sort, nub)
import Data.Map (Map, (!), (!?))
import Data.Maybe (mapMaybe, isJust, catMaybes, fromMaybe)
import Data.Monoid ((<>))
import Data.Text (Text, unpack, pack, intercalate)
import Data.Text.Encoding (encodeUtf8)
import Text.Read (readMaybe)
import qualified Data.Map as Map
import qualified Data.Text as Text
import qualified Data.ByteString as BS
-- | Import graph of Solidity source files: every node is labelled with a
-- source path and edges carry no extra data.
type FileGraph = Fgl.Gr Text ()
-- | Absolute paths of every file imported by one source unit.
--
-- Visits every node of the JSON AST (via 'universe'). For each node whose
-- @name@ field is the string @ImportDirective@, it yields that node's
-- @attributes.absolutePath@ string. Import nodes without such a string are
-- skipped. Results follow the traversal order of 'universe'.
importsFrom :: Value -> [Text]
importsFrom ast =
  [ path
  | node <- universe ast
  , Just (String "ImportDirective") <- [preview (key "name") node]
  , Just path <- [preview (key "attributes" . key "absolutePath" . _String) node]
  ]
-- | Print a single-file ("flattened") version of @target@ and everything it
-- transitively imports to stdout.
--
-- The output contains, in order:
--
--   * a @// hevm: flattened sources of <target>@ line;
--   * the merged pragma header produced by 'maximalPragma';
--   * each reachable source file, preceded by a @////// <path>@ line.
--     Import and pragma directives are commented out, and top-level
--     contracts/structs whose names collide across files get an @_<n>@
--     suffix.
--
-- @target@ must be a key of the dapp's source ASTs; otherwise this calls
-- 'error'. Each emitted file is re-read from disk using its AST key as the
-- file path.
flatten :: DappInfo -> Text -> IO ()
flatten dapp target = do
  let
    -- Import graph over all known sources: an edge s -> t means that file s
    -- imports file t.
    graph :: FileGraph
    graph = Fgl.mkGraph nodes edges
    
    -- Number the source paths 1, 2, ... to obtain graph node ids.
    nodes :: [(Int, Text)]
    nodes = zip [1..] (Map.keys asts)
    
    -- NOTE(review): `indices ! t` calls error if some file imports a path
    -- that is not itself a key of `asts`.
    edges =
      [ (indices ! s, indices ! t, ()) 
      | (s, v) <- Map.toList asts      
      , t      <- importsFrom v ]      
    
    -- Inverse of `nodes`: source path -> graph node id.
    indices :: Map Text Int
    indices = Map.fromList [(v, k) | (k, v) <- nodes]
    
    -- JSON ASTs of all sources, keyed by source path.
    asts :: Map Text Value
    asts = view (dappSources . sourceAsts) dapp
    -- AST ids of every SourceUnit (file root) node across all files; a
    -- declaration whose scope is one of these is declared at file level.
    topScopeIds :: [Integer]
    topScopeIds = mconcat $ fmap f $ Map.elems asts
      where
        id' = preview (key "id" . _Integer)
        f ast =
          [ fromJust' "no id for SourceUnit" $ id' node
          | node <- universe ast
          , nodeIs "SourceUnit" node
          ]
    -- File-level contracts and structs whose name occurs more than once
    -- among all file-level contracts/structs of the project. Each is mapped
    -- from its AST id to a per-name occurrence number rendered as text
    -- ("1", "2", ...), as assigned by `indexed`.
    contractsAndStructsToRename :: Map Integer Text
    contractsAndStructsToRename =
      Map.fromList
        $ indexed [ x | x <- xs, (snd x) `elem` xs' ]
      where
        -- (id, name) of every file-level contract/struct.
        xs = mconcat $ fmap f $ Map.elems asts
        -- Names that appear at least twice in `xs`.
        xs' = repeated $ fmap snd xs
        scope = preview (key "attributes" . key "scope" . _Integer)
        name = preview (key "attributes" . key "name" . _String)
        id' = preview (key "id" . _Integer)
        p x = (nodeIs "ContractDefinition" x || nodeIs "StructDefinition" x)
          && (fromJust' "no contract/struct scope" $ scope x) `elem` topScopeIds
        -- NOTE(review): the second error message below fires when the
        -- *name* is missing but says "no id".
        f ast =
          [ ( fromJust' "no id for top scoped contract or struct" $ id' node
            , fromJust' "no id for top scoped contract or struct" $ name node
            )
          | node <- universe ast
          , p node
          ]
    -- Structs declared inside a contract that is being renamed, as
    -- (struct id, (containing contract id, struct canonical name)).
    contractStructs :: [(Integer, (Integer, Text))]
    contractStructs = mconcat $ fmap f $ Map.elems asts
      where
        scope = preview (key "attributes" . key "scope" . _Integer)
        cname = preview (key "attributes" . key "canonicalName" . _String)
        id' = preview (key "id" . _Integer)
        p x = (nodeIs "StructDefinition" x)
          && (fromJust' "line:137 nested struct" $ scope x) `Map.member` contractsAndStructsToRename
        f ast =
          [ let
              id'' = fromJust' "no id for nested struct" $ id' node
              cname' = fromJust'
                ("no canonical name of nested struct with id:" ++ show id'') $ cname node
              ref = fromJust'
                ("no scope of nested struct with id:" ++ show id'') $ scope node
            in
              (id'', (ref, cname'))
          | node <- universe ast
          , p node
          ]
  -- Only the target and the files reachable from it through imports are
  -- emitted.
  case Map.lookup target indices of
    Nothing ->
      error "didn't find contract AST"
    Just root -> do
      let
        -- Restrict the import graph to the nodes reachable from the target
        -- by a breadth-first search along import edges.
        subgraph :: Fgl.Gr Text ()
        subgraph = Fgl.subgraph (Fgl.bfs root graph) graph
        
        -- Since edges point from importer to imported file, reversing the
        -- topological order puts every file after the files it imports.
        ordered :: [Text]
        ordered = reverse (Fgl.topsort' subgraph)
        
        -- Merge pragmas only from the files that will actually be emitted.
        pragma :: Text
        pragma = maximalPragma (Map.elems (Map.filterWithKey (\k _ -> k `elem` ordered) asts))
      
      -- Per file: a path header, then the source with imports/pragmas
      -- commented out and colliding contract/struct names suffixed.
      sources <-
        forM ordered $ \path -> do
          src <- BS.readFile (unpack path)
          pure $ mconcat
            [ "////// ", encodeUtf8 path, "\n"
            
            , fst
                (prefixContractAst
                  contractsAndStructsToRename
                  contractStructs
                  (stripImportsAndPragmas (src, 0) (asts ! path))
                  (asts ! path)), "\n"
            ]
      
      -- NOTE(review): presumably `demand` forces these values so that any
      -- `error` above is raised before anything is printed; confirm against
      -- EVM.Demand.
      demand target; demand pragma; demand sources
      
      putStrLn $ "// hevm: flattened sources of " <> unpack target
      putStrLn (unpack pragma)
      BS.putStr (mconcat sources)
-- | Build the pragma header for a flattened output from source-unit ASTs.
--
-- The first line is one @pragma solidity@ whose range is the intersection
-- ('SemVer.And') of every file's version pragma. It is followed by every
-- other pragma found in any of the files, one per line, sorted, with exact
-- duplicate lines removed.
--
-- Calls 'error' in three cases:
--
--   * no file has a version pragma;
--   * one file has several version pragmas;
--   * a version range cannot be parsed.
maximalPragma :: [Value] -> Text
maximalPragma asts = versionPragma <> mconcat otherPragmaLines
  where
    -- The single merged @pragma solidity@ line.
    versionPragma :: Text
    versionPragma =
      case mapMaybe versions asts of
        [] -> error "no Solidity version pragmas in any source files"
        xs ->
          "pragma solidity "
            <> pack (show (rangeIntersection xs))
            <> ";\n"

    -- Every non-version pragma line from every file, deduplicated per line.
    -- Deduplicating whole per-file pragma blocks instead would emit a
    -- pragma twice whenever two files share it but otherwise declare
    -- different pragmas. solc rejects e.g. a repeated
    -- @pragma experimental@ feature.
    otherPragmaLines :: [Text]
    otherPragmaLines =
      nub . sort $ concatMap (fmap renderPragma . otherPragmas) asts

    -- Re-render a pragma's literal tokens as one source line.
    renderPragma :: [Value] -> Text
    renderPragma xs =
      "pragma " <> intercalate " " [x | String x <- xs] <> ";\n"

    isVersionPragma :: [Value] -> Bool
    isVersionPragma (String "solidity" : _) = True
    isVersionPragma _ = False

    -- Literal token lists of every PragmaDirective in one AST.
    pragmaComponents :: Value -> [[Value]]
    pragmaComponents ast = components
      where
        ps :: [Value]
        ps = filter (nodeIs "PragmaDirective") (universe ast)
        components :: [[Value]]
        components = catMaybes $ fmap
          ((fmap toList) . preview (key "attributes" . key "literals" . _Array))
          ps

    -- Conjunction of all distinct version ranges (input is non-empty here).
    rangeIntersection :: [SemVerRange] -> SemVerRange
    rangeIntersection = foldr1 SemVer.And . nub . sort

    -- The version range declared by one file, if it declares one.
    versions :: Value -> Maybe SemVerRange
    versions ast = fmap grok components
      where
        components :: Maybe [Value]
        components =
          case filter isVersionPragma (pragmaComponents ast) of
            [_:xs] -> Just xs
            []  -> Nothing
            x   -> error $ "multiple version pragmas: " ++ show x
        grok :: [Value] -> SemVerRange
        grok xs =
          let
            rangeText = mconcat [x | String x <- xs]
          in
            case parseSemVerRange rangeText of
              Right r -> r
              Left _ ->
                error ("failed to parse SemVer range " ++ show rangeText)

    otherPragmas :: Value -> [[Value]]
    otherPragmas = (filter (not . isVersionPragma)) . pragmaComponents
-- | Whether an AST node has a @src@ field and a @name@ field equal to the
-- given node type (e.g. @"ContractDefinition"@).
nodeIs :: Text -> Value -> Bool
nodeIs t x =
  case (preview (key "src") x, preview (key "name" . _String) x) of
    (Just _, Just nodeName) -> nodeName == t
    _ -> False
-- | Comment out every import and pragma directive of a source unit.
--
-- Takes the source bytes paired with the running offset of earlier edits,
-- and returns the edited bytes with the updated offset (see 'stripAstNodes').
stripImportsAndPragmas :: (ByteString, Int) -> Value -> (ByteString, Int)
stripImportsAndPragmas bso ast = stripAstNodes bso ast isImportOrPragma
  where
    isImportOrPragma node =
      any (`nodeIs` node) ["ImportDirective", "PragmaDirective"]
-- | Wrap the source text of every AST node satisfying @p@ in a block comment.
--
-- @bso@ pairs the source bytes with an offset that is added to every AST
-- source position. Nodes are processed in ascending source order, and each
-- wrap grows the offset by the length of the inserted delimiters. The
-- result is the edited bytes with the final offset.
stripAstNodes :: (ByteString, Int)-> Value -> (Value -> Bool) -> (ByteString, Int)
stripAstNodes bso ast p = foldl' commentOut bso (sort ranges)
  where
    ranges = map sourceRange (filter p (universe ast))

    open, close :: ByteString
    open = "/* "
    close = " */"

    commentOut (buf, shift) (from, to) =
      let
        (before, after) = BS.splitAt (from + shift) buf
        (body, rest) = BS.splitAt (to - from) after
      in
        ( mconcat [before, open, body, close, rest]
        , shift + BS.length open + BS.length close
        )
-- | Parse a 'Text' with the type's 'Read' instance; 'Nothing' on failure.
readAs :: Read a => Text -> Maybe a
readAs t = readMaybe (Text.unpack t)
-- | Suffix the names of renamed contracts/structs, and references to them,
-- in one file's source bytes.
--
-- Arguments:
--
--   * @castr@ maps the AST id of each renamed definition to its suffix
--     number.
--   * @cs@ maps the id of each struct nested in a renamed contract to
--     @(contract id, struct canonical name)@.
--   * @bso@ is the current source bytes paired with the byte offset of
--     earlier edits.
--   * @ast@ is the file's AST.
--
-- @"_" <> suffix@ is inserted right after:
--
--   * each renamed definition's name;
--   * each identifier or type name referring to a renamed definition;
--   * the contract qualifier of each fully qualified reference @C.S@ to a
--     struct nested in a renamed contract.
--
-- Returns the edited bytes and the grown offset.
prefixContractAst :: Map Integer Text -> [(Integer, (Integer, Text))] -> (ByteString, Int) -> Value -> (ByteString, Int)
prefixContractAst castr cs bso ast = prefixAstNodes
  where
    bs = fst bso
    refDec = preview (key "attributes" . key "referencedDeclaration" . _Integer)
    name = preview (key "attributes" . key "name" . _String)
    id' = preview (key "id" . _Integer)

    -- A contract/struct definition that is itself being renamed.
    p x = (nodeIs "ContractDefinition" x || nodeIs "StructDefinition" x)
      && (fromJust' "id of any" $ id' x) `Map.member` castr

    -- An identifier/type name whose referenced declaration is renamed.
    p' x =
      (nodeIs "Identifier" x || nodeIs "UserDefinedTypeName" x)
        && (fromJust' "refDec of ident/userdef" $ refDec x) `Map.member` castr

    -- An identifier/type name that refers to a struct nested in a renamed
    -- contract and is spelled as that struct's canonical name (e.g. "C.S").
    p'' x =
      (nodeIs "Identifier" x || nodeIs "UserDefinedTypeName" x)
      && (isJust $ name x)
      && (
        let
          refs = fmap fst cs
          i = fromJust' "no id for ident/userdef" $ id' x
          ref = fromJust' ("no refDec for ident/userdef: " ++ show i) $ refDec x
          n = fromJust' ("no name for ident/userdef: " ++ show i) $ name x
          cn = fromJust'
            ("no match for lookup in nested structs: "
              ++ show i
              ++ " -> "
              ++ show ref
            ) $ lookup ref cs
        in
          -- `cn` is only forced once `ref` is known to be a key of `cs`,
          -- so the lookup cannot fail here.
          ref `elem` refs && n == snd cn
      )
    p''' x = p x || p' x || p'' x
    prefixAstNodes :: (ByteString, Int)
    prefixAstNodes  =
      cutRanges [sourceId node | node <- universe ast, p''' node]

    -- For a matched node: the insertion position (relative to the source
    -- before any edits; `cutRanges` adds the running offset), paired with
    -- the id of the renamed definition whose suffix goes there.
    sourceId :: Value -> (Int, Integer)
    sourceId v =
      if (not $ p v || p' v) &&  p'' v then (
        let
          ref = fromJust' "refDec of nested struct ref" $ refDec v
          cn = fromJust' "no match for lookup in nested structs" $ lookup ref cs
          -- Contract part of the canonical name "C.S". The suffix belongs
          -- after C, which is renamed, not after S, which is not, giving
          -- "C_k.S" rather than "C.S_k".
          owner = encodeUtf8 $ Text.takeWhile (/= '.') (snd cn)
        in
          (start + wordOffset owner window + BS.length owner, fst cn)
      ) else
        fromJust' "internal error: no id found for contract reference" x
      where
        (start, end) = sourceRange v
        -- Current bytes from this node's start onwards.
        -- NOTE(review): adding `snd bso` assumes every earlier edit lies
        -- before this node.
        window = snd $ BS.splitAt (start + snd bso) bs
        x :: Maybe (Int, Integer)
        x = case preview (key "name" . _String) v of
          Just t
            | t `elem` ["ContractDefinition", "StructDefinition"] ->
              let
                name' = encodeUtf8 $ fromJust' "no name for contract/struct" $ name v
                pos = start
                  + wordOffset name' window
                  + (BS.length name')
              in
                fmap ((,) pos) $ id' v
            | t `elem` ["UserDefinedTypeName", "Identifier"] ->
              fmap ((,) end) $ refDec v
            | otherwise ->
              error "internal error: not a contract reference"
          Nothing ->
            error "internal error: not a contract reference"

    -- Byte offset in @hay@ of the first occurrence of @needle@ that stands
    -- as a whole identifier, i.e. not adjacent to another identifier byte.
    -- A plain substring search can match inside the preceding keyword: a
    -- struct named "t" would hit the "t" of "struct", and a contract named
    -- "on" would hit the "on" of "contract". Falls back to the first raw
    -- occurrence when no whole-word match exists.
    wordOffset :: ByteString -> ByteString -> Int
    wordOffset needle hay = go 0
      where
        rawOffset = BS.length (fst (BS.breakSubstring needle hay))
        isIdentByte b =
          toEnum (fromIntegral b)
            `elem` (['a' .. 'z'] ++ ['A' .. 'Z'] ++ ['0' .. '9'] ++ ['_', '$'])
        byteAt k
          | k >= 0 && k < BS.length hay = Just (BS.index hay k)
          | otherwise = Nothing
        boundary k = maybe True (not . isIdentByte) (byteAt k)
        go from =
          let
            (skipped, rest) = BS.breakSubstring needle (BS.drop from hay)
            i = from + BS.length skipped
          in
            if BS.null rest
              then rawOffset
              else if boundary (i - 1) && boundary (i + BS.length needle)
                then i
                else go (i + 1)

    -- Insert the suffixes in ascending position order, growing the offset
    -- by each inserted suffix's length.
    cutRanges :: [(Int, Integer)] -> (ByteString, Int)
    cutRanges (sort -> rs) = foldl' f bso rs
      where
        f (bs', n) (i, t) =
          let
            t' = "_" <> (castr ! t)
          in
            ( prefix t' bs' (i + n)
            , n + Text.length t' )

    prefix :: Text -> ByteString -> Int -> ByteString
    prefix t x i =
      let (a, b) = BS.splitAt i x
      in a <> encodeUtf8 t <> b
-- | Half-open byte range @(start, end)@ of an AST node.
--
-- Parses the node's @src@ field, which has the form @"start:length:file"@.
-- Calls 'error' if the field is missing or malformed.
sourceRange :: Value -> (Int, Int)
sourceRange v =
  case parsed of
    Just range -> range
    Nothing -> error "internal error: no source position for AST node"
  where
    parsed = do
      src <- preview (key "src" . _String) v
      (offset, len) <-
        case Text.splitOn ":" src of
          [a, b, _] -> (,) <$> readAs a <*> readAs b
          _ -> Nothing
      pure (offset, offset + len)
-- | Unwrap a 'Just'; on 'Nothing', call 'error' with the given message.
fromJust' :: String -> Maybe a -> a
fromJust' msg = maybe (error msg) id
-- | Elements occurring at least twice, each listed once. The list is in
-- reverse order of each element's second occurrence.
repeated :: Eq a => [a] -> [a]
repeated xs = fst (foldl' step ([], []) xs)
  where
    step (dups, seen) x
      | x `elem` seen, x `notElem` dups = (x : dups, x : seen)
      | otherwise = (dups, x : seen)
-- | Replace each entry's name with its 1-based occurrence number among
-- entries sharing that name, rendered as text. Ids are kept, and the result
-- is in reverse input order.
indexed :: [(Integer, Text)] -> [(Integer, Text)]
indexed entries = fst (foldl' step ([], Map.empty) entries)
  where
    step (out, counts) (ident, nm) =
      let
        n = Map.findWithDefault (0 :: Integer) nm counts + 1
      in
        ((ident, pack (show n)) : out, Map.insert nm n counts)