From 501e7c9d646873c341411b63429743f99b3afded Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 22 Nov 2024 18:18:11 -0800 Subject: [PATCH] Fix opentelemetry adapter (#510) # What does this PR do? This PR fixes some of the issues with our telemetry setup to enable logs to be delivered to opentelemetry and jaeger. Main fixes 1) Updates the open telemetry provider to use the latest oltp exports instead of deprected ones. 2) Adds a tracing middleware, which injects traces into each HTTP request that the server recieves and this is going to be the root trace. Previously, we did this in the create_dynamic_route method, which is actually not the actual exectuion flow, but more of a config and this causes the traces to end prematurely. Through middleware, we plugin the trace start and end at the right location. 3) We manage our own methods to create traces and spans and this does not fit well with Opentelemetry SDK since it does not support provide a way to take in traces and spans that are already created. it expects us to use the SDK to create them. For now, I have a hacky approach of just maintaining a map from our internal telemetry objects to the open telemetry specfic ones. This is not the ideal solution. I will explore other ways to get around this issue. for now, to have something that works, i am going to keep this as is. Addresses: #509 --- llama_stack/apis/models/client.py | 2 +- llama_stack/distribution/server/server.py | 89 ++---- .../agents/meta_reference/agent_instance.py | 2 +- .../inline/agents/meta_reference/agents.py | 2 +- .../agents/meta_reference/persistence.py | 6 +- .../inline/eval/meta_reference/eval.py | 2 +- .../providers/inline/memory/faiss/faiss.py | 6 +- .../providers/remote/memory/chroma/chroma.py | 2 +- .../remote/telemetry/opentelemetry/config.py | 21 +- .../telemetry/opentelemetry/opentelemetry.py | 263 +++++++++--------- .../providers/utils/telemetry/tracing.py | 7 +- 11 files changed, 185 insertions(+), 217 deletions(-) diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index 34541b96e..1a72d8043 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -40,7 +40,7 @@ class ModelsClient(Models): response = await client.post( f"{self.base_url}/models/register", json={ - "model": json.loads(model.json()), + "model": json.loads(model.model_dump_json()), }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b8ff0e785..8116e2b39 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -17,13 +17,11 @@ import warnings from contextlib import asynccontextmanager from pathlib import Path -from ssl import SSLError -from typing import Any, Dict, Optional +from typing import Any, Union -import httpx import yaml -from fastapi import Body, FastAPI, HTTPException, Request, Response +from fastapi import Body, FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ValidationError @@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, - SpanStatus, start_trace, ) from llama_stack.distribution.datatypes import * # noqa: F403 @@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) -async def passthrough( - request: Request, - downstream_url: str, - downstream_headers: Optional[Dict[str, str]] = None, -): - await start_trace(request.path, {"downstream_url": downstream_url}) - - headers = dict(request.headers) - headers.pop("host", None) - headers.update(downstream_headers or {}) - - content = await request.body() - - client = httpx.AsyncClient() - erred = False - try: - req = client.build_request( - method=request.method, - url=downstream_url, - headers=headers, - content=content, - params=request.query_params, - ) - response = await client.send(req, stream=True) - - async def stream_response(): - async for chunk in response.aiter_raw(chunk_size=64): - yield chunk - - await response.aclose() - await client.aclose() - - return StreamingResponse( - stream_response(), - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.headers.get("content-type"), - ) - - except httpx.ReadTimeout: - erred = True - return Response(content="Downstream server timed out", status_code=504) - except httpx.NetworkError as e: - erred = True - return Response(content=f"Network error: {str(e)}", status_code=502) - except httpx.TooManyRedirects: - erred = True - return Response(content="Too many redirects", status_code=502) - except SSLError as e: - erred = True - return Response(content=f"SSL error: {str(e)}", status_code=502) - except httpx.HTTPStatusError as e: - erred = True - return Response(content=str(e), status_code=e.response.status_code) - except Exception as e: - erred = True - return Response(content=f"Unexpected error: {str(e)}", status_code=500) - finally: - await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR) - - def handle_sigint(app, *args, **kwargs): print("SIGINT or CTRL-C detected. Exiting gracefully...") @@ -217,7 +153,6 @@ async def maybe_await(value): async def sse_generator(event_gen): - await start_trace("sse_generator") try: event_gen = await event_gen async for item in event_gen: @@ -235,14 +170,10 @@ async def sse_generator(event_gen): }, } ) - finally: - await end_trace() def create_dynamic_typed_route(func: Any, method: str): async def endpoint(request: Request, **kwargs): - await start_trace(func.__name__) - set_request_provider_data(request.headers) is_streaming = is_streaming_request(func.__name__, request, **kwargs) @@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str): except Exception as e: traceback.print_exception(e) raise translate_exception(e) from e - finally: - await end_trace() sig = inspect.signature(func) new_params = [ @@ -282,6 +211,19 @@ def create_dynamic_typed_route(func: Any, method: str): return endpoint +class TracingMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + path = scope["path"] + await start_trace(path, {"location": "server"}) + try: + return await self.app(scope, receive, send) + finally: + await end_trace() + + def main(): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") @@ -338,6 +280,7 @@ def main(): print(yaml.dump(config.model_dump(), indent=2)) app = FastAPI(lifespan=lifespan) + app.add_middleware(TracingMiddleware) try: impls = asyncio.run(construct_stack(config)) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index e1713c0e3..8f800ad6f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin): # May be this should be a parameter of the agentic instance # that can define its behavior in a custom way for m in turn.input_messages: - msg = m.copy() + msg = m.model_copy() if isinstance(msg, UserMessage): msg.context = None messages.append(msg) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 13d9044fd..f33aadde3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents): await self.persistence_store.set( key=f"agent:{agent_id}", - value=agent_config.json(), + value=agent_config.model_dump_json(), ) return AgentCreateResponse( agent_id=agent_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index d51e25a32..1c99e3d75 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -39,7 +39,7 @@ class AgentPersistence: ) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), + value=session_info.model_dump_json(), ) return session_id @@ -60,13 +60,13 @@ class AgentPersistence: session_info.memory_bank_id = bank_id await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), + value=session_info.model_dump_json(), ) async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", - value=turn.json(), + value=turn.model_dump_json(), ) async def get_session_turns(self, session_id: str) -> List[Turn]: diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index d1df869b4..c6cacfcc3 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" await self.kvstore.set( key=key, - value=task_def.json(), + value=task_def.model_dump_json(), ) self.eval_tasks[task_def.identifier] = task_def diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 95791bc69..dfefefeb8 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex): np.savetxt(buffer, np_index) data = { "id_by_index": self.id_by_index, - "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, + "chunk_by_index": { + k: v.model_dump_json() for k, v in self.chunk_by_index.items() + }, "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } @@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" await self.kvstore.set( key=key, - value=memory_bank.json(), + value=memory_bank.model_dump_json(), ) # Store in cache diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 20185aade..207f6b54d 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): collection = await self.client.get_or_create_collection( name=memory_bank.identifier, - metadata={"bank": memory_bank.json()}, + metadata={"bank": memory_bank.model_dump_json()}, ) bank_index = BankWithIndex( bank=memory_bank, index=ChromaIndex(self.client, collection) diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py index 71a82aed9..5e9dff1a1 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py @@ -4,9 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from typing import Any, Dict + +from pydantic import BaseModel, Field class OpenTelemetryConfig(BaseModel): - jaeger_host: str = "localhost" - jaeger_port: int = 6831 + otel_endpoint: str = Field( + default="http://localhost:4318/v1/traces", + description="The OpenTelemetry collector endpoint URL", + ) + service_name: str = Field( + default="llama-stack", + description="The service name to use for telemetry", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}", + "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}", + } diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index 03e8f7d53..c9830fd9d 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -4,24 +4,31 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime +import threading from opentelemetry import metrics, trace -from opentelemetry.exporter.jaeger.thrift import JaegerExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import ( - ConsoleMetricExporter, - PeriodicExportingMetricReader, -) +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes + from llama_stack.apis.telemetry import * # noqa: F403 from .config import OpenTelemetryConfig +_GLOBAL_STORAGE = { + "active_spans": {}, + "counters": {}, + "gauges": {}, + "up_down_counters": {}, +} +_global_lock = threading.Lock() + def string_to_trace_id(s: str) -> int: # Convert the string to bytes and then to an integer @@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry): def __init__(self, config: OpenTelemetryConfig): self.config = config - self.resource = Resource.create( - {ResourceAttributes.SERVICE_NAME: "foobar-service"} + resource = Resource.create( + { + ResourceAttributes.SERVICE_NAME: self.config.service_name, + } ) - # Set up tracing with Jaeger exporter - jaeger_exporter = JaegerExporter( - agent_host_name=self.config.jaeger_host, - agent_port=self.config.jaeger_port, + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + otlp_exporter = OTLPSpanExporter( + endpoint=self.config.otel_endpoint, ) - trace_provider = TracerProvider(resource=self.resource) - trace_processor = BatchSpanProcessor(jaeger_exporter) - trace_provider.add_span_processor(trace_processor) - trace.set_tracer_provider(trace_provider) - self.tracer = trace.get_tracer(__name__) - + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) # Set up metrics - metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter()) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) metric_provider = MeterProvider( - resource=self.resource, metric_readers=[metric_reader] + resource=resource, metric_readers=[metric_reader] ) metrics.set_meter_provider(metric_provider) self.meter = metrics.get_meter(__name__) + self._lock = _global_lock async def initialize(self) -> None: pass async def shutdown(self) -> None: + trace.get_tracer_provider().force_flush() trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() @@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry): self._log_structured(event) def _log_unstructured(self, event: UnstructuredLogEvent) -> None: - span = trace.get_current_span() - span.add_event( - name=event.message, - attributes={"severity": event.severity.value, **event.attributes}, - timestamp=event.timestamp, - ) + with self._lock: + # Use global storage instead of instance storage + span_id = string_to_span_id(event.span_id) + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + + if span: + timestamp_ns = int(event.timestamp.timestamp() * 1e9) + span.add_event( + name=event.type, + attributes={ + "message": event.message, + "severity": event.severity.value, + **event.attributes, + }, + timestamp=timestamp_ns, + ) + else: + print( + f"Warning: No active span found for span_id {span_id}. Dropping event: {event}" + ) + + def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: + if name not in _GLOBAL_STORAGE["counters"]: + _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter( + name=name, + unit=unit, + description=f"Counter for {name}", + ) + return _GLOBAL_STORAGE["counters"][name] + + def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: + if name not in _GLOBAL_STORAGE["gauges"]: + _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge( + name=name, + unit=unit, + description=f"Gauge for {name}", + ) + return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: if isinstance(event.value, int): - self.meter.create_counter( - name=event.metric, - unit=event.unit, - description=f"Counter for {event.metric}", - ).add(event.value, attributes=event.attributes) + counter = self._get_or_create_counter(event.metric, event.unit) + counter.add(event.value, attributes=event.attributes) elif isinstance(event.value, float): - self.meter.create_gauge( - name=event.metric, - unit=event.unit, - description=f"Gauge for {event.metric}", - ).set(event.value, attributes=event.attributes) + up_down_counter = self._get_or_create_up_down_counter( + event.metric, event.unit + ) + up_down_counter.add(event.value, attributes=event.attributes) + + def _get_or_create_up_down_counter( + self, name: str, unit: str + ) -> metrics.UpDownCounter: + if name not in _GLOBAL_STORAGE["up_down_counters"]: + _GLOBAL_STORAGE["up_down_counters"][name] = ( + self.meter.create_up_down_counter( + name=name, + unit=unit, + description=f"UpDownCounter for {name}", + ) + ) + return _GLOBAL_STORAGE["up_down_counters"][name] def _log_structured(self, event: StructuredLogEvent) -> None: - if isinstance(event.payload, SpanStartPayload): - context = trace.set_span_in_context( - trace.NonRecordingSpan( - trace.SpanContext( - trace_id=string_to_trace_id(event.trace_id), - span_id=string_to_span_id(event.span_id), - is_remote=True, - ) - ) - ) - span = self.tracer.start_span( - name=event.payload.name, - kind=trace.SpanKind.INTERNAL, - context=context, - attributes=event.attributes, - ) + with self._lock: + span_id = string_to_span_id(event.span_id) + trace_id = string_to_trace_id(event.trace_id) + tracer = trace.get_tracer(__name__) - if event.payload.parent_span_id: - span.set_parent( - trace.SpanContext( - trace_id=string_to_trace_id(event.trace_id), - span_id=string_to_span_id(event.payload.parent_span_id), - is_remote=True, + if isinstance(event.payload, SpanStartPayload): + # Check if span already exists to prevent duplicates + if span_id in _GLOBAL_STORAGE["active_spans"]: + return + + parent_span = None + if event.payload.parent_span_id: + parent_span_id = string_to_span_id(event.payload.parent_span_id) + parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) + + # Create a new trace context with the trace_id + context = trace.Context(trace_id=trace_id) + if parent_span: + context = trace.set_span_in_context(parent_span, context) + + span = tracer.start_span( + name=event.payload.name, + context=context, + attributes=event.attributes or {}, + start_time=int(event.timestamp.timestamp() * 1e9), + ) + _GLOBAL_STORAGE["active_spans"][span_id] = span + + # Set as current span using context manager + with trace.use_span(span, end_on_exit=False): + pass # Let the span continue beyond this block + + elif isinstance(event.payload, SpanEndPayload): + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + if span: + if event.attributes: + span.set_attributes(event.attributes) + + status = ( + trace.Status(status_code=trace.StatusCode.OK) + if event.payload.status == SpanStatus.OK + else trace.Status(status_code=trace.StatusCode.ERROR) ) - ) - elif isinstance(event.payload, SpanEndPayload): - span = trace.get_current_span() - span.set_status( - trace.Status( - trace.StatusCode.OK - if event.payload.status == SpanStatus.OK - else trace.StatusCode.ERROR - ) - ) - span.end(end_time=event.timestamp) + span.set_status(status) + span.end(end_time=int(event.timestamp.timestamp() * 1e9)) + + # Remove from active spans + _GLOBAL_STORAGE["active_spans"].pop(span_id, None) async def get_trace(self, trace_id: str) -> Trace: - # we need to look up the root span id - raise NotImplementedError("not yet no") - - -# Usage example -async def main(): - telemetry = OpenTelemetryTelemetry("my-service") - await telemetry.initialize() - - # Log an unstructured event - await telemetry.log_event( - UnstructuredLogEvent( - trace_id="trace123", - span_id="span456", - timestamp=datetime.now(), - message="This is a log message", - severity=LogSeverity.INFO, - ) - ) - - # Log a metric event - await telemetry.log_event( - MetricEvent( - trace_id="trace123", - span_id="span456", - timestamp=datetime.now(), - metric="my_metric", - value=42, - unit="count", - ) - ) - - # Log a structured event (span start) - await telemetry.log_event( - StructuredLogEvent( - trace_id="trace123", - span_id="span789", - timestamp=datetime.now(), - payload=SpanStartPayload(name="my_operation"), - ) - ) - - # Log a structured event (span end) - await telemetry.log_event( - StructuredLogEvent( - trace_id="trace123", - span_id="span789", - timestamp=datetime.now(), - payload=SpanEndPayload(status=SpanStatus.OK), - ) - ) - - await telemetry.shutdown() - - -if __name__ == "__main__": - import asyncio - - asyncio.run(main()) + raise NotImplementedError("Trace retrieval not implemented yet") diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 3383f7a7a..b53dc0df9 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import * # noqa: F403 log = logging.getLogger(__name__) -def generate_short_uuid(len: int = 12): +def generate_short_uuid(len: int = 8): full_uuid = uuid.uuid4() uuid_bytes = full_uuid.bytes encoded = base64.urlsafe_b64encode(uuid_bytes) @@ -123,18 +123,19 @@ def setup_logger(api: Telemetry, level: int = logging.INFO): logger.addHandler(TelemetryHandler()) -async def start_trace(name: str, attributes: Dict[str, Any] = None): +async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext: global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER if BACKGROUND_LOGGER is None: log.info("No Telemetry implementation set. Skipping trace initialization...") return - trace_id = generate_short_uuid() + trace_id = generate_short_uuid(16) context = TraceContext(BACKGROUND_LOGGER, trace_id) context.push_span(name, {"__root__": True, **(attributes or {})}) CURRENT_TRACE_CONTEXT = context + return context async def end_trace(status: SpanStatus = SpanStatus.OK):