explicit type for trace

This commit is contained in:
Dinesh Yeduguru 2024-11-27 09:24:23 -08:00
parent dfe152cb97
commit 2dfbb9744d
3 changed files with 23 additions and 18 deletions

View file

@ -137,6 +137,13 @@ Event = Annotated[
] ]
@json_schema_type
class EvalTrace(BaseModel):
step: str
input: str
output: str
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/log-event") @webmethod(route="/telemetry/log-event")
@ -145,11 +152,11 @@ class Telemetry(Protocol):
@webmethod(route="/telemetry/get-trace", method="POST") @webmethod(route="/telemetry/get-trace", method="POST")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/get-traces-for-eval", method="POST") @webmethod(route="/telemetry/get-traces-for-agent-eval", method="POST")
async def get_traces_for_eval( async def get_traces_for_agent_eval(
self, self,
session_ids: List[str], session_ids: List[str],
lookback: str = "1h", lookback: str = "1h",
limit: int = 100, limit: int = 100,
dataset_id: Optional[str] = None, dataset_id: Optional[str] = None,
) -> List[Trace]: ... ) -> List[EvalTrace]: ...

View file

@ -54,9 +54,9 @@ class ConsoleTelemetryImpl(Telemetry):
async def get_trace(self, trace_id: str) -> Trace: async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError() raise NotImplementedError()
async def get_traces_for_session( async def get_traces_for_agent_eval(
self, session_id: str, lookback: str = "1h", limit: int = 100 self, session_ids: List[str], lookback: str = "1h", limit: int = 100
) -> List[Trace]: ) -> List[EvalTrace]:
raise NotImplementedError() raise NotImplementedError()

View file

@ -199,13 +199,13 @@ class OpenTelemetryAdapter(Telemetry):
span.end() span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None) _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
async def get_traces_for_eval( async def get_traces_for_agent_eval(
self, self,
session_ids: List[str], session_ids: List[str],
lookback: str = "1h", lookback: str = "1h",
limit: int = 100, limit: int = 100,
dataset_id: Optional[str] = None, dataset_id: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> List[EvalTrace]:
traces = [] traces = []
# Fetch traces for each session ID individually # Fetch traces for each session ID individually
@ -311,7 +311,7 @@ class OpenTelemetryAdapter(Telemetry):
except Exception as e: except Exception as e:
raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e raise Exception(f"Error querying Jaeger trace structure: {str(e)}") from e
async def get_trace_for_eval(self, trace_id: str) -> List[Dict[str, Any]]: async def get_trace_for_eval(self, trace_id: str) -> List[EvalTrace]:
""" """
Get simplified trace information focusing on first-level children of create_and_execute_turn operations. Get simplified trace information focusing on first-level children of create_and_execute_turn operations.
Returns a list of spans with name, input, and output information, sorted by start time. Returns a list of spans with name, input, and output information, sorted by start time.
@ -320,21 +320,19 @@ class OpenTelemetryAdapter(Telemetry):
if not trace_data: if not trace_data:
return [] return []
def find_execute_turn_children( def find_execute_turn_children(spans: List[Dict[str, Any]]) -> List[EvalTrace]:
spans: List[Dict[str, Any]] results: List[EvalTrace] = []
) -> List[Dict[str, Any]]:
results = []
for span in spans: for span in spans:
if span["name"] == "create_and_execute_turn": if span["name"] == "create_and_execute_turn":
# Extract and format children spans # Extract and format children spans
children = sorted(span["children"], key=lambda x: x["start_time"]) children = sorted(span["children"], key=lambda x: x["start_time"])
for child in children: for child in children:
results.append( results.append(
{ EvalTrace(
"name": child["name"], step=child["name"],
"input": child["tags"].get("input", ""), input=child["tags"].get("input", ""),
"output": child["tags"].get("output", ""), output=child["tags"].get("output", ""),
} )
) )
# Recursively search in children # Recursively search in children
results.extend(find_execute_turn_children(span["children"])) results.extend(find_execute_turn_children(span["children"]))