Update Telemetry API so OpenAPI generation can work (#640)

We cannot use recursive types because not only does our OpenAPI
generator not like them, even if it did, it is not easy for all client
languages to automatically construct proper APIs (especially considering
garbage collection) around them. For now, we can return a `Dict[str,
SpanWithStatus]` instead of `SpanWithChildren` and rely on the client to
reconstruct the tree.

Also fixed a super subtle issue with the OpenAPI generation process
(monkey-patching of json_schema_type wasn't working because of import
reordering.)
This commit is contained in:
Ashwin Bharambe 2024-12-16 13:00:14 -08:00 committed by GitHub
parent 78e2bfbe7a
commit 2e5bfcd42a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 349 additions and 473 deletions

View file

@ -150,8 +150,7 @@ class EvalTrace(BaseModel):
@json_schema_type
class SpanWithChildren(Span):
children: List["SpanWithChildren"] = Field(default_factory=list)
class SpanWithStatus(Span):
status: Optional[SpanStatus] = None
@ -192,7 +191,7 @@ class Telemetry(Protocol):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
) -> Dict[str, SpanWithStatus]: ...
@webmethod(route="/telemetry/query-spans", method="POST")
async def query_spans(

View file

@ -243,7 +243,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
) -> Dict[str, SpanWithStatus]:
return await self.trace_store.get_span_tree(
span_id=span_id,
attributes_to_return=attributes_to_return,

View file

@ -7,7 +7,7 @@
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithChildren
from llama_stack.apis.telemetry import QueryCondition, Span
class TelemetryDatasetMixin:
@ -53,19 +53,18 @@ class TelemetryDatasetMixin:
spans = []
for trace in traces:
span_tree = await self.get_span_tree(
spans_by_id = await self.get_span_tree(
span_id=trace.root_span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
def extract_spans(span: SpanWithChildren) -> List[Span]:
result = []
for span in spans_by_id.values():
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
):
result.append(
spans.append(
Span(
trace_id=trace.root_span_id,
span_id=span.span_id,
@ -77,11 +76,4 @@ class TelemetryDatasetMixin:
)
)
for child in span.children:
result.extend(extract_spans(child))
return result
spans.extend(extract_spans(span_tree))
return spans

View file

@ -6,11 +6,11 @@
import json
from datetime import datetime
from typing import List, Optional, Protocol
from typing import Dict, List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
from llama_stack.apis.telemetry import QueryCondition, SpanWithStatus, Trace
class TraceStore(Protocol):
@ -27,7 +27,7 @@ class TraceStore(Protocol):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
) -> Dict[str, SpanWithStatus]: ...
class SQLiteTraceStore(TraceStore):
@ -114,7 +114,7 @@ class SQLiteTraceStore(TraceStore):
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
) -> Dict[str, SpanWithStatus]:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
@ -143,6 +143,7 @@ class SQLiteTraceStore(TraceStore):
ORDER BY depth, start_time
"""
spans_by_id = {}
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
@ -151,12 +152,8 @@ class SQLiteTraceStore(TraceStore):
if not rows:
raise ValueError(f"Span {span_id} not found")
# Build span tree
spans_by_id = {}
root_span = None
for row in rows:
span = SpanWithChildren(
span = SpanWithStatus(
span_id=row["span_id"],
trace_id=row["trace_id"],
parent_span_id=row["parent_span_id"],
@ -165,14 +162,8 @@ class SQLiteTraceStore(TraceStore):
end_time=datetime.fromisoformat(row["end_time"]),
attributes=json.loads(row["filtered_attributes"]),
status=row["status"].lower(),
children=[],
)
spans_by_id[span.span_id] = span
if span.span_id == span_id:
root_span = span
elif span.parent_span_id in spans_by_id:
spans_by_id[span.parent_span_id].children.append(span)
return root_span
return spans_by_id

View file

@ -41,8 +41,6 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
"""
def trace_method(method: Callable) -> Callable:
from llama_stack.providers.utils.telemetry import tracing
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
@ -77,6 +75,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
@ -92,6 +92,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
@ -107,6 +109,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
from llama_stack.providers.utils.telemetry import tracing
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)