From 2dfbb9744d5db23c7d12da8dc55b5674363e17ab Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Wed, 27 Nov 2024 09:24:23 -0800 Subject: [PATCH] explicit type for trace --- llama_stack/apis/telemetry/telemetry.py | 13 ++++++++--- .../meta_reference/telemetry/console.py | 6 ++--- .../telemetry/opentelemetry/opentelemetry.py | 22 +++++++++---------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index c9d939d2d..b0550f848 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -137,6 +137,13 @@ Event = Annotated[ ] +@json_schema_type +class EvalTrace(BaseModel): + step: str + input: str + output: str + + @runtime_checkable class Telemetry(Protocol): @webmethod(route="/telemetry/log-event") @@ -145,11 +152,11 @@ class Telemetry(Protocol): @webmethod(route="/telemetry/get-trace", method="POST") async def get_trace(self, trace_id: str) -> Trace: ... - @webmethod(route="/telemetry/get-traces-for-eval", method="POST") - async def get_traces_for_eval( + @webmethod(route="/telemetry/get-traces-for-agent-eval", method="POST") + async def get_traces_for_agent_eval( self, session_ids: List[str], lookback: str = "1h", limit: int = 100, dataset_id: Optional[str] = None, - ) -> List[Trace]: ... + ) -> List[EvalTrace]: ... diff --git a/llama_stack/providers/inline/meta_reference/telemetry/console.py b/llama_stack/providers/inline/meta_reference/telemetry/console.py index 264b82b69..e6dbdda64 100644 --- a/llama_stack/providers/inline/meta_reference/telemetry/console.py +++ b/llama_stack/providers/inline/meta_reference/telemetry/console.py @@ -54,9 +54,9 @@ class ConsoleTelemetryImpl(Telemetry): async def get_trace(self, trace_id: str) -> Trace: raise NotImplementedError() - async def get_traces_for_session( - self, session_id: str, lookback: str = "1h", limit: int = 100 - ) -> List[Trace]: + async def get_traces_for_agent_eval( + self, session_ids: List[str], lookback: str = "1h", limit: int = 100 + ) -> List[EvalTrace]: raise NotImplementedError() diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index c4e4afabe..89d669ab1 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -199,13 +199,13 @@ class OpenTelemetryAdapter(Telemetry): span.end() _GLOBAL_STORAGE["active_spans"].pop(span_id, None) - async def get_traces_for_eval( + async def get_traces_for_agent_eval( self, session_ids: List[str], lookback: str = "1h", limit: int = 100, dataset_id: Optional[str] = None, - ) -> List[Dict[str, Any]]: + ) -> List[EvalTrace]: traces = [] # Fetch traces for each session ID individually @@ -311,7 +311,7 @@ class OpenTelemetryAdapter(Telemetry): except Exception as e: raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e - async def get_trace_for_eval(self, trace_id: str) -> List[Dict[str, Any]]: + async def get_trace_for_eval(self, trace_id: str) -> List[EvalTrace]: """ Get simplified trace information focusing on first-level children of create_and_execute_turn operations. Returns a list of spans with name, input, and output information, sorted by start time. @@ -320,21 +320,19 @@ class OpenTelemetryAdapter(Telemetry): if not trace_data: return [] - def find_execute_turn_children( - spans: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: - results = [] + def find_execute_turn_children(spans: List[Dict[str, Any]]) -> List[EvalTrace]: + results: List[EvalTrace] = [] for span in spans: if span["name"] == "create_and_execute_turn": # Extract and format children spans children = sorted(span["children"], key=lambda x: x["start_time"]) for child in children: results.append( - { - "name": child["name"], - "input": child["tags"].get("input", ""), - "output": child["tags"].get("output", ""), - } + EvalTrace( + step=child["name"], + input=child["tags"].get("input", ""), + output=child["tags"].get("output", ""), + ) ) # Recursively search in children results.extend(find_execute_turn_children(span["children"]))