Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-01 16:24:44 +00:00)

Merge branch 'main' into cprint
Commit: d876aa1eb4
22 changed files with 1221 additions and 1041 deletions
@@ -22,7 +22,11 @@ from docutils import nodes
 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
 pypi_url = "https://pypi.org/pypi/llama-stack/json"
-version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
+headers = {
+    'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
+    'Accept': 'application/json'
+}
+version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
 print(f"{version_tag=}")

 # generate the full link including text and url here
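For reference, the hunk above amounts to the following standalone snippet: fetch the latest llama-stack version from PyPI while sending a pip-like User-Agent. This is a sketch for context rather than part of the diff; the timeout argument is an added assumption.

import json

import requests

pypi_url = "https://pypi.org/pypi/llama-stack/json"
# Header values mirror the hunk above: PyPI responds more reliably when the
# client identifies itself the way pip does.
headers = {
    'User-Agent': 'pip/23.0.1 (python 3.11)',
    'Accept': 'application/json',
}
# The 10-second timeout is an extra precaution not present in the diff.
response = requests.get(pypi_url, headers=headers, timeout=10)
version_tag = json.loads(response.text)["info"]["version"]
print(f"{version_tag=}")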
@@ -12,16 +12,6 @@ from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.providers.datatypes import Api, RoutingTable
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
-from .routing_tables import (
-    BenchmarksRoutingTable,
-    DatasetsRoutingTable,
-    ModelsRoutingTable,
-    ScoringFunctionsRoutingTable,
-    ShieldsRoutingTable,
-    ToolGroupsRoutingTable,
-    VectorDBsRoutingTable,
-)
-


 async def get_routing_table_impl(
     api: Api,
@@ -29,6 +19,14 @@ async def get_routing_table_impl(
     _deps,
     dist_registry: DistributionRegistry,
 ) -> Any:
+    from ..routing_tables.benchmarks import BenchmarksRoutingTable
+    from ..routing_tables.datasets import DatasetsRoutingTable
+    from ..routing_tables.models import ModelsRoutingTable
+    from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
+    from ..routing_tables.shields import ShieldsRoutingTable
+    from ..routing_tables.toolgroups import ToolGroupsRoutingTable
+    from ..routing_tables.vector_dbs import VectorDBsRoutingTable
+
     api_to_tables = {
         "vector_dbs": VectorDBsRoutingTable,
         "models": ModelsRoutingTable,
@@ -50,15 +48,12 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(
     api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
 ) -> Any:
-    from .routers import (
-        DatasetIORouter,
-        EvalRouter,
-        InferenceRouter,
-        SafetyRouter,
-        ScoringRouter,
-        ToolRuntimeRouter,
-        VectorIORouter,
-    )
+    from .datasets import DatasetIORouter
+    from .eval_scoring import EvalRouter, ScoringRouter
+    from .inference import InferenceRouter
+    from .safety import SafetyRouter
+    from .tool_runtime import ToolRuntimeRouter
+    from .vector_io import VectorIORouter

     api_to_routers = {
         "vector_io": VectorIORouter,
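The pattern in these hunks is to defer the router imports into the factory functions and build the api-to-class mapping there, so the package no longer imports every router module at import time. A minimal, self-contained sketch of that pattern follows; the stub classes, simplified signature, and error message are illustrative assumptions, not llama-stack code.

import asyncio
from typing import Any


class DatasetIORouter:
    """Stub standing in for the class imported from .datasets in the diff."""

    def __init__(self, routing_table: Any) -> None:
        self.routing_table = routing_table


class SafetyRouter:
    """Stub standing in for the class imported from .safety in the diff."""

    def __init__(self, routing_table: Any) -> None:
        self.routing_table = routing_table


async def get_auto_router_impl(api: str, routing_table: Any) -> Any:
    # In the diff these imports live here, inside the function
    # (e.g. `from .datasets import DatasetIORouter`), which keeps the package
    # __init__ from importing every router module up front.
    api_to_routers = {
        "datasetio": DatasetIORouter,
        "safety": SafetyRouter,
    }
    if api not in api_to_routers:
        raise ValueError(f"API {api} not found in the routers list")
    return api_to_routers[api](routing_table=routing_table)


if __name__ == "__main__":
    router = asyncio.run(get_auto_router_impl("safety", routing_table=None))
    print(type(router).__name__)  # prints: SafetyRouter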
llama_stack/distribution/routers/datasets.py (new file, 71 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class DatasetIORouter(DatasetIO):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("DatasetIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("DatasetIORouter.shutdown")
        pass

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> None:
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        await self.routing_table.register_dataset(
            purpose=purpose,
            source=source,
            metadata=metadata,
            dataset_id=dataset_id,
        )

    async def iterrows(
        self,
        dataset_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        return await self.routing_table.get_provider_impl(dataset_id).append_rows(
            dataset_id=dataset_id,
            rows=rows,
        )
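DatasetIORouter is a thin pass-through: it resolves a provider implementation from the routing table, using the dataset id as the routing key, and forwards the call. Below is a self-contained sketch of that delegation with stand-in provider and routing-table classes (not the real llama-stack types).

import asyncio
from typing import Any


class LocalFSProvider:
    """Stand-in provider; returns a fixed page of rows."""

    async def iterrows(self, dataset_id: str, start_index: int | None = None, limit: int | None = None) -> list[dict[str, Any]]:
        rows = [{"dataset": dataset_id, "row": i} for i in range(10)]
        start = start_index or 0
        end = None if limit is None else start + limit
        return rows[start:end]


class RoutingTable:
    """Stand-in routing table: maps a routing key straight to a provider impl."""

    def __init__(self, impls_by_key: dict[str, Any]) -> None:
        self.impls_by_key = impls_by_key

    def get_provider_impl(self, routing_key: str) -> Any:
        return self.impls_by_key[routing_key]


class DatasetIORouter:
    def __init__(self, routing_table: RoutingTable) -> None:
        self.routing_table = routing_table

    async def iterrows(self, dataset_id: str, start_index: int | None = None, limit: int | None = None):
        # Same shape as the new file above: pick the provider by dataset_id, then delegate.
        return await self.routing_table.get_provider_impl(dataset_id).iterrows(
            dataset_id=dataset_id, start_index=start_index, limit=limit
        )


if __name__ == "__main__":
    router = DatasetIORouter(RoutingTable({"my-dataset": LocalFSProvider()}))
    print(asyncio.run(router.iterrows("my-dataset", start_index=2, limit=3)))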
llama_stack/distribution/routers/eval_scoring.py (new file, 148 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
    ScoreBatchResponse,
    ScoreResponse,
    Scoring,
    ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ScoringRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("ScoringRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ScoringRouter.shutdown")
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        res = {}
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] = None,
    ) -> ScoreResponse:
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")
        pass

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )
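ScoringRouter fans a request out: each scoring-function id is resolved to its own provider, scored with just that function, and the per-function results are merged into a single response. A simplified, runnable sketch of that merge loop follows; the provider class and result shapes are stand-ins.

import asyncio
from typing import Any


class ExactMatchProvider:
    """Stand-in scoring provider: marks each row right or wrong."""

    async def score(self, input_rows: list[dict[str, Any]], scoring_functions: dict[str, Any]) -> dict[str, Any]:
        fn_id = next(iter(scoring_functions))
        return {fn_id: [row.get("generated") == row.get("expected") for row in input_rows]}


async def score(
    input_rows: list[dict[str, Any]],
    scoring_functions: dict[str, Any],
    providers: dict[str, Any],
) -> dict[str, Any]:
    results: dict[str, Any] = {}
    for fn_identifier in scoring_functions:
        # Stands in for routing_table.get_provider_impl(fn_identifier) in the router above.
        provider = providers[fn_identifier]
        response = await provider.score(
            input_rows=input_rows,
            scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
        )
        # Merge this function's results into the combined response.
        results.update(response)
    return results


if __name__ == "__main__":
    rows = [{"generated": "4", "expected": "4"}, {"generated": "5", "expected": "4"}]
    providers = {"exact-match": ExactMatchProvider()}
    print(asyncio.run(score(rows, {"exact-match": None}, providers)))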
@@ -14,14 +14,9 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo
 from pydantic import Field, TypeAdapter

 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.responses import PaginatedResponse
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import DatasetPurpose, DataSource
-from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -54,24 +49,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIResponseFormatParam,
 )
 from llama_stack.apis.models import Model, ModelType
-from llama_stack.apis.safety import RunShieldResponse, Safety
-from llama_stack.apis.scoring import (
-    ScoreBatchResponse,
-    ScoreResponse,
-    Scoring,
-    ScoringFnParams,
-)
-from llama_stack.apis.shields import Shield
 from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
-from llama_stack.apis.tools import (
-    ListToolDefsResponse,
-    RAGDocument,
-    RAGQueryConfig,
-    RAGQueryResult,
-    RAGToolRuntime,
-    ToolRuntime,
-)
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
@@ -83,62 +61,6 @@ from llama_stack.providers.utils.telemetry.tracing import get_current_span
 logger = get_logger(name=__name__, category="core")


-class VectorIORouter(VectorIO):
-    """Routes to an provider based on the vector db identifier"""
-
-    def __init__(
-        self,
-        routing_table: RoutingTable,
-    ) -> None:
-        logger.debug("Initializing VectorIORouter")
-        self.routing_table = routing_table
-
-    async def initialize(self) -> None:
-        logger.debug("VectorIORouter.initialize")
-        pass
-
-    async def shutdown(self) -> None:
-        logger.debug("VectorIORouter.shutdown")
-        pass
-
-    async def register_vector_db(
-        self,
-        vector_db_id: str,
-        embedding_model: str,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
-        await self.routing_table.register_vector_db(
-            vector_db_id,
-            embedding_model,
-            embedding_dimension,
-            provider_id,
-            provider_vector_db_id,
-        )
-
-    async def insert_chunks(
-        self,
-        vector_db_id: str,
-        chunks: list[Chunk],
-        ttl_seconds: int | None = None,
-    ) -> None:
-        logger.debug(
-            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
-        )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
-
-    async def query_chunks(
-        self,
-        vector_db_id: str,
-        query: InterleavedContent,
-        params: dict[str, Any] | None = None,
-    ) -> QueryChunksResponse:
-        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
-
-
 class InferenceRouter(Inference):
     """Routes to an provider based on the model"""

@@ -671,295 +593,3 @@ class InferenceRouter(Inference):
                     status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
                 )
         return health_statuses
[removed below this point: the SafetyRouter, DatasetIORouter, ScoringRouter, EvalRouter, and ToolRuntimeRouter classes, which move unchanged into the new safety.py, datasets.py, eval_scoring.py, and tool_runtime.py files in this diff]
@@ -1,634 +0,0 @@ (file deleted)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
import uuid
from typing import Any

from pydantic import TypeAdapter

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
    ListToolGroupsResponse,
    ListToolsResponse,
    Tool,
    ToolGroup,
    ToolGroups,
    ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    BenchmarkWithACL,
    DatasetWithACL,
    ModelWithACL,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
    ScoringFnWithACL,
    ShieldWithACL,
    ToolGroupWithACL,
    ToolWithACL,
    VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable

logger = logging.getLogger(__name__)


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_tool(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj

        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass
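The central lookup in the deleted file is CommonRoutingTableImpl.get_provider_impl: resolve the registered object for a routing key, then return the provider that serves it, raising a descriptive error when nothing does. A simplified, self-contained sketch of that resolution follows; the registry and provider objects are stand-ins.

from dataclasses import dataclass
from typing import Any


@dataclass
class RegisteredObject:
    """Stand-in for a routable object stored in the distribution registry."""

    identifier: str
    provider_id: str


class RoutingTable:
    def __init__(self, objects: dict[str, RegisteredObject], impls_by_provider_id: dict[str, Any]) -> None:
        self.objects = objects  # stands in for the cached disk registry
        self.impls_by_provider_id = impls_by_provider_id

    def get_provider_impl(self, routing_key: str) -> Any:
        obj = self.objects.get(routing_key)
        if not obj:
            provider_ids = ", ".join(self.impls_by_provider_id)
            # Same spirit as the deleted code: name the key and the providers that were checked.
            raise ValueError(f"`{routing_key}` not served by any of the providers: {provider_ids}")
        return self.impls_by_provider_id[obj.provider_id]


if __name__ == "__main__":
    table = RoutingTable(
        objects={"llama3": RegisteredObject("llama3", "ollama")},
        impls_by_provider_id={"ollama": object()},
    )
    print(table.get_provider_impl("llama3"))
    try:
        table.get_provider_impl("unknown-model")
    except ValueError as e:
        print(e)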
llama_stack/distribution/routers/safety.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.inference import (
    Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class SafetyRouter(Safety):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing SafetyRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("SafetyRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("SafetyRouter.shutdown")
        pass

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)

    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
        return await self.routing_table.get_provider_impl(shield_id).run_shield(
            shield_id=shield_id,
            messages=messages,
            params=params,
        )

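SafetyRouter contains no safety logic of its own: run_shield resolves a provider implementation from the routing table by shield_id and forwards the call. A minimal, self-contained sketch of that delegation pattern, using hypothetical stand-in classes rather than the real RoutingTable and provider implementations (not code from this commit):

import asyncio


class _FakeShieldProvider:
    # Stand-in for a safety provider implementation.
    async def run_shield(self, shield_id, messages, params):
        return {"shield_id": shield_id, "violations": []}


class _FakeRoutingTable:
    # Stand-in for the routing table's get_provider_impl() lookup.
    def __init__(self, providers):
        self._providers = providers

    def get_provider_impl(self, routing_key):
        return self._providers[routing_key]


class _MiniSafetyRouter:
    # Mirrors the shape of SafetyRouter.run_shield: look up the provider, then delegate.
    def __init__(self, routing_table):
        self.routing_table = routing_table

    async def run_shield(self, shield_id, messages, params=None):
        impl = self.routing_table.get_provider_impl(shield_id)
        return await impl.run_shield(shield_id=shield_id, messages=messages, params=params)


async def _demo():
    router = _MiniSafetyRouter(_FakeRoutingTable({"llama-guard": _FakeShieldProvider()}))
    print(await router.run_shield("llama-guard", messages=[{"role": "user", "content": "hi"}]))


if __name__ == "__main__":
    asyncio.run(_demo())
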
llama_stack/distribution/routers/tool_runtime.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,
)
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    RAGDocument,
    RAGQueryConfig,
    RAGQueryResult,
    RAGToolRuntime,
    ToolRuntime,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class ToolRuntimeRouter(ToolRuntime):
    class RagToolImpl(RAGToolRuntime):
        def __init__(
            self,
            routing_table: RoutingTable,
        ) -> None:
            logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
            self.routing_table = routing_table

        async def query(
            self,
            content: InterleavedContent,
            vector_db_ids: list[str],
            query_config: RAGQueryConfig | None = None,
        ) -> RAGQueryResult:
            logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
            return await self.routing_table.get_provider_impl("knowledge_search").query(
                content, vector_db_ids, query_config
            )

        async def insert(
            self,
            documents: list[RAGDocument],
            vector_db_id: str,
            chunk_size_in_tokens: int = 512,
        ) -> None:
            logger.debug(
                f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
            )
            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
                documents, vector_db_id, chunk_size_in_tokens
            )

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ToolRuntimeRouter")
        self.routing_table = routing_table

        # HACK ALERT this should be in sync with "get_all_api_endpoints()"
        self.rag_tool = self.RagToolImpl(routing_table)
        for method in ("query", "insert"):
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

    async def initialize(self) -> None:
        logger.debug("ToolRuntimeRouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
    ) -> ListToolDefsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

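The loop over ("query", "insert") in ToolRuntimeRouter.__init__ uses setattr with attribute names that contain a dot ("rag_tool.query", "rag_tool.insert"); such attributes cannot be reached with normal dotted access and are only retrievable via getattr with the literal string, which is what the endpoint dispatch relies on. A tiny standalone illustration of that behavior (plain Python, nothing imported from llama_stack):

class Holder:
    pass


h = Holder()
setattr(h, "rag_tool.query", lambda: "routed")

# Normal attribute access would fail: h.rag_tool does not exist as an attribute.
print(getattr(h, "rag_tool.query")())  # -> "routed"
print(hasattr(h, "rag_tool"))          # -> False
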
llama_stack/distribution/routers/vector_io.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import (
    InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier"""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        logger.debug("VectorIORouter.shutdown")
        pass

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

llama_stack/distribution/routing_tables/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

llama_stack/distribution/routing_tables/benchmarks.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
    BenchmarkWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
    async def list_benchmarks(self) -> ListBenchmarksResponse:
        return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))

    async def get_benchmark(self, benchmark_id: str) -> Benchmark:
        benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
        if benchmark is None:
            raise ValueError(f"Benchmark '{benchmark_id}' not found")
        return benchmark

    async def register_benchmark(
        self,
        benchmark_id: str,
        dataset_id: str,
        scoring_functions: list[str],
        metadata: dict[str, Any] | None = None,
        provider_benchmark_id: str | None = None,
        provider_id: str | None = None,
    ) -> None:
        if metadata is None:
            metadata = {}
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if provider_benchmark_id is None:
            provider_benchmark_id = benchmark_id
        benchmark = BenchmarkWithACL(
            identifier=benchmark_id,
            dataset_id=dataset_id,
            scoring_functions=scoring_functions,
            metadata=metadata,
            provider_id=provider_id,
            provider_resource_id=provider_benchmark_id,
        )
        await self.register_object(benchmark)

llama_stack/distribution/routing_tables/common.py (new file, 218 lines)
@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
    AccessAttributes,
    RoutableObject,
    RoutableObjectWithProvider,
    RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable

logger = get_logger(name=__name__, category="core")


def get_impl_api(p: Any) -> Api:
    return p.__provider_spec__.api


# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
    api = get_impl_api(p)

    assert obj.provider_id != "remote", "Remote provider should not be registered"

    if api == Api.inference:
        return await p.register_model(obj)
    elif api == Api.safety:
        return await p.register_shield(obj)
    elif api == Api.vector_io:
        return await p.register_vector_db(obj)
    elif api == Api.datasetio:
        return await p.register_dataset(obj)
    elif api == Api.scoring:
        return await p.register_scoring_function(obj)
    elif api == Api.eval:
        return await p.register_benchmark(obj)
    elif api == Api.tool_runtime:
        return await p.register_tool(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")


async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
    api = get_impl_api(p)
    if api == Api.vector_io:
        return await p.unregister_vector_db(obj.identifier)
    elif api == Api.inference:
        return await p.unregister_model(obj.identifier)
    elif api == Api.datasetio:
        return await p.unregister_dataset(obj.identifier)
    elif api == Api.tool_runtime:
        return await p.unregister_tool(obj.identifier)
    else:
        raise ValueError(f"Unregister not supported for {api}")


Registry = dict[str, list[RoutableObjectWithProvider]]


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        impls_by_provider_id: dict[str, RoutedProtocol],
        dist_registry: DistributionRegistry,
    ) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.dist_registry = dist_registry

    async def initialize(self) -> None:
        async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
            for obj in objs:
                if cls is None:
                    obj.provider_id = provider_id
                else:
                    # Create a copy of the model data and explicitly set provider_id
                    model_data = obj.model_dump()
                    model_data["provider_id"] = provider_id
                    obj = cls(**model_data)
                await self.dist_registry.register(obj)

        # Register all objects from providers
        for pid, p in self.impls_by_provider_id.items():
            api = get_impl_api(p)
            if api == Api.inference:
                p.model_store = self
            elif api == Api.safety:
                p.shield_store = self
            elif api == Api.vector_io:
                p.vector_db_store = self
            elif api == Api.datasetio:
                p.dataset_store = self
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                await add_objects(scoring_functions, pid, ScoringFn)
            elif api == Api.eval:
                p.benchmark_store = self
            elif api == Api.tool_runtime:
                p.tool_store = self

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
        from .benchmarks import BenchmarksRoutingTable
        from .datasets import DatasetsRoutingTable
        from .models import ModelsRoutingTable
        from .scoring_functions import ScoringFunctionsRoutingTable
        from .shields import ShieldsRoutingTable
        from .toolgroups import ToolGroupsRoutingTable
        from .vector_dbs import VectorDBsRoutingTable

        def apiname_object():
            if isinstance(self, ModelsRoutingTable):
                return ("Inference", "model")
            elif isinstance(self, ShieldsRoutingTable):
                return ("Safety", "shield")
            elif isinstance(self, VectorDBsRoutingTable):
                return ("VectorIO", "vector_db")
            elif isinstance(self, DatasetsRoutingTable):
                return ("DatasetIO", "dataset")
            elif isinstance(self, ScoringFunctionsRoutingTable):
                return ("Scoring", "scoring_function")
            elif isinstance(self, BenchmarksRoutingTable):
                return ("Eval", "benchmark")
            elif isinstance(self, ToolGroupsRoutingTable):
                return ("Tools", "tool")
            else:
                raise ValueError("Unknown routing table type")

        apiname, objtype = apiname_object()

        # Get objects from disk registry
        obj = self.dist_registry.get_cached(objtype, routing_key)
        if not obj:
            provider_ids = list(self.impls_by_provider_id.keys())
            if len(provider_ids) > 1:
                provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
            else:
                provider_ids_str = f"provider: `{provider_ids[0]}`"
            raise ValueError(
                f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
            )

        if not provider_id or provider_id == obj.provider_id:
            return self.impls_by_provider_id[obj.provider_id]

        raise ValueError(f"Provider not found for `{routing_key}`")

    async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
        # Get from disk registry
        obj = await self.dist_registry.get(type, identifier)
        if not obj:
            return None

        # Check if user has permission to access this object
        if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
            logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
            return None

        return obj

    async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
        await self.dist_registry.delete(obj.type, obj.identifier)
        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])

    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
        # if provider_id is not specified, pick an arbitrary one from existing entries
        if not obj.provider_id and len(self.impls_by_provider_id) > 0:
            obj.provider_id = list(self.impls_by_provider_id.keys())[0]

        if obj.provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{obj.provider_id}` not found")

        p = self.impls_by_provider_id[obj.provider_id]

        # If object supports access control but no attributes set, use creator's attributes
        if not obj.access_attributes:
            creator_attributes = get_auth_attributes()
            if creator_attributes:
                obj.access_attributes = AccessAttributes(**creator_attributes)
                logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")

        registered_obj = await register_object_with_provider(obj, p)
        # TODO: This needs to be fixed for all APIs once they return the registered object
        if obj.type == ResourceType.model.value:
            await self.dist_registry.register(registered_obj)
            return registered_obj
        else:
            await self.dist_registry.register(obj)
            return obj

    async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
        objs = await self.dist_registry.get_all()
        filtered_objs = [obj for obj in objs if obj.type == type]

        # Apply attribute-based access control filtering
        if filtered_objs:
            filtered_objs = [
                obj
                for obj in filtered_objs
                if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
            ]

        return filtered_objs

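In CommonRoutingTableImpl.get_all_with_type, objects fetched from the registry are additionally filtered by attribute-based access control before being returned. A rough sketch of that filtering step, with a simplified stand-in for check_access (the real semantics live in llama_stack.distribution.access_control and may differ) and dataclass stand-ins for registry objects:

from dataclasses import dataclass, field


@dataclass
class _Obj:
    identifier: str
    type: str
    access_attributes: dict = field(default_factory=dict)


def _check_access(access_attributes, user_attributes):
    # Simplified rule: every attribute key the object requires must overlap
    # with the caller's attributes; objects without attributes are public.
    if not access_attributes:
        return True
    return all(
        set(values) & set(user_attributes.get(key, []))
        for key, values in access_attributes.items()
    )


objs = [
    _Obj("m1", "model"),
    _Obj("m2", "model", {"teams": ["ml-infra"]}),
    _Obj("d1", "dataset"),
]
user = {"teams": ["data-eng"]}

visible = [o for o in objs if o.type == "model" and _check_access(o.access_attributes, user)]
print([o.identifier for o in visible])  # -> ['m1']; 'm2' requires the ml-infra team
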
llama_stack/distribution/routing_tables/datasets.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from typing import Any

from llama_stack.apis.datasets import (
    Dataset,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
    DatasetWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def list_datasets(self) -> ListDatasetsResponse:
        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))

    async def get_dataset(self, dataset_id: str) -> Dataset:
        dataset = await self.get_object_by_identifier("dataset", dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset '{dataset_id}' not found")
        return dataset

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> Dataset:
        if isinstance(source, dict):
            if source["type"] == "uri":
                source = URIDataSource.parse_obj(source)
            elif source["type"] == "rows":
                source = RowsDataSource.parse_obj(source)

        if not dataset_id:
            dataset_id = f"dataset-{str(uuid.uuid4())}"

        provider_dataset_id = dataset_id

        # infer provider from source
        if metadata:
            if metadata.get("provider_id"):
                provider_id = metadata.get("provider_id")  # pass through from nvidia datasetio
        elif source.type == DatasetType.rows.value:
            provider_id = "localfs"
        elif source.type == DatasetType.uri.value:
            # infer provider from uri
            if source.uri.startswith("huggingface"):
                provider_id = "huggingface"
            else:
                provider_id = "localfs"
        else:
            raise ValueError(f"Unknown data source type: {source.type}")

        if metadata is None:
            metadata = {}

        dataset = DatasetWithACL(
            identifier=dataset_id,
            provider_resource_id=provider_dataset_id,
            provider_id=provider_id,
            purpose=purpose,
            source=source,
            metadata=metadata,
        )

        await self.register_object(dataset)
        return dataset

    async def unregister_dataset(self, dataset_id: str) -> None:
        dataset = await self.get_dataset(dataset_id)
        if dataset is None:
            raise ValueError(f"Dataset {dataset_id} not found")
        await self.unregister_object(dataset)

llama_stack/distribution/routing_tables/models.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> ListModelsResponse:
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        existing_model = await self.get_model(model_id)
        if existing_model is None:
            raise ValueError(f"Model {model_id} not found")
        await self.unregister_object(existing_model)

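register_model, like register_benchmark, register_scoring_function and register_shield in the neighboring routing tables, falls back to the single configured provider when provider_id is omitted and refuses to guess when several are configured. A compact sketch of that defaulting rule, using a plain dict in place of impls_by_provider_id (illustrative only, not the repository's helper):

def resolve_provider_id(impls_by_provider_id: dict, provider_id: str | None = None) -> str:
    # Mirrors the fallback used by the routing tables in this commit.
    if provider_id is not None:
        return provider_id
    if len(impls_by_provider_id) == 1:
        return next(iter(impls_by_provider_id))
    raise ValueError(
        "No provider specified and multiple providers available. Please specify a provider_id."
    )


print(resolve_provider_id({"ollama": object()}))  # -> 'ollama'
# resolve_provider_id({"ollama": object(), "vllm": object()}) would raise ValueError
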
llama_stack/distribution/routing_tables/scoring_functions.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
    ScoringFnWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        scoring_fn.provider_id = provider_id
        await self.register_object(scoring_fn)

llama_stack/distribution/routing_tables/shields.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
    ShieldWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> ListShieldsResponse:
        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))

    async def get_shield(self, identifier: str) -> Shield:
        shield = await self.get_object_by_identifier("shield", identifier)
        if shield is None:
            raise ValueError(f"Shield '{identifier}' not found")
        return shield

    async def register_shield(
        self,
        shield_id: str,
        provider_shield_id: str | None = None,
        provider_id: str | None = None,
        params: dict[str, Any] | None = None,
    ) -> Shield:
        if provider_shield_id is None:
            provider_shield_id = shield_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this shield type
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        if params is None:
            params = {}
        shield = ShieldWithACL(
            identifier=shield_id,
            provider_resource_id=provider_shield_id,
            provider_id=provider_id,
            params=params,
        )
        await self.register_object(shield)
        return shield

llama_stack/distribution/routing_tables/toolgroups.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
    ToolGroupWithACL,
    ToolWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        tool_group = await self.get_tool_group(toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group {toolgroup_id} not found")
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        pass

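register_tool_group tolerates re-registration only when the incoming tool definition is identical to the stored one (compared via model_dump); any difference raises. A rough standalone sketch of that comparison, using a made-up pydantic model in place of ToolWithACL (field names are illustrative, not the real schema):

from pydantic import BaseModel


class _ToolRecord(BaseModel):
    # Hypothetical stand-in for a registered tool record.
    identifier: str
    description: str = ""


existing = _ToolRecord(identifier="web_search", description="Search the web")
incoming = _ToolRecord(identifier="web_search", description="Search the web (v2)")

try:
    if existing.model_dump() != incoming.model_dump():
        # Same identifier but a different definition: refuse, as the routing table does.
        raise ValueError("Object web_search already exists in registry. Please use a different identifier.")
except ValueError as err:
    print(err)
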
llama_stack/distribution/routing_tables/vector_dbs.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
    VectorDBWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        existing_vector_db = await self.get_vector_db(vector_db_id)
        if existing_vector_db is None:
            raise ValueError(f"Vector DB {vector_db_id} not found")
        await self.unregister_object(existing_vector_db)

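register_vector_db builds the record as a plain dict and validates it through pydantic's TypeAdapter instead of calling the model constructor directly. A minimal illustration of the same validation call, with a hypothetical stand-in model in place of VectorDBWithACL (assumes pydantic v2):

from pydantic import BaseModel, TypeAdapter


class _VectorDBRecord(BaseModel):
    # Hypothetical stand-in; field names follow the dict built above, not the real schema.
    identifier: str
    provider_id: str
    provider_resource_id: str
    embedding_model: str
    embedding_dimension: int


data = {
    "identifier": "my-docs",
    "provider_id": "faiss",
    "provider_resource_id": "my-docs",
    "embedding_model": "all-MiniLM-L6-v2",
    "embedding_dimension": 384,
}

record = TypeAdapter(_VectorDBRecord).validate_python(data)
print(record.embedding_dimension)  # -> 384
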
@@ -219,6 +219,7 @@ exclude = [
     "^llama_stack/distribution/client\\.py$",
     "^llama_stack/distribution/request_headers\\.py$",
     "^llama_stack/distribution/routers/",
+    "^llama_stack/distribution/routing_tables/",
     "^llama_stack/distribution/server/endpoints\\.py$",
     "^llama_stack/distribution/server/server\\.py$",
     "^llama_stack/distribution/stack\\.py$",

@@ -17,15 +17,13 @@ from llama_stack.apis.models.models import Model, ModelType
 from llama_stack.apis.shields.shields import Shield
 from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter
 from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
-from llama_stack.distribution.routers.routing_tables import (
-    BenchmarksRoutingTable,
-    DatasetsRoutingTable,
-    ModelsRoutingTable,
-    ScoringFunctionsRoutingTable,
-    ShieldsRoutingTable,
-    ToolGroupsRoutingTable,
-    VectorDBsRoutingTable,
-)
+from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
+from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
+from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
+from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
+from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
+from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
+from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
 
 
 class Impl:

@@ -11,7 +11,7 @@ import pytest
 from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models import ModelType
 from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
-from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
+from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
 
 
 class AsyncMock(MagicMock):

@@ -37,7 +37,7 @@ async def test_setup(cached_disk_dist_registry):
 
 
 @pytest.mark.asyncio
-@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
+@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
 async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
     registry, routing_table = test_setup
     model_public = ModelWithACL(

@@ -102,7 +102,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
 
 
 @pytest.mark.asyncio
-@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
+@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
 async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
     registry, routing_table = test_setup
     model_public = ModelWithACL(

@@ -132,7 +132,7 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
 
 
 @pytest.mark.asyncio
-@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
+@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
 async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
     registry, routing_table = test_setup
     model = ModelWithACL(

@@ -154,7 +154,7 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
 
 
 @pytest.mark.asyncio
-@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
+@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
 async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
     registry, routing_table = test_setup
     model_public = ModelWithACL(

@@ -185,7 +185,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
 
 
 @pytest.mark.asyncio
-@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes")
+@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
 async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
     """Test that newly created resources inherit access attributes from their creator."""
     registry, routing_table = test_setup

@@ -19,8 +19,8 @@ from llama_stack.distribution.datatypes import (
     StackRunConfig,
 )
 from llama_stack.distribution.resolver import resolve_impls
-from llama_stack.distribution.routers.routers import InferenceRouter
-from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable
+from llama_stack.distribution.routers.inference import InferenceRouter
+from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
 from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec
 
 