feat(registry): more flexible model lookup

This commit is contained in:
Ashwin Bharambe 2025-07-22 09:08:51 -07:00
parent b5a6ecc331
commit 352bf3ec56

View file

@ -36,11 +36,24 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
# first try to get the model by identifier
# this works if model_id is an alias or is of the form provider_id/provider_model_id
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model is not None:
return model
# if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore)
models = await self.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
return matching_models[0]
async def register_model(
self,
model_id: str,
@ -49,24 +62,32 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
provider_model_id = provider_model_id or model_id
metadata = metadata or {}
model_type = model_type or ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
# an identifier different than provider_model_id implies it is an alias, so that
# becomes the globally unique identifier. otherwise provider_model_ids can conflict,
# so as a general rule we must use the provider_id to disambiguate.
if model_id != provider_model_id:
identifier = model_id
else:
identifier = f"{provider_id}/{provider_model_id}"
model = ModelWithOwner(
identifier=model_id,
identifier=identifier,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,