diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index d33794006..2e6ae154f 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -674,8 +674,7 @@ getSMPServerClient :: AgentClient -> NetworkRequestMode -> SMPTransportSession - getSMPServerClient c@AgentClient {active, smpClients, workerSeq} nm tSess = do unlessM (readTVarIO active) $ throwE INACTIVE ts <- liftIO getCurrentTime - atomically (getSessVar workerSeq tSess smpClients ts) - >>= either newClient (waitForProtocolClient c nm tSess smpClients) + withGetSessVar workerSeq tSess smpClients ts newClient (waitForProtocolClient c nm tSess smpClients) where newClient v = do prs <- liftIO TM.emptyIO @@ -686,29 +685,25 @@ getSMPProxyClient c@AgentClient {active, smpClients, smpProxiedRelays, workerSeq unlessM (readTVarIO active) $ throwE INACTIVE proxySrv <- maybe (getNextServer c userId proxySrvs [destSrv]) pure proxySrv_ ts <- liftIO getCurrentTime - atomically (getClientVar proxySrv ts) >>= \(tSess, auth, v) -> - either (newProxyClient tSess auth ts) (waitForProxyClient tSess auth) v + (tSess, auth) <- atomically $ proxiedTransportSession proxySrv + withGetSessVar workerSeq tSess smpClients ts (newProxyClient tSess auth ts) (waitForProxyClient tSess auth) where - getClientVar :: SMPServerWithAuth -> UTCTime -> STM (SMPTransportSession, Maybe SMP.BasicAuth, Either SMPClientVar SMPClientVar) - getClientVar proxySrv ts = do + proxiedTransportSession :: SMPServerWithAuth -> STM (SMPTransportSession, Maybe SMP.BasicAuth) + proxiedTransportSession proxySrv = do ProtoServerWithAuth srv auth <- TM.lookup destSess smpProxiedRelays >>= maybe (TM.insert destSess proxySrv smpProxiedRelays $> proxySrv) pure - let tSess = (userId, srv, qId) - (tSess,auth,) <$> getSessVar workerSeq tSess smpClients ts + pure ((userId, srv, qId), auth) newProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> UTCTime -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) newProxyClient tSess auth ts v = do prs <- liftIO TM.emptyIO - -- we do not need to check if it is a new proxied relay session, - -- as the client is just created and there are no sessions yet - rv <- atomically $ either id id <$> getSessVar workerSeq destSrv prs ts clnt <- smpConnectClient c nm tSess prs v - (clnt,) <$> newProxiedRelay clnt auth rv + -- the relay var is always new (the client is just created and has no sessions yet) + sess <- withGetSessVar workerSeq destSrv prs ts (newProxiedRelay clnt auth) (waitForProxiedRelay tSess) + pure (clnt, sess) waitForProxyClient :: SMPTransportSession -> Maybe SMP.BasicAuth -> SMPClientVar -> AM (SMPConnectedClient, Either AgentErrorType ProxiedRelay) waitForProxyClient tSess auth v = do clnt@(SMPConnectedClient _ prs) <- waitForProtocolClient c nm tSess smpClients v ts <- liftIO getCurrentTime - sess <- - atomically (getSessVar workerSeq destSrv prs ts) - >>= either (newProxiedRelay clnt auth) (waitForProxiedRelay tSess) + sess <- withGetSessVar workerSeq destSrv prs ts (newProxiedRelay clnt auth) (waitForProxiedRelay tSess) pure (clnt, sess) newProxiedRelay :: SMPConnectedClient -> Maybe SMP.BasicAuth -> ProxiedRelayVar -> AM (Either AgentErrorType ProxiedRelay) newProxiedRelay (SMPConnectedClient smp prs) proxyAuth rv = @@ -846,10 +841,7 @@ getNtfServerClient :: AgentClient -> NetworkRequestMode -> NtfTransportSession - getNtfServerClient c@AgentClient {active, ntfClients, workerSeq, proxySessTs, presetDomains} nm tSess@(_, srv, _) = do unlessM (readTVarIO active) $ throwE INACTIVE ts <- liftIO getCurrentTime - atomically (getSessVar workerSeq tSess ntfClients ts) - >>= either - (newProtocolClient c tSess ntfClients connectClient) - (waitForProtocolClient c nm tSess ntfClients) + withGetSessVar workerSeq tSess ntfClients ts (newProtocolClient c tSess ntfClients connectClient) (waitForProtocolClient c nm tSess ntfClients) where connectClient :: NtfClientVar -> AM NtfClient connectClient v = do @@ -870,10 +862,7 @@ getXFTPServerClient :: AgentClient -> XFTPTransportSession -> AM XFTPClient getXFTPServerClient c@AgentClient {active, xftpClients, workerSeq, proxySessTs, presetDomains} tSess@(_, srv, _) = do unlessM (readTVarIO active) $ throwE INACTIVE ts <- liftIO getCurrentTime - atomically (getSessVar workerSeq tSess xftpClients ts) - >>= either - (newProtocolClient c tSess xftpClients connectClient) - (waitForProtocolClient c NRMBackground tSess xftpClients) + withGetSessVar workerSeq tSess xftpClients ts (newProtocolClient c tSess xftpClients connectClient) (waitForProtocolClient c NRMBackground tSess xftpClients) where connectClient :: XFTPClientVar -> AM XFTPClient connectClient v = do diff --git a/src/Simplex/Messaging/Client/Agent.hs b/src/Simplex/Messaging/Client/Agent.hs index 76b2a7cf9..e41d7a811 100644 --- a/src/Simplex/Messaging/Client/Agent.hs +++ b/src/Simplex/Messaging/Client/Agent.hs @@ -205,11 +205,8 @@ getSMPServerClient' ca srv = snd <$> getSMPServerClient'' ca srv getSMPServerClient'' :: SMPClientAgent p -> SMPServer -> ExceptT SMPClientError IO (OwnServer, SMPClient) getSMPServerClient'' ca@SMPClientAgent {agentCfg, smpClients, smpSessions, workerSeq} srv = do ts <- liftIO getCurrentTime - atomically (getClientVar ts) >>= either (ExceptT . newSMPClient) waitForSMPClient + withGetSessVar workerSeq srv smpClients ts (ExceptT . newSMPClient) waitForSMPClient where - getClientVar :: UTCTime -> STM (Either SMPClientVar SMPClientVar) - getClientVar = getSessVar workerSeq srv smpClients - waitForSMPClient :: SMPClientVar -> ExceptT SMPClientError IO (OwnServer, SMPClient) waitForSMPClient v = do let ProtocolClientConfig {networkConfig = NetworkConfig {tcpConnectTimeout}} = smpCfg agentCfg diff --git a/src/Simplex/Messaging/Notifications/Server.hs b/src/Simplex/Messaging/Notifications/Server.hs index 02429e910..f700d8d39 100644 --- a/src/Simplex/Messaging/Notifications/Server.hs +++ b/src/Simplex/Messaging/Notifications/Server.hs @@ -647,13 +647,14 @@ pushNotification s srvHost_ isOwn tkn@NtfTknRec {token = DeviceToken pp _} ntf = getOrCreatePushWorker :: NtfPushServer -> (Maybe T.Text, PushProvider) -> OwnServer -> M (TBQueue (NtfTknRec, PushNotification)) getOrCreatePushWorker s@NtfPushServer {pushWorkers, pushWorkerSeq, pushQSize} key@(srvHost_, _) isOwn = do ts <- liftIO getCurrentTime - atomically (getSessVar pushWorkerSeq key pushWorkers ts) >>= \case - Left v -> do + withGetSessVar' pushWorkerSeq key pushWorkers ts createWorker existingWorker + where + createWorker v = do q <- liftIO $ newTBQueueIO pushQSize tId <- mkWeakThreadId =<< forkIO (runPushWorker s srvHost_ isOwn q) atomically $ putTMVar (sessionVar v) PushWorker {workerQ = q, workerThreadId = tId} pure q - Right v -> workerQ <$> atomically (readTMVar $ sessionVar v) + existingWorker v = workerQ <$> atomically (readTMVar $ sessionVar v) runPushWorker :: NtfPushServer -> Maybe T.Text -> OwnServer -> TBQueue (NtfTknRec, PushNotification) -> M () runPushWorker s srvHost_ isOwn q = forever $ do diff --git a/src/Simplex/Messaging/Session.hs b/src/Simplex/Messaging/Session.hs index ff5d7e0a0..16ff3c934 100644 --- a/src/Simplex/Messaging/Session.hs +++ b/src/Simplex/Messaging/Session.hs @@ -6,14 +6,20 @@ module Simplex.Messaging.Session ( SessionVar (..), getSessVar, removeSessVar, + withGetSessVar, + withGetSessVar', tryReadSessVar, ) where import Control.Concurrent.STM +import Control.Monad.Except (ExceptT (..), runExceptT) +import Control.Monad.IO.Class (liftIO) +import Control.Monad.IO.Unlift (MonadUnliftIO) import Data.Time (UTCTime) import Simplex.Messaging.TMap (TMap) import qualified Simplex.Messaging.TMap as TM -import Simplex.Messaging.Util (($>>=)) +import Simplex.Messaging.Util (whenM, ($>>=)) +import UnliftIO.Exception (bracketOnError) data SessionVar a = SessionVar { sessionVar :: TMVar a, @@ -38,5 +44,31 @@ removeSessVar v sessKey vs = Just v' | sessionVarId v == sessionVarId v' -> TM.delete sessKey vs _ -> pure () +-- | Get or create a session var and route to onNew (newly created) or onExisting. The new-var +-- branch is bracketed from the point of creation: if it is interrupted before filling the var +-- (e.g. an async exception during connect), the still-empty var is dropped from the map so the +-- next request creates a fresh session instead of blocking on a var that will never be filled. +-- A thrown ExceptT error is a normal result (the var keeps the error it was filled with) - only +-- an interrupting exception drops the empty var. +withGetSessVar :: + (Ord k, MonadUnliftIO m) => + TVar Int -> k -> TMap k (SessionVar a) -> UTCTime -> + (SessionVar a -> ExceptT e m b) -> (SessionVar a -> ExceptT e m b) -> ExceptT e m b +withGetSessVar sessSeq sessKey vs ts onNew onExisting = + ExceptT $ withGetSessVar' sessSeq sessKey vs ts (runExceptT . onNew) (runExceptT . onExisting) + +-- | withGetSessVar for actions in the underlying monad (without ExceptT). +withGetSessVar' :: + (Ord k, MonadUnliftIO m) => + TVar Int -> k -> TMap k (SessionVar a) -> UTCTime -> + (SessionVar a -> m b) -> (SessionVar a -> m b) -> m b +withGetSessVar' sessSeq sessKey vs ts onNew onExisting = + bracketOnError + (liftIO $ atomically $ getSessVar sessSeq sessKey vs ts) + (either (liftIO . atomically . dropEmptySessVar) (\_ -> pure ())) + (either onNew onExisting) + where + dropEmptySessVar v = whenM (isEmptyTMVar $ sessionVar v) $ removeSessVar v sessKey vs + tryReadSessVar :: Ord k => k -> TMap k (SessionVar a) -> STM (Maybe a) tryReadSessVar sessKey vs = TM.lookup sessKey vs $>>= (tryReadTMVar . sessionVar) diff --git a/tests/CoreTests/UtilTests.hs b/tests/CoreTests/UtilTests.hs index 580f4e9b0..c274cbf32 100644 --- a/tests/CoreTests/UtilTests.hs +++ b/tests/CoreTests/UtilTests.hs @@ -50,6 +50,13 @@ utilTests = do runExceptT (tryAllErrors throwTestException) `shouldReturn` Right (Left (TestException "user error (error)")) it "should return no errors as Right" $ runExceptT (tryAllErrors noErrors) `shouldReturn` Right (Right "no errors") + -- tryAllErrors rethrows asynchronous exceptions (it uses UnliftIO.catch). Any recovery placed + -- after `tryAllErrors action` - e.g. putTMVar to fill a SessionVar - is therefore SKIPPED when + -- the thread is killed mid-action. Unlike tryAllOwnErrors, it also rethrows the overflow exceptions. + it "should rethrow ThreadKilled" $ + runExceptT (tryAllErrors $ throwAsync ThreadKilled) `shouldThrow` (\e -> e == ThreadKilled) + it "should rethrow StackOverflow" $ + runExceptT (tryAllErrors $ throwAsync StackOverflow) `shouldThrow` (\e -> e == StackOverflow) describe "catchAllErrors" $ do it "should catch ExceptT error" $ runExceptT (throwTestError `catchAllErrors` handleCatch) `shouldReturn` Right "caught TestError \"error\"" diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index d043fd3c8..85bd08a1b 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -25,7 +25,7 @@ import Network.Socket import qualified Network.TLS as TLS import Simplex.Messaging.Agent.Store.Postgres.Options (DBOpts (..)) import Simplex.Messaging.Agent.Store.Shared (MigrationConfirmation (..)) -import Simplex.Messaging.Client (ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) +import Simplex.Messaging.Client (NetworkConfig (..), NetworkTimeout (..), ProtocolClientConfig (..), chooseTransportHost, defaultNetworkConfig) import Simplex.Messaging.Client.Agent (SMPClientAgentConfig (..), defaultSMPClientAgentConfig) import qualified Simplex.Messaging.Crypto as C import Simplex.Messaging.Encoding @@ -339,6 +339,16 @@ proxyCfgJ2QS = \case SQSMemory -> journalCfg (proxyCfgMS $ ASType SQSMemory SMSJournal) testStoreLogFile2 testStoreMsgsDir2 SQSPostgres -> journalCfgDB (proxyCfgMS $ ASType SQSPostgres SMSJournal) testStoreDBOpts2 testStoreMsgsDir2 +-- Proxy config with a short relay-connection timeout, to bound how long a failing +-- proxy->relay connection attempt blocks in the relay reconnection tests. +proxyCfgShortTimeout :: AServerConfig +proxyCfgShortTimeout = + updateCfg proxyCfg $ \cfg' -> + let aCfg = smpAgentCfg cfg' + cCfg = smpCfg aCfg + nt = NetworkTimeout {backgroundTimeout = 4_000000, interactiveTimeout = 4_000000} + in cfg' {smpAgentCfg = aCfg {smpCfg = cCfg {networkConfig = (networkConfig cCfg) {tcpConnectTimeout = nt}}}} + proxyVRangeV8 :: VersionRangeSMP proxyVRangeV8 = mkVersionRange minServerSMPRelayVersion sendingProxySMPVersion @@ -383,6 +393,15 @@ serverBracket process afterProcess f = do Nothing -> error $ "server did not " <> s _ -> pure () +-- A TCP server that accepts connections but never performs a TLS handshake, so a client +-- connecting to it stays blocked in the TLS handshake until its connection timeout. +withStallingServerOn :: HasCallStack => ServiceName -> IO a -> IO a +withStallingServerOn port action = + serverBracket + (\started -> runLocalTCPServer started port (\_ -> threadDelay maxBound)) + (pure ()) + (const action) + withSmpServerOn :: HasCallStack => (ASrvTransport, AStoreType) -> ServiceName -> IO a -> IO a withSmpServerOn ps port' = withSmpServerThreadOn ps port' . const diff --git a/tests/SMPProxyTests.hs b/tests/SMPProxyTests.hs index 0d8ccdf89..88525eac0 100644 --- a/tests/SMPProxyTests.hs +++ b/tests/SMPProxyTests.hs @@ -58,6 +58,14 @@ smpProxyTests = do describe "server configuration" $ do it "refuses proxy handshake unless enabled" testNoProxy it "checks basic auth in proxy requests" testProxyAuth + describe "relay reconnection" $ do + it "recovers when unresponsive relay restarts (control, no disconnect)" $ \_ -> + testProxyRecoversWithoutDisconnect + it "reconnects to relay after sender disconnects mid-connection" $ \_ -> + testProxyReconnectAfterRelayRestart + describe "agent client reconnection" $ do + it "reconnects after a connect is cancelled mid-flight" $ \_ -> + testAgentClientReconnectAfterCancel describe "proxy requests" $ do describe "bad relay URIs" $ do xit "host not resolved" todo @@ -447,6 +455,89 @@ testProxyAuth msType = do where proxyCfgAuth = updateCfg (proxyCfgMS msType) $ \cfg_ -> cfg_ {newQueueBasicAuth = Just "correct"} +-- Connect a sender client to the proxy and request a relay session to testSMPServer2 (PRXY). +-- On success the reply is PKEY; otherwise it is the proxy error for the relay connection. +requestRelaySession :: IO (Either SMP.ErrorType SMP.BrokerMsg) +requestRelaySession = + testSMPClient_ "localhost" testPort proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> + (\(_, _, reply) -> reply) <$> sendRecv th (Nothing, "1", NoEntity, SMP.PRXY testSMPServer2 Nothing) + +-- Shared "phase 2" of the reconnection tests: start a healthy relay, confirm it is reachable +-- directly (PING, not via the proxy) so a proxy failure can only mean the proxy didn't reconnect, +-- let any stored connection error expire, then require the proxy to establish the session (PKEY). +requireProxyReconnect :: IO () +requireProxyReconnect = + withSmpServerConfigOn (transport @TLS) proxyCfgJ2 testPort2 $ \_ -> do + testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do + (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PING) + reply `shouldBe` Right SMP.PONG + threadDelay 1500000 -- > persistErrorInterval (1s), so the stored connection error has expired + requestRelaySession >>= \case + Right SMP.PKEY {} -> pure () + reply -> expectationFailure $ "proxy failed to reach the healthy relay; expected PKEY, got: " <> show reply + +-- Control: same stalling relay and proxy config as the bug test, but the sender stays connected. +-- The connect fails by timing out (storing a Left error that self-heals via persistErrorInterval), +-- so once a healthy relay is running the proxy reconnects. This proves the stalling relay alone +-- does not cause the permanent failure - only the mid-connection disconnect does. +testProxyRecoversWithoutDisconnect :: IO () +testProxyRecoversWithoutDisconnect = + withSmpServerConfigOn (transport @TLS) proxyCfgShortTimeout testPort $ \_ -> do + withStallingServerOn testPort2 $ + requestRelaySession >>= \case + Right (SMP.ERR (SMP.PROXY (SMP.BROKER _))) -> pure () + reply -> expectationFailure $ "expected a proxy broker error from the unresponsive relay, got: " <> show reply + requireProxyReconnect + +-- Reproduces the production bug: an SMP proxy permanently fails to reconnect to a destination +-- relay after the relay restarts (logs: repeated PCEResponseTimeout). +-- +-- A PRXY request makes the proxy worker (forked via forkClient, registered in the sender's +-- endThreads) insert an empty SessionVar into smpClients and then block in connectClient. If the +-- sender disconnects while that connect is in flight, clientDisconnected kills the worker; +-- clientHandlers re-throws the async exception, so the SessionVar is never filled. Nothing removes +-- an empty SessionVar, so every later request waits the connection timeout on it - PROXY (BROKER +-- TIMEOUT) - forever, even once the relay is healthy again. +-- +-- The stalling relay (accepts TCP, never completes TLS) holds the connect open long enough to +-- interleave the disconnect. Phase 2 (requireProxyReconnect) is identical to the control above; +-- the only difference is this disconnect. +testProxyReconnectAfterRelayRestart :: IO () +testProxyReconnectAfterRelayRestart = + withSmpServerConfigOn (transport @TLS) proxyCfgShortTimeout testPort $ \_ -> do + -- disconnect the sender 1s into the 4s connect to the stalling relay, killing the in-flight worker + withStallingServerOn testPort2 $ + race_ (threadDelay 1000000) requestRelaySession + requireProxyReconnect + +-- Bug B (same root cause as the proxy, in the messaging agent): getSMPServerClient inserts an +-- empty SessionVar into smpClients, then connects inside newProtocolClient's tryAllErrors, which +-- rethrows async exceptions. If the connecting thread is cancelled mid-connect, putTMVar is +-- skipped and the empty var is left in smpClients, so every later connection to that server times +-- out on it. Phase 1 cancels a connect to a stalling relay; phase 2 requires a fresh connect to a +-- healthy relay to succeed. +testAgentClientReconnectAfterCancel :: IO () +testAgentClientReconnectAfterCancel = + withAgent 1 agentCfg agentServersLeak testDB $ \a -> do + withStallingServerOn testPort2 $ do + t <- async $ runExceptT $ A.createConnection a NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + threadDelay 1000000 -- let the connect to the stalling relay start, then kill it mid-flight + cancel t + withSmpServerConfigOn (transport @TLS) cfgJ2 testPort2 $ \_ -> do + testSMPClient_ "127.0.0.1" testPort2 proxyVRangeV8 Nothing $ \(th :: THandleSMP TLS 'TClient) -> do + (_, _, reply) <- sendRecv th (Nothing, "0", NoEntity, SMP.PING) + reply `shouldBe` Right SMP.PONG -- the relay is up and reachable, so a timeout can only be the poisoned var + r <- timeout 8000000 $ runExceptT $ A.createConnection a NRMInteractive 1 True True SCMInvitation Nothing Nothing CR.IKPQOn SMSubscribe + case r of + Just (Right _) -> pure () + _ -> expectationFailure $ "agent failed to connect after a cancelled connect; got: " <> show r + where + agentServersLeak = + initAgentServers + { smp = userServers [testSMPServer2], + netCfg = (netCfg initAgentServers) {tcpConnectTimeout = NetworkTimeout 4000000 4000000} + } + todo :: AStoreType -> IO () todo _ = fail "TODO"