align with other remote adapters, rename config base_url -> url

This commit is contained in:
Matthew Farrellee 2024-11-19 17:36:08 -05:00
parent 2980a18920
commit 4ccf4ef641
3 changed files with 8 additions and 8 deletions

View file

@@ -17,7 +17,7 @@ class NVIDIAConfig(BaseModel):
Configuration for the NVIDIA NIM inference endpoint. Configuration for the NVIDIA NIM inference endpoint.
Attributes: Attributes:
base_url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints api_key (str): The access key for the hosted NIM endpoints
There are two ways to access NVIDIA NIMs - There are two ways to access NVIDIA NIMs -
@@ -30,11 +30,11 @@ class NVIDIAConfig(BaseModel):
By default the configuration will attempt to read the NVIDIA_API_KEY environment By default the configuration will attempt to read the NVIDIA_API_KEY environment
variable to set the api_key. Please do not put your API key in code. variable to set the api_key. Please do not put your API key in code.
If you are using a self-hosted NVIDIA NIM, you can set the base_url to the If you are using a self-hosted NVIDIA NIM, you can set the url to the
URL of your running NVIDIA NIM and do not need to set the api_key. URL of your running NVIDIA NIM and do not need to set the api_key.
""" """
base_url: str = Field( url: str = Field(
default="https://integrate.api.nvidia.com", default="https://integrate.api.nvidia.com",
description="A base url for accessing the NVIDIA NIM", description="A base url for accessing the NVIDIA NIM",
) )
@@ -49,4 +49,4 @@ class NVIDIAConfig(BaseModel):
@property @property
def is_hosted(self) -> bool: def is_hosted(self) -> bool:
return "integrate.api.nvidia.com" in self.base_url return "integrate.api.nvidia.com" in self.url

View file

@@ -89,7 +89,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# TODO(mf): filter by available models # TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES) ModelRegistryHelper.__init__(self, model_aliases=_MODEL_ALIASES)
print(f"Initializing NVIDIAInferenceAdapter({config.base_url})...") print(f"Initializing NVIDIAInferenceAdapter({config.url})...")
if config.is_hosted: if config.is_hosted:
if not config.api_key: if not config.api_key:
@@ -110,7 +110,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config self._config = config
# make sure the client lives longer than any async calls # make sure the client lives longer than any async calls
self._client = AsyncOpenAI( self._client = AsyncOpenAI(
base_url=f"{self._config.base_url}/v1", base_url=f"{self._config.url}/v1",
api_key=self._config.api_key or "NO KEY", api_key=self._config.api_key or "NO KEY",
timeout=self._config.timeout, timeout=self._config.timeout,
) )
@@ -172,7 +172,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
response = await self._client.chat.completions.create(**request) response = await self._client.chat.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError( raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.base_url}: {e}" f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e ) from e
if stream: if stream:

View file

@@ -40,7 +40,7 @@ async def check_health(config: NVIDIAConfig) -> None:
if not config.is_hosted: if not config.is_hosted:
print("Checking NVIDIA NIM health...") print("Checking NVIDIA NIM health...")
try: try:
is_live, is_ready = await _get_health(config.base_url) is_live, is_ready = await _get_health(config.url)
if not is_live: if not is_live:
raise ConnectionError("NVIDIA NIM is not running") raise ConnectionError("NVIDIA NIM is not running")
if not is_ready: if not is_ready: