From d6eac9f2193fd15a55a8aa9bf68fa5a0ca5c7c64 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 24 May 2025 22:31:17 -0700 Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20change?= =?UTF-8?q?s=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.7-beta.1 [skip ci] --- llama_stack/distribution/routers/__init__.py | 4 +- llama_stack/distribution/routers/datasets.py | 71 ++++++++++++ llama_stack/distribution/routers/routers.py | 101 ------------------ llama_stack/distribution/routers/safety.py | 57 ++++++++++ .../telemetry/meta_reference/telemetry.py | 2 +- 5 files changed, 131 insertions(+), 104 deletions(-) create mode 100644 llama_stack/distribution/routers/datasets.py create mode 100644 llama_stack/distribution/routers/safety.py diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 84560b355..2befa4c16 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -50,15 +50,15 @@ async def get_routing_table_impl( async def get_auto_router_impl( api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig ) -> Any: + from .datasets import DatasetIORouter from .routers import ( - DatasetIORouter, EvalRouter, InferenceRouter, - SafetyRouter, ScoringRouter, ToolRuntimeRouter, VectorIORouter, ) + from .safety import SafetyRouter api_to_routers = { "vector_io": VectorIORouter, diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py new file mode 100644 index 000000000..6f28756c9 --- /dev/null +++ b/llama_stack/distribution/routers/datasets.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import DatasetPurpose, DataSource +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class DatasetIORouter(DatasetIO): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing DatasetIORouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("DatasetIORouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("DatasetIORouter.shutdown") + pass + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> None: + logger.debug( + f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", + ) + await self.routing_table.register_dataset( + purpose=purpose, + source=source, + metadata=metadata, + dataset_id=dataset_id, + ) + + async def iterrows( + self, + dataset_id: str, + start_index: int | None = None, + limit: int | None = None, + ) -> PaginatedResponse: + logger.debug( + f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", + ) + return await self.routing_table.get_provider_impl(dataset_id).iterrows( + dataset_id=dataset_id, + start_index=start_index, + limit=limit, + ) + + async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: + logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") + return await self.routing_table.get_provider_impl(dataset_id).append_rows( + dataset_id=dataset_id, + rows=rows, + ) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 0515b19f8..d0cb5ee7e 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -18,9 +18,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.apis.datasetio import DatasetIO -from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job from llama_stack.apis.inference import ( BatchChatCompletionResponse, @@ -54,14 +51,12 @@ from llama_stack.apis.inference.inference import ( OpenAIResponseFormatParam, ) from llama_stack.apis.models import Model, ModelType -from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.scoring import ( ScoreBatchResponse, ScoreResponse, Scoring, ScoringFnParams, ) -from llama_stack.apis.shields import Shield from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -673,102 +668,6 @@ class InferenceRouter(Inference): return health_statuses -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) - - -class DatasetIORouter(DatasetIO): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing DatasetIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("DatasetIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("DatasetIORouter.shutdown") - pass - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> None: - logger.debug( - f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", - ) - await self.routing_table.register_dataset( - purpose=purpose, - source=source, - metadata=metadata, - dataset_id=dataset_id, - ) - - async def iterrows( - self, - dataset_id: str, - start_index: int | None = None, - limit: int | None = None, - ) -> PaginatedResponse: - logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", - ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( - dataset_id=dataset_id, - start_index=start_index, - limit=limit, - ) - - async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: - logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( - dataset_id=dataset_id, - rows=rows, - ) - - class ScoringRouter(Scoring): def __init__( self, diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py new file mode 100644 index 000000000..9761d2db0 --- /dev/null +++ b/llama_stack/distribution/routers/safety.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.inference import ( + Message, +) +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + return await self.routing_table.get_provider_impl(shield_id).run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index 1bc979894..0f6cf8619 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -148,7 +148,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if span: timestamp_ns = int(event.timestamp.timestamp() * 1e9) span.add_event( - name=event.type, + name=event.type.value, attributes={ "message": event.message, "severity": event.severity.value,