mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-08 14:54:35 +00:00
Simplified Telemetry API and tying it to logger (#57)
* Simplified Telemetry API and tying it to logger * small update which adds a METRIC type * move span events one level down into structured log events --------- Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
This commit is contained in:
parent
1433aaf9f7
commit
191cd28831
15 changed files with 524 additions and 162 deletions
|
@ -19,6 +19,7 @@ class Api(Enum):
|
|||
safety = "safety"
|
||||
agentic_system = "agentic_system"
|
||||
memory = "memory"
|
||||
telemetry = "telemetry"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -4,17 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Dict, List
|
||||
|
||||
from llama_toolchain.agentic_system.api import AgenticSystem
|
||||
from llama_toolchain.agentic_system.providers import available_agentic_system_providers
|
||||
from llama_toolchain.inference.api import Inference
|
||||
from llama_toolchain.inference.providers import available_inference_providers
|
||||
from llama_toolchain.memory.api import Memory
|
||||
from llama_toolchain.memory.providers import available_memory_providers
|
||||
from llama_toolchain.safety.api import Safety
|
||||
from llama_toolchain.safety.providers import available_safety_providers
|
||||
from llama_toolchain.telemetry.api import Telemetry
|
||||
|
||||
from .datatypes import (
|
||||
Api,
|
||||
|
@ -44,7 +42,7 @@ def distribution_dependencies(distribution: DistributionSpec) -> List[str]:
|
|||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
return [Api.inference, Api.safety, Api.agentic_system, Api.memory]
|
||||
return [v for v in Api]
|
||||
|
||||
|
||||
def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
||||
|
@ -55,6 +53,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
Api.safety: Safety,
|
||||
Api.agentic_system: AgenticSystem,
|
||||
Api.memory: Memory,
|
||||
Api.telemetry: Telemetry,
|
||||
}
|
||||
|
||||
for api, protocol in protocols.items():
|
||||
|
@ -82,20 +81,13 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
|
|||
|
||||
|
||||
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
inference_providers_by_id = {
|
||||
a.provider_type: a for a in available_inference_providers()
|
||||
}
|
||||
safety_providers_by_id = {a.provider_type: a for a in available_safety_providers()}
|
||||
agentic_system_providers_by_id = {
|
||||
a.provider_type: a for a in available_agentic_system_providers()
|
||||
}
|
||||
ret = {}
|
||||
for api in stack_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_toolchain.{name}.providers")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_type: a for a in module.available_providers()},
|
||||
}
|
||||
|
||||
ret = {
|
||||
Api.inference: inference_providers_by_id,
|
||||
Api.safety: safety_providers_by_id,
|
||||
Api.agentic_system: agentic_system_providers_by_id,
|
||||
Api.memory: {a.provider_type: a for a in available_memory_providers()},
|
||||
}
|
||||
for k, v in ret.items():
|
||||
v["remote"] = remote_provider_spec(k)
|
||||
return ret
|
||||
|
|
|
@ -21,12 +21,16 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.memory: "meta-reference-faiss",
|
||||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
distribution_type="remote",
|
||||
description="Point to remote services for all llama stack APIs",
|
||||
providers={x: "remote" for x in Api},
|
||||
providers={
|
||||
**{x: "remote" for x in Api},
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
distribution_type="local-ollama",
|
||||
|
@ -36,6 +40,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
@ -46,6 +51,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
DistributionSpec(
|
||||
|
@ -56,6 +62,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
|
|||
Api.safety: "meta-reference",
|
||||
Api.agentic_system: "meta-reference",
|
||||
Api.memory: "meta-reference-faiss",
|
||||
Api.telemetry: "console",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
|
|
@ -38,6 +38,13 @@ from pydantic import BaseModel, ValidationError
|
|||
from termcolor import cprint
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_toolchain.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from .distribution import api_endpoints, api_providers
|
||||
from .dynamic import instantiate_provider
|
||||
|
@ -88,6 +95,8 @@ async def passthrough(
|
|||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
@ -95,6 +104,7 @@ async def passthrough(
|
|||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
|
@ -120,17 +130,25 @@ async def passthrough(
|
|||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(*args, **kwargs):
|
||||
|
@ -159,7 +177,7 @@ def create_dynamic_passthrough(
|
|||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints["return"]
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
|
@ -170,6 +188,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
if is_streaming:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
|
@ -187,6 +207,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
|
@ -195,6 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
else:
|
||||
|
||||
async def endpoint(**kwargs):
|
||||
await start_trace(func.__name__)
|
||||
try:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
|
@ -204,6 +227,8 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
if method == "post":
|
||||
|
@ -293,6 +318,8 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
|
|||
provider_specs[api] = providers[provider_type]
|
||||
|
||||
impls = resolve_impls(provider_specs, config)
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
|
||||
for provider_spec in provider_specs.values():
|
||||
api = provider_spec.api
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue