feat: NVIDIA allow non-llama model registration (#1859)
# What does this PR do?

Adds custom model registration functionality to `NVIDIAInferenceAdapter`, which lets inference happen on:

- post-training models
- non-llama models in the API Catalogue (behind https://integrate.api.nvidia.com and endpoints compatible with `AsyncOpenAI`)

## Example Usage:

```python
from llama_stack.apis.models import Model, ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

client.models.register(
    model_id=model_name,
    model_type=ModelType.llm,
    provider_id="nvidia",
)

response = client.inference.chat_completion(
    model_id=model_name,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a limerick about the wonders of GPU computing."},
    ],
)
```

## Test Plan

```bash
pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items

tests/unit/providers/nvidia/test_supervised_fine_tuning.py ......                                                                  [100%]

============================================================ warnings summary ============================================================
../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076
  /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
    warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================== 6 passed, 1 warning in 1.51s ======================================================
```

[//]: # (## Documentation)

Updated Readme.md

cc: @dglogo, @sumitb, @mattf
parent cc77f79f55
commit ace82836c1

8 changed files with 116 additions and 15 deletions
```diff
@@ -22,9 +22,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:
 
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
-- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
+- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
 - `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
-- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
 - `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
 - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
 - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
```
```diff
@@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
         default=60,
         description="Timeout for the HTTP requests",
     )
+    append_api_version: bool = Field(
+        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
+        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
             "api_key": "${env.NVIDIA_API_KEY:}",
+            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
         }
```
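The new field is driven by the `NVIDIA_APPEND_API_VERSION` environment variable; per the `default_factory` above, only the string `"false"` (case-insensitive) disables it, and any other value keeps the default of `True`. A minimal standalone sketch of that parsing, for illustration only (not adapter code):

```python
import os

# Mirror the default_factory in the config change above:
# only "false" (case-insensitive) turns the flag off.
for raw in ("True", "true", "FALSE", "false", ""):
    os.environ["NVIDIA_APPEND_API_VERSION"] = raw
    append = os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false"
    print(f"{raw!r:>9} -> append_api_version={append}")
```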
```diff
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     TextTruncation,
     ToolChoice,
     ToolConfig,
-    ToolDefinition,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
```
```diff
@@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
-from llama_stack.models.llama.datatypes import ToolPromptFormat
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
```
```diff
@@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
         }
 
-        base_url = f"{self._config.url}/v1"
+        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
 
         if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
             base_url = special_model_urls[provider_model_id]
         return _get_client_for_base_url(base_url)
 
     async def _get_provider_model_id(self, model_id: str) -> str:
```
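With `append_api_version` enabled the adapter targets `{url}/v1`; disabled, it uses the configured URL unchanged, while NVIDIA-hosted special model URLs still take precedence. A minimal sketch of that selection, with `resolve_base_url` as a hypothetical stand-in for the adapter logic above:

```python
# Hypothetical helper mirroring the base_url selection in the diff above;
# config_url / append_api_version stand in for the NVIDIAConfig fields.
def resolve_base_url(config_url: str, append_api_version: bool) -> str:
    return f"{config_url}/v1" if append_api_version else config_url


assert resolve_base_url("https://integrate.api.nvidia.com", True) == "https://integrate.api.nvidia.com/v1"
assert resolve_base_url("http://localhost:8000", False) == "http://localhost:8000"
```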
```diff
@@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             return await self._get_client(provider_model_id).chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Allow non-llama model registration.
+
+        Non-llama model registration: API Catalogue models, post-training models, etc.
+            client = LlamaStackAsLibraryClient("nvidia")
+            client.models.register(
+                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
+                model_type=ModelType.llm,
+                provider_id="nvidia",
+                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
+            )
+
+        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
+        """
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            llama_model = model.metadata.get("llama_model")
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != llama_model:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # not llama model
+                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+                else:
+                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
+        return model
```
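End to end, the registration path added above is what the PR description exercises: register a non-llama API Catalogue model, then run chat completion against it. A usage sketch based on that description (the mixtral id comes from the docstring and is illustrative; the endpoint must be AsyncOpenAI-compatible):

```python
from llama_stack.apis.models import ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

# Register a non-llama model from the NVIDIA API Catalogue.
client.models.register(
    model_id="mistralai/mixtral-8x7b-instruct-v0.1",
    model_type=ModelType.llm,
    provider_id="nvidia",
)

response = client.inference.chat_completion(
    model_id="mistralai/mixtral-8x7b-instruct-v0.1",
    messages=[{"role": "user", "content": "Write a limerick about the wonders of GPU computing."}],
)
```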
```diff
@@ -36,7 +36,6 @@ import os
 
 os.environ["NVIDIA_API_KEY"] = "your-api-key"
 os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
-os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
 os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
 os.environ["NVIDIA_PROJECT_ID"] = "test-project"
 os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
```
````diff
@@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
 
 ### Inference with the fine-tuned model
 
+#### 1. Register the model
+
+```python
+from llama_stack.apis.models import Model, ModelType
+
+client.models.register(
+    model_id="test-example-model@v1",
+    provider_id="nvidia",
+    provider_model_id="test-example-model@v1",
+    model_type=ModelType.llm,
+)
+```
+
+#### 2. Inference with the fine-tuned model
+
 ```python
 response = client.inference.completion(
     content="Complete the sentence using one word: Roses are red, violets are ",
````
```diff
@@ -98,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "NVIDIA API Key",
         ),
-        ## Nemo Customizer related variables
-        "NVIDIA_USER_ID": (
-            "llama-stack-user",
-            "NVIDIA User ID",
+        "NVIDIA_APPEND_API_VERSION": (
+            "True",
+            "Whether to append the API version to the base_url",
         ),
+        ## Nemo Customizer related variables
         "NVIDIA_DATASET_NAMESPACE": (
             "default",
             "NVIDIA Dataset Namespace",
         ),
-        "NVIDIA_ACCESS_POLICIES": (
-            "{}",
-            "NVIDIA Access Policies",
-        ),
         "NVIDIA_PROJECT_ID": (
             "test-project",
             "NVIDIA Project ID",
```
```diff
@@ -18,6 +18,7 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
```
```diff
@@ -18,6 +18,7 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
```
```diff
@@ -17,6 +17,8 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfigOptimizerConfig,
 )
 
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
```
```diff
@@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
         self.mock_make_request = self.make_request_patcher.start()
 
+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
+
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()
 
     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
```
```diff
@@ -303,6 +319,31 @@
             expected_params={"job_id": job_id},
         )
 
+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
+        model_type = ModelType.llm
+        model = Model(
+            identifier=model_id,
+            provider_id="nvidia",
+            provider_model_id=model_id,
+            provider_resource_id=model_id,
+            model_type=model_type,
+        )
+        result = self.run_async(self.inference_adapter.register_model(model))
+        assert result == model
+        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
+        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
+
+        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
+            self.run_async(
+                self.inference_adapter.chat_completion(
+                    model_id=model_id,
+                    messages=[{"role": "user", "content": "Hello, model"}],
+                )
+            )
+
+            mock_chat_completion.assert_called()
+
 
 if __name__ == "__main__":
     unittest.main()
```