mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
fix: Don't reuse session in NVIDIA post_training request handler
This commit is contained in:
parent
0d0b8d2be1
commit
6346024fa3
1 changed files with 7 additions and 12 deletions
|
@ -67,18 +67,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||
# TODO: filter by available models based on /config endpoint
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self.session = None
|
||||
|
||||
self.customizer_url = config.customizer_url
|
||||
if not self.customizer_url:
|
||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||
self.customizer_url = "http://nemo.test"
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
return self.session
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
|
@ -99,13 +93,13 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
if json and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
session = await self._get_session()
|
||||
for _ in range(self.config.max_retries):
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
async with aiohttp.ClientSession(headers=request_headers) as session:
|
||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def get_training_jobs(
|
||||
self,
|
||||
|
@ -176,6 +170,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
|||
- metrics: Optional[Dict] - Additional training metrics
|
||||
- status_logs: Optional[List] - Detailed logs of status changes
|
||||
"""
|
||||
print("Using local Llama Stack Customizer API")
|
||||
response = await self._make_request(
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_uuid}/status",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue