diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d839ffd6f..2277ac121 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -67,18 +67,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): self.timeout = aiohttp.ClientTimeout(total=config.timeout) # TODO: filter by available models based on /config endpoint ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self.session = None self.customizer_url = config.customizer_url if not self.customizer_url: warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2) self.customizer_url = "http://nemo.test" - async def _get_session(self) -> aiohttp.ClientSession: - if self.session is None or self.session.closed: - self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) - return self.session - async def _make_request( self, method: str, @@ -99,13 +93,13 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): if json and "Content-Type" not in request_headers: request_headers["Content-Type"] = "application/json" - session = await self._get_session() for _ in range(self.config.max_retries): - async with session.request(method, url, params=params, json=json, **kwargs) as response: - if response.status >= 400: - error_data = await response.json() - raise Exception(f"API request failed: {error_data}") - return await response.json() + async with aiohttp.ClientSession(headers=request_headers) as session: + async with session.request(method, url, params=params, json=json, **kwargs) as response: + if response.status >= 400: + error_data = await response.json() + raise Exception(f"API request failed: {error_data}") + return await response.json() async def get_training_jobs( self,