mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-22 22:42:25 +00:00
refactor lookup_model out so it can be used by vector dbs routing table
This commit is contained in:
parent
d3dee496ec
commit
50d16dc707
3 changed files with 26 additions and 22 deletions
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
|
|
@ -235,3 +236,23 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
]
|
]
|
||||||
|
|
||||||
return filtered_objs
|
return filtered_objs
|
||||||
|
|
||||||
|
|
||||||
|
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
    """Resolve *model_id* to a registered :class:`Model` via *routing_table*.

    Raises:
        ValueError: if no model matches, or if the unscoped id is ambiguous
            across providers.
    """
    # Fast path: direct identifier lookup. This succeeds when model_id is an
    # alias or already scoped as provider_id/provider_model_id.
    direct_hit = await routing_table.get_object_by_identifier("model", model_id)
    if direct_hit is not None:
        return direct_hit

    # Slow path: model_id must be an unscoped provider_model_id, so scan every
    # registered model (the KVStore offers no efficient secondary index).
    all_models = await routing_table.get_all_with_type("model")
    candidates = [m for m in all_models if m.provider_resource_id == model_id]

    if not candidates:
        raise ValueError(f"Model '{model_id}' not found")
    if len(candidates) > 1:
        raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in candidates]}")
    return candidates[0]
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import (
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl, lookup_model
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
@ -36,23 +36,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return OpenAIListModelsResponse(data=openai_models)
|
return OpenAIListModelsResponse(data=openai_models)
|
||||||
|
|
||||||
async def get_model(self, model_id: str) -> Model:
    """Fetch the model registered under *model_id*.

    Delegates to the shared ``lookup_model`` helper so the same resolution
    rules (alias, scoped id, unscoped provider_model_id scan) apply here and
    in the other routing tables.
    """
    resolved = await lookup_model(self, model_id)
    return resolved
|
||||||
async def register_model(
|
async def register_model(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import (
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl, lookup_model
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
provider_vector_db_id: str | None = None,
|
provider_vector_db_id: str | None = None,
|
||||||
vector_db_name: str | None = None,
|
vector_db_name: str | None = None,
|
||||||
) -> VectorDB:
|
) -> VectorDB:
|
||||||
if provider_vector_db_id is None:
|
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
||||||
provider_vector_db_id = vector_db_id
|
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
if len(self.impls_by_provider_id) > 0:
|
if len(self.impls_by_provider_id) > 0:
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||||
model = await self.get_object_by_identifier("model", embedding_model)
|
model = await lookup_model(self, embedding_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise ValueError(f"Model {embedding_model} not found")
|
raise ValueError(f"Model {embedding_model} not found")
|
||||||
if model.model_type != ModelType.embedding:
|
if model.model_type != ModelType.embedding:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue