Simplified Telemetry API and tying it to logger (#57)

* Simplified Telemetry API and tying it to logger

* small update which adds a METRIC type

* move span events one level down into structured log events

---------

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
Ashwin Bharambe 2024-09-11 14:25:37 -07:00 committed by GitHub
parent 1433aaf9f7
commit 191cd28831
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 524 additions and 162 deletions

View file

@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_toolchain.telemetry.tracing import (
end_trace,
setup_logger,
SpanStatus,
start_trace,
)
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from .distribution import api_endpoints, api_providers
from .dynamic import instantiate_provider
@ -88,6 +95,8 @@ async def passthrough(
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
@ -95,6 +104,7 @@ async def passthrough(
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
@ -120,17 +130,25 @@ async def passthrough(
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(*args, **kwargs):
@ -159,7 +177,7 @@ def create_dynamic_passthrough(
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints["return"]
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
if is_streaming:
async def endpoint(**kwargs):
await start_trace(func.__name__)
async def sse_generator(event_gen):
try:
async for item in event_gen:
@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
else:
async def endpoint(**kwargs):
await start_trace(func.__name__)
try:
return (
await func(**kwargs)
@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
if method == "post":
@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
provider_specs[api] = providers[provider_type]
impls = resolve_impls(provider_specs, config)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
for provider_spec in provider_specs.values():
api = provider_spec.api