This commit is contained in:
Charlie Doern 2025-08-14 21:25:11 -07:00 committed by GitHub
commit 71896ef764
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 25 additions and 17 deletions

View file

@ -171,7 +171,7 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}") logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Route to default provider for now - could aggregate from all providers in the future # Route to default provider for now - could aggregate from all providers in the future
# call retrieve on each vector dbs to get list of vector stores # call retrieve on each vector dbs to get list of vector stores
vector_dbs = await self.routing_table.get_all_with_type("vector_db") vector_dbs = await self.routing_table.get_all_with_type_filtered("vector_db")
all_stores = [] all_stores = []
for vector_db in vector_dbs: for vector_db in vector_dbs:
try: try:

View file

@ -19,7 +19,7 @@ logger = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse: async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) return ListBenchmarksResponse(data=await self.get_all_with_type_filtered("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark: async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)

View file

@ -232,15 +232,21 @@ class CommonRoutingTableImpl(RoutingTable):
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all() objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type] return [obj for obj in objs if obj.type == type]
async def get_all_with_type_filtered(self, type: str) -> list[RoutableObjectWithProvider]:
all_objs = await self.get_all_with_type(type=type)
# Apply attribute-based access control filtering # Apply attribute-based access control filtering
if filtered_objs: if all_objs:
filtered_objs = [ all_objs = [
obj for obj in filtered_objs if is_action_allowed(self.policy, "read", obj, get_authenticated_user()) obj
for obj in all_objs
if is_action_allowed(self.policy, "read", obj, get_authenticated_user())
and obj.provider_id in self.impls_by_provider_id
] ]
return filtered_objs return all_objs
async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model:
@ -257,7 +263,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
) )
# if not found, this means model_id is an unscoped provider_model_id, we need # if not found, this means model_id is an unscoped provider_model_id, we need
# to iterate (given a lack of an efficient index on the KVStore) # to iterate (given a lack of an efficient index on the KVStore)
models = await routing_table.get_all_with_type("model") models = await routing_table.get_all_with_type_filtered("model")
matching_models = [m for m in models if m.provider_resource_id == model_id] matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0: if len(matching_models) == 0:
raise ModelNotFoundError(model_id) raise ModelNotFoundError(model_id)

View file

@ -31,7 +31,7 @@ logger = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse: async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) return ListDatasetsResponse(data=await self.get_all_with_type_filtered(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset: async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id) dataset = await self.get_object_by_identifier("dataset", dataset_id)

View file

@ -43,10 +43,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models) await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) return ListModelsResponse(data=await self.get_all_with_type_filtered("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse: async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model") models = await self.get_all_with_type_filtered("model")
openai_models = [ openai_models = [
OpenAIModel( OpenAIModel(
id=model.identifier, id=model.identifier,
@ -122,7 +122,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id: str, provider_id: str,
models: list[Model], models: list[Model],
) -> None: ) -> None:
existing_models = await self.get_all_with_type("model") existing_models = await self.get_all_with_type_filtered("model")
# we may have an alias for the model registered by the user (or during initialization # we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of # from run.yaml) that we need to keep track of

View file

@ -24,7 +24,9 @@ logger = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) return ListScoringFunctionsResponse(
data=await self.get_all_with_type_filtered(ResourceType.scoring_function.value)
)
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)

View file

@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse: async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) return ListShieldsResponse(data=await self.get_all_with_type_filtered(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield: async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier) shield = await self.get_object_by_identifier("shield", identifier)

View file

@ -49,7 +49,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroup_id = group_id toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)] toolgroups = [await self.get_tool_group(toolgroup_id)]
else: else:
toolgroups = await self.get_all_with_type("tool_group") toolgroups = await self.get_all_with_type_filtered("tool_group")
all_tools = [] all_tools = []
for toolgroup in toolgroups: for toolgroup in toolgroups:
@ -83,7 +83,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse: async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) return ListToolGroupsResponse(data=await self.get_all_with_type_filtered("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)

View file

@ -35,7 +35,7 @@ logger = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse: async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) return ListVectorDBsResponse(data=await self.get_all_with_type_filtered("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB: async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)