Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-01 16:24:44 +00:00
add endpoint to export traces and standardize the span creation
parent 54bc5f2d55
commit c6b4bf8ada
4 changed files with 88 additions and 5 deletions
@@ -6,7 +6,16 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -135,3 +144,8 @@ class Telemetry(Protocol):
 
     @webmethod(route="/telemetry/get-trace", method="GET")
     async def get_trace(self, trace_id: str) -> Trace: ...
+
+    @webmethod(route="/telemetry/get-traces-for-session", method="POST")
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Trace]: ...
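For orientation, a minimal sketch of calling the new route from a client. The route name and parameters come from the @webmethod signature above; the server address, port, and the choice of a JSON body are assumptions for illustration, not part of this commit.

# Hypothetical client-side call to the new telemetry route.
import asyncio

import aiohttp


async def fetch_session_traces(session_id: str):
    # Parameter names mirror get_traces_for_session; encoding as a JSON body is assumed.
    payload = {"session_id": session_id, "lookback": "1h", "limit": 100}
    async with aiohttp.ClientSession() as http:
        async with http.post(
            "http://localhost:5000/telemetry/get-traces-for-session", json=payload
        ) as resp:
            resp.raise_for_status()
            return await resp.json()


if __name__ == "__main__":
    print(asyncio.run(fetch_session_traces("example-session-id")))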
@@ -285,8 +285,9 @@ class ChatAgent(ShieldRunnerMixin):
     ) -> AsyncGenerator:
         with tracing.span("run_shields") as span:
             span.set_attribute("turn_id", turn_id)
-            span.set_attribute("messages", [m.model_dump_json() for m in messages])
+            span.set_attribute("input", [m.model_dump_json() for m in messages])
             if len(shields) == 0:
+                span.set_attribute("output", "no shields")
                 return
 
             step_id = str(uuid.uuid4())
@@ -315,6 +316,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
+            span.set_attribute("output", e.violation.model_dump_json())
 
             yield CompletionMessage(
                 content=str(e),
@@ -334,6 +336,7 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             )
+            span.set_attribute("output", "no violations")
 
     async def _run(
         self,
@@ -365,7 +368,10 @@ class ChatAgent(ShieldRunnerMixin):
                 rag_context, bank_ids = await self._retrieve_context(
                     session_id, input_messages, attachments
                 )
-                span.set_attribute("rag_context", rag_context)
+                span.set_attribute(
+                    "input", [m.model_dump_json() for m in input_messages]
+                )
+                span.set_attribute("output", rag_context)
                 span.set_attribute("bank_ids", bank_ids)
 
         step_id = str(uuid.uuid4())
@@ -473,9 +479,11 @@ class ChatAgent(ShieldRunnerMixin):
                 if event.stop_reason is not None:
                     stop_reason = event.stop_reason
             span.set_attribute("stop_reason", stop_reason)
-            span.set_attribute("content", content)
             span.set_attribute(
-                "tool_calls", [tc.model_dump_json() for tc in tool_calls]
+                "input", [m.model_dump_json() for m in input_messages]
+            )
+            span.set_attribute(
+                "output", f"content: {content} tool_calls: {tool_calls}"
             )
 
         stop_reason = stop_reason or StopReason.out_of_tokens
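Across these hunks the span attributes are standardized: instead of ad-hoc keys such as "messages", "rag_context", "content", or "tool_calls", every span records an "input" and an "output" attribute. A minimal sketch of the resulting pattern, assuming the same tracing.span() helper used above (the work step and payloads are placeholders):

# Illustrative only: do_work and input_messages are hypothetical; the "input"/"output"
# attribute names are the convention these hunks standardize on.
with tracing.span("my_step") as span:
    span.set_attribute("input", [m.model_dump_json() for m in input_messages])
    result = do_work(input_messages)
    span.set_attribute("output", result)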
@@ -18,6 +18,10 @@ class OpenTelemetryConfig(BaseModel):
         default="llama-stack",
         description="The service name to use for telemetry",
     )
+    export_endpoint: str = Field(
+        default="http://localhost:16686/api/traces",
+        description="The Jaeger query endpoint URL",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
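A sketch of constructing the config with the new field. The default above points at Jaeger's query API on port 16686; the hostname used here is an assumption for illustration:

# Hypothetical override of the new export_endpoint field.
config = OpenTelemetryConfig(
    service_name="llama-stack",
    export_endpoint="http://jaeger.internal:16686/api/traces",  # assumed host
)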
@@ -5,6 +5,9 @@
 # the root directory of this source tree.
 
 import threading
+from typing import Any, Dict, List
+
+import aiohttp
 
 from opentelemetry import metrics, trace
 from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@@ -206,3 +209,57 @@ class OpenTelemetryAdapter(Telemetry):
 
     async def get_trace(self, trace_id: str) -> Trace:
         raise NotImplementedError("Trace retrieval not implemented yet")
+
+    async def get_traces_for_session(
+        self, session_id: str, lookback: str = "1h", limit: int = 100
+    ) -> List[Dict[str, Any]]:
+        params = {
+            "tags": f'{{"session_id":"{session_id}"}}',
+            "lookback": lookback,
+            "limit": limit,
+            "service": self.config.service_name,
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    self.config.export_endpoint, params=params
+                ) as response:
+                    if response.status != 200:
+                        raise Exception(
+                            f"Failed to query Jaeger: {response.status} {await response.text()}"
+                        )
+
+                    traces_data = await response.json()
+                    processed_traces = []
+
+                    for trace_data in traces_data.get("data", []):
+                        trace_steps = []
+                        for span in trace_data.get("spans", []):
+                            step_info = {
+                                "step": span.get("operationName"),
+                                "start_time": span.get("startTime"),
+                                "duration": span.get("duration"),
+                            }
+
+                            tags = span.get("tags", [])
+                            for tag in tags:
+                                if tag.get("key") == "input":
+                                    step_info["input"] = tag.get("value")
+                                elif tag.get("key") == "output":
+                                    step_info["output"] = tag.get("value")
+                            # we only want to return steps which have input and output
+                            if step_info.get("input") and step_info.get("output"):
+                                trace_steps.append(step_info)
+
+                        processed_traces.append(
+                            {
+                                "trace_id": trace_data.get("traceID"),
+                                "steps": trace_steps,
+                            }
+                        )
+
+                    return processed_traces
+
+        except Exception as e:
+            raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
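The parsing in get_traces_for_session walks the JSON returned by Jaeger's /api/traces query endpoint, which looks roughly like the sketch below (field names follow Jaeger's API; the values are illustrative):

# Illustrative shape of a Jaeger query response that the code above expects.
example_response = {
    "data": [
        {
            "traceID": "abc123def456",
            "spans": [
                {
                    "operationName": "run_shields",
                    "startTime": 1730000000000000,  # microseconds since epoch
                    "duration": 12345,              # microseconds
                    "tags": [
                        {"key": "input", "value": "[...]"},
                        {"key": "output", "value": "no violations"},
                    ],
                }
            ],
        }
    ]
}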