mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
Allow models to be registered as long as llama model is provided
This commit is contained in:
parent
f1b9578f8d
commit
ccb5445d2a
2 changed files with 30 additions and 21 deletions
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
from llama_models.datatypes import CoreModelId
|
||||
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest -v -s llama_stack/providers/tests/inference/test_model_registration.py
|
||||
|
@ -17,8 +18,17 @@ from llama_models.datatypes import CoreModelId
|
|||
|
||||
class TestModelRegistration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unsupported_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
async def test_register_unsupported_model(self, inference_stack, inference_model):
|
||||
inference_impl, models_impl = inference_stack
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(inference_model)
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
"remote::vllm",
|
||||
"remote::tgi",
|
||||
):
|
||||
pytest.skip("70B instruct is too big only for local inference providers")
|
||||
|
||||
# Try to register a model that's too large for local inference
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
|
@ -37,21 +47,10 @@ class TestModelRegistration:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_model(self, inference_stack):
|
||||
async def test_register_with_llama_model(self, inference_stack):
|
||||
_, models_impl = inference_stack
|
||||
|
||||
# Register a model to update
|
||||
model_id = CoreModelId.llama3_1_8b_instruct.value
|
||||
old_model = await models_impl.register_model(model_id=model_id)
|
||||
|
||||
# Update the model
|
||||
new_model_id = CoreModelId.llama3_2_3b_instruct.value
|
||||
updated_model = await models_impl.update_model(
|
||||
model_id=model_id, provider_model_id=new_model_id
|
||||
_ = await models_impl.register_model(
|
||||
model_id="custom-model",
|
||||
metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
|
||||
)
|
||||
|
||||
# Retrieve the updated model to verify changes
|
||||
assert updated_model.provider_resource_id != old_model.provider_resource_id
|
||||
|
||||
# Cleanup
|
||||
await models_impl.unregister_model(model_id=model_id)
|
||||
|
|
|
@ -51,7 +51,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
if identifier in self.alias_to_provider_id_map:
|
||||
return self.alias_to_provider_id_map[identifier]
|
||||
else:
|
||||
raise ValueError(f"Unknown model: `{identifier}`")
|
||||
return None
|
||||
|
||||
def get_llama_model(self, provider_model_id: str) -> str:
|
||||
if provider_model_id in self.provider_id_to_llama_model_map:
|
||||
|
@ -60,8 +60,18 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return None
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model.provider_resource_id = self.get_provider_model_id(
|
||||
model.provider_resource_id
|
||||
)
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
else:
|
||||
if model.metadata.get("llama_model") is None:
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
|
||||
"Please specify a llama_model in metadata or use a supported model identifier"
|
||||
)
|
||||
# Register the mapping from provider model id to llama model for future lookups
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
model.metadata["llama_model"]
|
||||
)
|
||||
|
||||
return model
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue