move the save-to-dataset flow to telemetry

Dinesh Yeduguru 2024-12-05 13:36:46 -08:00
parent 4c78432bc8
commit f5d427c178
9 changed files with 79 additions and 64 deletions

@@ -90,11 +90,6 @@ class Eval(Protocol):
         task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...

-    @webmethod(route="/eval/create-annotation-dataset", method="POST")
-    async def create_annotation_dataset(
-        self, session_id: str, dataset_id: str
-    ) -> None: ...
-
     @webmethod(route="/eval/job/status", method="GET")
     async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

@@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from typing_extensions import Annotated

+from llama_stack.apis.datasetio import DatasetIO
+
 # Add this constant near the top of the file, after the imports
 DEFAULT_TTL_DAYS = 7
@@ -165,6 +167,8 @@ class QueryCondition(BaseModel):
 @runtime_checkable
 class Telemetry(Protocol):
+    datasetio_api: DatasetIO
+
     @webmethod(route="/telemetry/log-event")
     async def log_event(
         self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
@@ -186,3 +190,64 @@ class Telemetry(Protocol):
         attributes_to_return: Optional[List[str]] = None,
         max_depth: Optional[int] = None,
     ) -> SpanWithChildren: ...
+
+    @webmethod(route="/telemetry/query-spans", method="POST")
+    async def query_spans(
+        self,
+        attribute_filters: List[QueryCondition],
+        attributes_to_return: List[str],
+        max_depth: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        traces = await self.query_traces(attribute_filters=attribute_filters)
+        rows = []
+        for trace in traces:
+            span_tree = await self.get_span_tree(
+                span_id=trace.root_span_id,
+                attributes_to_return=attributes_to_return,
+                max_depth=max_depth,
+            )
+
+            def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
+                rows = []
+                if span.attributes and all(
+                    attr in span.attributes and span.attributes[attr] is not None
+                    for attr in attributes_to_return
+                ):
+                    row = {
+                        "trace_id": trace.root_span_id,
+                        "span_id": span.span_id,
+                        "step_name": span.name,
+                    }
+                    for attr in attributes_to_return:
+                        row[attr] = str(span.attributes[attr])
+                    rows.append(row)
+
+                for child in span.children:
+                    rows.extend(extract_spans(child))
+
+                return rows
+
+            rows.extend(extract_spans(span_tree))
+
+        return rows
+
+    @webmethod(route="/telemetry/save-traces-to-dataset", method="POST")
+    async def save_traces_to_dataset(
+        self,
+        attribute_filters: List[QueryCondition],
+        attributes_to_save: List[str],
+        dataset_id: str,
+        max_depth: Optional[int] = None,
+    ) -> None:
+        annotation_rows = await self.query_spans(
+            attribute_filters=attribute_filters,
+            attributes_to_return=attributes_to_save,
+            max_depth=max_depth,
+        )
+
+        if annotation_rows:
+            await self.datasetio_api.append_rows(
+                dataset_id=dataset_id, rows=annotation_rows
+            )
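
The two routes added above generalize what the eval provider previously did for a single session: query_spans flattens matching span attributes into rows, and save_traces_to_dataset appends those rows through datasetio. A minimal caller-side sketch, assuming a resolved Telemetry implementation and an already-registered dataset (the session id and dataset name below are illustrative):

    from llama_stack.apis.telemetry import QueryCondition

    async def export_session(telemetry, session_id: str) -> None:
        # Keep only traces from this agent session, then save each span's
        # input/output attributes as one dataset row per matching span.
        await telemetry.save_traces_to_dataset(
            attribute_filters=[
                QueryCondition(key="session_id", op="eq", value=session_id),
            ],
            attributes_to_save=["input", "output"],
            dataset_id="agent-annotations",  # hypothetical dataset id
        )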

@@ -349,11 +349,14 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             method_owner = next(
                 (cls for cls in mro if name in cls.__dict__), None
             )
-            if (
-                method_owner is None
-                or method_owner.__name__ == protocol.__name__
-            ):
+            proto_method = getattr(protocol, name)
+            if method_owner is None:
                 missing_methods.append((name, "not_actually_implemented"))
+            elif method_owner.__name__ == protocol.__name__:
+                # Check if it's just a stub (...) or has real implementation
+                proto_source = inspect.getsource(proto_method)
+                if "..." in proto_source:
+                    missing_methods.append((name, "not_actually_implemented"))

     if missing_methods:
         raise ValueError(
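
The compliance check above now treats a method as unimplemented when it is inherited straight from the protocol and the protocol's own body is a bare stub. A self-contained sketch of that source-inspection trick (the Example protocol is illustrative, not from this patch):

    import inspect
    from typing import Protocol

    class Example(Protocol):
        async def stub(self) -> None: ...  # stub body, gets flagged

        async def with_default(self) -> None:
            print("real default body")  # counts as implemented

    for name in ("stub", "with_default"):
        src = inspect.getsource(getattr(Example, name))
        # Mirrors the resolver's heuristic: a bare ellipsis in the
        # protocol source marks the method as a stub.
        print(name, "is stub:", "..." in src)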

@@ -23,7 +23,6 @@ async def get_provider_impl(
         deps[Api.scoring],
         deps[Api.inference],
         deps[Api.agents],
-        deps[Api.telemetry],
     )
     await impl.initialize()
     return impl

@@ -16,7 +16,6 @@ from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval_tasks import EvalTask
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
-from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Telemetry
 from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from tqdm import tqdm
@@ -43,7 +42,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         scoring_api: Scoring,
         inference_api: Inference,
         agents_api: Agents,
-        telemetry_api: Telemetry,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
@@ -51,7 +49,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self.scoring_api = scoring_api
         self.inference_api = inference_api
         self.agents_api = agents_api
-        self.telemetry_api = telemetry_api

         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}
@@ -272,50 +269,3 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             raise ValueError(f"Job is not completed, Status: {status.value}")

         return self.jobs[job_id]
-
-    async def create_annotation_dataset(self, session_id: str, dataset_id: str) -> None:
-        traces = await self.telemetry_api.query_traces(
-            attribute_filters=[
-                QueryCondition(key="session_id", op="eq", value=session_id),
-            ]
-        )
-
-        annotation_rows = []
-        for trace in traces:
-            span_tree = await self.telemetry_api.get_span_tree(
-                span_id=trace.root_span_id,
-                attributes_to_return=[
-                    "input",
-                    "output",
-                    "name",
-                ],
-            )
-
-            def extract_spans(span: SpanWithChildren) -> List[Dict[str, Any]]:
-                rows = []
-                if (
-                    span.attributes
-                    and "input" in span.attributes
-                    and "output" in span.attributes
-                ):
-                    row = {
-                        "input_query": span.attributes.get("input", ""),
-                        "generated_answer": span.attributes.get("output", ""),
-                        "trace_id": trace.root_span_id,
-                        "span_id": span.span_id,
-                        "step_name": span.name,
-                    }
-                    rows.append(row)
-
-                for child in span.children:
-                    rows.extend(extract_spans(child))
-
-                return rows
-
-            annotation_rows.extend(extract_spans(span_tree))
-
-        if annotation_rows:
-            await self.datasetio_api.append_rows(
-                dataset_id=dataset_id, rows=annotation_rows
-            )

@@ -13,6 +13,6 @@ __all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]

 async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
-    impl = TelemetryAdapter(config)
+    impl = TelemetryAdapter(config, deps)
     await impl.initialize()
     return impl

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import threading
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -28,6 +28,8 @@ from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTrace
 from llama_stack.apis.telemetry import *  # noqa: F403

+from llama_stack.distribution.datatypes import Api
+
 from .config import TelemetryConfig, TelemetrySink

 _GLOBAL_STORAGE = {
@@ -55,8 +57,9 @@ def is_tracing_enabled(tracer):

 class TelemetryAdapter(Telemetry):
-    def __init__(self, config: TelemetryConfig) -> None:
+    def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
         self.config = config
+        self.datasetio_api = deps[Api.datasetio]

         resource = Resource.create(
             {
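
With this constructor change, the adapter follows the same dependency-injection convention as the other providers: deps maps Api enum members to live implementations, and the adapter keeps only the one it declared. A self-contained sketch of the pattern (the Api enum and classes below are stand-ins, not the real llama_stack types):

    from enum import Enum
    from typing import Any, Dict

    class Api(Enum):  # stand-in for llama_stack.distribution.datatypes.Api
        datasetio = "datasetio"

    class FakeDatasetIO:  # stand-in for a real DatasetIO provider impl
        async def append_rows(self, dataset_id: str, rows: list) -> None:
            print(f"appended {len(rows)} rows to {dataset_id}")

    class Adapter:
        # Mirrors TelemetryAdapter.__init__: pull the declared dependency
        # out of the deps mapping and keep a handle to it.
        def __init__(self, config: Any, deps: Dict[Api, Any]) -> None:
            self.config = config
            self.datasetio_api = deps[Api.datasetio]

    adapter = Adapter(config=None, deps={Api.datasetio: FakeDatasetIO()})
    assert isinstance(adapter.datasetio_api, FakeDatasetIO)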

@@ -23,7 +23,6 @@ def available_providers() -> List[ProviderSpec]:
             Api.scoring,
             Api.inference,
             Api.agents,
-            Api.telemetry,
         ],
     ),
 ]

@@ -18,6 +18,7 @@ def available_providers() -> List[ProviderSpec]:
             "opentelemetry-sdk",
             "opentelemetry-exporter-otlp-proto-http",
         ],
+        api_dependencies=[Api.datasetio],
         module="llama_stack.providers.inline.telemetry.meta_reference",
         config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
     ),
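
Declaring api_dependencies=[Api.datasetio] is what makes deps[Api.datasetio] available in the adapter's constructor: the distribution resolver builds dependencies first and hands them to get_provider_impl. A rough sketch of that resolution step (function and variable names are illustrative, not the actual resolver code):

    import importlib

    async def instantiate_provider(spec, config, resolved_apis: dict):
        # Every declared dependency has already been built, so the deps
        # mapping can be assembled by simple lookup before construction.
        deps = {api: resolved_apis[api] for api in spec.api_dependencies}
        module = importlib.import_module(spec.module)
        return await module.get_provider_impl(config, deps)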