diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 75f0dddd1..e13c4960b 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -36,7 +36,6 @@ from llama_stack.apis.inference import ( ) from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef -from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, register_schema, webmethod @@ -383,7 +382,6 @@ class AgentStepResponse(BaseModel): @runtime_checkable -@trace_protocol class Agents(Protocol): """Agents API for creating and interacting with agentic systems. @@ -395,7 +393,7 @@ class Agents(Protocol): - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. """ - @webmethod(route="/agents", method="POST") + @webmethod(route="/agents", method="POST", descriptive_name="create_agent") async def create_agent( self, agent_config: AgentConfig, @@ -407,7 +405,9 @@ class Agents(Protocol): """ ... - @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") + @webmethod( + route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn" + ) async def create_agent_turn( self, agent_id: str, @@ -439,6 +439,7 @@ class Agents(Protocol): @webmethod( route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", method="POST", + descriptive_name="resume_agent_turn", ) async def resume_agent_turn( self, @@ -501,7 +502,7 @@ class Agents(Protocol): """ ... - @webmethod(route="/agents/{agent_id}/session", method="POST") + @webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session") async def create_agent_session( self, agent_id: str, diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index bf4f18f96..565f22ae0 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -9,7 +9,6 @@ import inspect import json import logging import os -import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path @@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import ( request_provider_data_context, ) from llama_stack.distribution.resolver import ProviderRegistry -from llama_stack.distribution.server.endpoints import get_all_api_endpoints +from llama_stack.distribution.server.endpoints import ( + find_matching_endpoint, + initialize_endpoint_impls, +) from llama_stack.distribution.stack import ( construct_stack, get_stack_run_config_from_template, @@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): safe_config = redact_sensitive_fields(self.config.model_dump()) console.print(yaml.dump(safe_config, indent=2)) - endpoints = get_all_api_endpoints() - endpoint_impls = {} - - def _convert_path_to_regex(path: str) -> str: - # Convert {param} to named capture groups - # handle {param:path} as well which allows for forward slashes in the param value - pattern = re.sub( - r"{(\w+)(?::path)?}", - lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})", - path, - ) - - return f"^{pattern}$" - - for api, api_endpoints in endpoints.items(): - if api not in self.impls: - continue - for endpoint in api_endpoints: - impl = self.impls[api] - func = getattr(impl, endpoint.name) - if endpoint.method not in endpoint_impls: - endpoint_impls[endpoint.method] = {} - endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (func, endpoint.route) - - self.endpoint_impls = endpoint_impls + self.endpoint_impls = initialize_endpoint_impls(self.impls) return True async def request( @@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return response - def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict, str]: - """Find the matching endpoint implementation for a given method and path. - - Args: - method: HTTP method (GET, POST, etc.) - path: URL path to match against - - Returns: - A tuple of (endpoint_function, path_params) - - Raises: - ValueError: If no matching endpoint is found - """ - impls = self.endpoint_impls.get(method) - if not impls: - raise ValueError(f"No endpoint found for {path}") - - for regex, (func, route) in impls.items(): - match = re.match(regex, path) - if match: - # Extract named groups from the regex match - path_params = match.groupdict() - return func, path_params, route - - raise ValueError(f"No endpoint found for {path}") - async def _call_non_streaming( self, *, @@ -326,7 +278,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} - matched_func, path_params, route = self._find_matching_endpoint(options.method, path) + matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) body |= path_params body = self._convert_body(path, options.method, body) await start_trace(route, {"__location__": "library_client"}) @@ -371,7 +323,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): path = options.url body = options.params or {} body |= options.json_data or {} - func, path_params, route = self._find_matching_endpoint(options.method, path) + func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls) body |= path_params body = self._convert_body(path, options.method, body) @@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not body: return {} - func, _, _ = self._find_matching_endpoint(method, path) + func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 812f59ffd..98f01c067 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import inspect +import re from typing import Dict, List from pydantic import BaseModel @@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel): route: str method: str name: str + descriptive_name: str | None = None def toolgroup_protocol_map(): @@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: method = "delete" else: method = "post" - endpoints.append(ApiEndpoint(route=route, method=method, name=name)) + endpoints.append( + ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name) + ) apis[api] = endpoints return apis + + +def initialize_endpoint_impls(impls): + endpoints = get_all_api_endpoints() + endpoint_impls = {} + + def _convert_path_to_regex(path: str) -> str: + # Convert {param} to named capture groups + # handle {param:path} as well which allows for forward slashes in the param value + pattern = re.sub( + r"{(\w+)(?::path)?}", + lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})", + path, + ) + + return f"^{pattern}$" + + for api, api_endpoints in endpoints.items(): + if api not in impls: + continue + for endpoint in api_endpoints: + impl = impls[api] + func = getattr(impl, endpoint.name) + if endpoint.method not in endpoint_impls: + endpoint_impls[endpoint.method] = {} + endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = ( + func, + endpoint.descriptive_name or endpoint.route, + ) + + return endpoint_impls + + +def find_matching_endpoint(method, path, endpoint_impls): + """Find the matching endpoint implementation for a given method and path. + + Args: + method: HTTP method (GET, POST, etc.) + path: URL path to match against + endpoint_impls: A dictionary of endpoint implementations + + Returns: + A tuple of (endpoint_function, path_params, descriptive_name) + + Raises: + ValueError: If no matching endpoint is found + """ + impls = endpoint_impls.get(method.lower()) + if not impls: + raise ValueError(f"No endpoint found for {path}") + + for regex, (func, descriptive_name) in impls.items(): + match = re.match(regex, path) + if match: + # Extract named groups from the regex match + path_params = match.groupdict() + return func, path_params, descriptive_name + + raise ValueError(f"No endpoint found for {path}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 39de1e4df..b967b0269 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import ( request_provider_data_context, ) from llama_stack.distribution.resolver import InvalidProviderError +from llama_stack.distribution.server.endpoints import ( + find_matching_endpoint, + initialize_endpoint_impls, +) from llama_stack.distribution.stack import ( construct_stack, redact_sensitive_fields, @@ -222,20 +226,18 @@ def create_dynamic_typed_route(func: Any, method: str, route: str): class TracingMiddleware: - def __init__(self, app): + def __init__(self, app, impls): self.app = app + self.impls = impls async def __call__(self, scope, receive, send): if scope.get("type") == "lifespan": return await self.app(scope, receive, send) path = scope.get("path", "") - - # Try to match the path to a route template - route_template = self._match_path(path) - - # Use the matched template or original path - trace_path = route_template or path + if not hasattr(self, "endpoint_impls"): + self.endpoint_impls = initialize_endpoint_impls(self.impls) + _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls) trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path}) @@ -251,36 +253,6 @@ class TracingMiddleware: finally: await end_trace() - def _match_path(self, path): - """Match a path to a route template using simple segment matching.""" - path_segments = path.split("/") - - for route in self.app.app.routes: - if not hasattr(route, "path"): - continue - - route_path = route.path - route_segments = route_path.split("/") - - # Skip if number of segments doesn't match - if len(path_segments) != len(route_segments): - continue - - matches = True - for path_seg, route_seg in zip(path_segments, route_segments, strict=True): - # If route segment is a parameter (contains {...}), it matches anything - if route_seg.startswith("{") and route_seg.endswith("}"): - continue - # Otherwise, segments must match exactly - elif path_seg != route_seg: - matches = False - break - - if matches: - return route_path - - return None - class ClientVersionMiddleware: def __init__(self, app): @@ -399,7 +371,6 @@ def main(): logger.info(yaml.dump(safe_config, indent=2)) app = FastAPI(lifespan=lifespan) - app.add_middleware(TracingMiddleware) if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) @@ -463,6 +434,7 @@ def main(): app.exception_handler(Exception)(global_exception_handler) app.__llama_stack_impls__ = impls + app.add_middleware(TracingMiddleware, impls=impls) import uvicorn diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 88b6e9697..fe1726b07 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -180,25 +180,29 @@ class ChatAgent(ShieldRunnerMixin): return messages async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: - await self._initialize_tools(request.toolgroups) - async with tracing.span("create_and_execute_turn") as span: + span = tracing.get_current_span() + if span: span.set_attribute("session_id", request.session_id) span.set_attribute("agent_id", self.agent_id) span.set_attribute("request", request.model_dump_json()) turn_id = str(uuid.uuid4()) span.set_attribute("turn_id", turn_id) - async for chunk in self._run_turn(request, turn_id): - yield chunk + + await self._initialize_tools(request.toolgroups) + async for chunk in self._run_turn(request, turn_id): + yield chunk async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: - await self._initialize_tools() - async with tracing.span("resume_turn") as span: + span = tracing.get_current_span() + if span: span.set_attribute("agent_id", self.agent_id) span.set_attribute("session_id", request.session_id) - span.set_attribute("turn_id", request.turn_id) span.set_attribute("request", request.model_dump_json()) - async for chunk in self._run_turn(request): - yield chunk + span.set_attribute("turn_id", request.turn_id) + + await self._initialize_tools() + async for chunk in self._run_turn(request): + yield chunk async def _run_turn( self, diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index ad92338e6..d84b1e95f 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -18,6 +18,8 @@ class WebMethod: response_examples: Optional[List[Any]] = None method: Optional[str] = None raw_bytes_request_body: Optional[bool] = False + # A descriptive name of the corresponding span created by tracing + descriptive_name: Optional[str] = None class HasWebMethod(Protocol): @@ -34,6 +36,7 @@ def webmethod( request_examples: Optional[List[Any]] = None, response_examples: Optional[List[Any]] = None, raw_bytes_request_body: Optional[bool] = False, + descriptive_name: Optional[str] = None, ) -> Callable[[T], T]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -52,6 +55,7 @@ def webmethod( request_examples=request_examples, response_examples=response_examples, raw_bytes_request_body=raw_bytes_request_body, + descriptive_name=descriptive_name, ) return cls