llama-stack/llama_stack/providers/utils/telemetry/dataset_mixin.py
Dinesh Yeduguru 05f6b44da7
Fix telemetry (#787)
# What does this PR do?

PR fixes couple of issues with telemetry:
1) The REST refactor changed the method from get_span_tree to
query_span_tree, which is causing the server side to return empty spans
2) Library client has introduced a new event loop, which required
changing the location of where start and end trace are called


## Test Plan

LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/llamastack-fireworks/fireworks-run.yaml"
pytest -v tests/client-sdk/agents/test_agents.py -k
"test_builtin_tool_web_search"


And querying for spans from the agent run using the library client.
2025-01-16 10:36:13 -08:00

79 lines
2.6 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, Span
class TelemetryDatasetMixin:
"""Mixin class that provides dataset-related functionality for telemetry providers."""
datasetio_api: DatasetIO
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
dataset_id: str,
max_depth: Optional[int] = None,
) -> None:
spans = await self.query_spans(
attribute_filters=attribute_filters,
attributes_to_return=attributes_to_save,
max_depth=max_depth,
)
rows = [
{
"trace_id": span.trace_id,
"span_id": span.span_id,
"parent_span_id": span.parent_span_id,
"name": span.name,
"start_time": span.start_time,
"end_time": span.end_time,
**{attr: span.attributes.get(attr) for attr in attributes_to_save},
}
for span in spans
]
await self.datasetio_api.append_rows(dataset_id=dataset_id, rows=rows)
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
) -> List[Span]:
traces = await self.query_traces(attribute_filters=attribute_filters)
spans = []
for trace in traces:
spans_by_id = await self.query_span_tree(
span_id=trace.root_span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
for span in spans_by_id.values():
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
):
spans.append(
Span(
trace_id=trace.root_span_id,
span_id=span.span_id,
parent_span_id=span.parent_span_id,
name=span.name,
start_time=span.start_time,
end_time=span.end_time,
attributes=span.attributes,
)
)
return spans