Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
feat: NVIDIA allow non-llama model registration (#1859)
# What does this PR do?

Adds custom model registration functionality to NVIDIAInferenceAdapter, which lets inference happen on:
- post-training models
- non-Llama models in the API Catalogue (behind https://integrate.api.nvidia.com, with endpoints compatible with AsyncOpenAI)

## Example Usage:

```python
from llama_stack.apis.models import Model, ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

client.models.register(
    model_id=model_name,
    model_type=ModelType.llm,
    provider_id="nvidia",
)

response = client.inference.chat_completion(
    model_id=model_name,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a limerick about the wonders of GPU computing."},
    ],
)
```

## Test Plan

```bash
pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items

tests/unit/providers/nvidia/test_supervised_fine_tuning.py ......                                                                 [100%]

============================================================ warnings summary ============================================================
../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076
  /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
    warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================== 6 passed, 1 warning in 1.51s ======================================================
```

[//]: # (## Documentation)

Updated Readme.md

cc: @dglogo, @sumitb, @mattf
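Note that the hunks below touch only the unit test; the adapter change itself lives in the other files of this commit. A minimal sketch of the behavior the new test asserts — `register_model()` returns the model unchanged, records it in `alias_to_provider_id_map`, and makes `get_provider_model_id()` resolve the alias — might look like the following (only those three names come from the test's assertions; the class and everything else are assumed for illustration):

```python
# Hypothetical sketch, not the adapter code from this commit. Only
# register_model, alias_to_provider_id_map, and get_provider_model_id are
# names taken from the test's assertions; the rest is assumed.
from llama_stack.apis.models import Model


class RegistrySketch:
    def __init__(self) -> None:
        # alias -> provider-side model id; presumably pre-seeded with the
        # built-in Llama aliases in the real adapter, which would explain why
        # the test can assert len(...) > 1 after a single registration.
        self.alias_to_provider_id_map: dict[str, str] = {}

    async def register_model(self, model: Model) -> Model:
        # Accept any id, Llama or not; whether it is actually servable is
        # decided by the NVIDIA endpoint at inference time.
        self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
        return model

    def get_provider_model_id(self, alias: str) -> str | None:
        return self.alias_to_provider_id_map.get(alias)
```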
This commit is contained in:
parent cc77f79f55
commit ace82836c1

8 changed files with 116 additions and 15 deletions
```diff
@@ -17,6 +17,8 @@ from llama_stack_client.types.post_training_supervised_fine_tune_params import (
     TrainingConfigOptimizerConfig,
 )
 
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
```
```diff
@@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
         self.mock_make_request = self.make_request_patcher.start()
+
+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
 
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()
 
     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
```
```diff
@@ -303,6 +319,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
             expected_params={"job_id": job_id},
         )
+
+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
+        model_type = ModelType.llm
+        model = Model(
+            identifier=model_id,
+            provider_id="nvidia",
+            provider_model_id=model_id,
+            provider_resource_id=model_id,
+            model_type=model_type,
+        )
+        result = self.run_async(self.inference_adapter.register_model(model))
+        assert result == model
+        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
+        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
+
+        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
+            self.run_async(
+                self.inference_adapter.chat_completion(
+                    model_id=model_id,
+                    messages=[{"role": "user", "content": "Hello, model"}],
+                )
+            )
+
+        mock_chat_completion.assert_called()
 
 
 if __name__ == "__main__":
     unittest.main()
```
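For context on the mocking setup above: `_get_client` is patched to return a `MagicMock` whose `chat.completions.create` is an `AsyncMock`, so the adapter's async call path can be exercised without a network. A minimal standalone sketch of that pattern, independent of llama-stack (the `call_model()` helper is invented for this illustration; everything else is standard `unittest.mock`):

```python
# Standalone sketch of the async-mock pattern used in the test above.
# MagicMock auto-creates the chat.completions.create attribute chain, and
# AsyncMock makes the leaf call awaitable.
import asyncio
from unittest.mock import AsyncMock, MagicMock


async def call_model(client, model_id: str):
    # Mirrors the shape of an AsyncOpenAI-style chat completion call.
    return await client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": "Hello, model"}],
    )


mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value="ok")

assert asyncio.run(call_model(mock_client, "default/job-1234")) == "ok"
mock_client.chat.completions.create.assert_awaited_once()
```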