working nested tracing

This commit is contained in:
Dinesh Yeduguru 2024-11-22 14:59:35 -08:00
parent 9cebac8a3c
commit d2e6e59647
2 changed files with 19 additions and 74 deletions

View file

@ -17,13 +17,11 @@ import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from ssl import SSLError from typing import Any, Union
from typing import Any, Dict, Optional
import httpx
import yaml import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
SpanStatus,
start_trace, start_trace,
) )
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
) )
async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    """Forward ``request`` to ``downstream_url`` and stream the reply back.

    The incoming method, query parameters, body and headers (minus ``host``)
    are replayed against the downstream URL; ``downstream_headers`` override
    any incoming header with the same name.

    Args:
        request: The incoming request to forward.
        downstream_url: Absolute URL of the downstream server.
        downstream_headers: Extra headers merged over the incoming ones.

    Returns:
        A ``StreamingResponse`` relaying the downstream status, headers and
        raw body on success, or a plain ``Response`` carrying a 5xx (or the
        downstream status for ``HTTPStatusError``) when forwarding fails.
    """
    # Starlette's Request exposes the path through ``request.url``; it has no
    # ``path`` attribute of its own.
    await start_trace(request.url.path, {"downstream_url": downstream_url})

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    content = await request.body()

    client = httpx.AsyncClient()
    erred = False
    # True once ownership of ``client`` has moved to the streaming generator,
    # which then becomes responsible for closing it.
    handed_off = False
    try:
        req = client.build_request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=content,
            params=request.query_params,
        )
        response = await client.send(req, stream=True)

        async def stream_response():
            # ``finally`` guarantees cleanup even when the consumer stops
            # iterating early (client disconnect) or a read fails mid-stream.
            try:
                async for chunk in response.aiter_raw(chunk_size=64):
                    yield chunk
            finally:
                await response.aclose()
                await client.aclose()

        streaming = StreamingResponse(
            stream_response(),
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.headers.get("content-type"),
        )
        handed_off = True
        return streaming
    except httpx.ReadTimeout:
        erred = True
        return Response(content="Downstream server timed out", status_code=504)
    except httpx.NetworkError as e:
        erred = True
        return Response(content=f"Network error: {str(e)}", status_code=502)
    except httpx.TooManyRedirects:
        erred = True
        return Response(content="Too many redirects", status_code=502)
    except SSLError as e:
        erred = True
        return Response(content=f"SSL error: {str(e)}", status_code=502)
    except httpx.HTTPStatusError as e:
        erred = True
        return Response(content=str(e), status_code=e.response.status_code)
    except Exception as e:
        erred = True
        return Response(content=f"Unexpected error: {str(e)}", status_code=500)
    finally:
        try:
            # On any failure path the generator never runs, so the client
            # must be closed here to avoid leaking its connection pool.
            if not handed_off:
                await client.aclose()
        finally:
            await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(app, *args, **kwargs): def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...") print("SIGINT or CTRL-C detected. Exiting gracefully...")
@ -217,7 +153,6 @@ async def maybe_await(value):
async def sse_generator(event_gen): async def sse_generator(event_gen):
await start_trace("sse_generator")
try: try:
event_gen = await event_gen event_gen = await event_gen
async for item in event_gen: async for item in event_gen:
@ -235,14 +170,10 @@ async def sse_generator(event_gen):
}, },
} }
) )
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str): def create_dynamic_typed_route(func: Any, method: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers) set_request_provider_data(request.headers)
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
except Exception as e: except Exception as e:
traceback.print_exception(e) traceback.print_exception(e)
raise translate_exception(e) from e raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func) sig = inspect.signature(func)
new_params = [ new_params = [
@ -282,6 +211,20 @@ def create_dynamic_typed_route(func: Any, method: str):
return endpoint return endpoint
# Add new middleware class
class TracingMiddleware:
    """ASGI middleware that wraps each request in a trace named after its path.

    A trace is started before the wrapped app is invoked and ended once it
    returns or raises.
    """

    def __init__(self, app):
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope, receive, send):
        """Run the wrapped app inside a trace.

        Lifespan (startup/shutdown) scopes are passed straight through: they
        carry no ``path`` key, so reading it would raise ``KeyError``.
        """
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope["path"]
        await start_trace(path)
        try:
            return await self.app(scope, receive, send)
        finally:
            await end_trace()
def main(): def main():
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -338,6 +281,7 @@ def main():
print(yaml.dump(config.model_dump(), indent=2)) print(yaml.dump(config.model_dump(), indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
try: try:
impls = asyncio.run(construct_stack(config)) impls = asyncio.run(construct_stack(config))

View file

@ -123,7 +123,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
logger.addHandler(TelemetryHandler()) logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None): async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
@ -135,6 +135,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
context.push_span(name, {"__root__": True, **(attributes or {})}) context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context CURRENT_TRACE_CONTEXT = context
return context
async def end_trace(status: SpanStatus = SpanStatus.OK): async def end_trace(status: SpanStatus = SpanStatus.OK):