add test cases

This commit is contained in:
Ubuntu 2025-03-08 11:43:13 +00:00
parent 0db524cc26
commit d8a5455fe0
5 changed files with 425 additions and 29 deletions

View file

@ -52,7 +52,7 @@ from .openai_utils import (
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
@ -99,7 +99,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if content_has_media(content):
raise NotImplementedError("Media is not supported")
await check_health(self._config) # this raises errors
# ToDo: check health of NeMo endpoints and enable this
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
request = convert_completion_request(
request=CompletionRequest(
@ -179,7 +181,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
await check_health(self._config) # this raises errors
# await check_health(self._config) # this raises errors
request = await convert_chat_completion_request(
request=ChatCompletionRequest(

View file

@ -1,13 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
## ToDo: add supported models list, model validation logic

View file

@ -82,6 +82,9 @@ class NvidiaPostTrainingImpl:
for _ in range(self.config.max_retries):
async with aiohttp.ClientSession(headers=request_headers, timeout=self.timeout) as session:
async with session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
return await response.json()
@webmethod(route="/post-training/jobs", method="GET")
@ -175,9 +178,9 @@ class NvidiaPostTrainingImpl:
Fine-tunes a model on a dataset.
Currently only supports Lora finetuning for standlone docker container.
Assumptions:
- model is a valid Nvidia model
- nemo microservice is running and endpoint is set in config.customizer_url
- dataset is registered separately in nemo datastore
- model checkpoint is downloaded from ngc and exists in the local directory
- model checkpoint is downloaded as per nemo customizer requirements
Parameters:
training_config: TrainingConfig - Configuration for training

View file

@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Tuple
import httpx
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
async def _get_health(url: str) -> Tuple[bool, bool]:
"""
Query {url}/v1/health/{live,ready} to check if the server is running and ready
Args:
url (str): URL of the server
Returns:
Tuple[bool, bool]: (is_live, is_ready)
"""
async with httpx.AsyncClient() as client:
live = await client.get(f"{url}/v1/health/live")
ready = await client.get(f"{url}/v1/health/ready")
return live.status_code == 200, ready.status_code == 200
async def check_health(config: NvidiaPostTrainingConfig) -> None:
"""
Check if the server is running and ready
Args:
url (str): URL of the server
Raises:
RuntimeError: If the server is not running or ready
"""
if not _is_nvidia_hosted(config):
logger.info("Checking NVIDIA NIM health...")
try:
is_live, is_ready = await _get_health(config.url)
if not is_live:
raise ConnectionError("NVIDIA NIM is not running")
if not is_ready:
raise ConnectionError("NVIDIA NIM is not ready")
# TODO(mf): should we wait for the server to be ready?
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM: {e}") from e