diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index b9348e963..e52d4dab6 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -90,11 +90,6 @@ class Eval(Protocol):
         task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...
 
-    @webmethod(route="/eval/create-annotation-dataset", method="POST")
-    async def create_annotation_dataset(
-        self, session_id: str, dataset_id: str
-    ) -> None: ...
-
     @webmethod(route="/eval/job/status", method="GET")
     async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
 
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py
index 2ff783c46..d631c90ce 100644
--- a/llama_stack/apis/telemetry/telemetry.py
+++ b/llama_stack/apis/telemetry/telemetry.py
@@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated
 
+from llama_stack.apis.datasetio import DatasetIO
+
 # Add this constant near the top of the file, after the imports
 DEFAULT_TTL_DAYS = 7
 
@@ -165,6 +167,8 @@ class QueryCondition(BaseModel):
 
 @runtime_checkable
 class Telemetry(Protocol):
+    datasetio_api: DatasetIO
+
     @webmethod(route="/telemetry/log-event")
     async def log_event(
         self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
@@ -186,3 +190,64 @@
         attributes_to_return: Optional[List[str]] = None,
         max_depth: Optional[int] = None,
     ) -> SpanWithChildren: ...
+
+    @webmethod(route="/telemetry/query-spans", method="POST")
+    async def query_spans(
+        self,
+        attribute_filters: List[QueryCondition],
+        attributes_to_return: List[str],
+        max_depth: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        traces = await self.query_traces(attribute_filters=attribute_filters)
+
+        rows = []
+
+        for trace in traces:
+            span_tree = await self.get_span_tree(
+                span_id=trace.root_span_id,
+                attributes_to_return=attributes_to_return,
+                max_depth=max_depth,
+            )
+
+            def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
+                rows = []
+                if span.attributes and all(
+                    attr in span.attributes and span.attributes[attr] is not None
+                    for attr in attributes_to_return
+                ):
+                    row = {
+                        "trace_id": trace.root_span_id,
+                        "span_id": span.span_id,
+                        "step_name": span.name,
+                    }
+                    for attr in attributes_to_return:
+                        row[attr] = str(span.attributes[attr])
+                    rows.append(row)
+
+                for child in span.children:
+                    rows.extend(extract_spans(child))
+
+                return rows
+
+            rows.extend(extract_spans(span_tree))
+
+        return rows
+
+    @webmethod(route="/telemetry/save-traces-to-dataset", method="POST")
+    async def save_traces_to_dataset(
+        self,
+        attribute_filters: List[QueryCondition],
+        attributes_to_save: List[str],
+        dataset_id: str,
+        max_depth: Optional[int] = None,
+    ) -> None:
+        annotation_rows = await self.query_spans(
+            attribute_filters=attribute_filters,
+            attributes_to_return=attributes_to_save,
+            max_depth=max_depth,
+        )
+
+        if annotation_rows:
+            await self.datasetio_api.append_rows(
+                dataset_id=dataset_id, rows=annotation_rows
+            )
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 9b3812e9e..084bca24f 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -349,11 +349,14 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
                 method_owner = next(
                     (cls for cls in mro if name in cls.__dict__), None
                 )
-                if (
-                    method_owner is None
-                    or method_owner.__name__ == protocol.__name__
-                ):
+                proto_method = getattr(protocol, name)
+                if method_owner is None:
                     missing_methods.append((name, "not_actually_implemented"))
+                elif method_owner.__name__ == protocol.__name__:
+                    # Check if it's just a stub (...) or has real implementation
+                    proto_source = inspect.getsource(proto_method)
+                    if "..." in proto_source:
+                        missing_methods.append((name, "not_actually_implemented"))
 
     if missing_methods:
         raise ValueError(
diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py
index 68fe58323..56c115322 100644
--- a/llama_stack/providers/inline/eval/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py
@@ -23,7 +23,6 @@ async def get_provider_impl(
         deps[Api.scoring],
         deps[Api.inference],
         deps[Api.agents],
-        deps[Api.telemetry],
     )
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index d19fa8918..810d1a971 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -16,7 +16,6 @@ from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
-from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Telemetry
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from tqdm import tqdm
@@ -43,7 +42,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         scoring_api: Scoring,
         inference_api: Inference,
         agents_api: Agents,
-        telemetry_api: Telemetry,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
@@ -51,7 +49,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.scoring_api = scoring_api
         self.inference_api = inference_api
         self.agents_api = agents_api
-        self.telemetry_api = telemetry_api
 
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
@@ -272,50 +269,3 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
         return self.jobs[job_id]
-
-    async def create_annotation_dataset(self, session_id: str, dataset_id: str) -> None:
-        traces = await self.telemetry_api.query_traces(
-            attribute_filters=[
-                QueryCondition(key="session_id", op="eq", value=session_id),
-            ]
-        )
-
-        annotation_rows = []
-
-        for trace in traces:
-            span_tree = await self.telemetry_api.get_span_tree(
-                span_id=trace.root_span_id,
-                attributes_to_return=[
-                    "input",
-                    "output",
-                    "name",
-                ],
-            )
-
-            def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
-                rows = []
-                if (
-                    span.attributes
-                    and "input" in span.attributes
-                    and "output" in span.attributes
-                ):
-                    row = {
-                        "input_query": span.attributes.get("input", ""),
-                        "generated_answer": span.attributes.get("output", ""),
-                        "trace_id": trace.root_span_id,
-                        "span_id": span.span_id,
-                        "step_name": span.name,
-                    }
-                    rows.append(row)
-
-                for child in span.children:
-                    rows.extend(extract_spans(child))
-
-                return rows
-
-            annotation_rows.extend(extract_spans(span_tree))
-
-        if annotation_rows:
-            await self.datasetio_api.append_rows(
-                dataset_id=dataset_id, rows=annotation_rows
-            )
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
index 6213d5536..38871a7e4 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py
@@ -13,6 +13,6 @@ __all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]
 
 
 async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
-    impl = TelemetryAdapter(config)
+    impl = TelemetryAdapter(config, deps)
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
index 6540a667f..091e36b11 100644
--- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
+++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import threading
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -28,6 +28,8 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
 
 from llama_stack.apis.telemetry import *  # noqa: F403
 
+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink
 
 _GLOBAL_STORAGE = {
@@ -55,8 +57,9 @@ def is_tracing_enabled(tracer):
 
 
 class TelemetryAdapter(Telemetry):
-    def __init__(self, config: TelemetryConfig) -> None:
+    def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
         self.config = config
+        self.datasetio_api = deps[Api.datasetio]
 
         resource = Resource.create(
             {
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
index 81c55b2db..718c7eae5 100644
--- a/llama_stack/providers/registry/eval.py
+++ b/llama_stack/providers/registry/eval.py
@@ -23,7 +23,6 @@ def available_providers() -> List[ProviderSpec]:
                 Api.scoring,
                 Api.inference,
                 Api.agents,
-                Api.telemetry,
             ],
         ),
     ]
diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py
index a53ad5b94..d367bf894 100644
--- a/llama_stack/providers/registry/telemetry.py
+++ b/llama_stack/providers/registry/telemetry.py
@@ -18,6 +18,7 @@ def available_providers() -> List[ProviderSpec]:
                 "opentelemetry-sdk",
                 "opentelemetry-exporter-otlp-proto-http",
             ],
+            api_dependencies=[Api.datasetio],
             module="llama_stack.providers.inline.telemetry.meta_reference",
             config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
         ),
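
Usage sketch (not part of the patch): with this change, trace-to-dataset export moves from the Eval API to the Telemetry API. The protocol's default save_traces_to_dataset calls query_spans, which walks each trace's span tree, keeps spans carrying all requested attributes, and appends the flattened rows through the datasetio_api dependency. The export_session_traces helper and the telemetry handle below are illustrative assumptions, not code from this diff.

    # Illustrative sketch only -- assumes a concrete Telemetry implementation
    # (e.g. TelemetryAdapter) constructed with deps[Api.datasetio] wired in.
    from llama_stack.apis.telemetry import QueryCondition

    async def export_session_traces(telemetry, dataset_id: str, session_id: str) -> None:
        # Roughly replaces the removed Eval.create_annotation_dataset(): select traces
        # for one agent session, keep spans that carry both "input" and "output"
        # attributes, and append the resulting rows to the target dataset.
        await telemetry.save_traces_to_dataset(
            attribute_filters=[
                QueryCondition(key="session_id", op="eq", value=session_id),
            ],
            attributes_to_save=["input", "output"],
            dataset_id=dataset_id,
        )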