add register_model method

raspawar 2025-04-02 09:51:55 +00:00
parent c169c164b3
commit 3d2b374ee7
2 changed files with 99 additions and 10 deletions

llama_stack/providers/remote/inference/nvidia/nvidia.py

@@ -33,11 +33,15 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import (
     SamplingParams,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -114,10 +118,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
         }

-        base_url = f"{self._config.url}/v1"
-
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
+        # add /v1 in case of hosted models
+        base_url = self._config.url
+        if _is_nvidia_hosted(self._config):
+            if provider_model_id in special_model_urls:
+                base_url = special_model_urls[provider_model_id]
+            else:
+                base_url = f"{self._config.url}/v1"
         return _get_client_for_base_url(base_url)

     async def completion(
@@ -265,3 +272,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Allow non-llama model registration.
+
+        Non-llama model registration: API Catalogue models, post-training models, etc.
+            client = LlamaStackAsLibraryClient("nvidia")
+            client.models.register(
+                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
+                model_type=ModelType.llm,
+                provider_id="nvidia",
+                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
+            )
+
+        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
+        """
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            llama_model = model.metadata.get("llama_model")
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != llama_model:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # not llama model
+                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+                else:
+                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
+        return model
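Note on the `_get_client` change above: requests are now routed by deployment type. For NVIDIA-hosted configs, the two vision models keep their dedicated per-model URLs and everything else gets `/v1` appended; a self-hosted URL is used verbatim. A minimal sketch of that decision, with hypothetical names (`resolve_base_url` and its parameters are illustrative, not part of llama_stack):

    # Illustrative sketch only; not part of the codebase.
    def resolve_base_url(url: str, hosted: bool, provider_model_id: str, special_model_urls: dict[str, str]) -> str:
        if hosted:
            # hosted: dedicated URL for special models, otherwise append /v1
            return special_model_urls.get(provider_model_id, f"{url}/v1")
        # self-hosted (e.g., a local NIM): use the configured URL as-is
        return url

    assert (
        resolve_base_url("https://integrate.api.nvidia.com", True, "meta/llama3-8b-instruct", {})
        == "https://integrate.api.nvidia.com/v1"
    )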
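The docstring's registration example, expanded into a runnable sketch. It assumes an `nvidia` distribution is available locally and `NVIDIA_API_KEY` is set; the `LlamaStackAsLibraryClient` import path follows the llama_stack layout at the time of this commit and may differ in other versions:

    from llama_stack.apis.models import ModelType
    from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

    client = LlamaStackAsLibraryClient("nvidia")
    client.initialize()

    # Register a non-llama API Catalogue model; with this commit the adapter
    # records the alias instead of rejecting it for not mapping to a llama model.
    client.models.register(
        model_id="mistralai/mixtral-8x7b-instruct-v0.1",
        model_type=ModelType.llm,
        provider_id="nvidia",
        provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1",
    )

    response = client.inference.chat_completion(
        model_id="mistralai/mixtral-8x7b-instruct-v0.1",
        messages=[{"role": "user", "content": "Hello, model"}],
    )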

Unit tests for the NVIDIA post-training adapter (TestNvidiaPostTraining)

@@ -10,13 +10,9 @@ import warnings
 from unittest.mock import patch

 import pytest

-from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    TrainingConfig,
-    TrainingConfigDataConfig,
-    TrainingConfigOptimizerConfig,
-)
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
@@ -24,6 +20,12 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
     NvidiaPostTrainingJob,
     NvidiaPostTrainingJobStatusResponse,
 )
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigOptimizerConfig,
+)


 class TestNvidiaPostTraining(unittest.TestCase):
@@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
         self.mock_make_request = self.make_request_patcher.start()

+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
+
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()

     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
@@ -290,6 +306,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
             expected_params={"job_id": job_id},
         )

+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
+        model_type = ModelType.llm
+        model = Model(
+            identifier=model_id,
+            provider_id="nvidia",
+            provider_model_id=model_id,
+            provider_resource_id=model_id,
+            model_type=model_type,
+        )
+        result = self.run_async(self.inference_adapter.register_model(model))
+        assert result == model
+        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
+        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
+
+        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
+            self.run_async(
+                self.inference_adapter.chat_completion(
+                    model_id=model_id,
+                    messages=[{"role": "user", "content": "Hello, model"}],
+                )
+            )
+            mock_chat_completion.assert_called()
+

 if __name__ == "__main__":
     unittest.main()
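For reference, the mocking pattern the new setUp relies on, in isolation: `NVIDIAInferenceAdapter._get_client` is patched to return a `MagicMock` whose `chat.completions.create` is an `AsyncMock`, so adapter calls stay awaitable (matching the real AsyncOpenAI client) without reaching the NVIDIA API. A minimal sketch, assuming only the patch target shown in the diff; the mock's return payload is made up for illustration:

    import unittest.mock
    from unittest.mock import patch

    mock_client = unittest.mock.MagicMock()
    # AsyncMock keeps the call awaitable, like AsyncOpenAI's async interface
    mock_client.chat.completions.create = unittest.mock.AsyncMock(return_value=None)

    patcher = patch(
        "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
        return_value=mock_client,
    )
    patcher.start()
    # ... exercise the adapter; chat completions now hit the mock ...
    patcher.stop()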