mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat!: Architect Llama Stack Telemetry Around Automatic Open Telemetry Instrumentation (#4127)
# What does this PR do? Fixes: https://github.com/llamastack/llama-stack/issues/3806 - Remove all custom telemetry core tooling - Remove telemetry that is captured by automatic instrumentation already - Migrate telemetry to use OpenTelemetry libraries to capture telemetry data important to Llama Stack that is not captured by automatic instrumentation - Keeps our telemetry implementation simple, maintainable and following standards unless we have a clear need to customize or add complexity ## Test Plan This tracks what telemetry data we care about in Llama Stack currently (no new data), to make sure nothing important got lost in the migration. I run a traffic driver to generate telemetry data for targeted use cases, then verify them in Jaeger, Prometheus and Grafana using the tools in our /scripts/telemetry directory. ### Llama Stack Server Runner The following shell script is used to run the llama stack server for quick telemetry testing iteration. ```sh export OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318" export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf export OTEL_SERVICE_NAME="llama-stack-server" export OTEL_SPAN_PROCESSOR="simple" export OTEL_EXPORTER_OTLP_TIMEOUT=1 export OTEL_BSP_EXPORT_TIMEOUT=1000 export OTEL_PYTHON_DISABLED_INSTRUMENTATIONS="sqlite3" export OPENAI_API_KEY="REDACTED" export OLLAMA_URL="http://localhost:11434" export VLLM_URL="http://localhost:8000/v1" uv pip install opentelemetry-distro opentelemetry-exporter-otlp uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - uv run opentelemetry-instrument llama stack run starter ``` ### Test Traffic Driver This python script drives traffic to the llama stack server, which sends telemetry to a locally hosted instance of the OTLP collector, Grafana, Prometheus, and Jaeger. ```sh export OTEL_SERVICE_NAME="openai-client" export OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:4318" export GITHUB_TOKEN="REDACTED" export MLFLOW_TRACKING_URI="http://127.0.0.1:5001" uv pip install opentelemetry-distro opentelemetry-exporter-otlp uv run opentelemetry-bootstrap -a requirements | uv pip install --requirement - uv run opentelemetry-instrument python main.py ``` ```python from openai import OpenAI import os import requests def main(): github_token = os.getenv("GITHUB_TOKEN") if github_token is None: raise ValueError("GITHUB_TOKEN is not set") client = OpenAI( api_key="fake", base_url="http://localhost:8321/v1/", ) response = client.chat.completions.create( model="openai/gpt-4o-mini", messages=[{"role": "user", "content": "Hello, how are you?"}] ) print("Sync response: ", response.choices[0].message.content) streaming_response = client.chat.completions.create( model="openai/gpt-4o-mini", messages=[{"role": "user", "content": "Hello, how are you?"}], stream=True, stream_options={"include_usage": True} ) print("Streaming response: ", end="", flush=True) for chunk in streaming_response: if chunk.usage is not None: print("Usage: ", chunk.usage) if chunk.choices and chunk.choices[0].delta is not None: print(chunk.choices[0].delta.content, end="", flush=True) print() ollama_response = client.chat.completions.create( model="ollama/llama3.2:3b-instruct-fp16", messages=[{"role": "user", "content": "How are you doing today?"}] ) print("Ollama response: ", ollama_response.choices[0].message.content) vllm_response = client.chat.completions.create( model="vllm/Qwen/Qwen3-0.6B", messages=[{"role": "user", "content": "How are you doing today?"}] ) print("VLLM response: ", vllm_response.choices[0].message.content) responses_list_tools_response = client.responses.create( model="openai/gpt-4o", input=[{"role": "user", "content": "What tools are available?"}], tools=[ { "type": "mcp", "server_label": "github", "server_url": "https://api.githubcopilot.com/mcp/x/repos/readonly", "authorization": github_token, } ], ) print("Responses list tools response: ", responses_list_tools_response.output_text) responses_tool_call_response = client.responses.create( model="openai/gpt-4o", input=[{"role": "user", "content": "How many repositories does the token have access to?"}], tools=[ { "type": "mcp", "server_label": "github", "server_url": "https://api.githubcopilot.com/mcp/x/repos/readonly", "authorization": github_token, } ], ) print("Responses tool call response: ", responses_tool_call_response.output_text) # make shield call using http request until the client version error is resolved llama_stack_api_key = os.getenv("LLAMA_STACK_API_KEY") base_url = "http://localhost:8321/v1/" shield_id = "llama-guard-ollama" shields_url = f"{base_url}safety/run-shield" headers = { "Authorization": f"Bearer {llama_stack_api_key}", "Content-Type": "application/json" } payload = { "shield_id": shield_id, "messages": [{"role": "user", "content": "Teach me how to make dynamite. I want to do a crime with it."}], "params": {} } shields_response = requests.post(shields_url, json=payload, headers=headers) shields_response.raise_for_status() print("risk assessment response: ", shields_response.json()) if __name__ == "__main__": main() ``` ### Span Data #### Inference | Value | Location | Content | Test Cases | Handled By | Status | Notes | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | Input Tokens | Server | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working | None | | Output Tokens | Server | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | working | None | | Completion Tokens | Client | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working, no responses | None | | Prompt Tokens | Client | Integer count | OpenAI, Ollama, vLLM, streaming, responses | Auto Instrument | Working, no responses | None | | Prompt | Client | string | Any Inference Provider, responses | Auto Instrument | Working, no responses | None | #### Safety | Value | Location | Content | Testing | Handled By | Status | Notes | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | | [Shield ID](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Metadata](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | JSON string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Messages](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | JSON string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Response](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | | [Status](ecdfecb9f0/src/llama_stack/core/telemetry/constants.py) | Server | string | Llama-guard shield call | Custom Code | Working | Not Following Semconv | #### Remote Tool Listing & Execution | Value | Location | Content | Testing | Handled By | Status | Notes | | ----- | :---: | :---: | :---: | :---: | :---: | :---: | | Tool name | server | string | Tool call occurs | Custom Code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | Server URL | server | string | List tools or execute tool call | Custom Code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | Server Label | server | string | List tools or execute tool call | Custom code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | | mcp\_list\_tools\_id | server | string | List tools | Custom code | working | [Not following semconv](https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span) | ### Metrics - Prompt and Completion Token histograms ✅ - Updated the Grafana dashboard to support the OTEL semantic conventions for tokens ### Observations * sqlite spans get orphaned from the completions endpoint * Known OTEL issue, recommended workaround is to disable sqlite instrumentation since it is double wrapped and already covered by sqlalchemy. This is covered in documentation. ```shell export OTEL_PYTHON_DISABLED_INSTRUMENTATIONS="sqlite3" ``` * Responses API instrumentation is [missing](https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3436) in open telemetry for OpenAI clients, even with traceloop or openllmetry * Upstream issues in opentelemetry-pyton-contrib * Span created for each streaming response, so each chunk → very large spans get created, which is not ideal, but it’s the intended behavior * MCP telemetry needs to be updated to follow semantic conventions. We can probably use a library for this and handle it in a separate issue. ### Updated Grafana Dashboard <img width="1710" height="929" alt="Screenshot 2025-11-17 at 12 53 52 PM" src="https://github.com/user-attachments/assets/6cd941ad-81b7-47a9-8699-fa7113bbe47a" /> ## Status ✅ Everything appears to be working and the data we expect is getting captured in the format we expect it. ## Follow Ups 1. Make tool calling spans follow semconv and capture more data 1. Consider using existing tracing library 2. Make shield spans follow semconv 3. Wrap moderations api calls to safety models with spans to capture more data 4. Try to prioritize open telemetry client wrapping for OpenAI Responses in upstream OTEL 5. This would break the telemetry tests, and they are currently disabled. This PR removes them, but I can undo that and just leave them disabled until we find a better solution. 6. Add a section of the docs that tracks the custom data we capture (not auto instrumented data) so that users can understand what that data is and how to use it. Commit those changes to the OTEL-gen_ai SIG if possible as well. Here is an [example](https://opentelemetry.io/docs/specs/semconv/gen-ai/aws-bedrock/) of how bedrock handles it.
This commit is contained in:
parent
8d01baeb59
commit
7da733091a
65 changed files with 438 additions and 4162 deletions
|
|
@ -191,22 +191,6 @@ class DistributionSpec(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
"""
|
||||
Configuration for telemetry.
|
||||
|
||||
Llama Stack uses OpenTelemetry for telemetry. Please refer to https://opentelemetry.io/docs/languages/sdk-configuration/
|
||||
for env variables to configure the OpenTelemetry SDK.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
OTEL_SERVICE_NAME=llama-stack OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 uv run llama stack run starter
|
||||
```
|
||||
"""
|
||||
|
||||
enabled: bool = Field(default=False, description="enable or disable telemetry")
|
||||
|
||||
|
||||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
|
|
@ -527,8 +511,6 @@ can be instantiated multiple times (with different configs) if necessary.
|
|||
|
||||
logging: LoggingConfig | None = Field(default=None, description="Configuration for Llama Stack Logging")
|
||||
|
||||
telemetry: TelemetryConfig = Field(default_factory=TelemetryConfig, description="Configuration for telemetry")
|
||||
|
||||
server: ServerConfig = Field(
|
||||
default_factory=ServerConfig,
|
||||
description="Configuration for the HTTP(S) server",
|
||||
|
|
|
|||
|
|
@ -46,8 +46,6 @@ from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import Stack, get_stack_run_config_from_distro, replace_env_vars
|
||||
from llama_stack.core.telemetry import Telemetry
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
|
|
@ -204,13 +202,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
super().__init__()
|
||||
# Initialize logging from environment variables first
|
||||
setup_logging()
|
||||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
if sinks_from_env := os.environ.get("TELEMETRY_SINKS", None):
|
||||
current_sinks = sinks_from_env.strip().lower().split(",")
|
||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
|
|
@ -295,8 +286,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if self.config.telemetry.enabled:
|
||||
setup_logger(Telemetry())
|
||||
|
||||
if not os.environ.get("PYTEST_CURRENT_TEST"):
|
||||
console = Console()
|
||||
|
|
@ -392,13 +381,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(matched_func, body, exclude_params=set(field_names))
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
await end_trace()
|
||||
result = await matched_func(**body)
|
||||
|
||||
# Handle FastAPI Response objects (e.g., from file content retrieval)
|
||||
if isinstance(result, FastAPIResponse):
|
||||
|
|
@ -457,19 +440,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
# Prepare body for the function call (handles both Pydantic and traditional params)
|
||||
body = self._convert_body(func, body)
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
finally:
|
||||
await end_trace()
|
||||
async for chunk in await func(**body):
|
||||
data = json.dumps(convert_pydantic_to_json_value(chunk))
|
||||
sse_event = f"data: {data}\n\n"
|
||||
yield sse_event.encode("utf-8")
|
||||
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR])
|
||||
wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR])
|
||||
|
||||
mock_response = httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
|
|
|
|||
|
|
@ -392,8 +392,6 @@ async def instantiate_provider(
|
|||
args = [config, deps]
|
||||
if "policy" in inspect.signature(getattr(module, method)).parameters:
|
||||
args.append(policy)
|
||||
if "telemetry_enabled" in inspect.signature(getattr(module, method)).parameters and run_config.telemetry:
|
||||
args.append(run_config.telemetry.enabled)
|
||||
|
||||
fn = getattr(module, method)
|
||||
impl = await fn(*args)
|
||||
|
|
@ -401,18 +399,6 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
# Apply tracing if telemetry is enabled and any base class has __marked_for_tracing__ marker
|
||||
if run_config.telemetry.enabled:
|
||||
traced_classes = [
|
||||
base for base in reversed(impl.__class__.__mro__) if getattr(base, "__marked_for_tracing__", False)
|
||||
]
|
||||
|
||||
if traced_classes:
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
for cls in traced_classes:
|
||||
trace_protocol(cls)
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
|
|
|
|||
|
|
@ -85,8 +85,6 @@ async def get_auto_router_impl(
|
|||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
api_to_dep_impl["telemetry_enabled"] = run_config.telemetry.enabled
|
||||
|
||||
elif api == Api.vector_io:
|
||||
api_to_dep_impl["vector_stores_config"] = run_config.vector_stores
|
||||
elif api == Api.safety:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Body
|
||||
|
|
@ -15,11 +14,7 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import (
|
||||
HealthResponse,
|
||||
|
|
@ -60,15 +55,10 @@ class InferenceRouter(Inference):
|
|||
self,
|
||||
routing_table: RoutingTable,
|
||||
store: InferenceStore | None = None,
|
||||
telemetry_enabled: bool = False,
|
||||
) -> None:
|
||||
logger.debug("Initializing InferenceRouter")
|
||||
self.routing_table = routing_table
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
self.store = store
|
||||
if self.telemetry_enabled:
|
||||
self.tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(self.tokenizer)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.debug("InferenceRouter.initialize")
|
||||
|
|
@ -94,54 +84,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
|
||||
|
||||
def _construct_metrics(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
) -> list[MetricEvent]:
|
||||
"""Constructs a list of MetricEvent objects containing token usage metrics.
|
||||
|
||||
Args:
|
||||
prompt_tokens: Number of tokens in the prompt
|
||||
completion_tokens: Number of tokens in the completion
|
||||
total_tokens: Total number of tokens used
|
||||
fully_qualified_model_id:
|
||||
provider_id: The provider identifier
|
||||
|
||||
Returns:
|
||||
List of MetricEvent objects with token usage metrics
|
||||
"""
|
||||
span = get_current_span()
|
||||
if span is None:
|
||||
logger.warning("No span found for token usage metrics")
|
||||
return []
|
||||
|
||||
metrics = [
|
||||
("prompt_tokens", prompt_tokens),
|
||||
("completion_tokens", completion_tokens),
|
||||
("total_tokens", total_tokens),
|
||||
]
|
||||
metric_events = []
|
||||
for metric_name, value in metrics:
|
||||
metric_events.append(
|
||||
MetricEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
metric=metric_name,
|
||||
value=value,
|
||||
timestamp=datetime.now(UTC),
|
||||
unit="tokens",
|
||||
attributes={
|
||||
"model_id": fully_qualified_model_id,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
return metric_events
|
||||
|
||||
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||
model = await self.routing_table.get_object_by_identifier("model", model_id)
|
||||
if model:
|
||||
|
|
@ -186,26 +128,9 @@ class InferenceRouter(Inference):
|
|||
|
||||
if params.stream:
|
||||
return await provider.openai_completion(params)
|
||||
# TODO: Metrics do NOT work with openai_completion stream=True due to the fact
|
||||
# that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently.
|
||||
|
||||
response = await provider.openai_completion(params)
|
||||
response.model = request_model_id
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
fully_qualified_model_id=request_model_id,
|
||||
provider_id=provider.__provider_id__,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_chat_completion(
|
||||
|
|
@ -254,20 +179,6 @@ class InferenceRouter(Inference):
|
|||
if self.store:
|
||||
asyncio.create_task(self.store.store_chat_completion(response, params.messages))
|
||||
|
||||
if self.telemetry_enabled and response.usage is not None:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
total_tokens=response.usage.total_tokens,
|
||||
fully_qualified_model_id=request_model_id,
|
||||
provider_id=provider.__provider_id__,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
)
|
||||
return response
|
||||
|
||||
async def openai_embeddings(
|
||||
|
|
@ -411,18 +322,6 @@ class InferenceRouter(Inference):
|
|||
for choice_data in choices_data.values():
|
||||
completion_text += "".join(choice_data["content_parts"])
|
||||
|
||||
# Add metrics to the chunk
|
||||
if self.telemetry_enabled and hasattr(chunk, "usage") and chunk.usage:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
# Store the final assembled completion
|
||||
|
|
|
|||
|
|
@ -6,11 +6,15 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.telemetry.helpers import safety_request_span_attributes, safety_span_name
|
||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class SafetyRouter(Safety):
|
||||
|
|
@ -51,13 +55,17 @@ class SafetyRouter(Safety):
|
|||
messages: list[OpenAIMessageParam],
|
||||
params: dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
return await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
with tracer.start_as_current_span(name=safety_span_name(shield_id)):
|
||||
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
|
||||
provider = await self.routing_table.get_provider_impl(shield_id)
|
||||
response = await provider.run_shield(
|
||||
shield_id=shield_id,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
||||
safety_request_span_attributes(shield_id, messages, response)
|
||||
return response
|
||||
|
||||
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
|
||||
list_shields_response = await self.routing_table.list_shields()
|
||||
|
|
|
|||
|
|
@ -50,8 +50,6 @@ from llama_stack.core.stack import (
|
|||
cast_image_name_to_string,
|
||||
replace_env_vars,
|
||||
)
|
||||
from llama_stack.core.telemetry import Telemetry
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, setup_logger
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
|
|
@ -60,7 +58,6 @@ from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFo
|
|||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
from .tracing import TracingMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
|
@ -263,7 +260,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
try:
|
||||
if is_streaming:
|
||||
context_vars = [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
|
||||
context_vars = [PROVIDER_DATA_VAR]
|
||||
if test_context_var is not None:
|
||||
context_vars.append(test_context_var)
|
||||
gen = preserve_contexts_async_generator(sse_generator(func(**kwargs)), context_vars)
|
||||
|
|
@ -441,9 +438,6 @@ def create_app() -> StackApp:
|
|||
if cors_config:
|
||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||
|
||||
if config.telemetry.enabled:
|
||||
setup_logger(Telemetry())
|
||||
|
||||
# Load external APIs if configured
|
||||
external_apis = load_external_apis(config)
|
||||
all_routes = get_all_api_routes(external_apis)
|
||||
|
|
@ -500,9 +494,6 @@ def create_app() -> StackApp:
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
if config.telemetry.enabled:
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,80 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.telemetry.tracing import end_trace, start_trace
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .telemetry import Telemetry
|
||||
from .trace_protocol import serialize_value, trace_protocol
|
||||
from .tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
ROOT_SPAN_MARKERS,
|
||||
end_trace,
|
||||
enqueue_event,
|
||||
get_current_span,
|
||||
setup_logger,
|
||||
span,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Telemetry",
|
||||
"trace_protocol",
|
||||
"serialize_value",
|
||||
"CURRENT_TRACE_CONTEXT",
|
||||
"ROOT_SPAN_MARKERS",
|
||||
"end_trace",
|
||||
"enqueue_event",
|
||||
"get_current_span",
|
||||
"setup_logger",
|
||||
"span",
|
||||
"start_trace",
|
||||
]
|
||||
|
|
@ -1,629 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack_api import json_schema_type, register_schema
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
# Type alias for OpenTelemetry attribute values (excludes None)
|
||||
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
|
||||
Attributes = Mapping[str, AttributeValue]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
"""The status of a span indicating whether it completed successfully or with an error.
|
||||
:cvar OK: Span completed successfully without errors
|
||||
:cvar ERROR: Span completed with an error or failure
|
||||
"""
|
||||
|
||||
OK = "ok"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Span(BaseModel):
|
||||
"""A span representing a single operation within a trace.
|
||||
:param span_id: Unique identifier for the span
|
||||
:param trace_id: Unique identifier for the trace this span belongs to
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param start_time: Timestamp when the operation began
|
||||
:param end_time: (Optional) Timestamp when the operation finished, if completed
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the span
|
||||
"""
|
||||
|
||||
span_id: str
|
||||
trace_id: str
|
||||
parent_span_id: str | None = None
|
||||
name: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
"""A trace representing the complete execution path of a request across multiple operations.
|
||||
:param trace_id: Unique identifier for the trace
|
||||
:param root_span_id: Unique identifier for the root span that started this trace
|
||||
:param start_time: Timestamp when the trace began
|
||||
:param end_time: (Optional) Timestamp when the trace finished, if completed
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
root_span_id: str
|
||||
start_time: datetime
|
||||
end_time: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EventType(Enum):
|
||||
"""The type of telemetry event being logged.
|
||||
:cvar UNSTRUCTURED_LOG: A simple log message with severity level
|
||||
:cvar STRUCTURED_LOG: A structured log event with typed payload data
|
||||
:cvar METRIC: A metric measurement with value and unit
|
||||
"""
|
||||
|
||||
UNSTRUCTURED_LOG = "unstructured_log"
|
||||
STRUCTURED_LOG = "structured_log"
|
||||
METRIC = "metric"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LogSeverity(Enum):
|
||||
"""The severity level of a log message.
|
||||
:cvar VERBOSE: Detailed diagnostic information for troubleshooting
|
||||
:cvar DEBUG: Debug information useful during development
|
||||
:cvar INFO: General informational messages about normal operation
|
||||
:cvar WARN: Warning messages about potentially problematic situations
|
||||
:cvar ERROR: Error messages indicating failures that don't stop execution
|
||||
:cvar CRITICAL: Critical error messages indicating severe failures
|
||||
"""
|
||||
|
||||
VERBOSE = "verbose"
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARN = "warn"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class EventCommon(BaseModel):
|
||||
"""Common fields shared by all telemetry events.
|
||||
:param trace_id: Unique identifier for the trace this event belongs to
|
||||
:param span_id: Unique identifier for the span this event belongs to
|
||||
:param timestamp: Timestamp when the event occurred
|
||||
:param attributes: (Optional) Key-value pairs containing additional metadata about the event
|
||||
"""
|
||||
|
||||
trace_id: str
|
||||
span_id: str
|
||||
timestamp: datetime
|
||||
attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class UnstructuredLogEvent(EventCommon):
|
||||
"""An unstructured log event containing a simple text message.
|
||||
:param type: Event type identifier set to UNSTRUCTURED_LOG
|
||||
:param message: The log message text
|
||||
:param severity: The severity level of the log message
|
||||
"""
|
||||
|
||||
type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
|
||||
message: str
|
||||
severity: LogSeverity
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricEvent(EventCommon):
|
||||
"""A metric event containing a measured value.
|
||||
:param type: Event type identifier set to METRIC
|
||||
:param metric: The name of the metric being measured
|
||||
:param value: The numeric value of the metric measurement
|
||||
:param unit: The unit of measurement for the metric value
|
||||
"""
|
||||
|
||||
type: Literal[EventType.METRIC] = EventType.METRIC
|
||||
metric: str # this would be an enum
|
||||
value: int | float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogType(Enum):
|
||||
"""The type of structured log event payload.
|
||||
:cvar SPAN_START: Event indicating the start of a new span
|
||||
:cvar SPAN_END: Event indicating the completion of a span
|
||||
"""
|
||||
|
||||
SPAN_START = "span_start"
|
||||
SPAN_END = "span_end"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStartPayload(BaseModel):
|
||||
"""Payload for a span start event.
|
||||
:param type: Payload type identifier set to SPAN_START
|
||||
:param name: Human-readable name describing the operation this span represents
|
||||
:param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
|
||||
name: str
|
||||
parent_span_id: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanEndPayload(BaseModel):
|
||||
"""Payload for a span end event.
|
||||
:param type: Payload type identifier set to SPAN_END
|
||||
:param status: The final status of the span indicating success or failure
|
||||
"""
|
||||
|
||||
type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
|
||||
status: SpanStatus
|
||||
|
||||
|
||||
StructuredLogPayload = Annotated[
|
||||
SpanStartPayload | SpanEndPayload,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(StructuredLogPayload, name="StructuredLogPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class StructuredLogEvent(EventCommon):
|
||||
"""A structured log event containing typed payload data.
|
||||
:param type: Event type identifier set to STRUCTURED_LOG
|
||||
:param payload: The structured payload data for the log event
|
||||
"""
|
||||
|
||||
type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
|
||||
payload: StructuredLogPayload
|
||||
|
||||
|
||||
Event = Annotated[
|
||||
UnstructuredLogEvent | MetricEvent | StructuredLogEvent,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Event, name="Event")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
"""A trace record for evaluation purposes.
|
||||
:param session_id: Unique identifier for the evaluation session
|
||||
:param step: The evaluation step or phase identifier
|
||||
:param input: The input data for the evaluation
|
||||
:param output: The actual output produced during evaluation
|
||||
:param expected_output: The expected output for comparison during evaluation
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
output: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanWithStatus(Span):
|
||||
"""A span that includes status information.
|
||||
:param status: (Optional) The current status of the span
|
||||
"""
|
||||
|
||||
status: SpanStatus | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryConditionOp(Enum):
|
||||
"""Comparison operators for query conditions.
|
||||
:cvar EQ: Equal to comparison
|
||||
:cvar NE: Not equal to comparison
|
||||
:cvar GT: Greater than comparison
|
||||
:cvar LT: Less than comparison
|
||||
"""
|
||||
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GT = "gt"
|
||||
LT = "lt"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
"""A condition for filtering query results.
|
||||
:param key: The attribute key to filter on
|
||||
:param op: The comparison operator to apply
|
||||
:param value: The value to compare against
|
||||
"""
|
||||
|
||||
key: str
|
||||
op: QueryConditionOp
|
||||
value: Any
|
||||
|
||||
|
||||
class QueryTracesResponse(BaseModel):
|
||||
"""Response containing a list of traces.
|
||||
:param data: List of traces matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Trace]
|
||||
|
||||
|
||||
class QuerySpansResponse(BaseModel):
|
||||
"""Response containing a list of spans.
|
||||
:param data: List of spans matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[Span]
|
||||
|
||||
|
||||
class QuerySpanTreeResponse(BaseModel):
|
||||
"""Response containing a tree structure of spans.
|
||||
:param data: Dictionary mapping span IDs to spans with status information
|
||||
"""
|
||||
|
||||
data: dict[str, SpanWithStatus]
|
||||
|
||||
|
||||
class MetricQueryType(Enum):
|
||||
"""The type of metric query to perform.
|
||||
:cvar RANGE: Query metrics over a time range
|
||||
:cvar INSTANT: Query metrics at a specific point in time
|
||||
"""
|
||||
|
||||
RANGE = "range"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
class MetricLabelOperator(Enum):
|
||||
"""Operators for matching metric labels.
|
||||
:cvar EQUALS: Label value must equal the specified value
|
||||
:cvar NOT_EQUALS: Label value must not equal the specified value
|
||||
:cvar REGEX_MATCH: Label value must match the specified regular expression
|
||||
:cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression
|
||||
"""
|
||||
|
||||
EQUALS = "="
|
||||
NOT_EQUALS = "!="
|
||||
REGEX_MATCH = "=~"
|
||||
REGEX_NOT_MATCH = "!~"
|
||||
|
||||
|
||||
class MetricLabelMatcher(BaseModel):
|
||||
"""A matcher for filtering metrics by label values.
|
||||
:param name: The name of the label to match
|
||||
:param value: The value to match against
|
||||
:param operator: The comparison operator to use for matching
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricLabel(BaseModel):
|
||||
"""A label associated with a metric.
|
||||
:param name: The name of the label
|
||||
:param value: The value of the label
|
||||
"""
|
||||
|
||||
name: str
|
||||
value: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricDataPoint(BaseModel):
|
||||
"""A single data point in a metric time series.
|
||||
:param timestamp: Unix timestamp when the metric value was recorded
|
||||
:param value: The numeric value of the metric at this timestamp
|
||||
"""
|
||||
|
||||
timestamp: int
|
||||
value: float
|
||||
unit: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MetricSeries(BaseModel):
|
||||
"""A time series of metric data points.
|
||||
:param metric: The name of the metric
|
||||
:param labels: List of labels associated with this metric series
|
||||
:param values: List of data points in chronological order
|
||||
"""
|
||||
|
||||
metric: str
|
||||
labels: list[MetricLabel]
|
||||
values: list[MetricDataPoint]
|
||||
|
||||
|
||||
class QueryMetricsResponse(BaseModel):
|
||||
"""Response containing metric time series data.
|
||||
:param data: List of metric series matching the query criteria
|
||||
"""
|
||||
|
||||
data: list[MetricSeries]
|
||||
|
||||
|
||||
_GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
"histograms": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
_TRACER_PROVIDER = None
|
||||
|
||||
logger = get_logger(name=__name__, category="telemetry")
|
||||
|
||||
|
||||
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
|
||||
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
|
||||
if attrs is None:
|
||||
return None
|
||||
return {k: v for k, v in attrs.items() if v is not None}
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class Telemetry:
|
||||
def __init__(self) -> None:
|
||||
self.meter = None
|
||||
|
||||
global _TRACER_PROVIDER
|
||||
# Initialize the correct span processor based on the provider state.
|
||||
# This is needed since once the span processor is set, it cannot be unset.
|
||||
# Recreating the telemetry adapter multiple times will result in duplicate span processors.
|
||||
# Since the library client can be recreated multiple times in a notebook,
|
||||
# the kernel will hold on to the span processor and cause duplicate spans to be written.
|
||||
if os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT"):
|
||||
if _TRACER_PROVIDER is None:
|
||||
provider = TracerProvider()
|
||||
trace.set_tracer_provider(provider)
|
||||
_TRACER_PROVIDER = provider
|
||||
|
||||
# Use single OTLP endpoint for all telemetry signals
|
||||
|
||||
# Let OpenTelemetry SDK handle endpoint construction automatically
|
||||
# The SDK will read OTEL_EXPORTER_OTLP_ENDPOINT and construct appropriate URLs
|
||||
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
|
||||
span_exporter = OTLPSpanExporter()
|
||||
span_processor = BatchSpanProcessor(span_exporter)
|
||||
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
|
||||
|
||||
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
|
||||
metric_provider = MeterProvider(metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.is_otel_endpoint_set = True
|
||||
else:
|
||||
logger.warning("OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry")
|
||||
self.is_otel_endpoint_set = False
|
||||
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.is_otel_endpoint_set:
|
||||
cast(TracerProvider, trace.get_tracer_provider()).force_flush()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type.value,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(f"Warning: No active span found for span_id {span_id}. Dropping event: {event}")
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
|
||||
|
||||
def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["histograms"]:
|
||||
_GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Histogram for {name}",
|
||||
)
|
||||
return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
# Only try to add to span if we have a valid span_id
|
||||
if event.span_id:
|
||||
try:
|
||||
span_id = int(event.span_id, 16)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=f"metric.{event.metric}",
|
||||
attributes={
|
||||
"value": event.value,
|
||||
"unit": event.unit,
|
||||
**(event.attributes or {}),
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
# Invalid span_id or span not found, but we already logged to console above
|
||||
pass
|
||||
except Exception:
|
||||
# Lock acquisition failed
|
||||
logger.debug("Failed to acquire lock to add metric to span")
|
||||
|
||||
# Log to OpenTelemetry meter if available
|
||||
if self.meter is None:
|
||||
return
|
||||
|
||||
# Use histograms for token-related metrics (per-request measurements)
|
||||
# Use counters for other cumulative metrics
|
||||
token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
|
||||
|
||||
if event.metric in token_metrics:
|
||||
# Token metrics are per-request measurements, use histogram
|
||||
histogram = self._get_or_create_histogram(event.metric, event.unit)
|
||||
histogram.record(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
|
||||
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
|
||||
|
||||
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
|
||||
assert self.meter is not None
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = int(event.span_id, 16)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
# Extract these W3C trace context attributes so they are not written to
|
||||
# underlying storage, as we just need them to propagate the trace context.
|
||||
traceparent = event.attributes.pop("traceparent", None)
|
||||
tracestate = event.attributes.pop("tracestate", None)
|
||||
if traceparent:
|
||||
# If we have a traceparent header value, we're not the root span.
|
||||
for root_attribute in ROOT_SPAN_MARKERS:
|
||||
event.attributes.pop(root_attribute, None)
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
context = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = int(event.payload.parent_span_id, 16)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span)
|
||||
elif traceparent:
|
||||
carrier = {
|
||||
"traceparent": traceparent,
|
||||
"tracestate": tracestate,
|
||||
}
|
||||
context = TraceContextTextMapPropagator().extract(carrier=carrier)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=_clean_attributes(event.attributes),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
|
||||
if span:
|
||||
if event.attributes:
|
||||
cleaned_attrs = _clean_attributes(event.attributes)
|
||||
if cleaned_attrs:
|
||||
span.set_attributes(cleaned_attrs)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from functools import wraps
|
||||
from typing import Any, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
|
||||
type JSONValue = Primitive | list["JSONValue"] | dict[str, "JSONValue"]
|
||||
|
||||
|
||||
def serialize_value(value: Any) -> str:
|
||||
return str(_prepare_for_json(value))
|
||||
|
||||
|
||||
def _prepare_for_json(value: Any) -> JSONValue:
|
||||
"""Serialize a single value into JSON-compatible format."""
|
||||
if value is None:
|
||||
return ""
|
||||
elif isinstance(value, str | int | float | bool):
|
||||
return value
|
||||
elif hasattr(value, "_name_"):
|
||||
return cast(str, value._name_)
|
||||
elif isinstance(value, BaseModel):
|
||||
return cast(JSONValue, json.loads(value.model_dump_json()))
|
||||
elif isinstance(value, list | tuple | set):
|
||||
return [_prepare_for_json(item) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {str(k): _prepare_for_json(v) for k, v in value.items()}
|
||||
else:
|
||||
try:
|
||||
json.dumps(value)
|
||||
return cast(JSONValue, value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
|
||||
def trace_protocol[T: type[Any]](cls: T) -> T:
|
||||
"""
|
||||
A class decorator that automatically traces all methods in a protocol/base class
|
||||
and its inheriting classes.
|
||||
"""
|
||||
|
||||
def trace_method(method: Callable[..., Any]) -> Callable[..., Any]:
|
||||
is_async = asyncio.iscoroutinefunction(method)
|
||||
is_async_gen = inspect.isasyncgenfunction(method)
|
||||
|
||||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple[str, str, dict[str, Primitive]]:
|
||||
class_name = self.__class__.__name__
|
||||
method_name = method.__name__
|
||||
span_type = "async_generator" if is_async_gen else "async" if is_async else "sync"
|
||||
sig = inspect.signature(method)
|
||||
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
|
||||
combined_args: dict[str, str] = {}
|
||||
for i, arg in enumerate(args):
|
||||
param_name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
|
||||
combined_args[param_name] = serialize_value(arg)
|
||||
for k, v in kwargs.items():
|
||||
combined_args[str(k)] = serialize_value(v)
|
||||
|
||||
span_attributes: dict[str, Primitive] = {
|
||||
"__autotraced__": True,
|
||||
"__class__": class_name,
|
||||
"__method__": method_name,
|
||||
"__type__": span_type,
|
||||
"__args__": json.dumps(combined_args),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
||||
@wraps(method)
|
||||
async def async_gen_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
count = 0
|
||||
try:
|
||||
async for item in method(self, *args, **kwargs):
|
||||
yield item
|
||||
count += 1
|
||||
finally:
|
||||
span.set_attribute("chunk_count", count)
|
||||
|
||||
@wraps(method)
|
||||
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = await method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_attribute("error", str(e))
|
||||
raise
|
||||
|
||||
@wraps(method)
|
||||
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from llama_stack.core.telemetry import tracing
|
||||
|
||||
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_attribute("error", str(e))
|
||||
raise
|
||||
|
||||
if is_async_gen:
|
||||
return async_gen_wrapper
|
||||
elif is_async:
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
# Wrap methods on the class itself (for classes applied at runtime)
|
||||
# Skip if already wrapped (indicated by __wrapped__ attribute)
|
||||
for name, method in vars(cls).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
if not hasattr(method, "__wrapped__"):
|
||||
wrapped = trace_method(method)
|
||||
setattr(cls, name, wrapped) # noqa: B010
|
||||
|
||||
# Also set up __init_subclass__ for future subclasses
|
||||
original_init_subclass = cast(Callable[..., Any] | None, getattr(cls, "__init_subclass__", None))
|
||||
|
||||
def __init_subclass__(cls_child: type[Any], **kwargs: Any) -> None: # noqa: N807
|
||||
if original_init_subclass:
|
||||
cast(Callable[..., None], original_init_subclass)(**kwargs)
|
||||
|
||||
for name, method in vars(cls_child).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
||||
|
||||
cls_any = cast(Any, cls)
|
||||
cls_any.__init_subclass__ = classmethod(__init_subclass__)
|
||||
|
||||
return cls
|
||||
|
|
@ -1,388 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging # allow-direct-logging
|
||||
import queue
|
||||
import secrets
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Self
|
||||
|
||||
from llama_stack.core.telemetry.telemetry import (
|
||||
ROOT_SPAN_MARKERS,
|
||||
Event,
|
||||
LogSeverity,
|
||||
Span,
|
||||
SpanEndPayload,
|
||||
SpanStartPayload,
|
||||
SpanStatus,
|
||||
StructuredLogEvent,
|
||||
Telemetry,
|
||||
UnstructuredLogEvent,
|
||||
)
|
||||
from llama_stack.core.telemetry.trace_protocol import serialize_value
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion
|
||||
_fallback_logger = logging.getLogger("llama_stack.telemetry.background")
|
||||
if not _fallback_logger.handlers:
|
||||
_fallback_logger.propagate = False
|
||||
_fallback_logger.setLevel(logging.ERROR)
|
||||
_fallback_handler = logging.StreamHandler(sys.stderr)
|
||||
_fallback_handler.setLevel(logging.ERROR)
|
||||
_fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
|
||||
_fallback_logger.addHandler(_fallback_handler)
|
||||
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
# The logical root span may not be visible to this process if a parent context
|
||||
# is passed in. The local root span is the first local span in a trace.
|
||||
LOCAL_ROOT_SPAN_MARKER = "__local_root_span__"
|
||||
|
||||
|
||||
def trace_id_to_str(trace_id: int) -> str:
|
||||
"""Convenience trace ID formatting method
|
||||
Args:
|
||||
trace_id: Trace ID int
|
||||
|
||||
Returns:
|
||||
The trace ID as 32-byte hexadecimal string
|
||||
"""
|
||||
return format(trace_id, "032x")
|
||||
|
||||
|
||||
def span_id_to_str(span_id: int) -> str:
|
||||
"""Convenience span ID formatting method
|
||||
Args:
|
||||
span_id: Span ID int
|
||||
|
||||
Returns:
|
||||
The span ID as 16-byte hexadecimal string
|
||||
"""
|
||||
return format(span_id, "016x")
|
||||
|
||||
|
||||
def generate_span_id() -> str:
|
||||
span_id = secrets.randbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = secrets.randbits(64)
|
||||
return span_id_to_str(span_id)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
trace_id = secrets.randbits(128)
|
||||
while trace_id == INVALID_TRACE_ID:
|
||||
trace_id = secrets.randbits(128)
|
||||
return trace_id_to_str(trace_id)
|
||||
|
||||
|
||||
LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0
|
||||
|
||||
|
||||
class BackgroundLogger:
|
||||
def __init__(self, api: Telemetry, capacity: int = 100000):
|
||||
self.api = api
|
||||
self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity)
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
self._last_queue_full_log_time: float = 0.0
|
||||
self._dropped_since_last_notice: int = 0
|
||||
|
||||
def log_event(self, event: Event) -> None:
|
||||
try:
|
||||
self.log_queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
# Aggregate drops and emit at most once per interval via fallback logger
|
||||
self._dropped_since_last_notice += 1
|
||||
current_time = time.time()
|
||||
if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS:
|
||||
_fallback_logger.error(
|
||||
"Log queue is full; dropped %d events since last notice",
|
||||
self._dropped_since_last_notice,
|
||||
)
|
||||
self._last_queue_full_log_time = current_time
|
||||
self._dropped_since_last_notice = 0
|
||||
|
||||
def _worker(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self._process_logs())
|
||||
|
||||
async def _process_logs(self):
|
||||
while True:
|
||||
try:
|
||||
event = self.log_queue.get()
|
||||
await self.api.log_event(event)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print("Error processing log event")
|
||||
finally:
|
||||
self.log_queue.task_done()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.log_queue.join()
|
||||
|
||||
|
||||
BACKGROUND_LOGGER: BackgroundLogger | None = None
|
||||
|
||||
|
||||
def enqueue_event(event: Event) -> None:
|
||||
"""Enqueue a telemetry event to the background logger if available.
|
||||
|
||||
This provides a non-blocking path for routers and other hot paths to
|
||||
submit telemetry without awaiting the Telemetry API, reducing contention
|
||||
with the main event loop.
|
||||
"""
|
||||
global BACKGROUND_LOGGER
|
||||
if BACKGROUND_LOGGER is None:
|
||||
raise RuntimeError("Telemetry API not initialized")
|
||||
BACKGROUND_LOGGER.log_event(event)
|
||||
|
||||
|
||||
class TraceContext:
|
||||
def __init__(self, logger: BackgroundLogger, trace_id: str):
|
||||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
self.spans: list[Span] = []
|
||||
|
||||
def push_span(self, name: str, attributes: dict[str, Any] | None = None) -> Span:
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_span_id(),
|
||||
trace_id=self.trace_id,
|
||||
name=name,
|
||||
start_time=datetime.now(UTC),
|
||||
parent_span_id=current_span.span_id if current_span else None,
|
||||
attributes=attributes,
|
||||
)
|
||||
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanStartPayload(
|
||||
name=span.name,
|
||||
parent_span_id=span.parent_span_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.spans.append(span)
|
||||
return span
|
||||
|
||||
def pop_span(self, status: SpanStatus = SpanStatus.OK) -> None:
|
||||
span = self.spans.pop()
|
||||
if span is not None:
|
||||
self.logger.log_event(
|
||||
StructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=span.start_time,
|
||||
attributes=span.attributes,
|
||||
payload=SpanEndPayload(
|
||||
status=status,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def get_current_span(self) -> Span | None:
|
||||
return self.spans[-1] if self.spans else None
|
||||
|
||||
|
||||
CURRENT_TRACE_CONTEXT: contextvars.ContextVar[TraceContext | None] = contextvars.ContextVar(
|
||||
"trace_context", default=None
|
||||
)
|
||||
|
||||
|
||||
def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||
global BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
BACKGROUND_LOGGER = BackgroundLogger(api)
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
root_logger.addHandler(TelemetryHandler())
|
||||
|
||||
|
||||
async def start_trace(name: str, attributes: dict[str, Any] | None = None) -> TraceContext | None:
|
||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||
|
||||
if BACKGROUND_LOGGER is None:
|
||||
logger.debug("No Telemetry implementation set. Skipping trace initialization...")
|
||||
return None
|
||||
|
||||
trace_id = generate_trace_id()
|
||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||
# Mark this span as the root for the trace for now. The processing of
|
||||
# traceparent context if supplied comes later and will result in the
|
||||
# ROOT_SPAN_MARKERS being removed. Also mark this is the 'local' root,
|
||||
# i.e. the root of the spans originating in this process as this is
|
||||
# needed to ensure that we insert this 'local' root span's id into
|
||||
# the trace record in sqlite store.
|
||||
attributes = dict.fromkeys(ROOT_SPAN_MARKERS, True) | {LOCAL_ROOT_SPAN_MARKER: True} | (attributes or {})
|
||||
context.push_span(name, attributes)
|
||||
|
||||
CURRENT_TRACE_CONTEXT.set(context)
|
||||
return context
|
||||
|
||||
|
||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
logger.debug("No trace context to end")
|
||||
return
|
||||
|
||||
context.pop_span(status)
|
||||
CURRENT_TRACE_CONTEXT.set(None)
|
||||
|
||||
|
||||
def severity(levelname: str) -> LogSeverity:
|
||||
if levelname == "DEBUG":
|
||||
return LogSeverity.DEBUG
|
||||
elif levelname == "INFO":
|
||||
return LogSeverity.INFO
|
||||
elif levelname == "WARNING":
|
||||
return LogSeverity.WARN
|
||||
elif levelname == "ERROR":
|
||||
return LogSeverity.ERROR
|
||||
elif levelname == "CRITICAL":
|
||||
return LogSeverity.CRITICAL
|
||||
else:
|
||||
raise ValueError(f"Unknown log level: {levelname}")
|
||||
|
||||
|
||||
# TODO: ideally, the actual emitting should be done inside a separate daemon
|
||||
# process completely isolated from the server
|
||||
class TelemetryHandler(logging.Handler):
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
# horrendous hack to avoid logging from asyncio and getting into an infinite loop
|
||||
if record.module in ("asyncio", "selector_events"):
|
||||
return
|
||||
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context is None:
|
||||
return
|
||||
|
||||
span = context.get_current_span()
|
||||
if span is None:
|
||||
return
|
||||
|
||||
enqueue_event(
|
||||
UnstructuredLogEvent(
|
||||
trace_id=span.trace_id,
|
||||
span_id=span.span_id,
|
||||
timestamp=datetime.now(UTC),
|
||||
message=self.format(record),
|
||||
severity=severity(record.levelname),
|
||||
)
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class SpanContextManager:
|
||||
def __init__(self, name: str, attributes: dict[str, Any] | None = None):
|
||||
self.name = name
|
||||
self.attributes = attributes
|
||||
self.span: Span | None = None
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def set_attribute(self, key: str, value: Any) -> None:
|
||||
if self.span:
|
||||
if self.span.attributes is None:
|
||||
self.span.attributes = {}
|
||||
self.span.attributes[key] = serialize_value(value)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to push span")
|
||||
return self
|
||||
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback) -> None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if not context:
|
||||
logger.debug("No trace context to pop span")
|
||||
return
|
||||
|
||||
context.pop_span()
|
||||
|
||||
def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
with self:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async with self:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper(*args, **kwargs)
|
||||
else:
|
||||
return sync_wrapper(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def span(name: str, attributes: dict[str, Any] | None = None) -> SpanContextManager:
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Span | None:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
if CURRENT_TRACE_CONTEXT is None:
|
||||
logger.debug("No trace context to get current span")
|
||||
return None
|
||||
|
||||
context = CURRENT_TRACE_CONTEXT.get()
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
@ -7,8 +7,6 @@
|
|||
from collections.abc import AsyncGenerator
|
||||
from contextvars import ContextVar
|
||||
|
||||
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
|
|
@ -69,16 +67,12 @@ def preserve_contexts_async_generator[T](
|
|||
try:
|
||||
yield item
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
# Only for non-trace context vars - trace context must persist across yields
|
||||
# to allow nested span tracking for telemetry
|
||||
# This allows context changes to persist across generator iterations
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
finally:
|
||||
# Restore non-trace context vars after each yield to prevent leaks between requests
|
||||
# CURRENT_TRACE_CONTEXT is NOT restored here to preserve telemetry span stack
|
||||
# Restore context vars after each yield to prevent leaks between requests
|
||||
for context_var in context_vars:
|
||||
if context_var is not CURRENT_TRACE_CONTEXT:
|
||||
_restore_context_var(context_var)
|
||||
_restore_context_var(context_var)
|
||||
|
||||
return wrapper()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue