mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 11:08:20 +00:00
update routing table registration types
This commit is contained in:
parent
e33a04d1f2
commit
2692d73d7d
2 changed files with 24 additions and 16 deletions
|
@ -142,14 +142,14 @@ class ToolGroupWithACL(ToolGroup, ResourceWithACL):
|
||||||
|
|
||||||
|
|
||||||
RoutableObject = Union[
|
RoutableObject = Union[
|
||||||
ModelWithACL,
|
Model,
|
||||||
ShieldWithACL,
|
Shield,
|
||||||
VectorDBWithACL,
|
VectorDB,
|
||||||
DatasetWithACL,
|
Dataset,
|
||||||
ScoringFnWithACL,
|
ScoringFn,
|
||||||
BenchmarkWithACL,
|
Benchmark,
|
||||||
ToolWithACL,
|
Tool,
|
||||||
ToolGroupWithACL,
|
ToolGroup,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -42,9 +42,17 @@ from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorD
|
||||||
from llama_stack.distribution.access_control import check_access
|
from llama_stack.distribution.access_control import check_access
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
AccessAttributes,
|
AccessAttributes,
|
||||||
|
BenchmarkWithACL,
|
||||||
|
DatasetWithACL,
|
||||||
|
ModelWithACL,
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
RoutableObjectWithProvider,
|
RoutableObjectWithProvider,
|
||||||
RoutedProtocol,
|
RoutedProtocol,
|
||||||
|
ScoringFnWithACL,
|
||||||
|
ShieldWithACL,
|
||||||
|
ToolGroupWithACL,
|
||||||
|
ToolWithACL,
|
||||||
|
VectorDBWithACL,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
|
@ -270,7 +278,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
model_type = ModelType.llm
|
model_type = ModelType.llm
|
||||||
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
raise ValueError("Embedding model must have an embedding dimension in its metadata")
|
||||||
model = Model(
|
model = ModelWithACL(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
@ -316,7 +324,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
)
|
)
|
||||||
if params is None:
|
if params is None:
|
||||||
params = {}
|
params = {}
|
||||||
shield = Shield(
|
shield = ShieldWithACL(
|
||||||
identifier=shield_id,
|
identifier=shield_id,
|
||||||
provider_resource_id=provider_shield_id,
|
provider_resource_id=provider_shield_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
@ -370,7 +378,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDB).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
return vector_db
|
return vector_db
|
||||||
|
|
||||||
|
@ -418,7 +426,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
|
||||||
dataset = Dataset(
|
dataset = DatasetWithACL(
|
||||||
identifier=dataset_id,
|
identifier=dataset_id,
|
||||||
provider_resource_id=provider_dataset_id,
|
provider_resource_id=provider_dataset_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
@ -465,7 +473,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
scoring_fn = ScoringFn(
|
scoring_fn = ScoringFnWithACL(
|
||||||
identifier=scoring_fn_id,
|
identifier=scoring_fn_id,
|
||||||
description=description,
|
description=description,
|
||||||
return_type=return_type,
|
return_type=return_type,
|
||||||
|
@ -507,7 +515,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
)
|
)
|
||||||
if provider_benchmark_id is None:
|
if provider_benchmark_id is None:
|
||||||
provider_benchmark_id = benchmark_id
|
provider_benchmark_id = benchmark_id
|
||||||
benchmark = Benchmark(
|
benchmark = BenchmarkWithACL(
|
||||||
identifier=benchmark_id,
|
identifier=benchmark_id,
|
||||||
dataset_id=dataset_id,
|
dataset_id=dataset_id,
|
||||||
scoring_functions=scoring_functions,
|
scoring_functions=scoring_functions,
|
||||||
|
@ -550,7 +558,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
|
|
||||||
for tool_def in tool_defs:
|
for tool_def in tool_defs:
|
||||||
tools.append(
|
tools.append(
|
||||||
Tool(
|
ToolWithACL(
|
||||||
identifier=tool_def.name,
|
identifier=tool_def.name,
|
||||||
toolgroup_id=toolgroup_id,
|
toolgroup_id=toolgroup_id,
|
||||||
description=tool_def.description or "",
|
description=tool_def.description or "",
|
||||||
|
@ -575,7 +583,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
||||||
await self.register_object(tool)
|
await self.register_object(tool)
|
||||||
|
|
||||||
await self.dist_registry.register(
|
await self.dist_registry.register(
|
||||||
ToolGroup(
|
ToolGroupWithACL(
|
||||||
identifier=toolgroup_id,
|
identifier=toolgroup_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
provider_resource_id=toolgroup_id,
|
provider_resource_id=toolgroup_id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue