diff --git a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs index b4f6bcfe53a1e..5876dee1b6278 100644 --- a/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs +++ b/server/src-lib/Hasura/GraphQL/Transport/WebSocket.hs @@ -16,7 +16,6 @@ import qualified Data.Aeson.TH as J import qualified Data.ByteString.Lazy as BL import qualified Data.CaseInsensitive as CI import qualified Data.HashMap.Strict as Map -import qualified Data.TByteString as TBS import qualified Data.Text as T import qualified Data.Text.Encoding as TE import qualified Language.GraphQL.Draft.Syntax as G @@ -70,19 +69,12 @@ sendMsg :: (MonadIO m) => WSConn -> ServerMsg -> m () sendMsg wsConn = liftIO . WS.sendMsg wsConn . encodeServerMsg -data SubsDetail - = SDStarted - | SDStopped - deriving (Show, Eq) -$(J.deriveToJSON - J.defaultOptions { J.constructorTagModifier = J.snakeCase . drop 2 - , J.sumEncoding = J.TaggedObject "type" "detail" - } - ''SubsDetail) - data OpDetail - = ODCompleted - | ODError !QErr + = ODStarted {- ^ operation began executing, or its live query was registered -} + | ODProtoErr !Text {- ^ operation rejected before validation; the same text is also sent to the client as an error message -} + | ODQueryErr !QErr {- ^ validation or execution error -} + | ODCompleted {- ^ complete message sent for the operation -} + | ODStopped {- ^ live query removed after a client stop -} deriving (Show, Eq) $(J.deriveToJSON J.defaultOptions { J.constructorTagModifier = J.snakeCase . 
drop 2 @@ -93,9 +85,8 @@ $(J.deriveToJSON data WSEvent = EAccepted | ERejected !QErr - | EProtocolError !TBS.TByteString !ConnErrMsg - | EOperation !OperationId !OpDetail - | ESubscription !OperationId !SubsDetail + | EConnErr !ConnErrMsg {- ^ connection-level error not tied to an operation, e.g. an unparseable client message or a failed connection_init -} + | EOperation !OperationId !(Maybe OperationName) !OpDetail {- ^ event for a single operation, tagged with the request's operation name when present -} | EClosed deriving (Show, Eq) $(J.deriveToJSON @@ -154,31 +145,28 @@ onConn (L.Logger logger) wsId requestHead = do throw404 "only /v1alpha1/graphql is supported on websockets" onStart :: WSServerEnv -> WSConn -> StartMsg -> IO () -onStart serverEnv wsConn msg@(StartMsg opId q) = catchAndSend $ do +onStart serverEnv wsConn (StartMsg opId q) = catchAndIgnore $ do {- early exits go through withComplete, which has already messaged the client and logged, so catchAndIgnore only discards the ExceptT () result -} opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap - when (isJust opM) $ withExceptT preExecErr $ loggingQErr $ - throw400 UnexpectedPayload $ + when (isJust opM) $ withComplete $ sendConnErr $ "an operation already exists with this id: " <> unOperationId opId userInfoM <- liftIO $ IORef.readIORef userInfoR userInfo <- case userInfoM of Just (Right userInfo) -> return userInfo - Just (Left initErr) -> throwError $ SMConnErr $ ConnErrMsg $ - "cannot start as connection_init failed with : " <> initErr + Just (Left initErr) -> do + let connErr = "cannot start as connection_init failed with : " <> initErr + withComplete $ sendConnErr connErr Nothing -> do - let err = "start received before the connection is initialised" - liftIO $ logger $ WSLog wsId $ - -- TODO: we are encoding the start msg back into a bytestring - -- should we be throwing protocol error here? - EProtocolError (TBS.fromLBS $ J.encode msg) err - throwError $ SMConnErr err + let connErr = "start received before the connection is initialised" + withComplete $ sendConnErr connErr -- validate and build tx gCtxMap <- fmap snd $ liftIO $ IORef.readIORef gCtxMapRef let gCtx = getGCtx (userRole userInfo) gCtxMap - (opTy, fields) <- withExceptT preExecErr $ loggingQErr $ + + (opTy, fields) <- either (withComplete . 
preExecErr) return $ runReaderT (validateGQ q) gCtx let qTx = RQ.setHeadersTx userInfo >> resolveSelSet userInfo gCtx opTy fields @@ -189,35 +177,55 @@ onStart serverEnv wsConn msg@(StartMsg opId q) = catchAndSend $ do liftIO $ STM.atomically $ STMMap.insert lq opId opMap liftIO $ LQ.addLiveQuery runTx lqMap lq qTx (wsId, opId) liveQOnChange - liftIO $ logger $ WSLog wsId $ ESubscription opId SDStarted - - _ -> withExceptT postExecErr $ loggingQErr $ do - resp <- ExceptT $ runTx qTx - sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess resp - sendMsg wsConn $ SMComplete $ CompletionMsg opId - liftIO $ logger $ WSLog wsId $ EOperation opId ODCompleted + logOpEv ODStarted + _ -> do {- run the transaction once, send its result or error, then complete -} + logOpEv ODStarted + resp <- liftIO $ runTx qTx + either postExecErr sendSuccResp resp + sendCompleted where - (WSServerEnv (L.Logger logger) _ runTx lqMap gCtxMapRef _) = serverEnv + WSServerEnv (L.Logger logger) _ runTx lqMap gCtxMapRef _ = serverEnv wsId = WS.getWSId wsConn - (WSConnData userInfoR opMap) = WS.getData wsConn + WSConnData userInfoR opMap = WS.getData wsConn - -- on change, send message on the websocket - liveQOnChange resp = WS.sendMsg wsConn $ encodeServerMsg $ SMData $ - DataMsg opId resp + logOpEv opDet = {- log an EOperation event for this operation id and the request's operation name -} + liftIO $ logger $ WSLog wsId $ + EOperation opId (_grOperationName q) opDet + + sendConnErr connErr = do {- send an error message for this operation, then log it as ODProtoErr -} + sendMsg wsConn $ SMErr $ ErrorMsg opId $ J.toJSON connErr + logOpEv $ ODProtoErr connErr + + sendCompleted = do {- send complete for this operation, then log ODCompleted -} + sendMsg wsConn $ SMComplete $ CompletionMsg opId + logOpEv ODCompleted - loggingQErr m = catchError m $ \qErr -> do - liftIO $ logger $ WSLog wsId $ EOperation opId $ ODError qErr - throwError qErr + postExecErr qErr = do {- execution error: log it, then send it as a data message carrying the error -} + logOpEv $ ODQueryErr qErr + sendMsg wsConn $ SMData $ DataMsg opId $ + GQExecError $ pure $ encodeQErr False qErr - preExecErr qErr = SMErr $ ErrorMsg opId $ encodeQErr False qErr - postExecErr qErr = SMData $ DataMsg opId $ GQExecError - [encodeQErr False qErr] + -- TODO: validation errors (preExecErr) are sent as an error message, whereas execution errors (postExecErr) are sent as a data message with errors; should validation errors also use the GraphQL response format? 
+ preExecErr qErr = do {- validation error: log it, then send it as an error message -} + logOpEv $ ODQueryErr qErr + sendMsg wsConn $ SMErr $ ErrorMsg opId $ encodeQErr False qErr + + sendSuccResp bs = + sendMsg wsConn $ SMData $ DataMsg opId $ GQSuccess bs + + withComplete :: ExceptT () IO () -> ExceptT () IO a {- run the reporting action, send complete, then abort the rest of onStart via throwError () -} + withComplete action = do + action + sendCompleted + throwError () + + -- on change, send message on the websocket + liveQOnChange resp = + WS.sendMsg wsConn $ encodeServerMsg $ SMData $ DataMsg opId resp - catchAndSend :: ExceptT ServerMsg IO () -> IO () - catchAndSend m = do - res <- runExceptT m - either (sendMsg wsConn) return res + catchAndIgnore :: ExceptT () IO () -> IO () {- the () failure carries no data: withComplete has already messaged the client and logged -} + catchAndIgnore m = void $ runExceptT m onMessage :: AuthMode @@ -227,8 +235,7 @@ onMessage authMode serverEnv wsConn msgRaw = case J.eitherDecode msgRaw of Left e -> do let err = ConnErrMsg $ "parsing ClientMessage failed: " <> T.pack e - liftIO $ logger $ WSLog (WS.getWSId wsConn) $ - EProtocolError (TBS.fromLBS msgRaw) err + liftIO $ logger $ WSLog (WS.getWSId wsConn) $ EConnErr err sendMsg wsConn $ SMConnErr err Right msg -> case msg of @@ -247,7 +254,8 @@ onStop serverEnv wsConn (StopMsg opId) = do opM <- liftIO $ STM.atomically $ STMMap.lookup opId opMap case opM of Just liveQ -> do - liftIO $ logger $ WSLog wsId $ ESubscription opId SDStopped + let opNameM = _grOperationName $ LQ._lqRequest liveQ {- a stop message carries only the id, so the name is read from the stored live query's request -} + liftIO $ logger $ WSLog wsId $ EOperation opId opNameM ODStopped LQ.removeLiveQuery lqMap liveQ (wsId, opId) Nothing -> return () STM.atomically $ STMMap.delete opId opMap @@ -266,7 +274,9 @@ onConnInit (L.Logger logger) manager wsConn authMode connParamsM = do Left e -> do liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $ Just $ Left $ qeError e - sendMsg wsConn $ SMConnErr $ ConnErrMsg $ qeError e + let connErr = ConnErrMsg $ qeError e {- log the failed connection_init, then report it to the client -} + liftIO $ logger $ WSLog (WS.getWSId wsConn) $ EConnErr connErr + sendMsg wsConn $ SMConnErr connErr Right userInfo -> do liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $ Just $ Right 
userInfo