feat(telemetry): clean up spans (#1760)

This commit is contained in:
ehhuang 2025-03-21 20:05:11 -07:00 committed by GitHub
parent e4de9e59fd
commit 06788643b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 105 additions and 109 deletions

View file

@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
) )
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -383,7 +382,6 @@ class AgentStepResponse(BaseModel):
@runtime_checkable @runtime_checkable
@trace_protocol
class Agents(Protocol): class Agents(Protocol):
"""Agents API for creating and interacting with agentic systems. """Agents API for creating and interacting with agentic systems.
@ -395,7 +393,7 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
""" """
@webmethod(route="/agents", method="POST") @webmethod(route="/agents", method="POST", descriptive_name="create_agent")
async def create_agent( async def create_agent(
self, self,
agent_config: AgentConfig, agent_config: AgentConfig,
@ -407,7 +405,9 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
)
async def create_agent_turn( async def create_agent_turn(
self, self,
agent_id: str, agent_id: str,
@ -439,6 +439,7 @@ class Agents(Protocol):
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST", method="POST",
descriptive_name="resume_agent_turn",
) )
async def resume_agent_turn( async def resume_agent_turn(
self, self,
@ -501,7 +502,7 @@ class Agents(Protocol):
""" """
... ...
@webmethod(route="/agents/{agent_id}/session", method="POST") @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
async def create_agent_session( async def create_agent_session(
self, self,
agent_id: str, agent_id: str,

View file

@ -9,7 +9,6 @@ import inspect
import json import json
import logging import logging
import os import os
import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import ProviderRegistry from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump()) safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2)) console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints() self.endpoint_impls = initialize_endpoint_impls(self.impls)
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (func, endpoint.route)
self.endpoint_impls = endpoint_impls
return True return True
async def request( async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
return response return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict, str]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, route) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, route
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming( async def _call_non_streaming(
self, self,
*, *,
@ -326,7 +278,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params, route = self._find_matching_endpoint(options.method, path) matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"}) await start_trace(route, {"__location__": "library_client"})
@ -371,7 +323,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params, route = self._find_matching_endpoint(options.method, path) func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body: if not body:
return {} return {}
func, _, _ = self._find_matching_endpoint(method, path) func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -5,6 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import inspect import inspect
import re
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str route: str
method: str method: str
name: str name: str
descriptive_name: str | None = None
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete" method = "delete"
else: else:
method = "post" method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name)) endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints apis[api] = endpoints
return apis return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context, request_provider_data_context,
) )
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
redact_sensitive_fields, redact_sensitive_fields,
@ -222,20 +226,18 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app): def __init__(self, app, impls):
self.app = app self.app = app
self.impls = impls
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan": if scope.get("type") == "lifespan":
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
path = scope.get("path", "") path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
# Try to match the path to a route template self.endpoint_impls = initialize_endpoint_impls(self.impls)
route_template = self._match_path(path) _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
# Use the matched template or original path
trace_path = route_template or path
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
@ -251,36 +253,6 @@ class TracingMiddleware:
finally: finally:
await end_trace() await end_trace()
def _match_path(self, path):
"""Match a path to a route template using simple segment matching."""
path_segments = path.split("/")
for route in self.app.app.routes:
if not hasattr(route, "path"):
continue
route_path = route.path
route_segments = route_path.split("/")
# Skip if number of segments doesn't match
if len(path_segments) != len(route_segments):
continue
matches = True
for path_seg, route_seg in zip(path_segments, route_segments, strict=True):
# If route segment is a parameter (contains {...}), it matches anything
if route_seg.startswith("{") and route_seg.endswith("}"):
continue
# Otherwise, segments must match exactly
elif path_seg != route_seg:
matches = False
break
if matches:
return route_path
return None
class ClientVersionMiddleware: class ClientVersionMiddleware:
def __init__(self, app): def __init__(self, app):
@ -399,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2)) logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
@ -463,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn import uvicorn

View file

@ -180,25 +180,29 @@ class ChatAgent(ShieldRunnerMixin):
return messages return messages
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
await self._initialize_tools(request.toolgroups) span = tracing.get_current_span()
async with tracing.span("create_and_execute_turn") as span: if span:
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
turn_id = str(uuid.uuid4()) turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id) span.set_attribute("turn_id", turn_id)
async for chunk in self._run_turn(request, turn_id):
yield chunk await self._initialize_tools(request.toolgroups)
async for chunk in self._run_turn(request, turn_id):
yield chunk
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
await self._initialize_tools() span = tracing.get_current_span()
async with tracing.span("resume_turn") as span: if span:
span.set_attribute("agent_id", self.agent_id) span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id) span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json()) span.set_attribute("request", request.model_dump_json())
async for chunk in self._run_turn(request): span.set_attribute("turn_id", request.turn_id)
yield chunk
await self._initialize_tools()
async for chunk in self._run_turn(request):
yield chunk
async def _run_turn( async def _run_turn(
self, self,

View file

@ -18,6 +18,8 @@ class WebMethod:
response_examples: Optional[List[Any]] = None response_examples: Optional[List[Any]] = None
method: Optional[str] = None method: Optional[str] = None
raw_bytes_request_body: Optional[bool] = False raw_bytes_request_body: Optional[bool] = False
# A descriptive name of the corresponding span created by tracing
descriptive_name: Optional[str] = None
class HasWebMethod(Protocol): class HasWebMethod(Protocol):
@ -34,6 +36,7 @@ def webmethod(
request_examples: Optional[List[Any]] = None, request_examples: Optional[List[Any]] = None,
response_examples: Optional[List[Any]] = None, response_examples: Optional[List[Any]] = None,
raw_bytes_request_body: Optional[bool] = False, raw_bytes_request_body: Optional[bool] = False,
descriptive_name: Optional[str] = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
""" """
Decorator that supplies additional metadata to an endpoint operation function. Decorator that supplies additional metadata to an endpoint operation function.
@ -52,6 +55,7 @@ def webmethod(
request_examples=request_examples, request_examples=request_examples,
response_examples=response_examples, response_examples=response_examples,
raw_bytes_request_body=raw_bytes_request_body, raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name,
) )
return cls return cls