mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
add endpoint to export traces and standerdize the span creation
This commit is contained in:
parent
54bc5f2d55
commit
c6b4bf8ada
4 changed files with 88 additions and 5 deletions
|
@ -6,7 +6,16 @@
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Protocol,
|
||||||
|
runtime_checkable,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -135,3 +144,8 @@ class Telemetry(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/telemetry/get-trace", method="GET")
|
@webmethod(route="/telemetry/get-trace", method="GET")
|
||||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||||
|
|
||||||
|
@webmethod(route="/telemetry/get-traces-for-session", method="POST")
|
||||||
|
async def get_traces_for_session(
|
||||||
|
self, session_id: str, lookback: str = "1h", limit: int = 100
|
||||||
|
) -> List[Trace]: ...
|
||||||
|
|
|
@ -285,8 +285,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
with tracing.span("run_shields") as span:
|
with tracing.span("run_shields") as span:
|
||||||
span.set_attribute("turn_id", turn_id)
|
span.set_attribute("turn_id", turn_id)
|
||||||
span.set_attribute("messages", [m.model_dump_json() for m in messages])
|
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||||
if len(shields) == 0:
|
if len(shields) == 0:
|
||||||
|
span.set_attribute("output", "no shields")
|
||||||
return
|
return
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
|
@ -315,6 +316,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
span.set_attribute("output", e.violation.model_dump_json())
|
||||||
|
|
||||||
yield CompletionMessage(
|
yield CompletionMessage(
|
||||||
content=str(e),
|
content=str(e),
|
||||||
|
@ -334,6 +336,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
span.set_attribute("output", "no violations")
|
||||||
|
|
||||||
async def _run(
|
async def _run(
|
||||||
self,
|
self,
|
||||||
|
@ -365,7 +368,10 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
rag_context, bank_ids = await self._retrieve_context(
|
rag_context, bank_ids = await self._retrieve_context(
|
||||||
session_id, input_messages, attachments
|
session_id, input_messages, attachments
|
||||||
)
|
)
|
||||||
span.set_attribute("rag_context", rag_context)
|
span.set_attribute(
|
||||||
|
"input", [m.model_dump_json() for m in input_messages]
|
||||||
|
)
|
||||||
|
span.set_attribute("output", rag_context)
|
||||||
span.set_attribute("bank_ids", bank_ids)
|
span.set_attribute("bank_ids", bank_ids)
|
||||||
|
|
||||||
step_id = str(uuid.uuid4())
|
step_id = str(uuid.uuid4())
|
||||||
|
@ -473,9 +479,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
if event.stop_reason is not None:
|
if event.stop_reason is not None:
|
||||||
stop_reason = event.stop_reason
|
stop_reason = event.stop_reason
|
||||||
span.set_attribute("stop_reason", stop_reason)
|
span.set_attribute("stop_reason", stop_reason)
|
||||||
span.set_attribute("content", content)
|
|
||||||
span.set_attribute(
|
span.set_attribute(
|
||||||
"tool_calls", [tc.model_dump_json() for tc in tool_calls]
|
"input", [m.model_dump_json() for m in input_messages]
|
||||||
|
)
|
||||||
|
span.set_attribute(
|
||||||
|
"output", f"content: {content} tool_calls: {tool_calls}"
|
||||||
)
|
)
|
||||||
|
|
||||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||||
|
|
|
@ -18,6 +18,10 @@ class OpenTelemetryConfig(BaseModel):
|
||||||
default="llama-stack",
|
default="llama-stack",
|
||||||
description="The service name to use for telemetry",
|
description="The service name to use for telemetry",
|
||||||
)
|
)
|
||||||
|
export_endpoint: str = Field(
|
||||||
|
default="http://localhost:16686/api/traces",
|
||||||
|
description="The Jaeger query endpoint URL",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
|
@ -5,6 +5,9 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from opentelemetry import metrics, trace
|
from opentelemetry import metrics, trace
|
||||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||||
|
@ -206,3 +209,57 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
|
|
||||||
async def get_trace(self, trace_id: str) -> Trace:
|
async def get_trace(self, trace_id: str) -> Trace:
|
||||||
raise NotImplementedError("Trace retrieval not implemented yet")
|
raise NotImplementedError("Trace retrieval not implemented yet")
|
||||||
|
|
||||||
|
async def get_traces_for_session(
|
||||||
|
self, session_id: str, lookback: str = "1h", limit: int = 100
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
params = {
|
||||||
|
"tags": f'{{"session_id":"{session_id}"}}',
|
||||||
|
"lookback": lookback,
|
||||||
|
"limit": limit,
|
||||||
|
"service": self.config.service_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
self.config.export_endpoint, params=params
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to query Jaeger: {response.status} {await response.text()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
traces_data = await response.json()
|
||||||
|
processed_traces = []
|
||||||
|
|
||||||
|
for trace_data in traces_data.get("data", []):
|
||||||
|
trace_steps = []
|
||||||
|
for span in trace_data.get("spans", []):
|
||||||
|
step_info = {
|
||||||
|
"step": span.get("operationName"),
|
||||||
|
"start_time": span.get("startTime"),
|
||||||
|
"duration": span.get("duration"),
|
||||||
|
}
|
||||||
|
|
||||||
|
tags = span.get("tags", [])
|
||||||
|
for tag in tags:
|
||||||
|
if tag.get("key") == "input":
|
||||||
|
step_info["input"] = tag.get("value")
|
||||||
|
elif tag.get("key") == "output":
|
||||||
|
step_info["output"] = tag.get("value")
|
||||||
|
# we only want to return steps which have input and output
|
||||||
|
if step_info.get("input") and step_info.get("output"):
|
||||||
|
trace_steps.append(step_info)
|
||||||
|
|
||||||
|
processed_traces.append(
|
||||||
|
{
|
||||||
|
"trace_id": trace_data.get("traceID"),
|
||||||
|
"steps": trace_steps,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_traces
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error querying Jaeger traces: {str(e)}") from e
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue