From 59991bf5471b15f7bba6c230b4859282cec70e4f Mon Sep 17 00:00:00 2001 From: rakeshkky Date: Tue, 7 May 2019 19:19:19 +0530 Subject: [PATCH 1/5] close websocket connection on JWT expiry, fix #578 --- server/graphql-engine.cabal | 3 ++ server/src-lib/Hasura/GraphQL/Explain.hs | 2 +- .../Hasura/GraphQL/Transport/WebSocket.hs | 35 +++++++++++++------ .../GraphQL/Transport/WebSocket/Server.hs | 35 +++++++++++++------ server/src-lib/Hasura/RQL/Types/Permission.hs | 19 +++++----- server/src-lib/Hasura/Server/Auth.hs | 14 ++++---- server/src-lib/Hasura/Server/Auth/JWT.hs | 13 +++---- server/src-lib/Hasura/Server/Utils.hs | 8 +++++ 8 files changed, 85 insertions(+), 44 deletions(-) diff --git a/server/graphql-engine.cabal b/server/graphql-engine.cabal index 061f539a1e687..92127eba47051 100644 --- a/server/graphql-engine.cabal +++ b/server/graphql-engine.cabal @@ -146,6 +146,9 @@ library -- metrics in multiplexed subs , ekg-core + -- hashable time + , hashable-time + exposed-modules: Hasura.Prelude , Hasura.Logging , Hasura.EncJSON diff --git a/server/src-lib/Hasura/GraphQL/Explain.hs b/server/src-lib/Hasura/GraphQL/Explain.hs index d58616de59cca..d11535905cac2 100644 --- a/server/src-lib/Hasura/GraphQL/Explain.hs +++ b/server/src-lib/Hasura/GraphQL/Explain.hs @@ -129,4 +129,4 @@ explainGQLQuery pgExecCtx sc sqlGenCtx (GQLExplain query userVarsRaw)= do throw400 InvalidParams "only queries can be explained" where usrVars = mkUserVars $ maybe [] Map.toList userVarsRaw - userInfo = mkUserInfo (fromMaybe adminRole $ roleFromVars usrVars) usrVars + userInfo = mkUserInfo (fromMaybe adminRole $ roleFromVars usrVars) usrVars Nothing diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index 1a89228108c93..36afac87bf774 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -16,6 +16,7 @@ import qualified Data.CaseInsensitive as CI import 
qualified Data.HashMap.Strict as Map import qualified Data.Text as T import qualified Data.Text.Encoding as TE +import qualified Data.Time.Clock as TC import qualified Language.GraphQL.Draft.Syntax as G import qualified ListT import qualified Network.HTTP.Client as H @@ -39,7 +40,8 @@ import Hasura.RQL.Types import Hasura.Server.Auth (AuthMode, getUserInfo) import Hasura.Server.Cors -import Hasura.Server.Utils (bsToTxt) +import Hasura.Server.Utils (bsToTxt, + diffTimeToMicro) type OperationMap = STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName) @@ -58,7 +60,7 @@ data WSConnState data WSConnData = WSConnData -- the role and headers are set only on connection_init message - { _wscUser :: !(IORef.IORef WSConnState) + { _wscUser :: !(STM.TVar WSConnState) -- we only care about subscriptions, -- the other operations (query/mutations) -- are not tracked here @@ -138,14 +140,26 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do sendMsg wsConn SMConnKeepAlive threadDelay $ 5 * 1000 * 1000 + jwtExpiryHandler wsConn = do + currTime <- TC.getCurrentTime + expTime <- STM.atomically $ do + connState <- STM.readTVar $ (_wscUser . 
WS.getData) wsConn + case connState of + CSNotInitialised _ -> STM.retry + CSInitError _ -> STM.retry + CSInitialised userInfo _ -> + maybe STM.retry return $ userJWTExpiry userInfo + threadDelay $ diffTimeToMicro $ TC.diffUTCTime expTime currTime + accept hdrs = do logger $ WSLog wsId Nothing EAccepted Nothing connData <- WSConnData - <$> IORef.newIORef (CSNotInitialised hdrs) + <$> STM.newTVarIO (CSNotInitialised hdrs) <*> STMMap.newIO let acceptRequest = WS.defaultAcceptRequest { WS.acceptSubprotocol = Just "graphql-ws"} - return $ Right (connData, acceptRequest, Just keepAliveAction) + return $ Right $ WS.AcceptWith connData acceptRequest + (Just keepAliveAction) (Just jwtExpiryHandler) reject qErr = do logger $ WSLog wsId Nothing (ERejected qErr) Nothing @@ -202,7 +216,7 @@ onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do when (isJust opM) $ withComplete $ sendConnErr $ "an operation already exists with this id: " <> unOperationId opId - userInfoM <- liftIO $ IORef.readIORef userInfoR + userInfoM <- liftIO $ STM.readTVarIO userInfoR (userInfo, reqHdrs) <- case userInfoM of CSInitialised userInfo reqHdrs -> return (userInfo, reqHdrs) CSInitError initErr -> do @@ -344,7 +358,7 @@ logWSEvent :: (MonadIO m) => L.Logger -> WSConn -> WSEvent -> m () logWSEvent (L.Logger logger) wsConn wsEv = do - userInfoME <- liftIO $ IORef.readIORef userInfoR + userInfoME <- liftIO $ STM.readTVarIO userInfoR let userInfoM = case userInfoME of CSInitialised userInfo _ -> return $ userVars userInfo _ -> Nothing @@ -357,17 +371,17 @@ onConnInit :: (MonadIO m) => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m () onConnInit logger manager wsConn authMode connParamsM = do - headers <- mkHeaders <$> liftIO (IORef.readIORef (_wscUser $ WS.getData wsConn)) + headers <- mkHeaders <$> liftIO (STM.readTVarIO (_wscUser $ WS.getData wsConn)) res <- runExceptT $ getUserInfo logger manager headers authMode case res of Left e -> do - liftIO $ 
IORef.writeIORef (_wscUser $ WS.getData wsConn) $ + liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $ CSInitError $ qeError e let connErr = ConnErrMsg $ qeError e logWSEvent logger wsConn $ EConnErr connErr sendMsg wsConn $ SMConnErr connErr Right userInfo -> do - liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $ + liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $ CSInitialised userInfo paramHeaders sendMsg wsConn SMConnAck -- TODO: send it periodically? Why doesn't apollo's protocol use @@ -389,10 +403,9 @@ onConnInit logger manager wsConn authMode connParamsM = do onClose :: L.Logger -> LQ.LiveQueriesState - -> WS.ConnectionException -> WSConn -> IO () -onClose logger lqMap _ wsConn = do +onClose logger lqMap wsConn = do logWSEvent logger wsConn EClosed operations <- STM.atomically $ ListT.toList $ STMMap.listT opMap void $ A.forConcurrently operations $ \(_, (lqId, _)) -> diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs index 7ed6e7178d56e..b8e33b83f8580 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket/Server.hs @@ -9,6 +9,7 @@ module Hasura.GraphQL.Transport.WebSocket.Server , closeConn , sendMsg + , AcceptWith(..) , OnConnH , OnCloseH , OnMessageH @@ -51,6 +52,7 @@ data WSEvent | ERejected | EMessageReceived !TBS.TByteString | EMessageSent !TBS.TByteString + | EJwtExpired | ECloseReceived | ECloseSent !TBS.TByteString | EClosed @@ -118,10 +120,17 @@ closeAll (WSServer (L.Logger writeLog) connMap) msg = do return conns void $ A.mapConcurrently (flip closeConn msg . 
snd) conns -type AcceptWith a = (a, WS.AcceptRequest, Maybe (WSConn a -> IO ())) +data AcceptWith a + = AcceptWith + { _awData :: !a + , _awReq :: !WS.AcceptRequest + , _awKeepAlive :: !(Maybe (WSConn a -> IO ())) + , _awOnJwtExpiry :: !(Maybe (WSConn a -> IO ())) + } + type OnConnH a = WSId -> WS.RequestHead -> IO (Either WS.RejectRequest (AcceptWith a)) -type OnCloseH a = WS.ConnectionException -> WSConn a -> IO () +type OnCloseH a = WSConn a -> IO () type OnMessageH a = WSConn a -> BL.ByteString -> IO () data WSHandlers a @@ -149,7 +158,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pending WS.rejectRequestWith pendingConn rejectRequest writeLog $ WSLog wsId ERejected - onAccept wsId (a, acceptWithParams, keepAliveM) = do + onAccept wsId (AcceptWith a acceptWithParams keepAliveM onJwtExpiryM) = do conn <- WS.acceptRequestWith pendingConn acceptWithParams writeLog $ WSLog wsId EAccepted @@ -168,19 +177,23 @@ createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pending writeLog $ WSLog wsId $ EMessageSent $ TBS.fromLBS msg keepAliveRefM <- forM keepAliveM $ \action -> A.async $ action wsConn + onJwtExpiryRefM <- forM onJwtExpiryM $ \action -> A.async $ action wsConn - -- terminates on WS.ConnectionException - let waitOnRefs = maybeToList keepAliveRefM <> [rcvRef, sendRef] + -- terminates on WS.ConnectionException and JWT expiry + let waitOnRefs = catMaybes [keepAliveRefM, onJwtExpiryRefM] + <> [rcvRef, sendRef] res <- try $ A.waitAnyCancel waitOnRefs case res of - Left e -> do + Left ( _ :: WS.ConnectionException) -> do writeLog $ WSLog (_wcConnId wsConn) ECloseReceived - onConnClose e wsConn - -- this will never happen as both the threads never finish - Right _ -> return () + onConnClose wsConn + -- this will happen when jwt is expired + Right _ -> do + writeLog $ WSLog (_wcConnId wsConn) EJwtExpired + onConnClose wsConn - onConnClose e wsConn = do + onConnClose wsConn = do STM.atomically $ STMMap.delete 
(_wcConnId wsConn) connMap - _hOnClose wsHandlers e wsConn + _hOnClose wsHandlers wsConn writeLog $ WSLog (_wcConnId wsConn) EClosed diff --git a/server/src-lib/Hasura/RQL/Types/Permission.hs b/server/src-lib/Hasura/RQL/Types/Permission.hs index 74040d20858b6..531390b5b1c10 100644 --- a/server/src-lib/Hasura/RQL/Types/Permission.hs +++ b/server/src-lib/Hasura/RQL/Types/Permission.hs @@ -11,9 +11,7 @@ module Hasura.RQL.Types.Permission , getVarVal , roleFromVars - , UserInfo - , userRole - , userVars + , UserInfo(..) , mkUserInfo , userInfoToList , adminUserInfo @@ -25,13 +23,17 @@ module Hasura.RQL.Types.Permission ) where import Hasura.Prelude -import Hasura.Server.Utils (adminSecretHeader, deprecatedAccessKeyHeader, userRoleHeader) +import Hasura.Server.Utils (adminSecretHeader, + deprecatedAccessKeyHeader, + userRoleHeader) import Hasura.SQL.Types import qualified Database.PG.Query as Q import Data.Aeson import Data.Hashable +import Data.Hashable.Time () +import Data.Time.Clock import Instances.TH.Lift () import Language.Haskell.TH.Syntax (Lift) @@ -84,11 +86,12 @@ mkUserVars l = data UserInfo = UserInfo - { userRole :: !RoleName - , userVars :: !UserVars + { userRole :: !RoleName + , userVars :: !UserVars + , userJWTExpiry :: !(Maybe UTCTime) } deriving (Show, Eq, Generic) -mkUserInfo :: RoleName -> UserVars -> UserInfo +mkUserInfo :: RoleName -> UserVars -> Maybe UTCTime -> UserInfo mkUserInfo rn (UserVars v) = UserInfo rn $ UserVars $ Map.insert userRoleHeader (getRoleTxt rn) $ foldl (flip Map.delete) v [adminSecretHeader, deprecatedAccessKeyHeader] @@ -107,7 +110,7 @@ userInfoToList userInfo = adminUserInfo :: UserInfo adminUserInfo = - mkUserInfo adminRole $ mkUserVars [] + mkUserInfo adminRole (mkUserVars []) Nothing data PermType = PTInsert diff --git a/server/src-lib/Hasura/Server/Auth.hs b/server/src-lib/Hasura/Server/Auth.hs index 65d30499fba9a..5f8c59604e0cb 100644 --- a/server/src-lib/Hasura/Server/Auth.hs +++ b/server/src-lib/Hasura/Server/Auth.hs 
@@ -167,7 +167,7 @@ mkUserInfoFromResp logger url method statusCode respBody throw500 "missing x-hasura-role key in webhook response" Just rn -> do logWebHookResp L.LevelInfo Nothing - return $ mkUserInfo rn usrVars + return $ mkUserInfo rn usrVars Nothing logError = logWebHookResp L.LevelError $ Just respBody @@ -234,7 +234,7 @@ getUserInfo logger manager rawHeaders = \case AMAdminSecret adminScrt unAuthRole -> case adminSecretM of Just givenAdminScrt -> userInfoWhenAdminSecret adminScrt givenAdminScrt - Nothing -> userInfoWhenNoAdminSecret unAuthRole + Nothing -> userInfoWhenNoAdminSecret unAuthRole AMAdminSecretAndHook accKey hook -> whenAdminSecretAbsent accKey (userInfoFromAuthHook logger manager hook rawHeaders) @@ -246,16 +246,16 @@ getUserInfo logger manager rawHeaders = \case -- when admin secret is absent, run the action to retrieve UserInfo, otherwise -- adminsecret override whenAdminSecretAbsent ak action = - maybe action (userInfoWhenAdminSecret ak) $ adminSecretM + maybe action (userInfoWhenAdminSecret ak) adminSecretM - adminSecretM= foldl1 (<|>) $ map (flip getVarVal usrVars) [adminSecretHeader, deprecatedAccessKeyHeader] + adminSecretM= foldl1 (<|>) $ map (`getVarVal` usrVars) [adminSecretHeader, deprecatedAccessKeyHeader] usrVars = mkUserVars $ hdrsToText rawHeaders userInfoFromHeaders = case roleFromVars usrVars of - Just rn -> mkUserInfo rn usrVars - Nothing -> mkUserInfo adminRole usrVars + Just rn -> mkUserInfo rn usrVars Nothing + Nothing -> mkUserInfo adminRole usrVars Nothing userInfoWhenAdminSecret key reqKey = do when (reqKey /= getAdminSecret key) $ throw401 $ "invalid " <> adminSecretHeader <> "/" <> deprecatedAccessKeyHeader @@ -263,4 +263,4 @@ getUserInfo logger manager rawHeaders = \case userInfoWhenNoAdminSecret = \case Nothing -> throw401 $ adminSecretHeader <> "/" <> deprecatedAccessKeyHeader <> " required, but not found" - Just role -> return $ mkUserInfo role usrVars + Just role -> return $ mkUserInfo role usrVars Nothing diff 
--git a/server/src-lib/Hasura/Server/Auth/JWT.hs b/server/src-lib/Hasura/Server/Auth/JWT.hs index 18a9890096c35..a06fc8ec31ff0 100644 --- a/server/src-lib/Hasura/Server/Auth/JWT.hs +++ b/server/src-lib/Hasura/Server/Auth/JWT.hs @@ -28,7 +28,8 @@ import Hasura.Prelude import Hasura.RQL.Types import Hasura.Server.Auth.JWT.Internal (parseHmacKey, parseRsaKey) import Hasura.Server.Auth.JWT.Logging -import Hasura.Server.Utils (bsToTxt, userRoleHeader) +import Hasura.Server.Utils (bsToTxt, diffTimeToMicro, + userRoleHeader) import qualified Control.Concurrent as C import qualified Data.Aeson as A @@ -106,13 +107,12 @@ jwkRefreshCtrl -> m () jwkRefreshCtrl lggr mngr url ref time = void $ liftIO $ C.forkIO $ do - C.threadDelay $ delay time + C.threadDelay $ diffTimeToMicro time forever $ do res <- runExceptT $ updateJwkRef lggr mngr url ref mTime <- either (const $ return Nothing) return res - C.threadDelay $ maybe (60 * aSecond) delay mTime + C.threadDelay $ maybe (60 * aSecond) diffTimeToMicro mTime where - delay t = (floor (realToFrac t :: Double) - 10) * aSecond aSecond = 1000 * 1000 @@ -183,7 +183,7 @@ processJwt jwtCtx headers mUnAuthRole = withoutAuthZHeader = do unAuthRole <- maybe missingAuthzHeader return mUnAuthRole - return $ mkUserInfo unAuthRole $ mkUserVars $ hdrsToText headers + return $ mkUserInfo unAuthRole (mkUserVars $ hdrsToText headers) Nothing missingAuthzHeader = throw400 InvalidHeaders "Missing Authorization header in JWT authentication mode" @@ -204,6 +204,7 @@ processAuthZHeader jwtCtx headers authzHeader = do let claimsNs = fromMaybe defaultClaimNs $ jcxClaimNs jwtCtx claimsFmt = jcxClaimsFormat jwtCtx + expTimeM = fmap (\(NumericDate t) -> t) $ claims ^. claimExp -- see if the hasura claims key exist in the claims map let mHasuraClaims = Map.lookup claimsNs $ claims ^. 
unregisteredClaims @@ -227,7 +228,7 @@ processAuthZHeader jwtCtx headers authzHeader = do -- transform the map of text:aeson-value -> text:text metadata <- decodeJSON $ A.Object finalClaims - return $ mkUserInfo role $ mkUserVars $ Map.toList metadata + return $ mkUserInfo role (mkUserVars $ Map.toList metadata) expTimeM where parseAuthzHeader = do diff --git a/server/src-lib/Hasura/Server/Utils.hs b/server/src-lib/Hasura/Server/Utils.hs index e3e642c0704eb..46e0d94d99c03 100644 --- a/server/src-lib/Hasura/Server/Utils.hs +++ b/server/src-lib/Hasura/Server/Utils.hs @@ -4,6 +4,7 @@ import qualified Database.PG.Query.Connection as Q import Data.Aeson import Data.List.Split +import Data.Time.Clock import Network.URI import System.Environment import System.Exit @@ -143,3 +144,10 @@ matchRegex regex caseSensitive src = fmapL :: (a -> a') -> Either a b -> Either a' b fmapL fn (Left e) = Left (fn e) fmapL _ (Right x) = pure x + +-- diff time to microseconds: whole seconds minus a 10s margin (negative if diff < 10s) +diffTimeToMicro :: NominalDiffTime -> Int +diffTimeToMicro diff = + (floor (realToFrac diff :: Double) - 10) * aSecond + where + aSecond = 1000 * 1000 From 5c6333b662b4abe3d846658a3b72b40c4cbae0b5 Mon Sep 17 00:00:00 2001 From: rakeshkky Date: Fri, 10 May 2019 17:43:11 +0530 Subject: [PATCH 2/5] add test for websocket connection close on jwt expiry --- server/tests-py/test_jwt.py | 28 +++++++++++++++++++++++++++ server/tests-py/test_subscriptions.py | 17 ++++++++-------- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/server/tests-py/test_jwt.py b/server/tests-py/test_jwt.py index 811bd925b2f5f..ced3f37d9720f 100644 --- a/server/tests-py/test_jwt.py +++ b/server/tests-py/test_jwt.py @@ -1,10 +1,12 @@ from datetime import datetime, timedelta import math import json +import time import yaml import pytest import jwt +from test_subscriptions import init_ws_conn from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa from
cryptography.hazmat.primitives import serialization @@ -207,3 +209,29 @@ def gen_rsa_key(): encryption_algorithm=serialization.NoEncryption() ) return pem + +class TestSubscriptionJwtExpiry(object): + + def test_jwt_expiry(self, hge_ctx, ws_client): + curr_time = datetime.now() + self.claims = { + 'sub': '1234567890', + 'name': 'John Doe', + 'iat': math.floor(curr_time.timestamp()) + } + self.claims['https://hasura.io/jwt/claims'] = mk_claims(hge_ctx.hge_jwt_conf, { + 'x-hasura-user-id': '1', + 'x-hasura-default-role': 'user', + 'x-hasura-allowed-roles': ['user'], + }) + exp = curr_time + timedelta(seconds=5) + self.claims['exp'] = round(exp.timestamp()) + token = jwt.encode(self.claims, hge_ctx.hge_jwt_key, algorithm='RS512').decode('utf-8') + payload = { + 'headers': { + 'Authorization': 'Bearer ' + token + } + } + init_ws_conn(hge_ctx, ws_client, payload) + time.sleep(5) + assert ws_client.remote_closed == True, ws_client.remote_closed diff --git a/server/tests-py/test_subscriptions.py b/server/tests-py/test_subscriptions.py index 9465dccbd1520..42f315f28e1e3 100644 --- a/server/tests-py/test_subscriptions.py +++ b/server/tests-py/test_subscriptions.py @@ -9,14 +9,16 @@ Refer: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init ''' -def init_ws_conn(hge_ctx, ws_client): - payload = {} - if hge_ctx.hge_key is not None: - payload = { - 'headers' : { - 'X-Hasura-Admin-Secret': hge_ctx.hge_key +def init_ws_conn(hge_ctx, ws_client, payload = None): + if payload is None: + payload = {} + if hge_ctx.hge_key is not None: + payload = { + 'headers' : { + 'X-Hasura-Admin-Secret': hge_ctx.hge_key + } } - } + init_msg = { 'type': 'connection_init', 'payload': payload, @@ -251,4 +253,3 @@ def test_live_queries(self, hge_ctx, ws_client): @classmethod def dir(cls): return 'queries/subscriptions/live_queries' - From 936dff3f8c85700a4946dfe3f31a6b302fd9c45c Mon Sep 17 00:00:00 2001 From: rakeshkky Date: Mon, 13 May 2019 15:04:52 
+0530 Subject: [PATCH 3/5] remove jwt expiry time from UserInfo Explicitly return jwt expiry time while resolving user info --- server/graphql-engine.cabal | 3 - server/src-lib/Hasura/GraphQL/Explain.hs | 2 +- .../Hasura/GraphQL/Transport/WebSocket.hs | 26 ++++---- server/src-lib/Hasura/RQL/Types/Permission.hs | 11 ++-- server/src-lib/Hasura/Server/Auth.hs | 59 ++++++++++++------- server/src-lib/Hasura/Server/Auth/JWT.hs | 13 ++-- 6 files changed, 63 insertions(+), 51 deletions(-) diff --git a/server/graphql-engine.cabal b/server/graphql-engine.cabal index 1663edb8449b2..141f610cb87f3 100644 --- a/server/graphql-engine.cabal +++ b/server/graphql-engine.cabal @@ -146,9 +146,6 @@ library -- metrics in multiplexed subs , ekg-core - -- hashable time - , hashable-time - exposed-modules: Hasura.Prelude , Hasura.Logging , Hasura.EncJSON diff --git a/server/src-lib/Hasura/GraphQL/Explain.hs b/server/src-lib/Hasura/GraphQL/Explain.hs index d11535905cac2..d58616de59cca 100644 --- a/server/src-lib/Hasura/GraphQL/Explain.hs +++ b/server/src-lib/Hasura/GraphQL/Explain.hs @@ -129,4 +129,4 @@ explainGQLQuery pgExecCtx sc sqlGenCtx (GQLExplain query userVarsRaw)= do throw400 InvalidParams "only queries can be explained" where usrVars = mkUserVars $ maybe [] Map.toList userVarsRaw - userInfo = mkUserInfo (fromMaybe adminRole $ roleFromVars usrVars) usrVars Nothing + userInfo = mkUserInfo (fromMaybe adminRole $ roleFromVars usrVars) usrVars diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index 11e19f5da3bdd..3ecf9b5b8b5d2 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -38,8 +38,7 @@ import qualified Hasura.Logging as L import Hasura.Prelude import Hasura.RQL.Types import Hasura.RQL.Types.Error (Code (StartFailed)) -import Hasura.Server.Auth (AuthMode, - getUserInfo) +import Hasura.Server.Auth (AuthMode, getUserInfoWithExpTime) 
import Hasura.Server.Cors import Hasura.Server.Utils (bsToTxt, diffTimeToMicro) @@ -61,7 +60,8 @@ data WSConnState = CSNotInitialised !WsHeaders | CSInitError Text -- headers from the client (in conn params) to forward to the remote schema - | CSInitialised UserInfo [H.Header] + -- and JWT expiry time if any + | CSInitialised UserInfo (Maybe TC.UTCTime) [H.Header] data WSConnData = WSConnData @@ -152,10 +152,10 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do expTime <- STM.atomically $ do connState <- STM.readTVar $ (_wscUser . WS.getData) wsConn case connState of - CSNotInitialised _ -> STM.retry - CSInitError _ -> STM.retry - CSInitialised userInfo _ -> - maybe STM.retry return $ userJWTExpiry userInfo + CSNotInitialised _ -> STM.retry + CSInitError _ -> STM.retry + CSInitialised _ expTimeM _ -> + maybe STM.retry return expTimeM threadDelay $ diffTimeToMicro $ TC.diffUTCTime expTime currTime accept hdrs errType = do @@ -228,7 +228,7 @@ onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do userInfoM <- liftIO $ STM.readTVarIO userInfoR (userInfo, reqHdrs) <- case userInfoM of - CSInitialised userInfo reqHdrs -> return (userInfo, reqHdrs) + CSInitialised userInfo _ reqHdrs -> return (userInfo, reqHdrs) CSInitError initErr -> do let e = "cannot start as connection_init failed with : " <> initErr withComplete $ sendStartErr e @@ -382,8 +382,8 @@ logWSEvent logWSEvent (L.Logger logger) wsConn wsEv = do userInfoME <- liftIO $ STM.readTVarIO userInfoR let userInfoM = case userInfoME of - CSInitialised userInfo _ -> return $ userVars userInfo - _ -> Nothing + CSInitialised userInfo _ _ -> return $ userVars userInfo + _ -> Nothing liftIO $ logger $ WSLog wsId userInfoM wsEv Nothing where WSConnData userInfoR _ _ = WS.getData wsConn @@ -394,7 +394,7 @@ onConnInit => L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m () onConnInit logger manager wsConn authMode connParamsM = do headers <- mkHeaders <$> liftIO (STM.readTVarIO 
(_wscUser $ WS.getData wsConn)) - res <- runExceptT $ getUserInfo logger manager headers authMode + res <- runExceptT $ getUserInfoWithExpTime logger manager headers authMode case res of Left e -> do liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $ @@ -402,9 +402,9 @@ onConnInit logger manager wsConn authMode connParamsM = do let connErr = ConnErrMsg $ qeError e logWSEvent logger wsConn $ EConnErr connErr sendMsg wsConn $ SMConnErr connErr - Right userInfo -> do + Right (userInfo, expTimeM) -> do liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $ - CSInitialised userInfo paramHeaders + CSInitialised userInfo expTimeM paramHeaders sendMsg wsConn SMConnAck -- TODO: send it periodically? Why doesn't apollo's protocol use -- ping/pong frames of websocket spec? diff --git a/server/src-lib/Hasura/RQL/Types/Permission.hs b/server/src-lib/Hasura/RQL/Types/Permission.hs index 531390b5b1c10..9404f1212ec66 100644 --- a/server/src-lib/Hasura/RQL/Types/Permission.hs +++ b/server/src-lib/Hasura/RQL/Types/Permission.hs @@ -32,8 +32,6 @@ import qualified Database.PG.Query as Q import Data.Aeson import Data.Hashable -import Data.Hashable.Time () -import Data.Time.Clock import Instances.TH.Lift () import Language.Haskell.TH.Syntax (Lift) @@ -86,12 +84,11 @@ mkUserVars l = data UserInfo = UserInfo - { userRole :: !RoleName - , userVars :: !UserVars - , userJWTExpiry :: !(Maybe UTCTime) + { userRole :: !RoleName + , userVars :: !UserVars } deriving (Show, Eq, Generic) -mkUserInfo :: RoleName -> UserVars -> Maybe UTCTime -> UserInfo +mkUserInfo :: RoleName -> UserVars -> UserInfo mkUserInfo rn (UserVars v) = UserInfo rn $ UserVars $ Map.insert userRoleHeader (getRoleTxt rn) $ foldl (flip Map.delete) v [adminSecretHeader, deprecatedAccessKeyHeader] @@ -110,7 +107,7 @@ userInfoToList userInfo = adminUserInfo :: UserInfo adminUserInfo = - mkUserInfo adminRole (mkUserVars []) Nothing + mkUserInfo adminRole $ mkUserVars [] data PermType = 
PTInsert diff --git a/server/src-lib/Hasura/Server/Auth.hs b/server/src-lib/Hasura/Server/Auth.hs index 5f8c59604e0cb..edd1b0db41b0b 100644 --- a/server/src-lib/Hasura/Server/Auth.hs +++ b/server/src-lib/Hasura/Server/Auth.hs @@ -1,8 +1,6 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE RankNTypes #-} - module Hasura.Server.Auth ( getUserInfo + , getUserInfoWithExpTime , AuthMode(..) , mkAuthMode , AdminSecret (..) @@ -23,6 +21,7 @@ import Control.Exception (try) import Control.Lens import Data.Aeson import Data.IORef (newIORef) +import Data.Time.Clock (UTCTime) import qualified Data.Aeson as J import qualified Data.ByteString.Lazy as BL @@ -102,11 +101,13 @@ mkAuthMode mAdminSecret mWebHook mJwtSecret mUnAuthRole httpManager lCtx = (Just _, Just _, Just _) -> throwError "Fatal Error: Both webhook and JWT mode cannot be enabled at the same time" where - requiresAdminScrtMsg = " requires --admin-secret (HASURA_GRAPHQL_ADMIN_SECRET) or --access-key (HASURA_GRAPHQL_ACCESS_KEY) to be set" + requiresAdminScrtMsg = + " requires --admin-secret (HASURA_GRAPHQL_ADMIN_SECRET) or " + <> " --access-key (HASURA_GRAPHQL_ACCESS_KEY) to be set" unAuthRoleNotReqForWebHook = - when (isJust mUnAuthRole) $ - throwError $ "Fatal Error: --unauthorized-role (HASURA_GRAPHQL_UNAUTHORIZED_ROLE) is not allowed" - <> " when --auth-hook (HASURA_GRAPHQL_AUTH_HOOK) is set" + when (isJust mUnAuthRole) $ throwError $ + "Fatal Error: --unauthorized-role (HASURA_GRAPHQL_UNAUTHORIZED_ROLE) is not allowed" + <> " when --auth-hook (HASURA_GRAPHQL_AUTH_HOOK) is set" mkJwtCtx :: ( MonadIO m @@ -167,7 +168,7 @@ mkUserInfoFromResp logger url method statusCode respBody throw500 "missing x-hasura-role key in webhook response" Just rn -> do logWebHookResp L.LevelInfo Nothing - return $ mkUserInfo rn usrVars Nothing + return $ mkUserInfo rn usrVars logError = logWebHookResp L.LevelError $ Just respBody @@ -219,7 +220,6 @@ userInfoFromAuthHook logger manager hook reqHeaders = do , "Cache-Control", "Connection", 
"DNT" ] - getUserInfo :: (MonadIO m, MonadError QErr m) => L.Logger @@ -227,17 +227,29 @@ getUserInfo -> [N.Header] -> AuthMode -> m UserInfo -getUserInfo logger manager rawHeaders = \case +getUserInfo l m r a = fst <$> getUserInfoWithExpTime l m r a - AMNoAuth -> return userInfoFromHeaders +getUserInfoWithExpTime + :: (MonadIO m, MonadError QErr m) + => L.Logger + -> H.Manager + -> [N.Header] + -> AuthMode + -> m (UserInfo, Maybe UTCTime) +getUserInfoWithExpTime logger manager rawHeaders = \case + + AMNoAuth -> return (userInfoFromHeaders, Nothing) AMAdminSecret adminScrt unAuthRole -> case adminSecretM of - Just givenAdminScrt -> userInfoWhenAdminSecret adminScrt givenAdminScrt - Nothing -> userInfoWhenNoAdminSecret unAuthRole + Just givenAdminScrt -> + withNoExpTime $ userInfoWhenAdminSecret adminScrt givenAdminScrt + Nothing -> + withNoExpTime $ userInfoWhenNoAdminSecret unAuthRole AMAdminSecretAndHook accKey hook -> - whenAdminSecretAbsent accKey (userInfoFromAuthHook logger manager hook rawHeaders) + whenAdminSecretAbsent accKey $ + withNoExpTime $ userInfoFromAuthHook logger manager hook rawHeaders AMAdminSecretAndJWT accKey jwtSecret unAuthRole -> whenAdminSecretAbsent accKey (processJwt jwtSecret rawHeaders unAuthRole) @@ -246,21 +258,26 @@ getUserInfo logger manager rawHeaders = \case -- when admin secret is absent, run the action to retrieve UserInfo, otherwise -- adminsecret override whenAdminSecretAbsent ak action = - maybe action (userInfoWhenAdminSecret ak) adminSecretM + maybe action (withNoExpTime . 
userInfoWhenAdminSecret ak) adminSecretM - adminSecretM= foldl1 (<|>) $ map (`getVarVal` usrVars) [adminSecretHeader, deprecatedAccessKeyHeader] + adminSecretM= foldl1 (<|>) $ + map (`getVarVal` usrVars) [adminSecretHeader, deprecatedAccessKeyHeader] usrVars = mkUserVars $ hdrsToText rawHeaders userInfoFromHeaders = case roleFromVars usrVars of - Just rn -> mkUserInfo rn usrVars Nothing - Nothing -> mkUserInfo adminRole usrVars Nothing + Just rn -> mkUserInfo rn usrVars + Nothing -> mkUserInfo adminRole usrVars userInfoWhenAdminSecret key reqKey = do - when (reqKey /= getAdminSecret key) $ throw401 $ "invalid " <> adminSecretHeader <> "/" <> deprecatedAccessKeyHeader + when (reqKey /= getAdminSecret key) $ throw401 $ + "invalid " <> adminSecretHeader <> "/" <> deprecatedAccessKeyHeader return userInfoFromHeaders userInfoWhenNoAdminSecret = \case - Nothing -> throw401 $ adminSecretHeader <> "/" <> deprecatedAccessKeyHeader <> " required, but not found" - Just role -> return $ mkUserInfo role usrVars Nothing + Nothing -> throw401 $ adminSecretHeader <> "/" + <> deprecatedAccessKeyHeader <> " required, but not found" + Just role -> return $ mkUserInfo role usrVars + + withNoExpTime a = (, Nothing) <$> a diff --git a/server/src-lib/Hasura/Server/Auth/JWT.hs b/server/src-lib/Hasura/Server/Auth/JWT.hs index a06fc8ec31ff0..9a9b7b21ada91 100644 --- a/server/src-lib/Hasura/Server/Auth/JWT.hs +++ b/server/src-lib/Hasura/Server/Auth/JWT.hs @@ -17,8 +17,8 @@ import Crypto.JWT import Data.IORef (IORef, modifyIORef, readIORef) import Data.List (find) -import Data.Time.Clock (NominalDiffTime, diffUTCTime, - getCurrentTime) +import Data.Time.Clock (NominalDiffTime, UTCTime, + diffUTCTime, getCurrentTime) import Data.Time.Format (defaultTimeLocale, parseTimeM) import Network.URI (URI) @@ -172,7 +172,7 @@ processJwt => JWTCtx -> HTTP.RequestHeaders -> Maybe RoleName - -> m UserInfo + -> m (UserInfo, Maybe UTCTime) processJwt jwtCtx headers mUnAuthRole = maybe withoutAuthZHeader 
withAuthZHeader mAuthZHeader where @@ -183,7 +183,8 @@ processJwt jwtCtx headers mUnAuthRole = withoutAuthZHeader = do unAuthRole <- maybe missingAuthzHeader return mUnAuthRole - return $ mkUserInfo unAuthRole (mkUserVars $ hdrsToText headers) Nothing + return $ (, Nothing) $ + mkUserInfo unAuthRole $ mkUserVars $ hdrsToText headers missingAuthzHeader = throw400 InvalidHeaders "Missing Authorization header in JWT authentication mode" @@ -194,7 +195,7 @@ processAuthZHeader => JWTCtx -> HTTP.RequestHeaders -> BLC.ByteString - -> m UserInfo + -> m (UserInfo, Maybe UTCTime) processAuthZHeader jwtCtx headers authzHeader = do -- try to parse JWT token from Authorization header jwt <- parseAuthzHeader @@ -228,7 +229,7 @@ processAuthZHeader jwtCtx headers authzHeader = do -- transform the map of text:aeson-value -> text:text metadata <- decodeJSON $ A.Object finalClaims - return $ mkUserInfo role (mkUserVars $ Map.toList metadata) expTimeM + return $ (, expTimeM) $ mkUserInfo role $ mkUserVars $ Map.toList metadata where parseAuthzHeader = do From 366731188e6df4ab13395c12a9bbb95941866d15 Mon Sep 17 00:00:00 2001 From: rakeshkky Date: Mon, 13 May 2019 18:53:13 +0530 Subject: [PATCH 4/5] fetch current time after resolving jwt expiry time --- server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index 3ecf9b5b8b5d2..e97492f102d7d 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -148,7 +148,6 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do threadDelay $ 5 * 1000 * 1000 jwtExpiryHandler wsConn = do - currTime <- TC.getCurrentTime expTime <- STM.atomically $ do connState <- STM.readTVar $ (_wscUser . 
WS.getData) wsConn case connState of @@ -156,6 +155,7 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do CSInitError _ -> STM.retry CSInitialised _ expTimeM _ -> maybe STM.retry return expTimeM + currTime <- TC.getCurrentTime threadDelay $ diffTimeToMicro $ TC.diffUTCTime expTime currTime accept hdrs errType = do From 41506f96970cc7b2c28c38d526110f91cdea836f Mon Sep 17 00:00:00 2001 From: rakeshkky Date: Mon, 13 May 2019 23:21:37 +0530 Subject: [PATCH 5/5] log jwt expiry time, if any, with WebSocket events --- .../Hasura/GraphQL/Transport/WebSocket.hs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index e97492f102d7d..2b9e79535a9c4 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -111,6 +111,7 @@ data WSLog = WSLog { _wslWebsocketId :: !WS.WSId , _wslUser :: !(Maybe UserVars) + , _wslJwtExpiry :: !(Maybe TC.UTCTime) , _wslEvent :: !WSEvent , _wslMsg :: !(Maybe Text) } deriving (Show, Eq) @@ -159,7 +160,7 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do threadDelay $ diffTimeToMicro $ TC.diffUTCTime expTime currTime accept hdrs errType = do - logger $ WSLog wsId Nothing EAccepted Nothing + logger $ WSLog wsId Nothing Nothing EAccepted Nothing connData <- WSConnData <$> STM.newTVarIO (CSNotInitialised hdrs) <*> STMMap.newIO @@ -170,7 +171,7 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do (Just keepAliveAction) (Just jwtExpiryHandler) reject qErr = do - logger $ WSLog wsId Nothing (ERejected qErr) Nothing + logger $ WSLog wsId Nothing Nothing (ERejected qErr) Nothing return $ Left $ WS.RejectRequest (H.statusCode $ qeStatus qErr) (H.statusMessage $ qeStatus qErr) [] @@ -192,7 +193,7 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do if readCookie then return reqHdrs else do - liftIO $ logger $ WSLog 
wsId Nothing EAccepted (Just corsNote) + liftIO $ logger $ WSLog wsId Nothing Nothing EAccepted (Just corsNote) return $ filter (\h -> fst h /= "Cookie") reqHdrs CCAllowedOrigins ds -- if the origin is in our cors domains, no error @@ -381,10 +382,12 @@ logWSEvent => L.Logger -> WSConn -> WSEvent -> m () logWSEvent (L.Logger logger) wsConn wsEv = do userInfoME <- liftIO $ STM.readTVarIO userInfoR - let userInfoM = case userInfoME of - CSInitialised userInfo _ _ -> return $ userVars userInfo - _ -> Nothing - liftIO $ logger $ WSLog wsId userInfoM wsEv Nothing + let (userVarsM, jwtExpM) = case userInfoME of + CSInitialised userInfo jwtM _ -> ( Just $ userVars userInfo + , jwtM + ) + _ -> (Nothing, Nothing) + liftIO $ logger $ WSLog wsId userVarsM jwtExpM wsEv Nothing where WSConnData userInfoR _ _ = WS.getData wsConn wsId = WS.getWSId wsConn