working nested tracing

This commit is contained in:
Dinesh Yeduguru 2024-11-22 14:59:35 -08:00
parent 9cebac8a3c
commit d2e6e59647
2 changed files with 19 additions and 74 deletions

View file

@@ -17,13 +17,11 @@ import warnings
from contextlib import asynccontextmanager
from pathlib import Path
from ssl import SSLError
from typing import Any, Dict, Optional
from typing import Any, Union
import httpx
import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.providers.utils.telemetry.tracing import (
    end_trace,
    setup_logger,
    SpanStatus,
    start_trace,
)
)
from llama_stack.distribution.datatypes import * # noqa: F403
@@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
async def passthrough(
    request: Request,
    downstream_url: str,
    downstream_headers: Optional[Dict[str, str]] = None,
):
    await start_trace(request.path, {"downstream_url": downstream_url})

    headers = dict(request.headers)
    headers.pop("host", None)
    headers.update(downstream_headers or {})

    content = await request.body()

    client = httpx.AsyncClient()
    erred = False
    try:
        req = client.build_request(
            method=request.method,
            url=downstream_url,
            headers=headers,
            content=content,
            params=request.query_params,
        )
        response = await client.send(req, stream=True)

        async def stream_response():
            async for chunk in response.aiter_raw(chunk_size=64):
                yield chunk

            await response.aclose()
            await client.aclose()

        return StreamingResponse(
            stream_response(),
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.headers.get("content-type"),
        )
    except httpx.ReadTimeout:
        erred = True
        return Response(content="Downstream server timed out", status_code=504)
    except httpx.NetworkError as e:
        erred = True
        return Response(content=f"Network error: {str(e)}", status_code=502)
    except httpx.TooManyRedirects:
        erred = True
        return Response(content="Too many redirects", status_code=502)
    except SSLError as e:
        erred = True
        return Response(content=f"SSL error: {str(e)}", status_code=502)
    except httpx.HTTPStatusError as e:
        erred = True
        return Response(content=str(e), status_code=e.response.status_code)
    except Exception as e:
        erred = True
        return Response(content=f"Unexpected error: {str(e)}", status_code=500)
    finally:
        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)

def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
@@ -217,7 +153,6 @@ async def maybe_await(value):
async def sse_generator(event_gen):
    await start_trace("sse_generator")
    try:
        event_gen = await event_gen
        async for item in event_gen:
@@ -235,14 +170,10 @@ async def sse_generator(event_gen):
                },
            }
        )
    finally:
        await end_trace()
def create_dynamic_typed_route(func: Any, method: str):
    async def endpoint(request: Request, **kwargs):
        await start_trace(func.__name__)
        set_request_provider_data(request.headers)

        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
        except Exception as e:
            traceback.print_exception(e)
            raise translate_exception(e) from e
        finally:
            await end_trace()

    sig = inspect.signature(func)
    new_params = [
@@ -282,6 +211,20 @@ def create_dynamic_typed_route(func: Any, method: str):
    return endpoint


# Add new middleware class
class TracingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        path = scope["path"]
        await start_trace(path)
        try:
            return await self.app(scope, receive, send)
        finally:
            await end_trace()

def main():
    """Start the LlamaStack server."""
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@@ -338,6 +281,7 @@ def main():
    print(yaml.dump(config.model_dump(), indent=2))

    app = FastAPI(lifespan=lifespan)
    app.add_middleware(TracingMiddleware)

    try:
        impls = asyncio.run(construct_stack(config))
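
A note on how the new middleware takes effect: app.add_middleware(TracingMiddleware) wraps the whole ASGI application, so a trace is opened for every incoming request before any route handler runs and closed after the response is sent, which is why the per-endpoint start_trace/end_trace calls above could be removed. The following is a minimal, self-contained sketch of that pattern, not the server code itself: the stub start_trace/end_trace functions stand in for the real telemetry helpers in llama_stack.providers.utils.telemetry.tracing, and the /health route is only for illustration.

# Sketch only: request-scoped tracing via an ASGI middleware.
from fastapi import FastAPI


async def start_trace(name, attributes=None):
    # Stand-in for the real tracing helper.
    print(f"start trace: {name}")


async def end_trace():
    print("end trace")


class TracingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Lifespan and websocket events pass through untraced in this sketch.
            return await self.app(scope, receive, send)
        await start_trace(scope["path"])
        try:
            return await self.app(scope, receive, send)
        finally:
            await end_trace()


app = FastAPI()
app.add_middleware(TracingMiddleware)


@app.get("/health")
async def health():
    return {"status": "ok"}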

View file

@@ -123,7 +123,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
    logger.addHandler(TelemetryHandler())


async def start_trace(name: str, attributes: Dict[str, Any] = None):
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
    global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER

    if BACKGROUND_LOGGER is None:
@@ -135,6 +135,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
    context.push_span(name, {"__root__": True, **(attributes or {})})
    CURRENT_TRACE_CONTEXT = context
    return context
async def end_trace(status: SpanStatus = SpanStatus.OK):
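
The tracing.py change is the piece the commit title points at: start_trace now returns the TraceContext it creates, so a caller that manages a trace explicitly can keep a handle on it while CURRENT_TRACE_CONTEXT continues to accumulate nested spans pushed by inner calls. The sketch below is only an illustrative model of that span-stack behaviour, not the actual llama_stack TraceContext implementation (the real class also carries span IDs, timestamps, and a background logger); push_span follows the usage visible in the diff, while Span and pop_span are assumed here for symmetry.

# Illustrative sketch of nested spans on a trace context; not the real implementation.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Span:
    name: str
    parent: Optional["Span"]
    attributes: Dict[str, Any] = field(default_factory=dict)


class TraceContext:
    def __init__(self, trace_id: str):
        self.trace_id = trace_id
        self.spans: List[Span] = []  # stack of currently open spans

    def push_span(self, name: str, attributes: Optional[Dict[str, Any]] = None) -> Span:
        # A new span is parented to whatever span is currently open.
        parent = self.spans[-1] if self.spans else None
        span = Span(name=name, parent=parent, attributes=attributes or {})
        self.spans.append(span)
        return span

    def pop_span(self) -> Span:
        # Assumed counterpart to push_span; closes the innermost open span.
        return self.spans.pop()


# The middleware opens the root span for the request path; handlers nest under it.
context = TraceContext(trace_id="demo-trace")
context.push_span("/inference/chat_completion", {"__root__": True})
inner = context.push_span("model_forward", {"model": "demo"})
assert inner.parent is not None and inner.parent.name == "/inference/chat_completion"
context.pop_span()
context.pop_span()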