From 6346024fa391614c46f79e77877ca53e4d4653dc Mon Sep 17 00:00:00 2001 From: Jash Gulabrai Date: Fri, 6 Jun 2025 10:42:10 -0400 Subject: [PATCH] fix: Don't reuse session in NVIDIA post_training request handler --- .../post_training/nvidia/post_training.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d839ffd6f..341870290 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -67,18 +67,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): self.timeout = aiohttp.ClientTimeout(total=config.timeout) # TODO: filter by available models based on /config endpoint ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self.session = None self.customizer_url = config.customizer_url if not self.customizer_url: warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2) self.customizer_url = "http://nemo.test" - async def _get_session(self) -> aiohttp.ClientSession: - if self.session is None or self.session.closed: - self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) - return self.session - async def _make_request( self, method: str, @@ -99,13 +93,13 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): if json and "Content-Type" not in request_headers: request_headers["Content-Type"] = "application/json" - session = await self._get_session() for _ in range(self.config.max_retries): - async with session.request(method, url, params=params, json=json, **kwargs) as response: - if response.status >= 400: - error_data = await response.json() - raise Exception(f"API request failed: {error_data}") - return await response.json() + async with aiohttp.ClientSession(headers=request_headers) as session: + async with session.request(method, url, params=params, json=json, **kwargs) as response: + if response.status >= 400: + error_data = await response.json() + raise Exception(f"API request failed: {error_data}") + return await response.json() async def get_training_jobs( self, @@ -176,6 +170,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): - metrics: Optional[Dict] - Additional training metrics - status_logs: Optional[List] - Detailed logs of status changes """ + print("Using local Llama Stack Customizer API") response = await self._make_request( "GET", f"/v1/customization/jobs/{job_uuid}/status",