mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
Merge 3f580503ef
into 81ecaf6221
This commit is contained in:
commit
71896ef764
9 changed files with 25 additions and 17 deletions
|
@ -171,7 +171,7 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
|
||||||
# Route to default provider for now - could aggregate from all providers in the future
|
# Route to default provider for now - could aggregate from all providers in the future
|
||||||
# call retrieve on each vector dbs to get list of vector stores
|
# call retrieve on each vector dbs to get list of vector stores
|
||||||
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
|
vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db")
|
||||||
all_stores = []
|
all_stores = []
|
||||||
for vector_db in vector_dbs:
|
for vector_db in vector_dbs:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -19,7 +19,7 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||||
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
|
return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark"))
|
||||||
|
|
||||||
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
|
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
|
||||||
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
|
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
|
||||||
|
|
|
@ -232,15 +232,21 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||||
objs = await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
return [obj for obj in objs if obj.type == type]
|
||||||
|
|
||||||
|
async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||||
|
all_objs = await self.get_all_with_type(type=type)
|
||||||
|
|
||||||
# Apply attribute-based access control filtering
|
# Apply attribute-based access control filtering
|
||||||
if filtered_objs:
|
if all_objs:
|
||||||
filtered_objs = [
|
all_objs = [
|
||||||
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
|
obj
|
||||||
|
for obj in all_objs
|
||||||
|
if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
|
||||||
|
and obj.provider_id in self.impls_by_provider_id
|
||||||
]
|
]
|
||||||
|
|
||||||
return filtered_objs
|
return all_objs
|
||||||
|
|
||||||
|
|
||||||
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
|
||||||
|
@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
|
||||||
)
|
)
|
||||||
# if not found, this means model_id is an unscoped provider_model_id, we need
|
# if not found, this means model_id is an unscoped provider_model_id, we need
|
||||||
# to iterate (given a lack of an efficient index on the KVStore)
|
# to iterate (given a lack of an efficient index on the KVStore)
|
||||||
models = await routing_table.get_all_with_type("model")
|
models = await routing_table.get_all_with_type_filtered("model")
|
||||||
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
matching_models = [m for m in models if m.provider_resource_id == model_id]
|
||||||
if len(matching_models) == 0:
|
if len(matching_models) == 0:
|
||||||
raise ModelNotFoundError(model_id)
|
raise ModelNotFoundError(model_id)
|
||||||
|
|
|
@ -31,7 +31,7 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
async def list_datasets(self) -> ListDatasetsResponse:
|
async def list_datasets(self) -> ListDatasetsResponse:
|
||||||
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
|
return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value))
|
||||||
|
|
||||||
async def get_dataset(self, dataset_id: str) -> Dataset:
|
async def get_dataset(self, dataset_id: str) -> Dataset:
|
||||||
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
dataset = await self.get_object_by_identifier("dataset", dataset_id)
|
||||||
|
|
|
@ -43,10 +43,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
await self.update_registered_models(provider_id, models)
|
await self.update_registered_models(provider_id, models)
|
||||||
|
|
||||||
async def list_models(self) -> ListModelsResponse:
|
async def list_models(self) -> ListModelsResponse:
|
||||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
return ListModelsResponse(data=await self.get_all_with_type_filtered("model"))
|
||||||
|
|
||||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||||
models = await self.get_all_with_type("model")
|
models = await self.get_all_with_type_filtered("model")
|
||||||
openai_models = [
|
openai_models = [
|
||||||
OpenAIModel(
|
OpenAIModel(
|
||||||
id=model.identifier,
|
id=model.identifier,
|
||||||
|
@ -122,7 +122,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
models: list[Model],
|
models: list[Model],
|
||||||
) -> None:
|
) -> None:
|
||||||
existing_models = await self.get_all_with_type("model")
|
existing_models = await self.get_all_with_type_filtered("model")
|
||||||
|
|
||||||
# we may have an alias for the model registered by the user (or during initialization
|
# we may have an alias for the model registered by the user (or during initialization
|
||||||
# from run.yaml) that we need to keep track of
|
# from run.yaml) that we need to keep track of
|
||||||
|
|
|
@ -24,7 +24,9 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||||
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
|
return ListScoringFunctionsResponse(
|
||||||
|
data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value)
|
||||||
|
)
|
||||||
|
|
||||||
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
|
||||||
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
|
||||||
|
|
|
@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> ListShieldsResponse:
|
async def list_shields(self) -> ListShieldsResponse:
|
||||||
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
|
return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value))
|
||||||
|
|
||||||
async def get_shield(self, identifier: str) -> Shield:
|
async def get_shield(self, identifier: str) -> Shield:
|
||||||
shield = await self.get_object_by_identifier("shield", identifier)
|
shield = await self.get_object_by_identifier("shield", identifier)
|
||||||
|
|
|
@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
toolgroup_id = group_id
|
toolgroup_id = group_id
|
||||||
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
toolgroups = [await self.get_tool_group(toolgroup_id)]
|
||||||
else:
|
else:
|
||||||
toolgroups = await self.get_all_with_type("tool_group")
|
toolgroups = await self.get_all_with_type_filtered("tool_group")
|
||||||
|
|
||||||
all_tools = []
|
all_tools = []
|
||||||
for toolgroup in toolgroups:
|
for toolgroup in toolgroups:
|
||||||
|
@ -83,7 +83,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
|
||||||
|
|
||||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||||
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
|
return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group"))
|
||||||
|
|
||||||
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
|
||||||
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
|
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
|
||||||
|
|
|
@ -35,7 +35,7 @@ logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||||
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
|
return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db"))
|
||||||
|
|
||||||
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
|
||||||
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue