mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Merge 6b6d8d70a5
into 40fdce79b3
This commit is contained in:
commit
112b935546
1 changed files with 6 additions and 12 deletions
|
@ -67,18 +67,12 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||||
# TODO: filter by available models based on /config endpoint
|
# TODO: filter by available models based on /config endpoint
|
||||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||||
self.session = None
|
|
||||||
|
|
||||||
self.customizer_url = config.customizer_url
|
self.customizer_url = config.customizer_url
|
||||||
if not self.customizer_url:
|
if not self.customizer_url:
|
||||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||||
self.customizer_url = "http://nemo.test"
|
self.customizer_url = "http://nemo.test"
|
||||||
|
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self.session is None or self.session.closed:
|
|
||||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
@ -99,8 +93,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
if json and "Content-Type" not in request_headers:
|
if json and "Content-Type" not in request_headers:
|
||||||
request_headers["Content-Type"] = "application/json"
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
session = await self._get_session()
|
|
||||||
for _ in range(self.config.max_retries):
|
for _ in range(self.config.max_retries):
|
||||||
|
async with aiohttp.ClientSession(headers=request_headers) as session:
|
||||||
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||||
if response.status >= 400:
|
if response.status >= 400:
|
||||||
error_data = await response.json()
|
error_data = await response.json()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue