mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
fix: Don't reuse session in NVIDIA post_training request handler
This commit is contained in:
parent
0d0b8d2be1
commit
6346024fa3
1 changed file with 7 additions and 12 deletions
|
@ -67,18 +67,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||||
# TODO: filter by available models based on /config endpoint
|
# TODO: filter by available models based on /config endpoint
|
||||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||||
self.session = None
|
|
||||||
|
|
||||||
self.customizer_url = config.customizer_url
|
self.customizer_url = config.customizer_url
|
||||||
if not self.customizer_url:
|
if not self.customizer_url:
|
||||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||||
self.customizer_url = "http://nemo.test"
|
self.customizer_url = "http://nemo.test"
|
||||||
|
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self.session is None or self.session.closed:
|
|
||||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
@ -99,13 +93,13 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
if json and "Content-Type" not in request_headers:
|
if json and "Content-Type" not in request_headers:
|
||||||
request_headers["Content-Type"] = "application/json"
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
session = await self._get_session()
|
|
||||||
for _ in range(self.config.max_retries):
|
for _ in range(self.config.max_retries):
|
||||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
async with aiohttp.ClientSession(headers=request_headers) as session:
|
||||||
if response.status >= 400:
|
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||||
error_data = await response.json()
|
if response.status >= 400:
|
||||||
raise Exception(f"API request failed: {error_data}")
|
error_data = await response.json()
|
||||||
return await response.json()
|
raise Exception(f"API request failed: {error_data}")
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
async def get_training_jobs(
|
async def get_training_jobs(
|
||||||
self,
|
self,
|
||||||
|
@ -176,6 +170,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
- metrics: Optional[Dict] - Additional training metrics
|
- metrics: Optional[Dict] - Additional training metrics
|
||||||
- status_logs: Optional[List] - Detailed logs of status changes
|
- status_logs: Optional[List] - Detailed logs of status changes
|
||||||
"""
|
"""
|
||||||
|
print("Using local Llama Stack Customizer API")
|
||||||
response = await self._make_request(
|
response = await self._make_request(
|
||||||
"GET",
|
"GET",
|
||||||
f"/v1/customization/jobs/{job_uuid}/status",
|
f"/v1/customization/jobs/{job_uuid}/status",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue