mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: NVIDIA allow non-llama model registration (#1859)
# What does this PR do? Adds custom model registration functionality to NVIDIAInferenceAdapter which let's the inference happen on: - post-training model - non-llama models in API Catalogue(behind https://integrate.api.nvidia.com and endpoints compatible with AyncOpenAI) ## Example Usage: ```python from llama_stack.apis.models import Model, ModelType from llama_stack.distribution.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") _ = client.initialize() client.models.register( model_id=model_name, model_type=ModelType.llm, provider_id="nvidia" ) response = client.inference.chat_completion( model_id=model_name, messages=[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Write a limerick about the wonders of GPU computing."}], ) ``` ## Test Plan ```bash pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py ========================================================== test session starts =========================================================== platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0 rootdir: /home/ubuntu/llama-stack configfile: pyproject.toml plugins: anyio-4.9.0 collected 6 items tests/unit/providers/nvidia/test_supervised_fine_tuning.py ...... [100%] ============================================================ warnings summary ============================================================ ../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076 /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/ warn( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ====================================================== 6 passed, 1 warning in 1.51s ====================================================== ``` [//]: # (## Documentation) Updated Readme.md cc: @dglogo, @sumitb, @mattf
This commit is contained in:
parent
cc77f79f55
commit
ace82836c1
8 changed files with 116 additions and 15 deletions
|
@ -22,9 +22,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
|||
The following environment variables can be configured:
|
||||
|
||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
|
||||
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
|
||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue