mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 18:32:00 +00:00
add register_model method
This commit is contained in:
parent
c169c164b3
commit
3d2b374ee7
2 changed files with 99 additions and 10 deletions
|
|
@ -33,11 +33,15 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
SamplingParams,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.inference import (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
@ -114,10 +118,13 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1"
|
||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
|
||||
# add /v1 in case of hosted models
|
||||
base_url = self._config.url
|
||||
if _is_nvidia_hosted(self._config):
|
||||
if provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
else:
|
||||
base_url = f"{self._config.url}/v1"
|
||||
return _get_client_for_base_url(base_url)
|
||||
|
||||
async def completion(
|
||||
|
|
@ -265,3 +272,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Allow non-llama model registration.
|
||||
|
||||
Non-llama model registration: API Catalogue models, post-training models, etc.
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.models.register(
|
||||
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
|
||||
model_type=ModelType.llm,
|
||||
provider_id="nvidia",
|
||||
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
|
||||
)
|
||||
|
||||
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
|
||||
"""
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
llama_model = model.metadata.get("llama_model")
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
# not llama model
|
||||
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
else:
|
||||
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
|
||||
return model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue