mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
working nested tracing
This commit is contained in:
parent
9cebac8a3c
commit
d2e6e59647
2 changed files with 19 additions and 74 deletions
|
@ -17,13 +17,11 @@ import warnings
|
|||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from ssl import SSLError
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
|||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
params=request.query_params,
|
||||
)
|
||||
response = await client.send(req, stream=True)
|
||||
|
||||
async def stream_response():
|
||||
async for chunk in response.aiter_raw(chunk_size=64):
|
||||
yield chunk
|
||||
|
||||
await response.aclose()
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
|
@ -217,7 +153,6 @@ async def maybe_await(value):
|
|||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
await start_trace("sse_generator")
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
|
@ -235,14 +170,10 @@ async def sse_generator(event_gen):
|
|||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
@ -282,6 +211,20 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
return endpoint
|
||||
|
||||
|
||||
# Add new middleware class
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
await start_trace(path)
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def main():
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
@ -338,6 +281,7 @@ def main():
|
|||
print(yaml.dump(config.model_dump(), indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
|
@ -123,7 +123,7 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
|||
logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
||||
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
|
@ -135,6 +135,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
|||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||
|
||||
CURRENT_TRACE_CONTEXT = context
|
||||
return context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue