mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge remote-tracking branch 'upstream/main' into rm-build
This commit is contained in:
commit
39ad54696c
81 changed files with 585 additions and 4236 deletions
|
|
@ -191,22 +191,6 @@ class DistributionSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for telemetry.
|
||||
|
||||
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
|
||||
for env variables to configure the OpenTelemetry SDK.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
|
||||
```
|
||||
"""
|
||||
|
||||
enabled: bool = Field(default=False, description="enable or disable telemetry")
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
|
|
@ -528,8 +512,6 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
|
||||
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
|||
|
|
@ -45,8 +45,6 @@ from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import Stack, get_stack_run_config_from_distro, replace_env_vars
|
||||
from llama_stack.core.telemetry import Telemetry
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
|
|
@ -203,13 +201,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
super().__init__()
|
||||
# Initialize logging from environment variables first
|
||||
setup_logging()
|
||||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
|
||||
current_sinks = sinks_from_env.strip().lower().split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
|
|
@ -281,8 +272,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if self.config.telemetry.enabled:
|
||||
setup_logger(Telemetry())
|
||||
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
console = Console()
|
||||
|
|
@ -378,13 +367,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
await end_trace()
|
||||
result = await matched_func(**body)
|
||||
|
||||
# Handle FastAPI Response objects (e.g., from file content retrieval)
|
||||
if isinstance(result, FastAPIResponse):
|
||||
|
|
@ -443,19 +426,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
finally:
|
||||
await end_trace()
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
|
|
|
|||
|
|
@ -392,8 +392,6 @@ async def instantiate_provider(
|
|||
args = [config, deps]
|
||||
if "policy" in inspect.signature(getattr(module, method)).parameters:
|
||||
args.append(policy)
|
||||
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
|
||||
args.append(run_config.telemetry.enabled)
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
@ -401,18 +399,6 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
# Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker
|
||||
if run_config.telemetry.enabled:
|
||||
traced_classes = [
|
||||
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
||||
]
|
||||
|
||||
if traced_classes:
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
for cls in traced_classes:
|
||||
trace_protocol(cls)
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
|
|
|
|||
|
|
@ -85,8 +85,6 @@ async def get_auto_router_impl(
|
|||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
|
||||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
elif api == Api.safety:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
|
|
@ -15,11 +14,7 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import (
|
||||
HealthResponse,
|
||||
|
|
@ -60,15 +55,10 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
store: InferenceStore | None = None,
|
||||
telemetry_enabled: bool = False,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
self.store = store
|
||||
if self.telemetry_enabled:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
|
|
@ -94,54 +84,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
) -> list[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
fully_qualified_model_id:
|
||||
provider_id: The provider identifier
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=datetime.now(UTC),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": fully_qualified_model_id,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||
model = await self.routing_table.get_object_by_identifier("model", model_id)
|
||||
if model:
|
||||
|
|
@ -186,26 +128,9 @@ class InferenceRouter(Inference):
|
|||
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
|
||||
response = await provider.openai_completion(params)
|
||||
response.model = request_model_id
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
fully_qualified_model_id=request_model_id,
|
||||
provider_id=provider.__provider_id__,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -254,20 +179,6 @@ class InferenceRouter(Inference):
|
|||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
fully_qualified_model_id=request_model_id,
|
||||
provider_id=provider.__provider_id__,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_embeddings(
|
||||
|
|
@ -411,18 +322,6 @@ class InferenceRouter(Inference):
|
|||
for choice_data in choices_data.values():
|
||||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
# Store the final assembled completion
|
||||
|
|
|
|||
|
|
@ -6,11 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.telemetry.helpers import safety_request_span_attributes, safety_span_name
|
||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
|
|
@ -51,13 +55,17 @@ class SafetyRouter(Safety):
|
|||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
with tracer.start_as_current_span(name=safety_span_name(shield_id)):
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
response = await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
safety_request_span_attributes(shield_id, messages, response)
|
||||
return response
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
|
|
|
|||
|
|
@ -50,8 +50,6 @@ from llama_stack.core.stack import (
|
|||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.core.telemetry import Telemetry
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
|
|
@ -60,7 +58,6 @@ from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFo
|
|||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
from .tracing import TracingMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
|
@ -263,7 +260,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
try:
|
||||
if is_streaming:
|
||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
context_vars = [PROVIDER_DATA_VAR]
|
||||
if test_context_var is not None:
|
||||
context_vars.append(test_context_var)
|
||||
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
|
||||
|
|
@ -441,9 +438,6 @@ def create_app() -> StackApp:
|
|||
if cors_config:
|
||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||
|
||||
if config.telemetry.enabled:
|
||||
setup_logger(Telemetry())
|
||||
|
||||
# Load external APIs if configured
|
||||
external_apis = load_external_apis(config)
|
||||
all_routes = get_all_api_routes(external_apis)
|
||||
|
|
@ -500,9 +494,6 @@ def create_app() -> StackApp:
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
if config.telemetry.enabled:
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.telemetry.tracing import end_trace, start_trace
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .telemetry import Telemetry
|
||||
from .trace_protocol import serialize_value, trace_protocol
|
||||
from .tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
ROOT_SPAN_MARKERS,
|
||||
end_trace,
|
||||
enqueue_event,
|
||||
get_current_span,
|
||||
setup_logger,
|
||||
span,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Telemetry",
|
||||
"trace_protocol",
|
||||
"serialize_value",
|
||||
"CURRENT_TRACE_CONTEXT",
|
||||
"ROOT_SPAN_MARKERS",
|
||||
"end_trace",
|
||||
"enqueue_event",
|
||||
"get_current_span",
|
||||
"setup_logger",
|
||||
"span",
|
||||
"start_trace",
|
||||
]
|
||||
|
|
@ -1,629 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack_api import json_schema_type, register_schema
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
# Type alias for OpenTelemetry attribute values (excludes None)
|
||||
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
|
||||
Attributes = Mapping[str, AttributeValue]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
"""The status of a span indicating whether it completed successfully or with an error.
|
||||
:cvar OK: Span completed successfully without errors
|
||||
:cvar ERROR: Span completed with an error or failure
|
||||
"""
|
||||
|
||||
OK = "ok"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Span(BaseModel):
|
||||
"""A span representing a single operation within a trace.
|
||||
:param span_id: Unique identifier for the span
|
||||
:param trace_id: Unique identifier for the trace this span belongs to
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param start_time: Timestamp when the operation began
|
||||
:param end_time: (Optional) Timestamp when the operation finished, if completed
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
|
||||
"""
|
||||
|
||||
span_id: str
|
||||
trace_id: str
|
||||
parent_span_id: str | None = None
|
||||
name: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
"""A trace representing the complete execution path of a request across multiple operations.
|
||||
:param trace_id: Unique identifier for the trace
|
||||
:param root_span_id: Unique identifier for the root span that started this trace
|
||||
:param start_time: Timestamp when the trace began
|
||||
:param end_time: (Optional) Timestamp when the trace finished, if completed
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
root_span_id: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EventType(Enum):
|
||||
"""The type of telemetry event being logged.
|
||||
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
|
||||
:cvar STRUCTURED_LOG: A structured log event with typed payload data
|
||||
:cvar METRIC: A metric measurement with value and unit
|
||||
"""
|
||||
|
||||
UNSTRUCTURED_LOG = "unstructured_log"
|
||||
STRUCTURED_LOG = "structured_log"
|
||||
METRIC = "metric"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogSeverity(Enum):
|
||||
"""The severity level of a log message.
|
||||
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
|
||||
:cvar DEBUG: Debug information useful during development
|
||||
:cvar INFO: General informational messages about normal operation
|
||||
:cvar WARN: Warning messages about potentially problematic situations
|
||||
:cvar ERROR: Error messages indicating failures that don't stop execution
|
||||
:cvar CRITICAL: Critical error messages indicating severe failures
|
||||
"""
|
||||
|
||||
VERBOSE = "verbose"
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class EventCommon(BaseModel):
|
||||
"""Common fields shared by all telemetry events.
|
||||
:param trace_id: Unique identifier for the trace this event belongs to
|
||||
:param span_id: Unique identifier for the span this event belongs to
|
||||
:param timestamp: Timestamp when the event occurred
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
"""An unstructured log event containing a simple text message.
|
||||
:param type: Event type identifier set to UNSTRUCTURED_LOG
|
||||
:param message: The log message text
|
||||
:param severity: The severity level of the log message
|
||||
"""
|
||||
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
"""A metric event containing a measured value.
|
||||
:param type: Event type identifier set to METRIC
|
||||
:param metric: The name of the metric being measured
|
||||
:param value: The numeric value of the metric measurement
|
||||
:param unit: The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
"""The type of structured log event payload.
|
||||
:cvar SPAN_START: Event indicating the start of a new span
|
||||
:cvar SPAN_END: Event indicating the completion of a span
|
||||
"""
|
||||
|
||||
SPAN_START = "span_start"
|
||||
SPAN_END = "span_end"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
"""Payload for a span start event.
|
||||
:param type: Payload type identifier set to SPAN_START
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
|
||||
name: str
|
||||
parent_span_id: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanEndPayload(BaseModel):
|
||||
"""Payload for a span end event.
|
||||
:param type: Payload type identifier set to SPAN_END
|
||||
:param status: The final status of the span indicating success or failure
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
|
||||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = Annotated[
|
||||
SpanStartPayload | SpanEndPayload,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogEvent(EventCommon):
|
||||
"""A structured log event containing typed payload data.
|
||||
:param type: Event type identifier set to STRUCTURED_LOG
|
||||
:param payload: The structured payload data for the log event
|
||||
"""
|
||||
|
||||
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = Annotated[
|
||||
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
"""A trace record for evaluation purposes.
|
||||
:param session_id: Unique identifier for the evaluation session
|
||||
:param step: The evaluation step or phase identifier
|
||||
:param input: The input data for the evaluation
|
||||
:param output: The actual output produced during evaluation
|
||||
:param expected_output: The expected output for comparison during evaluation
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
output: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanWithStatus(Span):
|
||||
"""A span that includes status information.
|
||||
:param status: (Optional) The current status of the span
|
||||
"""
|
||||
|
||||
status: SpanStatus | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryConditionOp(Enum):
|
||||
"""Comparison operators for query conditions.
|
||||
:cvar EQ: Equal to comparison
|
||||
:cvar NE: Not equal to comparison
|
||||
:cvar GT: Greater than comparison
|
||||
:cvar LT: Less than comparison
|
||||
"""
|
||||
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GT = "gt"
|
||||
LT = "lt"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
"""A condition for filtering query results.
|
||||
:param key: The attribute key to filter on
|
||||
:param op: The comparison operator to apply
|
||||
:param value: The value to compare against
|
||||
"""
|
||||
|
||||
key: str
|
||||
op: QueryConditionOp
|
||||
value: Any
|
||||
|
||||
|
||||
class QueryTracesResponse(BaseModel):
|
||||
"""Response containing a list of traces.
|
||||
:param data: List of traces matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Trace]
|
||||
|
||||
|
||||
class QuerySpansResponse(BaseModel):
|
||||
"""Response containing a list of spans.
|
||||
:param data: List of spans matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Span]
|
||||
|
||||
|
||||
class QuerySpanTreeResponse(BaseModel):
|
||||
"""Response containing a tree structure of spans.
|
||||
:param data: Dictionary mapping span IDs to spans with status information
|
||||
"""
|
||||
|
||||
data: dict[str, SpanWithStatus]
|
||||
|
||||
|
||||
class MetricQueryType(Enum):
|
||||
"""The type of metric query to perform.
|
||||
:cvar RANGE: Query metrics over a time range
|
||||
:cvar INSTANT: Query metrics at a specific point in time
|
||||
"""
|
||||
|
||||
RANGE = "range"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
class MetricLabelOperator(Enum):
|
||||
"""Operators for matching metric labels.
|
||||
:cvar EQUALS: Label value must equal the specified value
|
||||
:cvar NOT_EQUALS: Label value must not equal the specified value
|
||||
:cvar REGEX_MATCH: Label value must match the specified regular expression
|
||||
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
|
||||
"""
|
||||
|
||||
EQUALS = "="
|
||||
NOT_EQUALS = "!="
|
||||
REGEX_MATCH = "=~"
|
||||
REGEX_NOT_MATCH = "!~"
|
||||
|
||||
|
||||
class MetricLabelMatcher(BaseModel):
|
||||
"""A matcher for filtering metrics by label values.
|
||||
:param name: The name of the label to match
|
||||
:param value: The value to match against
|
||||
:param operator: The comparison operator to use for matching
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricLabel(BaseModel):
|
||||
"""A label associated with a metric.
|
||||
:param name: The name of the label
|
||||
:param value: The value of the label
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricDataPoint(BaseModel):
|
||||
"""A single data point in a metric time series.
|
||||
:param timestamp: Unix timestamp when the metric value was recorded
|
||||
:param value: The numeric value of the metric at this timestamp
|
||||
"""
|
||||
|
||||
timestamp: int
|
||||
value: float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricSeries(BaseModel):
|
||||
"""A time series of metric data points.
|
||||
:param metric: The name of the metric
|
||||
:param labels: List of labels associated with this metric series
|
||||
:param values: List of data points in chronological order
|
||||
"""
|
||||
|
||||
metric: str
|
||||
labels: list[MetricLabel]
|
||||
values: list[MetricDataPoint]
|
||||
|
||||
|
||||
class QueryMetricsResponse(BaseModel):
|
||||
"""Response containing metric time series data.
|
||||
:param data: List of metric series matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[MetricSeries]
|
||||
|
||||
|
||||
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
"histograms": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
||||
logger = get_logger(name=__name__, category="telemetry")
|
||||
|
||||
|
||||
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
|
||||
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
|
||||
if attrs is None:
|
||||
return None
|
||||
return {k: v for k, v in attrs.items() if v is not None}
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class Telemetry:
|
||||
def __init__(self) -> None:
|
||||
self.meter = None
|
||||
|
||||
global _TRACER_PROVIDER
|
||||
# Initialize the correct span processor based on the provider state.
|
||||
# This is needed since once the span processor is set, it cannot be unset.
|
||||
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
|
||||
# Since the library client can be recreated multiple times in a notebook,
|
||||
# the kernel will hold on to the span processor and cause duplicate spans to be written.
|
||||
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
|
||||
if _TRACER_PROVIDER is None:
|
||||
provider = TracerProvider()
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
|
||||
# Use single OTLP endpoint for all telemetry signals
|
||||
|
||||
# Let OpenTelemetry SDK handle endpoint construction automatically
|
||||
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
|
||||
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.is_otel_endpoint_set = True
|
||||
else:
|
||||
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
|
||||
self.is_otel_endpoint_set = False
|
||||
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.is_otel_endpoint_set:
|
||||
cast(TracerProvider, trace.get_tracer_provider()).force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type.value,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
||||
|
||||
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["histograms"]:
|
||||
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Histogram for {name}",
|
||||
)
|
||||
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
# Only try to add to span if we have a valid span_id
|
||||
if event.span_id:
|
||||
try:
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=f"metric.{event.metric}",
|
||||
attributes={
|
||||
"value": event.value,
|
||||
"unit": event.unit,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
# Invalid span_id or span not found, but we already logged to console above
|
||||
pass
|
||||
except Exception:
|
||||
# Lock acquisition failed
|
||||
logger.debug("Failed to acquire lock to add metric to span")
|
||||
|
||||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
|
||||
# Use histograms for token-related metrics (per-request measurements)
|
||||
# Use counters for other cumulative metrics
|
||||
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
|
||||
|
||||
if event.metric in token_metrics:
|
||||
# Token metrics are per-request measurements, use histogram
|
||||
histogram = self._get_or_create_histogram(event.metric, event.unit)
|
||||
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
context = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=_clean_attributes(event.attributes),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
|
||||
if span:
|
||||
if event.attributes:
|
||||
cleaned_attrs = _clean_attributes(event.attributes)
|
||||
if cleaned_attrs:
|
||||
span.set_attributes(cleaned_attrs)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
|
||||
type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]
|
||||
|
||||
|
||||
def serialize_value(value: Any) -> str:
|
||||
return str(_prepare_for_json(value))
|
||||
|
||||
|
||||
def _prepare_for_json(value: Any) -> JSONValue:
|
||||
"""Serialize a single value into JSON-compatible format."""
|
||||
if value is None:
|
||||
return ""
|
||||
elif isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
elif hasattr(value, "_name_"):
|
||||
return cast(str, value._name_)
|
||||
elif isinstance(value, BaseModel):
|
||||
return cast(JSONValue, json.loads(value.model_dump_json()))
|
||||
elif isinstance(value, list | tuple | set):
|
||||
return [_prepare_for_json(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {str(k): _prepare_for_json(v) for k, v in value.items()}
|
||||
else:
|
||||
try:
|
||||
json.dumps(value)
|
||||
return cast(JSONValue, value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
|
||||
def trace_protocol[T: type[Any]](cls: T) -> T:
|
||||
"""
|
||||
A class decorator that automatically traces all methods in a protocol/base class
|
||||
and its inheriting classes.
|
||||
"""
|
||||
|
||||
def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
|
||||
is_async = asyncio.iscoroutinefunction(method)
|
||||
is_async_gen = inspect.isasyncgenfunction(method)
|
||||
|
||||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
|
||||
class_name = self.__class__.__name__
|
||||
method_name = method.__name__
|
||||
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
|
||||
sig = inspect.signature(method)
|
||||
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
||||
combined_args: dict[str, str] = {}
|
||||
for i, arg in enumerate(args):
|
||||
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
|
||||
combined_args[param_name] = serialize_value(arg)
|
||||
for k, v in kwargs.items():
|
||||
combined_args[str(k)] = serialize_value(v)
|
||||
|
||||
span_attributes: dict[str, Primitive] = {
|
||||
"__autotraced__": True,
|
||||
"__class__": class_name,
|
||||
"__method__": method_name,
|
||||
"__type__": span_type,
|
||||
"__args__": json.dumps(combined_args),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
||||
@wraps(method)
|
||||
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
count = 0
|
||||
try:
|
||||
async for item in method(self, *args, **kwargs):
|
||||
yield item
|
||||
count += 1
|
||||
finally:
|
||||
span.set_attribute("chunk_count", count)
|
||||
|
||||
@wraps(method)
|
||||
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = await method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_attribute("error", str(e))
|
||||
raise
|
||||
|
||||
@wraps(method)
|
||||
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_attribute("error", str(e))
|
||||
raise
|
||||
|
||||
if is_async_gen:
|
||||
return async_gen_wrapper
|
||||
elif is_async:
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
# Wrap methods on the class itself (for classes applied at runtime)
|
||||
# Skip if already wrapped (indicated by __wrapped__ attribute)
|
||||
for name, method in vars(cls).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
if not hasattr(method, "__wrapped__"):
|
||||
wrapped = trace_method(method)
|
||||
setattr(cls, name, wrapped) # noqa: B010
|
||||
|
||||
# Also set up __init_subclass__ for future subclasses
|
||||
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
|
||||
|
||||
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
|
||||
if original_init_subclass:
|
||||
cast(Callable[..., None], original_init_subclass)(**kwargs)
|
||||
|
||||
for name, method in vars(cls_child).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
||||
|
||||
cls_any = cast(Any, cls)
|
||||
cls_any.__init_subclass__ = classmethod(__init_subclass__)
|
||||
|
||||
return cls
|
||||
|
|
@ -1,388 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Self
|
||||
|
||||
from llama_stack.core.telemetry.telemetry import (
|
||||
ROOT_SPAN_MARKERS,
|
||||
Event,
|
||||
LogSeverity,
|
||||
Span,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.core.telemetry.trace_protocol import serialize_value
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
|
||||
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
|
||||
if not _fallback_logger.handlers:
|
||||
_fallback_logger.propagate = False
|
||||
_fallback_logger.setLevel(logging.ERROR)
|
||||
_fallback_handler = logging.StreamHandler(sys.stderr)
|
||||
_fallback_handler.setLevel(logging.ERROR)
|
||||
_fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
|
||||
_fallback_logger.addHandler(_fallback_handler)
|
||||
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
# The logical root span may not be visible to this process if a parent context
|
||||
# is passed in. The local root span is the first local span in a trace.
|
||||
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
|
||||
|
||||
|
||||
def trace_id_to_str(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
Args:
|
||||
trace_id: Trace ID int
|
||||
|
||||
Returns:
|
||||
The trace ID as 32-byte hexadecimal string
|
||||
"""
|
||||
return format(trace_id, "032x")
|
||||
|
||||
|
||||
def span_id_to_str(span_id: int) -> str:
|
||||
"""Convenience span ID formatting method
|
||||
Args:
|
||||
span_id: Span ID int
|
||||
|
||||
Returns:
|
||||
The span ID as 16-byte hexadecimal string
|
||||
"""
|
||||
return format(span_id, "016x")
|
||||
|
||||
|
||||
def generate_span_id() -> str:
|
||||
span_id = secrets.randbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = secrets.randbits(64)
|
||||
return span_id_to_str(span_id)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
trace_id = secrets.randbits(128)
|
||||
while trace_id == INVALID_TRACE_ID:
|
||||
trace_id = secrets.randbits(128)
|
||||
return trace_id_to_str(trace_id)
|
||||
|
||||
|
||||
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0
|
||||
|
||||
|
||||
class BackgroundLogger:
|
||||
def __init__(self, api: Telemetry, capacity: int = 100000):
|
||||
self.api = api
|
||||
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
self._last_queue_full_log_time: float = 0.0
|
||||
self._dropped_since_last_notice: int = 0
|
||||
|
||||
def log_event(self, event: Event) -> None:
|
||||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
# Aggregate drops and emit at most once per interval via fallback logger
|
||||
self._dropped_since_last_notice += 1
|
||||
current_time = time.time()
|
||||
if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
|
||||
_fallback_logger.error(
|
||||
"Log queue is full; dropped %d events since last notice",
|
||||
self._dropped_since_last_notice,
|
||||
)
|
||||
self._last_queue_full_log_time = current_time
|
||||
self._dropped_since_last_notice = 0
|
||||
|
||||
def _worker(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._process_logs())
|
||||
|
||||
async def _process_logs(self):
|
||||
while True:
|
||||
try:
|
||||
event = self.log_queue.get()
|
||||
await self.api.log_event(event)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print("Error processing log event")
|
||||
finally:
|
||||
self.log_queue.task_done()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.log_queue.join()
|
||||
|
||||
|
||||
BACKGROUND_LOGGER: BackgroundLogger | None = None
|
||||
|
||||
|
||||
def enqueue_event(event: Event) -> None:
|
||||
"""Enqueue a telemetry event to the background logger if available.
|
||||
|
||||
This provides a non-blocking path for routers and other hot paths to
|
||||
submit telemetry without awaiting the Telemetry API, reducing contention
|
||||
with the main event loop.
|
||||
"""
|
||||
global BACKGROUND_LOGGER
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
BACKGROUND_LOGGER.log_event(event)
|
||||
|
||||
|
||||
class TraceContext:
|
||||
def __init__(self, logger: BackgroundLogger, trace_id: str):
|
||||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
self.spans: list[Span] = []
|
||||
|
||||
def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_span_id(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(UTC),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanStartPayload(
|
||||
name=span.name,
|
||||
parent_span_id=span.parent_span_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.spans.append(span)
|
||||
return span
|
||||
|
||||
def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
|
||||
span = self.spans.pop()
|
||||
if span is not None:
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanEndPayload(
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def get_current_span(self) -> Span | None:
|
||||
return self.spans[-1] if self.spans else None
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
|
||||
"trace_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||
global BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return None
|
||||
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
# Mark this span as the root for the trace for now. The processing of
|
||||
# traceparent context if supplied comes later and will result in the
|
||||
# ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
|
||||
# i.e. the root of the spans originating in this process as this is
|
||||
# needed to ensure that we insert this 'local' root span's id into
|
||||
# the trace record in sqlite store.
|
||||
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
|
||||
context.push_span(name, attributes)
|
||||
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
logger.debug("No trace context to end")
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT.set(None)
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
if levelname == "DEBUG":
|
||||
return LogSeverity.DEBUG
|
||||
elif levelname == "INFO":
|
||||
return LogSeverity.INFO
|
||||
elif levelname == "WARNING":
|
||||
return LogSeverity.WARN
|
||||
elif levelname == "ERROR":
|
||||
return LogSeverity.ERROR
|
||||
elif levelname == "CRITICAL":
|
||||
return LogSeverity.CRITICAL
|
||||
else:
|
||||
raise ValueError(f"Unknown log level: {levelname}")
|
||||
|
||||
|
||||
# TODO: ideally, the actual emitting should be done inside a separate daemon
|
||||
# process completely isolated from the server
|
||||
class TelemetryHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
|
||||
if record.module in ("asyncio", "selector_events"):
|
||||
return
|
||||
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
return
|
||||
|
||||
span = context.get_current_span()
|
||||
if span is None:
|
||||
return
|
||||
|
||||
enqueue_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SpanContextManager:
|
||||
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
|
||||
self.name = name
|
||||
self.attributes = attributes
|
||||
self.span: Span | None = None
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
if self.span:
|
||||
if self.span.attributes is None:
|
||||
self.span.attributes = {}
|
||||
self.span.attributes[key] = serialize_value(value)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async with self:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper(*args, **kwargs)
|
||||
else:
|
||||
return sync_wrapper(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Span | None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
if CURRENT_TRACE_CONTEXT is None:
|
||||
logger.debug("No trace context to get current span")
|
||||
return None
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
@ -7,8 +7,6 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
|
|
@ -69,16 +67,12 @@ def preserve_contexts_async_generator[T](
|
|||
try:
|
||||
yield item
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
# Only for non-trace context vars - trace context must persist across yields
|
||||
# to allow nested span tracking for telemetry
|
||||
# This allows context changes to persist across generator iterations
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
finally:
|
||||
# Restore non-trace context vars after each yield to prevent leaks between requests
|
||||
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
|
||||
# Restore context vars after each yield to prevent leaks between requests
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
_restore_context_var(context_var)
|
||||
_restore_context_var(context_var)
|
||||
|
||||
return wrapper()
|
||||
|
|
|
|||
|
|
@ -272,8 +272,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -281,8 +281,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -131,5 +131,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -140,5 +140,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -138,5 +138,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -153,5 +153,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -114,5 +114,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -135,5 +135,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -132,5 +132,3 @@ registered_resources:
|
|||
provider_id: tavily-search
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -251,5 +251,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -114,5 +114,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -275,8 +275,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -284,8 +284,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -272,8 +272,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -281,8 +281,6 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from llama_stack.core.datatypes import (
|
|||
Provider,
|
||||
SafetyConfig,
|
||||
ShieldInput,
|
||||
TelemetryConfig,
|
||||
ToolGroupInput,
|
||||
VectorStoresConfig,
|
||||
)
|
||||
|
|
@ -184,7 +183,6 @@ class RunConfigSettings(BaseModel):
|
|||
default_benchmarks: list[BenchmarkInput] | None = None
|
||||
vector_stores_config: VectorStoresConfig | None = None
|
||||
safety_config: SafetyConfig | None = None
|
||||
telemetry: TelemetryConfig = Field(default_factory=lambda: TelemetryConfig(enabled=True))
|
||||
storage_backends: dict[str, Any] | None = None
|
||||
storage_stores: dict[str, Any] | None = None
|
||||
|
||||
|
|
@ -284,7 +282,6 @@ class RunConfigSettings(BaseModel):
|
|||
"server": {
|
||||
"port": 8321,
|
||||
},
|
||||
"telemetry": self.telemetry.model_dump(exclude_none=True) if self.telemetry else None,
|
||||
}
|
||||
|
||||
if self.vector_stores_config:
|
||||
|
|
|
|||
|
|
@ -132,5 +132,3 @@ registered_resources:
|
|||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ CATEGORIES = [
|
|||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
"openai",
|
||||
"openai_responses",
|
||||
"openai_conversations",
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ async def get_provider_impl(
|
|||
config: MetaReferenceAgentsImplConfig,
|
||||
deps: dict[Api, Any],
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
|
|
@ -29,7 +28,6 @@ async def get_provider_impl(
|
|||
deps[Api.conversations],
|
||||
deps[Api.prompts],
|
||||
deps[Api.files],
|
||||
telemetry_enabled,
|
||||
policy,
|
||||
)
|
||||
await impl.initialize()
|
||||
|
|
|
|||
|
|
@ -50,7 +50,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
prompts_api: Prompts,
|
||||
files_api: Files,
|
||||
policy: list[AccessRule],
|
||||
telemetry_enabled: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
|
|
@ -59,7 +58,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_runtime_api = tool_runtime_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.conversations_api = conversations_api
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
self.prompts_api = prompts_api
|
||||
self.files_api = files_api
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
|
|
@ -111,6 +109,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
result = await self.openai_responses_impl.create_openai_response(
|
||||
|
|
@ -130,6 +129,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
guardrails,
|
||||
parallel_tool_calls,
|
||||
max_tool_calls,
|
||||
metadata,
|
||||
)
|
||||
return result # type: ignore[no-any-return]
|
||||
|
||||
|
|
|
|||
|
|
@ -336,6 +336,7 @@ class OpenAIResponsesImpl:
|
|||
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
|
@ -390,6 +391,7 @@ class OpenAIResponsesImpl:
|
|||
guardrail_ids=guardrail_ids,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
max_tool_calls=max_tool_calls,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
|
@ -442,6 +444,7 @@ class OpenAIResponsesImpl:
|
|||
guardrail_ids: list[str] | None = None,
|
||||
parallel_tool_calls: bool | None = True,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# These should never be None when called from create_openai_response (which sets defaults)
|
||||
# but we assert here to help mypy understand the types
|
||||
|
|
@ -490,6 +493,7 @@ class OpenAIResponsesImpl:
|
|||
guardrail_ids=guardrail_ids,
|
||||
instructions=instructions,
|
||||
max_tool_calls=max_tool_calls,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Stream the response
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@ import uuid
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from opentelemetry import trace
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack_api import (
|
||||
|
|
@ -79,6 +80,7 @@ from .utils import (
|
|||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
def convert_tooldef_to_chat_tool(tool_def):
|
||||
|
|
@ -118,6 +120,7 @@ class StreamingResponseOrchestrator:
|
|||
prompt: OpenAIResponsePrompt | None = None,
|
||||
parallel_tool_calls: bool | None = None,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
):
|
||||
self.inference_api = inference_api
|
||||
self.ctx = ctx
|
||||
|
|
@ -135,6 +138,7 @@ class StreamingResponseOrchestrator:
|
|||
self.parallel_tool_calls = parallel_tool_calls
|
||||
# Max number of total calls to built-in tools that can be processed in a response
|
||||
self.max_tool_calls = max_tool_calls
|
||||
self.metadata = metadata
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
|
|
@ -162,6 +166,7 @@ class StreamingResponseOrchestrator:
|
|||
model=self.ctx.model,
|
||||
status="completed",
|
||||
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
|
||||
|
|
@ -197,6 +202,7 @@ class StreamingResponseOrchestrator:
|
|||
prompt=self.prompt,
|
||||
parallel_tool_calls=self.parallel_tool_calls,
|
||||
max_tool_calls=self.max_tool_calls,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
|
|
@ -1106,8 +1112,10 @@ class StreamingResponseOrchestrator:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"mcp_list_tools_id": list_id,
|
||||
}
|
||||
# List MCP tools with authorization from tool config
|
||||
async with tracing.span("list_mcp_tools", attributes):
|
||||
|
||||
# TODO: follow semantic conventions for Open Telemetry tool spans
|
||||
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
|
||||
with tracer.start_as_current_span("list_mcp_tools", attributes=attributes):
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers,
|
||||
|
|
@ -1183,9 +1191,9 @@ class StreamingResponseOrchestrator:
|
|||
if mcp_server.require_approval == "never":
|
||||
return False
|
||||
if isinstance(mcp_server, ApprovalFilter):
|
||||
if tool_name in mcp_server.always:
|
||||
if mcp_server.always and tool_name in mcp_server.always:
|
||||
return True
|
||||
if tool_name in mcp_server.never:
|
||||
if mcp_server.never and tool_name in mcp_server.never:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ import json
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from opentelemetry import trace
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
ImageContentItem,
|
||||
|
|
@ -42,6 +43,7 @@ from llama_stack_api import (
|
|||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
|
|
@ -296,8 +298,9 @@ class ToolExecutor:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
# Invoke MCP tool with authorization from tool config
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
# TODO: follow semantic conventions for Open Telemetry tool spans
|
||||
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
|
||||
with tracer.start_as_current_span("invoke_mcp_tool", attributes=attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
tool_name=function_name,
|
||||
|
|
@ -318,7 +321,7 @@ class ToolExecutor:
|
|||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
query = tool_kwargs.get("query", "")
|
||||
async with tracing.span("knowledge_search", {}):
|
||||
with tracer.start_as_current_span("knowledge_search"):
|
||||
result = await self._execute_knowledge_search_via_vector_store(
|
||||
query=query,
|
||||
response_file_search_tool=response_file_search_tool,
|
||||
|
|
@ -327,7 +330,9 @@ class ToolExecutor:
|
|||
attributes = {
|
||||
"tool_name": function_name,
|
||||
}
|
||||
async with tracing.span("invoke_tool", attributes):
|
||||
# TODO: follow semantic conventions for Open Telemetry tool spans
|
||||
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span
|
||||
with tracer.start_as_current_span("invoke_tool", attributes=attributes):
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
|
||||
|
||||
|
|
@ -31,15 +30,12 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
shield_id=identifier,
|
||||
messages=messages,
|
||||
params={},
|
||||
)
|
||||
|
||||
responses = await asyncio.gather(*[run_shield_with_span(identifier) for identifier in identifiers])
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(shield_id=identifier, messages=messages, params={})
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for identifier, response in zip(identifiers, responses, strict=False):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from collections.abc import AsyncIterator, Iterable
|
|||
|
||||
from openai import AuthenticationError
|
||||
|
||||
from llama_stack.core.telemetry.tracing import get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import (
|
||||
|
|
@ -84,7 +83,7 @@ class BedrockInferenceAdapter(OpenAIMixin):
|
|||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Override to enable streaming usage metrics and handle authentication errors."""
|
||||
# Enable streaming usage metrics when telemetry is active
|
||||
if params.stream and get_current_span() is not None:
|
||||
if params.stream:
|
||||
if params.stream_options is None:
|
||||
params.stream_options = {"include_usage": True}
|
||||
elif "include_usage" not in params.stream_options:
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from typing import Any
|
|||
import litellm
|
||||
import requests
|
||||
|
||||
from llama_stack.core.telemetry.tracing import get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
|
@ -59,7 +58,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
|
||||
# Add usage tracking for streaming when telemetry is active
|
||||
stream_options = params.stream_options
|
||||
if params.stream and get_current_span() is not None:
|
||||
if params.stream:
|
||||
if stream_options is None:
|
||||
stream_options = {"include_usage": True}
|
||||
elif "include_usage" not in stream_options:
|
||||
|
|
|
|||
|
|
@ -217,10 +217,9 @@ class LiteLLMOpenAIMixin(
|
|||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
# Add usage tracking for streaming when telemetry is active
|
||||
from llama_stack.core.telemetry.tracing import get_current_span
|
||||
|
||||
stream_options = params.stream_options
|
||||
if params.stream and get_current_span() is not None:
|
||||
if params.stream:
|
||||
if stream_options is None:
|
||||
stream_options = {"include_usage": True}
|
||||
elif "include_usage" not in stream_options:
|
||||
|
|
|
|||
|
|
@ -89,6 +89,7 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
|||
# sse_client and streamablehttp_client have different signatures, but both
|
||||
# are called the same way here, so we cast to Any to avoid type errors
|
||||
client = cast(Any, sse_client)
|
||||
|
||||
async with client(endpoint, headers=headers) as client_streams:
|
||||
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
||||
await session.initialize()
|
||||
|
|
|
|||
5
src/llama_stack/telemetry/__init__.py
Normal file
5
src/llama_stack/telemetry/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
27
src/llama_stack/telemetry/constants.py
Normal file
27
src/llama_stack/telemetry/constants.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
This file contains constants used for naming data captured for telemetry.
|
||||
|
||||
This is used to ensure that the data captured for telemetry is consistent and can be used to
|
||||
identify and correlate data. If custom telemetry data is added to llama stack, please add
|
||||
constants for it here.
|
||||
"""
|
||||
|
||||
llama_stack_prefix = "llama_stack"
|
||||
|
||||
# Safety Attributes
|
||||
RUN_SHIELD_OPERATION_NAME = "run_shield"
|
||||
|
||||
SAFETY_REQUEST_PREFIX = f"{llama_stack_prefix}.safety.request"
|
||||
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.shield_id"
|
||||
SAFETY_REQUEST_MESSAGES_ATTRIBUTE = f"{SAFETY_REQUEST_PREFIX}.messages"
|
||||
|
||||
SAFETY_RESPONSE_PREFIX = f"{llama_stack_prefix}.safety.response"
|
||||
SAFETY_RESPONSE_METADATA_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.metadata"
|
||||
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.level"
|
||||
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE = f"{SAFETY_RESPONSE_PREFIX}.violation.user_message"
|
||||
43
src/llama_stack/telemetry/helpers.py
Normal file
43
src/llama_stack/telemetry/helpers.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from llama_stack_api import OpenAIMessageParam, RunShieldResponse
|
||||
|
||||
from .constants import (
|
||||
RUN_SHIELD_OPERATION_NAME,
|
||||
SAFETY_REQUEST_MESSAGES_ATTRIBUTE,
|
||||
SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE,
|
||||
SAFETY_RESPONSE_METADATA_ATTRIBUTE,
|
||||
SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE,
|
||||
SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE,
|
||||
)
|
||||
|
||||
|
||||
def safety_span_name(shield_id: str) -> str:
|
||||
return f"{RUN_SHIELD_OPERATION_NAME} {shield_id}"
|
||||
|
||||
|
||||
# TODO: Consider using Wrapt to automatically instrument code
|
||||
# This is the industry standard way to package automatically instrumentation in python.
|
||||
def safety_request_span_attributes(
|
||||
shield_id: str, messages: list[OpenAIMessageParam], response: RunShieldResponse
|
||||
) -> None:
|
||||
span = trace.get_current_span()
|
||||
span.set_attribute(SAFETY_REQUEST_SHIELD_ID_ATTRIBUTE, shield_id)
|
||||
messages_json = json.dumps([msg.model_dump() for msg in messages])
|
||||
span.set_attribute(SAFETY_REQUEST_MESSAGES_ATTRIBUTE, messages_json)
|
||||
|
||||
if response.violation:
|
||||
if response.violation.metadata:
|
||||
metadata_json = json.dumps(response.violation.metadata)
|
||||
span.set_attribute(SAFETY_RESPONSE_METADATA_ATTRIBUTE, metadata_json)
|
||||
if response.violation.user_message:
|
||||
span.set_attribute(SAFETY_RESPONSE_USER_MESSAGE_ATTRIBUTE, response.violation.user_message)
|
||||
span.set_attribute(SAFETY_RESPONSE_VIOLATION_LEVEL_ATTRIBUTE, response.violation.violation_level.value)
|
||||
|
|
@ -89,6 +89,7 @@ class Agents(Protocol):
|
|||
),
|
||||
] = None,
|
||||
max_tool_calls: int | None = None,
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
|
||||
"""Create a model response.
|
||||
|
||||
|
|
@ -100,6 +101,7 @@ class Agents(Protocol):
|
|||
:param include: (Optional) Additional fields to include in the response.
|
||||
:param guardrails: (Optional) List of guardrails to apply during response generation. Can be guardrail IDs (strings) or guardrail specifications.
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response.
|
||||
:param metadata: (Optional) Dictionary of metadata key-value pairs to attach to the response.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
def telemetry_traceable(cls):
|
||||
"""
|
||||
Mark a protocol for automatic tracing when telemetry is enabled.
|
||||
|
||||
This is a metadata-only decorator with no dependencies on core.
|
||||
Actual tracing is applied by core routers at runtime if telemetry is enabled.
|
||||
|
||||
Usage:
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class MyProtocol(Protocol):
|
||||
...
|
||||
"""
|
||||
cls.__marked_for_tracing__ = True
|
||||
return cls
|
||||
|
|
@ -9,7 +9,6 @@ from typing import Annotated, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
|
|
@ -157,7 +156,6 @@ class ConversationItemDeletedResource(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Conversations(Protocol):
|
||||
"""Conversations
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.common.responses import Order
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
|
|
@ -102,7 +101,6 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Files(Protocol):
|
||||
"""Files
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from llama_stack_api.common.content_types import InterleavedContent
|
|||
from llama_stack_api.common.responses import (
|
||||
Order,
|
||||
)
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.models import Model
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
|
|
@ -989,7 +988,6 @@ class OpenAIEmbeddingsRequestWithExtraBody(BaseModel, extra="allow"):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class InferenceProvider(Protocol):
|
||||
"""
|
||||
This protocol defines the interface that should be implemented by all inference providers.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -106,7 +105,6 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Models(Protocol):
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
|
|
|||
|
|
@ -597,6 +597,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param usage: (Optional) Token usage information for the response
|
||||
:param instructions: (Optional) System message inserted into the model's context
|
||||
:param max_tool_calls: (Optional) Max number of total calls to built-in tools that can be processed in a response
|
||||
:param metadata: (Optional) Dictionary of metadata key-value pairs
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
|
@ -619,6 +620,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
usage: OpenAIResponseUsage | None = None
|
||||
instructions: str | None = None
|
||||
max_tool_calls: int | None = None
|
||||
metadata: dict[str, str] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
|
|
@ -93,7 +92,6 @@ class ListPromptsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Prompts(Protocol):
|
||||
"""Prompts
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.inference import OpenAIMessageParam
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.shields import Shield
|
||||
|
|
@ -94,7 +93,6 @@ class ShieldStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Safety(Protocol):
|
||||
"""Safety
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -49,7 +48,6 @@ class ListShieldsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from pydantic import BaseModel
|
|||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack_api.common.content_types import URL, InterleavedContent
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.resource import Resource, ResourceType
|
||||
from llama_stack_api.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -109,7 +108,6 @@ class ListToolDefsResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
async def register_tool_group(
|
||||
|
|
@ -191,7 +189,6 @@ class SpecialToolGroup(Enum):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class ToolRuntime(Protocol):
|
||||
tool_store: ToolStore | None = None
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
|||
from fastapi import Body, Query
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack_api.common.tracing import telemetry_traceable
|
||||
from llama_stack_api.inference import InterleavedContent
|
||||
from llama_stack_api.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack_api.vector_stores import VectorStore
|
||||
|
|
@ -572,7 +571,6 @@ class VectorStoreTable(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@telemetry_traceable
|
||||
class VectorIO(Protocol):
|
||||
vector_store_table: VectorStoreTable | None = None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue