mirror of https://github.com/meta-llama/llama-stack.git

commit 3d2b374ee7 (parent c169c164b3)
add register_model method

2 changed files with 99 additions and 10 deletions
llama_stack/providers/remote/inference/nvidia/nvidia.py

@@ -33,11 +33,15 @@ from llama_stack.apis.inference import (
     ToolChoice,
     ToolConfig,
 )
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.models.llama.datatypes import (
     SamplingParams,
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
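The newly imported ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR is consumed by register_model further down: it maps Hugging Face repo ids to llama model descriptors so that checkpoints of known llama models can be recognized. An illustrative sketch of its shape only; the entries below are assumptions, not the real table:

    # Illustrative shape only -- the real mapping is built inside
    # llama_stack.providers.utils.inference from the llama model registry.
    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: dict[str, str] = {
        "meta-llama/Llama-3.1-8B-Instruct": "Llama3.1-8B-Instruct",  # assumed entry
        "meta-llama/Llama-3.2-3B-Instruct": "Llama3.2-3B-Instruct",  # assumed entry
    }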
@@ -114,10 +118,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
         }
 
-        base_url = f"{self._config.url}/v1"
-
-        if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
-            base_url = special_model_urls[provider_model_id]
+        # add /v1 in case of hosted models
+        base_url = self._config.url
+        if _is_nvidia_hosted(self._config):
+            if provider_model_id in special_model_urls:
+                base_url = special_model_urls[provider_model_id]
+            else:
+                base_url = f"{self._config.url}/v1"
         return _get_client_for_base_url(base_url)
 
     async def completion(
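The net effect: a self-hosted NIM URL is now used as-is, /v1 is appended only for NVIDIA-hosted endpoints, and the two vision models keep their dedicated URLs. A standalone sketch of just that branch logic, with illustrative stand-ins for the adapter's config fields (the hosted URL in the example is an assumed value):

    def resolve_base_url(url: str, hosted: bool, provider_model_id: str, special_model_urls: dict[str, str]) -> str:
        base_url = url  # self-hosted NIM endpoints are used verbatim
        if hosted:
            if provider_model_id in special_model_urls:
                base_url = special_model_urls[provider_model_id]
            else:
                base_url = f"{url}/v1"  # hosted models expect the /v1 suffix
        return base_url

    # e.g. a hosted vision model routes to its dedicated endpoint:
    special = {"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct"}
    assert resolve_base_url("https://integrate.api.nvidia.com", True, "meta/llama-3.2-90b-vision-instruct", special) == special["meta/llama-3.2-90b-vision-instruct"]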
@@ -265,3 +272,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Allow non-llama model registration.
+
+        Non-llama model registration: API Catalogue models, post-training models, etc.
+            client = LlamaStackAsLibraryClient("nvidia")
+            client.models.register(
+                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
+                model_type=ModelType.llm,
+                provider_id="nvidia",
+                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
+            )
+
+        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
+        """
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            llama_model = model.metadata.get("llama_model")
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != llama_model:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # not a llama model
+                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+                else:
+                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
+        return model
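Taken on its own, the flow above means an unknown non-llama provider model id is simply self-aliased. A minimal usage sketch, assuming an already-constructed adapter (`adapter` is a stand-in; the Model fields mirror the new unit test below):

    import asyncio

    from llama_stack.apis.models import Model, ModelType

    # `adapter` is assumed to be a configured NVIDIAInferenceAdapter.
    model = Model(
        identifier="mistralai/mixtral-8x7b-instruct-v0.1",
        provider_id="nvidia",
        provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1",
        provider_resource_id="mistralai/mixtral-8x7b-instruct-v0.1",
        model_type=ModelType.llm,
    )
    registered = asyncio.run(adapter.register_model(model))

    # No llama mapping and no "llama_model" metadata, so the id is self-aliased
    # and later lookups resolve it unchanged:
    assert adapter.get_provider_model_id("mistralai/mixtral-8x7b-instruct-v0.1") == registered.provider_resource_id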
NVIDIA post-training unit tests (second changed file):

@@ -10,13 +10,9 @@ import warnings
 from unittest.mock import patch
 
 import pytest
-from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    TrainingConfig,
-    TrainingConfigDataConfig,
-    TrainingConfigOptimizerConfig,
-)
 
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
@@ -24,6 +20,12 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
     NvidiaPostTrainingJob,
     NvidiaPostTrainingJobStatusResponse,
 )
+from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
+from llama_stack_client.types.post_training_supervised_fine_tune_params import (
+    TrainingConfig,
+    TrainingConfigDataConfig,
+    TrainingConfigOptimizerConfig,
+)
 
 
 class TestNvidiaPostTraining(unittest.TestCase):
@@ -40,8 +42,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
         self.mock_make_request = self.make_request_patcher.start()
 
+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
+
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()
 
     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
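One operational note: the new setUp reads os.environ["NVIDIA_BASE_URL"] unconditionally, so that variable must be set before the suite runs; any value works here because _get_client is patched to return the mock. A minimal sketch with a placeholder endpoint:

    import os

    # Placeholder endpoint; _get_client is patched, so the URL is never contacted.
    os.environ.setdefault("NVIDIA_BASE_URL", "http://nim.test")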
@@ -290,6 +306,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
             expected_params={"job_id": job_id},
         )
 
+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
+        model_type = ModelType.llm
+        model = Model(
+            identifier=model_id,
+            provider_id="nvidia",
+            provider_model_id=model_id,
+            provider_resource_id=model_id,
+            model_type=model_type,
+        )
+        result = self.run_async(self.inference_adapter.register_model(model))
+        assert result == model
+        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
+        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
+
+        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
+            self.run_async(
+                self.inference_adapter.chat_completion(
+                    model_id=model_id,
+                    messages=[{"role": "user", "content": "Hello, model"}],
+                )
+            )
+
+            mock_chat_completion.assert_called()
 
 
 if __name__ == "__main__":
     unittest.main()