mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 08:29:43 +00:00
refactor lookup_model out so it can be used by vector dbs routing table
This commit is contained in:
parent
d3dee496ec
commit
50d16dc707
3 changed files with 26 additions and 22 deletions
|
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
@ -36,23 +36,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
async def get_model(self, model_id: str) -> Model:
|
||||
# first try to get the model by identifier
|
||||
# this works if model_id is an alias or is of the form provider_id/provider_model_id
|
||||
model = await self.get_object_by_identifier("model", model_id)
|
||||
if model is not None:
|
||||
return model
|
||||
|
||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||
# to iterate (given a lack of an efficient index on the KVStore)
|
||||
models = await self.get_all_with_type("model")
|
||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||
if len(matching_models) == 0:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
|
||||
if len(matching_models) > 1:
|
||||
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
|
||||
|
||||
return matching_models[0]
|
||||
return await lookup_model(self, model_id)
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue