Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/osw/utils/_httpx_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,18 @@ def _ensure_initialized(self):
from osw.utils.workflow import ApiGatewayTransport, connect

osw_instance = connect()
mw_site = osw_instance.site.mw_site

def _relogin():
cred = osw_instance.site._cred_mngr.get_credential(
osw_instance.site._iri
)
mw_site.login(username=cred.username, password=cred.password)

self._inner = ApiGatewayTransport(
gateway_url=self._gateway_url,
mw_site=osw_instance.site.mw_site,
mw_site=mw_site,
relogin_cb=_relogin,
)

async def handle_async_request(self, request):
Expand Down
39 changes: 35 additions & 4 deletions src/osw/utils/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ class ApiGatewayTransport(httpx.AsyncBaseTransport):
and injects MediaWiki session cookies + CSRF tokens.
"""

def __init__(self, gateway_url: str, mw_site, csrf_required: bool = True):
def __init__(
self,
gateway_url: str,
mw_site,
csrf_required: bool = True,
relogin_cb=None,
):
"""
Parameters
----------
Expand All @@ -127,11 +133,14 @@ def __init__(self, gateway_url: str, mw_site, csrf_required: bool = True):
Authenticated mwclient Site instance.
csrf_required
Whether to send MW CSRF token for write methods.
relogin_cb
Callable that re-authenticates ``mw_site`` when the session expires.
"""
self._gateway_url = gateway_url.rstrip("/")
self._mw_site = mw_site
self._csrf_token = None
self._csrf_required = csrf_required
self._relogin_cb = relogin_cb

def _get_csrf_token(self) -> str:
if self._csrf_token is None:
Expand Down Expand Up @@ -230,11 +239,19 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
response = await httpx.AsyncHTTPTransport().handle_async_request(
redirect_req
)
# Refresh CSRF token and retry once on 403
# Refresh CSRF token and retry on 403; if still 403, re-login
if response.status_code == 403:
self._csrf_token = None
rewritten = self._rewrite_request(request)
response = await httpx.AsyncHTTPTransport().handle_async_request(rewritten)
if response.status_code == 403 and self._relogin_cb:
log.warning("ApiGateway session expired, re-authenticating")
self._relogin_cb()
self._csrf_token = None
rewritten = self._rewrite_request(request)
response = await httpx.AsyncHTTPTransport().handle_async_request(
rewritten
)
return response


Expand All @@ -250,10 +267,17 @@ def get_gateway_httpx_settings(gateway_url: str, osw_instance: OSW) -> dict:
osw_instance
A connected OSW instance (provides mwclient session).
"""
mw_site = osw_instance.site.mw_site

def _relogin():
cred = osw_instance.site._cred_mngr.get_credential(osw_instance.site._iri)
mw_site.login(username=cred.username, password=cred.password)

transport = ApiGatewayTransport(
gateway_url=gateway_url,
mw_site=osw_instance.site.mw_site,
mw_site=mw_site,
csrf_required=False,
relogin_cb=_relogin,
)
return {"transport": transport, "base_url": gateway_url}

Expand Down Expand Up @@ -640,9 +664,16 @@ async def _deploy(param: DeployParam):
if _is_apigateway_url(gateway_url) and param.osw is not None:
_original_api_url = environ.get("PREFECT_API_URL")
environ["PREFECT_API_URL"] = gateway_url
_mw_site = param.osw.site.mw_site

def _relogin():
cred = param.osw.site._cred_mngr.get_credential(param.osw.site._iri)
_mw_site.login(username=cred.username, password=cred.password)

_gw_transport = ApiGatewayTransport(
gateway_url=gateway_url,
mw_site=param.osw.site.mw_site,
mw_site=_mw_site,
relogin_cb=_relogin,
)
# Patch httpx.AsyncClient to auto-inject our transport when
# the base_url is an ApiGateway URL. One patch covers ALL
Expand Down
Loading