module Snap.Snaplet.Auth.Backends.SqliteSimple
( initSqliteAuth
) where
import Control.Concurrent
import qualified Data.Aeson as A
import Data.ByteString (ByteString)
import qualified Data.Configurator as C
import qualified Data.HashMap.Lazy as HM
import Data.Maybe
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Lazy as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Database.SQLite.Simple as S
import Database.SQLite.Simple.FromField
import Database.SQLite.Simple.FromRow
import qualified Database.SQLite.Simple.ToField as S
import Database.SQLite.Simple.Types
import Database.SQLite3 (SQLData(..))
import Paths_snaplet_sqlite_simple
import Snap
import Snap.Snaplet.Auth
import Snap.Snaplet.Session
import Snap.Snaplet.SqliteSimple
import Web.ClientSession
data SqliteAuthManager = SqliteAuthManager
{ pamTable :: AuthTable
, pamConnPool :: MVar S.Connection
}
initSqliteAuth
:: SnapletLens b SessionManager
-> Snaplet Sqlite
-> SnapletInit b (AuthManager b)
initSqliteAuth sess db = makeSnaplet "sqlite-auth" desc datadir $ do
config <- getSnapletUserConfig
authTable <- liftIO $ C.lookupDefault "snap_auth_user" config "authTable"
authSettings <- authSettingsFromConfig
key <- liftIO $ getKey (asSiteKey authSettings)
let tableDesc = defAuthTable { tblName = authTable }
let manager = SqliteAuthManager tableDesc $
sqliteConn $ db ^# snapletValue
liftIO $ createTableIfMissing manager
rng <- liftIO mkRNG
return AuthManager
{ backend = manager
, session = sess
, activeUser = Nothing
, minPasswdLen = asMinPasswdLen authSettings
, rememberCookieName = asRememberCookieName authSettings
, rememberPeriod = asRememberPeriod authSettings
, siteKey = key
, lockout = asLockout authSettings
, randomNumberGenerator = rng
}
where
desc = "An Sqlite backend for user authentication"
datadir = Just $ liftM (++"/resources/auth") getDataDir
tableExists :: S.Connection -> T.Text -> IO Bool
tableExists conn tblName = do
r <- S.query conn "SELECT name FROM sqlite_master WHERE type='table' AND name=?" (Only tblName)
case r of
[Only (_ :: String)] -> return True
_ -> return False
createInitialSchema :: S.Connection -> AuthTable -> IO ()
createInitialSchema conn pamTable = do
let q = Query $ T.concat
[ "CREATE TABLE ", tblName pamTable, " (uid INTEGER PRIMARY KEY,"
, "login text UNIQUE NOT NULL,"
, "password text,"
, "activated_at timestamp,suspended_at timestamp,remember_token text,"
, "login_count INTEGER NOT NULL,failed_login_count INTEGER NOT NULL,"
, "locked_out_until timestamp,current_login_at timestamp,"
, "last_login_at timestamp,current_login_ip text,"
, "last_login_ip text,created_at timestamp,updated_at timestamp);"
]
S.execute_ conn q
versionTable :: AuthTable -> T.Text
versionTable pamTable = T.concat [tblName pamTable, "_version"]
schemaVersion :: S.Connection -> AuthTable -> IO Int
schemaVersion conn pamTable = do
let verTbl = versionTable pamTable
versionExists <- tableExists conn verTbl
if not versionExists
then return 0
else
do
let q = T.concat ["SELECT version FROM ", verTbl, " LIMIT 1"]
[Only v] <- S.query_ conn (Query q) :: IO [Only Int]
return v
setSchemaVersion :: S.Connection -> AuthTable -> Int -> IO ()
setSchemaVersion conn pamTable v = do
let q = Query $ T.concat ["UPDATE ", versionTable pamTable, " SET version = ?"]
S.execute conn q (Only v)
upgradeSchema :: Connection -> AuthTable -> Int -> IO ()
upgradeSchema conn pam fromVersion = do
ver <- schemaVersion conn pam
when (ver == fromVersion) (upgrade ver >> setSchemaVersion conn pam (fromVersion+1))
where
upgrade 0 = do
S.execute_ conn (Query $ T.concat ["CREATE TABLE ", versionTable pam, " (version INTEGER)"])
S.execute_ conn (Query $ T.concat ["INSERT INTO ", versionTable pam, " VALUES (1)"])
upgrade 1 = do
S.execute_ conn (addColumnQ (colEmail pam))
S.execute_ conn (addColumnQ (colResetToken pam))
S.execute_ conn (addColumnQ (colResetRequestedAt pam))
upgrade 2 = do
S.execute_ conn (addColumnQ (colRoles pam))
S.execute_ conn (addColumnQ (colMeta pam))
upgrade _ = error "unknown version"
addColumnQ (c,t) =
Query $ T.concat [ "ALTER TABLE ", tblName pam, " ADD COLUMN ", c, " ", t]
createTableIfMissing :: SqliteAuthManager -> IO ()
createTableIfMissing SqliteAuthManager{..} =
withMVar pamConnPool $ \conn -> do
authTblExists <- tableExists conn $ tblName pamTable
unless authTblExists $ createInitialSchema conn pamTable
upgradeSchema conn pamTable 0
upgradeSchema conn pamTable 1
upgradeSchema conn pamTable 2
buildUid :: Int -> UserId
buildUid = UserId . T.pack . show
instance FromField UserId where
fromField f = buildUid <$> fromField f
instance FromField Password where
fromField f = Encrypted <$> fromField f
instance FromRow AuthUser where
fromRow =
AuthUser
<$> _userId
<*> _userLogin
<*> _userEmail
<*> _userPassword
<*> _userActivatedAt
<*> _userSuspendedAt
<*> _userRememberToken
<*> _userLoginCount
<*> _userFailedLoginCount
<*> _userLockedOutUntil
<*> _userCurrentLoginAt
<*> _userLastLoginAt
<*> _userCurrentLoginIp
<*> _userLastLoginIp
<*> _userCreatedAt
<*> _userUpdatedAt
<*> _userResetToken
<*> _userResetRequestedAt
<*> decodeRoles
<*> decodeMeta
where
!_userId = field
!_userLogin = field
!_userEmail = field
!_userPassword = field
!_userActivatedAt = field
!_userSuspendedAt = field
!_userRememberToken = field
!_userLoginCount = field
!_userFailedLoginCount = field
!_userLockedOutUntil = field
!_userCurrentLoginAt = field
!_userLastLoginAt = field
!_userCurrentLoginIp = field
!_userLastLoginIp = field
!_userCreatedAt = field
!_userUpdatedAt = field
!_userResetToken = field
!_userResetRequestedAt = field
!_userRoles = field :: RowParser (Maybe LT.Text)
!_userMeta = field :: RowParser (Maybe LT.Text)
decodeRoles :: RowParser [Role]
decodeRoles = do
roles <- fmap (fmap (map Role) . textDecodeBS) _userRoles
return $ fromMaybe [] roles
decodeMeta = do
meta <- fmap (fmap (fromMaybe HM.empty . A.decode' . LT.encodeUtf8)) _userMeta
return $ fromMaybe HM.empty meta
textDecodeBS :: Maybe LT.Text -> Maybe [ByteString]
textDecodeBS Nothing = Nothing
textDecodeBS (Just t) = fmap (map T.encodeUtf8) . A.decode' . LT.encodeUtf8 $ t
querySingle :: (ToRow q, FromRow a)
=> MVar S.Connection -> Query -> q -> IO (Maybe a)
querySingle pool q ps = withMVar pool $ \conn -> return . listToMaybe =<<
S.query conn q ps
authExecute :: ToRow q
=> MVar S.Connection -> Query -> q -> IO ()
authExecute pool q ps = do
withMVar pool $ \conn -> S.execute conn q ps
return ()
instance S.ToField Password where
toField (ClearText bs) = S.toField bs
toField (Encrypted bs) = S.toField bs
data AuthTable
= AuthTable
{ tblName :: Text
, colId :: (Text, Text)
, colLogin :: (Text, Text)
, colEmail :: (Text, Text)
, colPassword :: (Text, Text)
, colActivatedAt :: (Text, Text)
, colSuspendedAt :: (Text, Text)
, colRememberToken :: (Text, Text)
, colLoginCount :: (Text, Text)
, colFailedLoginCount :: (Text, Text)
, colLockedOutUntil :: (Text, Text)
, colCurrentLoginAt :: (Text, Text)
, colLastLoginAt :: (Text, Text)
, colCurrentLoginIp :: (Text, Text)
, colLastLoginIp :: (Text, Text)
, colCreatedAt :: (Text, Text)
, colUpdatedAt :: (Text, Text)
, colResetToken :: (Text, Text)
, colResetRequestedAt :: (Text, Text)
, colRoles :: (Text, Text)
, colMeta :: (Text, Text)
}
defAuthTable :: AuthTable
defAuthTable
= AuthTable
{ tblName = "snap_auth_user"
, colId = ("uid", "INTEGER PRIMARY KEY")
, colLogin = ("login", "text UNIQUE NOT NULL")
, colEmail = ("email", "text")
, colPassword = ("password", "text")
, colActivatedAt = ("activated_at", "timestamp")
, colSuspendedAt = ("suspended_at", "timestamp")
, colRememberToken = ("remember_token", "text")
, colLoginCount = ("login_count", "INTEGER NOT NULL")
, colFailedLoginCount = ("failed_login_count", "INTEGER NOT NULL")
, colLockedOutUntil = ("locked_out_until", "timestamp")
, colCurrentLoginAt = ("current_login_at", "timestamp")
, colLastLoginAt = ("last_login_at", "timestamp")
, colCurrentLoginIp = ("current_login_ip", "text")
, colLastLoginIp = ("last_login_ip", "text")
, colCreatedAt = ("created_at", "timestamp")
, colUpdatedAt = ("updated_at", "timestamp")
, colResetToken = ("reset_token", "text")
, colResetRequestedAt = ("reset_requested_at", "timestamp")
, colRoles = ("roles_json", "text")
, colMeta = ("meta_json", "text")
}
colDef :: [(AuthTable -> (Text, Text), AuthUser -> SQLData)]
colDef =
[ (colId , S.toField . fmap unUid . userId)
, (colLogin , S.toField . userLogin)
, (colEmail , S.toField . userEmail)
, (colPassword , S.toField . userPassword)
, (colActivatedAt , S.toField . userActivatedAt)
, (colSuspendedAt , S.toField . userSuspendedAt)
, (colRememberToken , S.toField . userRememberToken)
, (colLoginCount , S.toField . userLoginCount)
, (colFailedLoginCount, S.toField . userFailedLoginCount)
, (colLockedOutUntil , S.toField . userLockedOutUntil)
, (colCurrentLoginAt , S.toField . userCurrentLoginAt)
, (colLastLoginAt , S.toField . userLastLoginAt)
, (colCurrentLoginIp , S.toField . userCurrentLoginIp)
, (colLastLoginIp , S.toField . userLastLoginIp)
, (colCreatedAt , S.toField . userCreatedAt)
, (colUpdatedAt , S.toField . userUpdatedAt)
, (colResetToken , S.toField . userResetToken)
, (colResetRequestedAt, S.toField . userResetRequestedAt)
, (colRoles , S.toField . encodeOrNull . userRoles)
, (colMeta , S.toField . encodeOrNullHM . userMeta)
]
where
encodeOrNull [] = Nothing
encodeOrNull x = Just . LT.decodeUtf8 . A.encode $ x
encodeOrNullHM hm | HM.null hm = Nothing
| otherwise = Just . LT.decodeUtf8 . A.encode $ hm
colNames :: AuthTable -> T.Text
colNames pam =
T.intercalate "," . map (\(f,_) -> fst (f pam)) $ colDef
saveQuery :: AuthTable -> AuthUser -> (Text, [SQLData])
saveQuery at u@AuthUser{..} = maybe insertQuery updateQuery userId
where
insertQuery = (T.concat [ "INSERT INTO "
, tblName at
, " ("
, T.intercalate "," cols
, ") VALUES ("
, T.intercalate "," vals
, ")"
]
, params)
qval f = fst (f at) `T.append` " = ?"
updateQuery uid =
(T.concat [ "UPDATE "
, tblName at
, " SET "
, T.intercalate "," (map (qval . fst) $ tail colDef)
, " WHERE "
, fst (colId at)
, " = ?"
]
, params ++ [S.toField $ unUid uid])
cols = map (fst . ($at) . fst) $ tail colDef
vals = map (const "?") cols
params = map (($u) . snd) $ tail colDef
instance IAuthBackend SqliteAuthManager where
save SqliteAuthManager{..} u@AuthUser{..} = do
let (qstr, params) = saveQuery pamTable u
withMVar pamConnPool $ \conn -> do
S.execute conn (Query qstr) params
let q2 = Query $ T.concat
[ "select ", colNames pamTable, " from "
, tblName pamTable
, " where "
, fst (colLogin pamTable)
, " = ?"
]
res <- S.query conn q2 [userLogin]
case res of
[savedUser] -> return $ Right savedUser
_ -> return . Left $ AuthError "snaplet-sqlite-simple: Failed user save"
lookupByUserId SqliteAuthManager{..} uid = do
let q = Query $ T.concat
[ "select ", colNames pamTable, " from "
, tblName pamTable
, " where "
, fst (colId pamTable)
, " = ?"
]
querySingle pamConnPool q [unUid uid]
lookupByLogin SqliteAuthManager{..} login = do
let q = Query $ T.concat
[ "select ", colNames pamTable, " from "
, tblName pamTable
, " where "
, fst (colLogin pamTable)
, " = ?"
]
querySingle pamConnPool q [login]
lookupByRememberToken SqliteAuthManager{..} token = do
let q = Query $ T.concat
[ "select ", colNames pamTable, " from "
, tblName pamTable
, " where "
, fst (colRememberToken pamTable)
, " = ?"
]
querySingle pamConnPool q [token]
destroy SqliteAuthManager{..} AuthUser{..} = do
let q = Query $ T.concat
[ "delete from "
, tblName pamTable
, " where "
, fst (colLogin pamTable)
, " = ?"
]
authExecute pamConnPool q [userLogin]