forked from phoenix-oss/llama-stack-mirror
merge
This commit is contained in:
commit
a54d757ade
197 changed files with 9392 additions and 3089 deletions
|
@ -19,6 +19,8 @@ from llama_stack.apis.datasets import (
|
|||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
|
@ -32,11 +34,22 @@ from llama_stack.apis.tools import (
|
|||
ToolHost,
|
||||
)
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
from llama_stack.distribution.access_control import check_access
|
||||
from llama_stack.distribution.datatypes import (
|
||||
AccessAttributes,
|
||||
BenchmarkWithACL,
|
||||
DatasetWithACL,
|
||||
ModelWithACL,
|
||||
RoutableObject,
|
||||
RoutableObjectWithProvider,
|
||||
RoutedProtocol,
|
||||
ScoringFnWithACL,
|
||||
ShieldWithACL,
|
||||
ToolGroupWithACL,
|
||||
ToolWithACL,
|
||||
VectorDBWithACL,
|
||||
)
|
||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
|
||||
|
@ -165,6 +178,11 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not obj:
|
||||
return None
|
||||
|
||||
# Check if user has permission to access this object
|
||||
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
|
||||
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
|
||||
return None
|
||||
|
||||
return obj
|
||||
|
||||
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
|
||||
|
@ -181,6 +199,13 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
p = self.impls_by_provider_id[obj.provider_id]
|
||||
|
||||
# If object supports access control but no attributes set, use creator's attributes
|
||||
if not obj.access_attributes:
|
||||
creator_attributes = get_auth_attributes()
|
||||
if creator_attributes:
|
||||
obj.access_attributes = AccessAttributes(**creator_attributes)
|
||||
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
|
||||
|
||||
registered_obj = await register_object_with_provider(obj, p)
|
||||
# TODO: This needs to be fixed for all APIs once they return the registered object
|
||||
if obj.type == ResourceType.model.value:
|
||||
|
@ -193,7 +218,17 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||
|
||||
# Apply attribute-based access control filtering
|
||||
if filtered_objs:
|
||||
filtered_objs = [
|
||||
obj
|
||||
for obj in filtered_objs
|
||||
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
|
||||
]
|
||||
|
||||
return filtered_objs
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
|
@ -230,7 +265,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
model_type = ModelType.llm
|
||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||
model = Model(
|
||||
model = ModelWithACL(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -276,7 +311,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
)
|
||||
if params is None:
|
||||
params = {}
|
||||
shield = Shield(
|
||||
shield = ShieldWithACL(
|
||||
identifier=shield_id,
|
||||
provider_resource_id=provider_shield_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -330,7 +365,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
||||
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
return vector_db
|
||||
|
||||
|
@ -358,6 +393,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
metadata: Optional[Dict[str, Any]] = None,
|
||||
dataset_id: Optional[str] = None,
|
||||
) -> Dataset:
|
||||
if isinstance(source, dict):
|
||||
if source["type"] == "uri":
|
||||
source = URIDataSource.parse_obj(source)
|
||||
elif source["type"] == "rows":
|
||||
source = RowsDataSource.parse_obj(source)
|
||||
|
||||
if not dataset_id:
|
||||
dataset_id = f"dataset-{str(uuid.uuid4())}"
|
||||
|
||||
|
@ -378,7 +419,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
dataset = Dataset(
|
||||
dataset = DatasetWithACL(
|
||||
identifier=dataset_id,
|
||||
provider_resource_id=provider_dataset_id,
|
||||
provider_id=provider_id,
|
||||
|
@ -429,7 +470,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
benchmark = Benchmark(
|
||||
benchmark = BenchmarkWithACL(
|
||||
identifier=benchmark_id,
|
||||
dataset_id=dataset_id,
|
||||
grader_ids=grader_ids,
|
||||
|
@ -473,7 +514,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
for tool_def in tool_defs:
|
||||
tools.append(
|
||||
Tool(
|
||||
ToolWithACL(
|
||||
identifier=tool_def.name,
|
||||
toolgroup_id=toolgroup_id,
|
||||
description=tool_def.description or "",
|
||||
|
@ -498,7 +539,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
await self.register_object(tool)
|
||||
|
||||
await self.dist_registry.register(
|
||||
ToolGroup(
|
||||
ToolGroupWithACL(
|
||||
identifier=toolgroup_id,
|
||||
provider_id=provider_id,
|
||||
provider_resource_id=toolgroup_id,
|
||||
|
@ -511,7 +552,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = await self.list_tools(toolgroup_id).data
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue