mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
more validation
This commit is contained in:
parent
ccb5445d2a
commit
acf9af841b
2 changed files with 37 additions and 4 deletions
|
@ -54,3 +54,20 @@ class TestModelRegistration:
|
||||||
model_id="custom-model",
|
model_id="custom-model",
|
||||||
metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
|
metadata={"llama_model": CoreModelId.llama3_1_8b_instruct.value},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await models_impl.register_model(
|
||||||
|
model_id="custom-model-2",
|
||||||
|
metadata={"llama_model": CoreModelId.llama3_2_3b_instruct.value},
|
||||||
|
provider_model_id="custom-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_with_invalid_llama_model(self, inference_stack):
|
||||||
|
_, models_impl = inference_stack
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await models_impl.register_model(
|
||||||
|
model_id="custom-model-2",
|
||||||
|
metadata={"llama_model": "invalid-llama-model"},
|
||||||
|
)
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_models.datatypes import CoreModelId
|
||||||
from llama_models.sku_list import all_registered_models
|
from llama_models.sku_list import all_registered_models
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
@ -69,9 +70,24 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
|
f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. "
|
||||||
"Please specify a llama_model in metadata or use a supported model identifier"
|
"Please specify a llama_model in metadata or use a supported model identifier"
|
||||||
)
|
)
|
||||||
# Register the mapping from provider model id to llama model for future lookups
|
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
if existing_llama_model:
|
||||||
model.metadata["llama_model"]
|
if existing_llama_model != model.metadata["llama_model"]:
|
||||||
)
|
raise ValueError(
|
||||||
|
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Validate that the llama model is a valid CoreModelId
|
||||||
|
try:
|
||||||
|
CoreModelId(model.metadata["llama_model"])
|
||||||
|
except ValueError as err:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. "
|
||||||
|
f"Must be one of: {', '.join(m.value for m in CoreModelId)}"
|
||||||
|
) from err
|
||||||
|
# Register the mapping from provider model id to llama model for future lookups
|
||||||
|
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||||
|
model.metadata["llama_model"]
|
||||||
|
)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue