diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
index e52d4dab6..b9348e963 100644
--- a/llama_stack/apis/eval/eval.py
+++ b/llama_stack/apis/eval/eval.py
@@ -90,6 +90,11 @@ class Eval(Protocol):
         task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...
 
+    @webmethod(route="/eval/create-annotation-dataset", method="POST")
+    async def create_annotation_dataset(
+        self, task_id: str, session_id: str, dataset_id: str
+    ) -> None: ...
+
     @webmethod(route="/eval/job/status", method="GET")
     async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
 
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 5b75a525b..88449b82e 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -322,6 +322,15 @@ class EvalRouter(Eval):
             task_config=task_config,
         )
 
+    async def create_annotation_dataset(self, task_id: str, session_id: str, dataset_id: str) -> None:
+        return await self.routing_table.get_provider_impl(
+            task_id
+        ).create_annotation_dataset(
+            task_id=task_id,
+            session_id=session_id,
+            dataset_id=dataset_id,
+        )
+
     async def job_status(
         self,
         task_id: str,
diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py
index 56c115322..68fe58323 100644
--- a/llama_stack/providers/inline/eval/meta_reference/__init__.py
+++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py
@@ -23,6 +23,7 @@ async def get_provider_impl(
         deps[Api.scoring],
         deps[Api.inference],
         deps[Api.agents],
+        deps[Api.telemetry],
     )
     await impl.initialize()
     return impl
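Note that the router resolves the provider by eval task id, so `task_id` is threaded through the route, the router, and the provider implementation, matching the `job_status` pattern above. As a minimal usage sketch of the new method (the eval task id and dataset id below are illustrative placeholders, and the dataset is assumed to be registered already):

from llama_stack.apis.eval import Eval


async def snapshot_session(eval_api: Eval, session_id: str) -> None:
    # Materialize every input/output pair recorded for this agent session
    # into a dataset that can be reviewed, annotated, and scored later.
    await eval_api.create_annotation_dataset(
        task_id="meta-reference-agent-eval",  # hypothetical registered eval task
        session_id=session_id,
        dataset_id="agent-annotations",  # hypothetical pre-registered dataset
    )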
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index c6cacfcc3..d19fa8918 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 from enum import Enum
+from typing import Any, Dict, List, Optional
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from .....apis.common.job_types import Job
@@ -15,6 +16,7 @@ from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
+from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Telemetry
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from tqdm import tqdm
@@ -41,6 +43,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         scoring_api: Scoring,
         inference_api: Inference,
         agents_api: Agents,
+        telemetry_api: Telemetry,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
@@ -48,6 +51,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.scoring_api = scoring_api
         self.inference_api = inference_api
         self.agents_api = agents_api
+        self.telemetry_api = telemetry_api
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
 
@@ -268,3 +272,53 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
         return self.jobs[job_id]
+
+    async def create_annotation_dataset(self, task_id: str, session_id: str, dataset_id: str) -> None:
+        # Fetch every trace recorded for this agent session.
+        traces = await self.telemetry_api.query_traces(
+            attribute_filters=[
+                QueryCondition(key="session_id", op="eq", value=session_id),
+            ]
+        )
+
+        annotation_rows = []
+
+        for trace in traces:
+            span_tree = await self.telemetry_api.get_span_tree(
+                span_id=trace.root_span_id,
+                attributes_to_return=[
+                    "input",
+                    "output",
+                    "name",
+                ],
+            )
+
+            # Walk the span tree depth-first, keeping one row per span that
+            # recorded both an input and an output.
+            def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
+                rows = []
+                if (
+                    span.attributes
+                    and "input" in span.attributes
+                    and "output" in span.attributes
+                ):
+                    row = {
+                        "input_query": span.attributes.get("input", ""),
+                        "generated_answer": span.attributes.get("output", ""),
+                        "trace_id": trace.root_span_id,
+                        "span_id": span.span_id,
+                        "step_name": span.name,
+                    }
+                    rows.append(row)
+
+                for child in span.children:
+                    rows.extend(extract_spans(child))
+
+                return rows
+
+            annotation_rows.extend(extract_spans(span_tree))
+
+        if annotation_rows:
+            await self.datasetio_api.append_rows(
+                dataset_id=dataset_id, rows=annotation_rows
+            )
diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py
index 718c7eae5..81c55b2db 100644
--- a/llama_stack/providers/registry/eval.py
+++ b/llama_stack/providers/registry/eval.py
@@ -23,6 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             Api.scoring,
             Api.inference,
             Api.agents,
+            Api.telemetry,
         ],
     ),
 ]
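For reference, the depth-first flattening that create_annotation_dataset performs can be exercised on its own. A self-contained sketch, with a FakeSpan dataclass standing in for telemetry's SpanWithChildren and a made-up two-span tree:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class FakeSpan:
    # Stand-in for llama_stack.apis.telemetry.SpanWithChildren; only the
    # fields the traversal touches are modeled here.
    span_id: str
    name: str
    attributes: Optional[Dict[str, Any]] = None
    children: List["FakeSpan"] = field(default_factory=list)


def extract_spans(span: FakeSpan, trace_id: str) -> List[Dict[str, Any]]:
    rows = []
    # A span contributes a row only when both sides of the exchange were captured.
    if span.attributes and "input" in span.attributes and "output" in span.attributes:
        rows.append(
            {
                "input_query": span.attributes["input"],
                "generated_answer": span.attributes["output"],
                "trace_id": trace_id,
                "span_id": span.span_id,
                "step_name": span.name,
            }
        )
    for child in span.children:
        rows.extend(extract_spans(child, trace_id))
    return rows


root = FakeSpan(
    span_id="root-1",
    name="create_and_execute_turn",
    attributes={"input": "What is 2 + 2?", "output": "4"},
    children=[
        FakeSpan(
            span_id="child-1",
            name="inference",
            attributes={"input": "What is 2 + 2?", "output": "4"},
        )
    ],
)
print(extract_spans(root, trace_id="root-1"))  # two rows, one per qualifying span

Each emitted row carries the trace and span ids, so annotations written against the dataset can later be joined back to the originating telemetry spans.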