mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
feat(telemetry): clean up spans (#1760)
This commit is contained in:
parent
e4de9e59fd
commit
06788643b3
6 changed files with 105 additions and 109 deletions
|
@ -36,7 +36,6 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -383,7 +382,6 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Agents(Protocol):
|
||||
"""Agents API for creating and interacting with agentic systems.
|
||||
|
||||
|
@ -395,7 +393,7 @@ class Agents(Protocol):
|
|||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
|
||||
@webmethod(route="/agents", method="POST")
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
@ -407,7 +405,9 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -439,6 +439,7 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
|
@ -501,7 +502,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST")
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
|
|
@ -9,7 +9,6 @@ import inspect
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
|
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
safe_config = redact_sensitive_fields(self.config.model_dump())
|
||||
console.print(yaml.dump(safe_config, indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
# handle {param:path} as well which allows for forward slashes in the param value
|
||||
pattern = re.sub(
|
||||
r"{(\w+)(?::path)?}",
|
||||
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||
path,
|
||||
)
|
||||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
if api not in self.impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
impl = self.impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (func, endpoint.route)
|
||||
|
||||
self.endpoint_impls = endpoint_impls
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
return True
|
||||
|
||||
async def request(
|
||||
|
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
return response
|
||||
|
||||
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict, str]:
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = self.endpoint_impls.get(method)
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, (func, route) in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params, route
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
async def _call_non_streaming(
|
||||
self,
|
||||
*,
|
||||
|
@ -326,7 +278,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params, route = self._find_matching_endpoint(options.method, path)
|
||||
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
body |= path_params
|
||||
body = self._convert_body(path, options.method, body)
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
|
@ -371,7 +323,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route = self._find_matching_endpoint(options.method, path)
|
||||
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
if not body:
|
||||
return {}
|
||||
|
||||
func, _, _ = self._find_matching_endpoint(method, path)
|
||||
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
|
|||
route: str
|
||||
method: str
|
||||
name: str
|
||||
descriptive_name: str | None = None
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
|
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
method = "delete"
|
||||
else:
|
||||
method = "post"
|
||||
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
|
||||
endpoints.append(
|
||||
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
|
||||
)
|
||||
|
||||
apis[api] = endpoints
|
||||
|
||||
return apis
|
||||
|
||||
|
||||
def initialize_endpoint_impls(impls):
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
# Convert {param} to named capture groups
|
||||
# handle {param:path} as well which allows for forward slashes in the param value
|
||||
pattern = re.sub(
|
||||
r"{(\w+)(?::path)?}",
|
||||
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
|
||||
path,
|
||||
)
|
||||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_endpoints in endpoints.items():
|
||||
if api not in impls:
|
||||
continue
|
||||
for endpoint in api_endpoints:
|
||||
impl = impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
if endpoint.method not in endpoint_impls:
|
||||
endpoint_impls[endpoint.method] = {}
|
||||
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
|
||||
func,
|
||||
endpoint.descriptive_name or endpoint.route,
|
||||
)
|
||||
|
||||
return endpoint_impls
|
||||
|
||||
|
||||
def find_matching_endpoint(method, path, endpoint_impls):
|
||||
"""Find the matching endpoint implementation for a given method and path.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
path: URL path to match against
|
||||
endpoint_impls: A dictionary of endpoint implementations
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
"""
|
||||
impls = endpoint_impls.get(method.lower())
|
||||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, (func, descriptive_name) in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params, descriptive_name
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
|
|||
request_provider_data_context,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.endpoints import (
|
||||
find_matching_endpoint,
|
||||
initialize_endpoint_impls,
|
||||
)
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
redact_sensitive_fields,
|
||||
|
@ -222,20 +226,18 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
|
|||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app):
|
||||
def __init__(self, app, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Try to match the path to a route template
|
||||
route_template = self._match_path(path)
|
||||
|
||||
# Use the matched template or original path
|
||||
trace_path = route_template or path
|
||||
if not hasattr(self, "endpoint_impls"):
|
||||
self.endpoint_impls = initialize_endpoint_impls(self.impls)
|
||||
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
|
||||
|
||||
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
|
||||
|
||||
|
@ -251,36 +253,6 @@ class TracingMiddleware:
|
|||
finally:
|
||||
await end_trace()
|
||||
|
||||
def _match_path(self, path):
|
||||
"""Match a path to a route template using simple segment matching."""
|
||||
path_segments = path.split("/")
|
||||
|
||||
for route in self.app.app.routes:
|
||||
if not hasattr(route, "path"):
|
||||
continue
|
||||
|
||||
route_path = route.path
|
||||
route_segments = route_path.split("/")
|
||||
|
||||
# Skip if number of segments doesn't match
|
||||
if len(path_segments) != len(route_segments):
|
||||
continue
|
||||
|
||||
matches = True
|
||||
for path_seg, route_seg in zip(path_segments, route_segments, strict=True):
|
||||
# If route segment is a parameter (contains {...}), it matches anything
|
||||
if route_seg.startswith("{") and route_seg.endswith("}"):
|
||||
continue
|
||||
# Otherwise, segments must match exactly
|
||||
elif path_seg != route_seg:
|
||||
matches = False
|
||||
break
|
||||
|
||||
if matches:
|
||||
return route_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
|
||||
def __init__(self, app):
|
||||
|
@ -399,7 +371,6 @@ def main():
|
|||
logger.info(yaml.dump(safe_config, indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
|
@ -463,6 +434,7 @@ def main():
|
|||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -180,25 +180,29 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return messages
|
||||
|
||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async with tracing.span("create_and_execute_turn") as span:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
await self._initialize_tools(request.toolgroups)
|
||||
async for chunk in self._run_turn(request, turn_id):
|
||||
yield chunk
|
||||
|
||||
async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
|
||||
await self._initialize_tools()
|
||||
async with tracing.span("resume_turn") as span:
|
||||
span = tracing.get_current_span()
|
||||
if span:
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
async for chunk in self._run_turn(request):
|
||||
yield chunk
|
||||
span.set_attribute("turn_id", request.turn_id)
|
||||
|
||||
await self._initialize_tools()
|
||||
async for chunk in self._run_turn(request):
|
||||
yield chunk
|
||||
|
||||
async def _run_turn(
|
||||
self,
|
||||
|
|
|
@ -18,6 +18,8 @@ class WebMethod:
|
|||
response_examples: Optional[List[Any]] = None
|
||||
method: Optional[str] = None
|
||||
raw_bytes_request_body: Optional[bool] = False
|
||||
# A descriptive name of the corresponding span created by tracing
|
||||
descriptive_name: Optional[str] = None
|
||||
|
||||
|
||||
class HasWebMethod(Protocol):
|
||||
|
@ -34,6 +36,7 @@ def webmethod(
|
|||
request_examples: Optional[List[Any]] = None,
|
||||
response_examples: Optional[List[Any]] = None,
|
||||
raw_bytes_request_body: Optional[bool] = False,
|
||||
descriptive_name: Optional[str] = None,
|
||||
) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator that supplies additional metadata to an endpoint operation function.
|
||||
|
@ -52,6 +55,7 @@ def webmethod(
|
|||
request_examples=request_examples,
|
||||
response_examples=response_examples,
|
||||
raw_bytes_request_body=raw_bytes_request_body,
|
||||
descriptive_name=descriptive_name,
|
||||
)
|
||||
return cls
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue