From 00bad820899eaf01328f7d5b4b67f3f1a2e11ac3 Mon Sep 17 00:00:00 2001 From: SimonTaurus Date: Fri, 8 May 2026 04:29:12 +0200 Subject: [PATCH] fix: handle expired MW sessions in ApiGateway transport On 403, refresh CSRF token and retry. If still 403, re-login via CredentialManager and retry once more. Applied to all transport creation points (get_gateway_httpx_settings, _deploy, _httpx_gateway). --- src/osw/utils/_httpx_gateway.py | 11 +++++++++- src/osw/utils/workflow.py | 39 +++++++++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/osw/utils/_httpx_gateway.py b/src/osw/utils/_httpx_gateway.py index 6cdd649..db72f50 100644 --- a/src/osw/utils/_httpx_gateway.py +++ b/src/osw/utils/_httpx_gateway.py @@ -37,9 +37,18 @@ def _ensure_initialized(self): from osw.utils.workflow import ApiGatewayTransport, connect osw_instance = connect() + mw_site = osw_instance.site.mw_site + + def _relogin(): + cred = osw_instance.site._cred_mngr.get_credential( + osw_instance.site._iri + ) + mw_site.login(username=cred.username, password=cred.password) + self._inner = ApiGatewayTransport( gateway_url=self._gateway_url, - mw_site=osw_instance.site.mw_site, + mw_site=mw_site, + relogin_cb=_relogin, ) async def handle_async_request(self, request): diff --git a/src/osw/utils/workflow.py b/src/osw/utils/workflow.py index 5a8c3f4..4c7425b 100644 --- a/src/osw/utils/workflow.py +++ b/src/osw/utils/workflow.py @@ -116,7 +116,13 @@ class ApiGatewayTransport(httpx.AsyncBaseTransport): and injects MediaWiki session cookies + CSRF tokens. """ - def __init__(self, gateway_url: str, mw_site, csrf_required: bool = True): + def __init__( + self, + gateway_url: str, + mw_site, + csrf_required: bool = True, + relogin_cb=None, + ): """ Parameters ---------- @@ -127,11 +133,14 @@ def __init__(self, gateway_url: str, mw_site, csrf_required: bool = True): Authenticated mwclient Site instance. csrf_required Whether to send MW CSRF token for write methods. + relogin_cb + Callable that re-authenticates ``mw_site`` when the session expires. """ self._gateway_url = gateway_url.rstrip("/") self._mw_site = mw_site self._csrf_token = None self._csrf_required = csrf_required + self._relogin_cb = relogin_cb def _get_csrf_token(self) -> str: if self._csrf_token is None: @@ -230,11 +239,19 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response: response = await httpx.AsyncHTTPTransport().handle_async_request( redirect_req ) - # Refresh CSRF token and retry once on 403 + # Refresh CSRF token and retry on 403; if still 403, re-login if response.status_code == 403: self._csrf_token = None rewritten = self._rewrite_request(request) response = await httpx.AsyncHTTPTransport().handle_async_request(rewritten) + if response.status_code == 403 and self._relogin_cb: + log.warning("ApiGateway session expired, re-authenticating") + self._relogin_cb() + self._csrf_token = None + rewritten = self._rewrite_request(request) + response = await httpx.AsyncHTTPTransport().handle_async_request( + rewritten + ) return response @@ -250,10 +267,17 @@ def get_gateway_httpx_settings(gateway_url: str, osw_instance: OSW) -> dict: osw_instance A connected OSW instance (provides mwclient session). """ + mw_site = osw_instance.site.mw_site + + def _relogin(): + cred = osw_instance.site._cred_mngr.get_credential(osw_instance.site._iri) + mw_site.login(username=cred.username, password=cred.password) + transport = ApiGatewayTransport( gateway_url=gateway_url, - mw_site=osw_instance.site.mw_site, + mw_site=mw_site, csrf_required=False, + relogin_cb=_relogin, ) return {"transport": transport, "base_url": gateway_url} @@ -640,9 +664,16 @@ async def _deploy(param: DeployParam): if _is_apigateway_url(gateway_url) and param.osw is not None: _original_api_url = environ.get("PREFECT_API_URL") environ["PREFECT_API_URL"] = gateway_url + _mw_site = param.osw.site.mw_site + + def _relogin(): + cred = param.osw.site._cred_mngr.get_credential(param.osw.site._iri) + _mw_site.login(username=cred.username, password=cred.password) + _gw_transport = ApiGatewayTransport( gateway_url=gateway_url, - mw_site=param.osw.site.mw_site, + mw_site=_mw_site, + relogin_cb=_relogin, ) # Patch httpx.AsyncClient to auto-inject our transport when # the base_url is an ApiGateway URL. One patch covers ALL