fix: allow lookup of models registered at runtime

adds tests for ModelRegistryHelper to stabilize behavior
This commit is contained in:
Matthew Farrellee 2025-04-16 15:51:08 -04:00
parent e4d001c4e4
commit 9982aa64f0
3 changed files with 165 additions and 1 deletions

View file

@ -59,6 +59,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: List[ProviderModelEntry]):
self.supported_model_ids = {entry.provider_model_id for entry in model_entries}
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for entry in model_entries:
@ -79,6 +81,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def register_model(self, model: Model) -> Model:
if model.provider_resource_id not in self.supported_model_ids:
raise ValueError(
f"Model id '{model.provider_resource_id}' is not supported. Supported ids are: {', '.join(self.supported_model_ids)}"
)
if model.model_id in self.alias_to_provider_id_map:
# be idemopotent
if model.provider_resource_id != self.alias_to_provider_id_map[model.model_id]:
raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
)
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
@ -108,7 +120,12 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
self.alias_to_provider_id_map[model.model_id] = model.provider_resource_id
return model
async def unregister_model(self, model_id: str) -> None:
pass
if model_id not in self.alias_to_provider_id_map:
raise ValueError(f"Model id '{model_id}' is not registered.")
del self.alias_to_provider_id_map[model_id]