Fix opentelemetry adapter (#510)

# What does this PR do?

This PR fixes some of the issues with our telemetry setup to enable logs
to be delivered to opentelemetry and jaeger. Main fixes
1) Updates the open telemetry provider to use the latest oltp exports
instead of deprected ones.
2) Adds a tracing middleware, which injects traces into each HTTP
request that the server recieves and this is going to be the root trace.
Previously, we did this in the create_dynamic_route method, which is
actually not the actual exectuion flow, but more of a config and this
causes the traces to end prematurely. Through middleware, we plugin the
trace start and end at the right location.
3) We manage our own methods to create traces and spans and this does
not fit well with Opentelemetry SDK since it does not support provide a
way to take in traces and spans that are already created. it expects us
to use the SDK to create them. For now, I have a hacky approach of just
maintaining a map from our internal telemetry objects to the open
telemetry specfic ones. This is not the ideal solution. I will explore
other ways to get around this issue. for now, to have something that
works, i am going to keep this as is.

Addresses: #509
This commit is contained in:
Dinesh Yeduguru 2024-11-22 18:18:11 -08:00 committed by GitHub
parent beab798a1d
commit 501e7c9d64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 185 additions and 217 deletions

View file

@ -40,7 +40,7 @@ class ModelsClient(Models):
response = await client.post( response = await client.post(
f"{self.base_url}/models/register", f"{self.base_url}/models/register",
json={ json={
"model": json.loads(model.json()), "model": json.loads(model.model_dump_json()),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )

View file

@ -17,13 +17,11 @@ import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from ssl import SSLError from typing import Any, Union
from typing import Any, Dict, Optional
import httpx
import yaml import yaml
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
setup_logger, setup_logger,
SpanStatus,
start_trace, start_trace,
) )
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
) )
async def passthrough(
request: Request,
downstream_url: str,
downstream_headers: Optional[Dict[str, str]] = None,
):
await start_trace(request.path, {"downstream_url": downstream_url})
headers = dict(request.headers)
headers.pop("host", None)
headers.update(downstream_headers or {})
content = await request.body()
client = httpx.AsyncClient()
erred = False
try:
req = client.build_request(
method=request.method,
url=downstream_url,
headers=headers,
content=content,
params=request.query_params,
)
response = await client.send(req, stream=True)
async def stream_response():
async for chunk in response.aiter_raw(chunk_size=64):
yield chunk
await response.aclose()
await client.aclose()
return StreamingResponse(
stream_response(),
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.ReadTimeout:
erred = True
return Response(content="Downstream server timed out", status_code=504)
except httpx.NetworkError as e:
erred = True
return Response(content=f"Network error: {str(e)}", status_code=502)
except httpx.TooManyRedirects:
erred = True
return Response(content="Too many redirects", status_code=502)
except SSLError as e:
erred = True
return Response(content=f"SSL error: {str(e)}", status_code=502)
except httpx.HTTPStatusError as e:
erred = True
return Response(content=str(e), status_code=e.response.status_code)
except Exception as e:
erred = True
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
finally:
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
def handle_sigint(app, *args, **kwargs): def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...") print("SIGINT or CTRL-C detected. Exiting gracefully...")
@ -217,7 +153,6 @@ async def maybe_await(value):
async def sse_generator(event_gen): async def sse_generator(event_gen):
await start_trace("sse_generator")
try: try:
event_gen = await event_gen event_gen = await event_gen
async for item in event_gen: async for item in event_gen:
@ -235,14 +170,10 @@ async def sse_generator(event_gen):
}, },
} }
) )
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str): def create_dynamic_typed_route(func: Any, method: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers) set_request_provider_data(request.headers)
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
except Exception as e: except Exception as e:
traceback.print_exception(e) traceback.print_exception(e)
raise translate_exception(e) from e raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func) sig = inspect.signature(func)
new_params = [ new_params = [
@ -282,6 +211,19 @@ def create_dynamic_typed_route(func: Any, method: str):
return endpoint return endpoint
class TracingMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
path = scope["path"]
await start_trace(path, {"location": "server"})
try:
return await self.app(scope, receive, send)
finally:
await end_trace()
def main(): def main():
"""Start the LlamaStack server.""" """Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@ -338,6 +280,7 @@ def main():
print(yaml.dump(config.model_dump(), indent=2)) print(yaml.dump(config.model_dump(), indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
try: try:
impls = asyncio.run(construct_stack(config)) impls = asyncio.run(construct_stack(config))

View file

@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
# May be this should be a parameter of the agentic instance # May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way # that can define its behavior in a custom way
for m in turn.input_messages: for m in turn.input_messages:
msg = m.copy() msg = m.model_copy()
if isinstance(msg, UserMessage): if isinstance(msg, UserMessage):
msg.context = None msg.context = None
messages.append(msg) messages.append(msg)

View file

@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
await self.persistence_store.set( await self.persistence_store.set(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
value=agent_config.json(), value=agent_config.model_dump_json(),
) )
return AgentCreateResponse( return AgentCreateResponse(
agent_id=agent_id, agent_id=agent_id,

View file

@ -39,7 +39,7 @@ class AgentPersistence:
) )
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(), value=session_info.model_dump_json(),
) )
return session_id return session_id
@ -60,13 +60,13 @@ class AgentPersistence:
session_info.memory_bank_id = bank_id session_info.memory_bank_id = bank_id
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(), value=session_info.model_dump_json(),
) )
async def add_turn_to_session(self, session_id: str, turn: Turn): async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:

View file

@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=task_def.json(), value=task_def.model_dump_json(),
) )
self.eval_tasks[task_def.identifier] = task_def self.eval_tasks[task_def.identifier] = task_def

View file

@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex):
np.savetxt(buffer, np_index) np.savetxt(buffer, np_index)
data = { data = {
"id_by_index": self.id_by_index, "id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, "chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=memory_bank.json(), value=memory_bank.model_dump_json(),
) )
# Store in cache # Store in cache

View file

@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_or_create_collection( collection = await self.client.get_or_create_collection(
name=memory_bank.identifier, name=memory_bank.identifier,
metadata={"bank": memory_bank.json()}, metadata={"bank": memory_bank.model_dump_json()},
) )
bank_index = BankWithIndex( bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection) bank=memory_bank, index=ChromaIndex(self.client, collection)

View file

@ -4,9 +4,24 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel): class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost" otel_endpoint: str = Field(
jaeger_port: int = 6831 default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
}

View file

@ -4,24 +4,31 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime import threading
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import ( from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig from .config import OpenTelemetryConfig
_GLOBAL_STORAGE = {
"active_spans": {},
"counters": {},
"gauges": {},
"up_down_counters": {},
}
_global_lock = threading.Lock()
def string_to_trace_id(s: str) -> int: def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer # Convert the string to bytes and then to an integer
@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig): def __init__(self, config: OpenTelemetryConfig):
self.config = config self.config = config
self.resource = Resource.create( resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"} {
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
) )
# Set up tracing with Jaeger exporter provider = TracerProvider(resource=resource)
jaeger_exporter = JaegerExporter( trace.set_tracer_provider(provider)
agent_host_name=self.config.jaeger_host, otlp_exporter = OTLPSpanExporter(
agent_port=self.config.jaeger_port, endpoint=self.config.otel_endpoint,
) )
trace_provider = TracerProvider(resource=self.resource) span_processor = BatchSpanProcessor(otlp_exporter)
trace_processor = BatchSpanProcessor(jaeger_exporter) trace.get_tracer_provider().add_span_processor(span_processor)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
# Set up metrics # Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter()) metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider( metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader] resource=resource, metric_readers=[metric_reader]
) )
metrics.set_meter_provider(metric_provider) metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__) self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown() trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown() metrics.get_meter_provider().shutdown()
@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry):
self._log_structured(event) self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None: def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span() with self._lock:
span.add_event( # Use global storage instead of instance storage
name=event.message, span_id = string_to_span_id(event.span_id)
attributes={"severity": event.severity.value, **event.attributes}, span = _GLOBAL_STORAGE["active_spans"].get(span_id)
timestamp=event.timestamp,
) if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.type,
attributes={
"message": event.message,
"severity": event.severity.value,
**event.attributes,
},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
if name not in _GLOBAL_STORAGE["counters"]:
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
name=name,
unit=unit,
description=f"Counter for {name}",
)
return _GLOBAL_STORAGE["counters"][name]
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
if name not in _GLOBAL_STORAGE["gauges"]:
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
name=name,
unit=unit,
description=f"Gauge for {name}",
)
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int): if isinstance(event.value, int):
self.meter.create_counter( counter = self._get_or_create_counter(event.metric, event.unit)
name=event.metric, counter.add(event.value, attributes=event.attributes)
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
elif isinstance(event.value, float): elif isinstance(event.value, float):
self.meter.create_gauge( up_down_counter = self._get_or_create_up_down_counter(
name=event.metric, event.metric, event.unit
unit=event.unit, )
description=f"Gauge for {event.metric}", up_down_counter.add(event.value, attributes=event.attributes)
).set(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
self, name: str, unit: str
) -> metrics.UpDownCounter:
if name not in _GLOBAL_STORAGE["up_down_counters"]:
_GLOBAL_STORAGE["up_down_counters"][name] = (
self.meter.create_up_down_counter(
name=name,
unit=unit,
description=f"UpDownCounter for {name}",
)
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None: def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload): with self._lock:
context = trace.set_span_in_context( span_id = string_to_span_id(event.span_id)
trace.NonRecordingSpan( trace_id = string_to_trace_id(event.trace_id)
trace.SpanContext( tracer = trace.get_tracer(__name__)
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
if event.payload.parent_span_id: if isinstance(event.payload, SpanStartPayload):
span.set_parent( # Check if span already exists to prevent duplicates
trace.SpanContext( if span_id in _GLOBAL_STORAGE["active_spans"]:
trace_id=string_to_trace_id(event.trace_id), return
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True, parent_span = None
if event.payload.parent_span_id:
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
) )
) span.set_status(status)
elif isinstance(event.payload, SpanEndPayload): span.end(end_time=int(event.timestamp.timestamp() * 1e9))
span = trace.get_current_span()
span.set_status( # Remove from active spans
trace.Status( _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
async def get_trace(self, trace_id: str) -> Trace: async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id raise NotImplementedError("Trace retrieval not implemented yet")
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryTelemetry("my-service")
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import * # noqa: F403
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def generate_short_uuid(len: int = 12): def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4() full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes) encoded = base64.urlsafe_b64encode(uuid_bytes)
@ -123,18 +123,19 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
logger.addHandler(TelemetryHandler()) logger.addHandler(TelemetryHandler())
async def start_trace(name: str, attributes: Dict[str, Any] = None): async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
if BACKGROUND_LOGGER is None: if BACKGROUND_LOGGER is None:
log.info("No Telemetry implementation set. Skipping trace initialization...") log.info("No Telemetry implementation set. Skipping trace initialization...")
return return
trace_id = generate_short_uuid() trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id) context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})}) context.push_span(name, {"__root__": True, **(attributes or {})})
CURRENT_TRACE_CONTEXT = context CURRENT_TRACE_CONTEXT = context
return context
async def end_trace(status: SpanStatus = SpanStatus.OK): async def end_trace(status: SpanStatus = SpanStatus.OK):