From 9cebac8a3c27d453087b05204c3ab0b2a69576b4 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 22 Nov 2024 09:22:48 -0800 Subject: [PATCH] working telemetry v0 --- llama_stack/apis/models/client.py | 2 +- .../agents/meta_reference/agent_instance.py | 2 +- .../inline/agents/meta_reference/agents.py | 2 +- .../agents/meta_reference/persistence.py | 6 +- .../inline/eval/meta_reference/eval.py | 2 +- .../providers/inline/memory/faiss/faiss.py | 6 +- .../providers/remote/memory/chroma/chroma.py | 2 +- .../remote/telemetry/opentelemetry/config.py | 4 +- .../telemetry/opentelemetry/opentelemetry.py | 267 ++++++++++++------ .../providers/utils/telemetry/tracing.py | 4 +- 10 files changed, 196 insertions(+), 101 deletions(-) diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py index 34541b96e..1a72d8043 100644 --- a/llama_stack/apis/models/client.py +++ b/llama_stack/apis/models/client.py @@ -40,7 +40,7 @@ class ModelsClient(Models): response = await client.post( f"{self.base_url}/models/register", json={ - "model": json.loads(model.json()), + "model": json.loads(model.model_dump_json()), }, headers={"Content-Type": "application/json"}, ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 6d7fb95c1..ddb357600 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin): # May be this should be a parameter of the agentic instance # that can define its behavior in a custom way for m in turn.input_messages: - msg = m.copy() + msg = m.model_copy() if isinstance(msg, UserMessage): msg.context = None messages.append(msg) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 13d9044fd..f33aadde3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents): await self.persistence_store.set( key=f"agent:{agent_id}", - value=agent_config.json(), + value=agent_config.model_dump_json(), ) return AgentCreateResponse( agent_id=agent_id, diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index d51e25a32..1c99e3d75 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -39,7 +39,7 @@ class AgentPersistence: ) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), + value=session_info.model_dump_json(), ) return session_id @@ -60,13 +60,13 @@ class AgentPersistence: session_info.memory_bank_id = bank_id await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", - value=session_info.json(), + value=session_info.model_dump_json(), ) async def add_turn_to_session(self, session_id: str, turn: Turn): await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", - value=turn.json(), + value=turn.model_dump_json(), ) async def get_session_turns(self, session_id: str) -> List[Turn]: diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index d1df869b4..c6cacfcc3 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" await self.kvstore.set( key=key, - value=task_def.json(), + value=task_def.model_dump_json(), ) self.eval_tasks[task_def.identifier] = task_def diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index 95791bc69..dfefefeb8 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex): np.savetxt(buffer, np_index) data = { "id_by_index": self.id_by_index, - "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, + "chunk_by_index": { + k: v.model_dump_json() for k, v in self.chunk_by_index.items() + }, "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), } @@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" await self.kvstore.set( key=key, - value=memory_bank.json(), + value=memory_bank.model_dump_json(), ) # Store in cache diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 20185aade..207f6b54d 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): collection = await self.client.get_or_create_collection( name=memory_bank.identifier, - metadata={"bank": memory_bank.json()}, + metadata={"bank": memory_bank.model_dump_json()}, ) bank_index = BankWithIndex( bank=memory_bank, index=ChromaIndex(self.client, collection) diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py index 71a82aed9..e3e7b8d44 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py @@ -8,5 +8,5 @@ from pydantic import BaseModel class OpenTelemetryConfig(BaseModel): - jaeger_host: str = "localhost" - jaeger_port: int = 6831 + otel_endpoint: str = "http://localhost:4318/v1/traces" + service_name: str = "llama-stack" diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py index 03e8f7d53..b520f078d 100644 --- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py @@ -4,20 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import threading from datetime import datetime from opentelemetry import metrics, trace -from opentelemetry.exporter.jaeger.thrift import JaegerExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.export import ( - ConsoleMetricExporter, - PeriodicExportingMetricReader, -) +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.semconv.resource import ResourceAttributes +from llama_stack.providers.utils.telemetry.tracing import generate_short_uuid + from llama_stack.apis.telemetry import * # noqa: F403 from .config import OpenTelemetryConfig @@ -42,33 +43,42 @@ class OpenTelemetryAdapter(Telemetry): def __init__(self, config: OpenTelemetryConfig): self.config = config - self.resource = Resource.create( - {ResourceAttributes.SERVICE_NAME: "foobar-service"} + resource = Resource.create( + { + ResourceAttributes.SERVICE_NAME: self.config.service_name, + } ) - # Set up tracing with Jaeger exporter - jaeger_exporter = JaegerExporter( - agent_host_name=self.config.jaeger_host, - agent_port=self.config.jaeger_port, + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + otlp_exporter = OTLPSpanExporter( + endpoint=self.config.otel_endpoint, ) - trace_provider = TracerProvider(resource=self.resource) - trace_processor = BatchSpanProcessor(jaeger_exporter) - trace_provider.add_span_processor(trace_processor) - trace.set_tracer_provider(trace_provider) - self.tracer = trace.get_tracer(__name__) - + span_processor = BatchSpanProcessor(otlp_exporter) + trace.get_tracer_provider().add_span_processor(span_processor) # Set up metrics - metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter()) + metric_reader = PeriodicExportingMetricReader( + OTLPMetricExporter( + endpoint=self.config.otel_endpoint, + ) + ) metric_provider = MeterProvider( - resource=self.resource, metric_readers=[metric_reader] + resource=resource, metric_readers=[metric_reader] ) metrics.set_meter_provider(metric_provider) self.meter = metrics.get_meter(__name__) + # Initialize metric storage + self._counters = {} + self._gauges = {} + self._up_down_counters = {} + self._active_spans = {} + self._lock = threading.Lock() async def initialize(self) -> None: pass async def shutdown(self) -> None: + trace.get_tracer_provider().force_flush() trace.get_tracer_provider().shutdown() metrics.get_meter_provider().shutdown() @@ -81,112 +91,195 @@ class OpenTelemetryAdapter(Telemetry): self._log_structured(event) def _log_unstructured(self, event: UnstructuredLogEvent) -> None: - span = trace.get_current_span() - span.add_event( - name=event.message, - attributes={"severity": event.severity.value, **event.attributes}, - timestamp=event.timestamp, - ) + with self._lock: + # Check if there's an existing span in the cache + span_id = string_to_span_id(event.span_id) + span = self._active_spans.get(span_id) + + if span: + # Use existing span + timestamp_ns = int(event.timestamp.timestamp() * 1e9) + span.add_event( + name=event.message, + attributes={"severity": event.severity.value, **event.attributes}, + timestamp=timestamp_ns, + ) + else: + print( + f"Warning: No active span found for span_id {span_id}. Dropping event: {event}" + ) + + def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter: + if name not in self._counters: + self._counters[name] = self.meter.create_counter( + name=name, + unit=unit, + description=f"Counter for {name}", + ) + return self._counters[name] + + def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge: + if name not in self._gauges: + self._gauges[name] = self.meter.create_gauge( + name=name, + unit=unit, + description=f"Gauge for {name}", + ) + return self._gauges[name] def _log_metric(self, event: MetricEvent) -> None: if isinstance(event.value, int): - self.meter.create_counter( - name=event.metric, - unit=event.unit, - description=f"Counter for {event.metric}", - ).add(event.value, attributes=event.attributes) + counter = self._get_or_create_counter(event.metric, event.unit) + counter.add(event.value, attributes=event.attributes) elif isinstance(event.value, float): - self.meter.create_gauge( - name=event.metric, - unit=event.unit, - description=f"Gauge for {event.metric}", - ).set(event.value, attributes=event.attributes) + up_down_counter = self._get_or_create_up_down_counter( + event.metric, event.unit + ) + up_down_counter.add(event.value, attributes=event.attributes) + + def _get_or_create_up_down_counter( + self, name: str, unit: str + ) -> metrics.UpDownCounter: + if name not in self._up_down_counters: + self._up_down_counters[name] = self.meter.create_up_down_counter( + name=name, + unit=unit, + description=f"UpDownCounter for {name}", + ) + return self._up_down_counters[name] def _log_structured(self, event: StructuredLogEvent) -> None: - if isinstance(event.payload, SpanStartPayload): - context = trace.set_span_in_context( - trace.NonRecordingSpan( - trace.SpanContext( - trace_id=string_to_trace_id(event.trace_id), - span_id=string_to_span_id(event.span_id), - is_remote=True, - ) - ) - ) - span = self.tracer.start_span( - name=event.payload.name, - kind=trace.SpanKind.INTERNAL, - context=context, - attributes=event.attributes, + with self._lock: + trace_id = string_to_trace_id(event.trace_id) + span_id = string_to_span_id(event.span_id) + + tracer = trace.get_tracer(__name__) + span_context = trace.SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=True, + trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED), + trace_state=trace.TraceState(), ) - if event.payload.parent_span_id: - span.set_parent( - trace.SpanContext( - trace_id=string_to_trace_id(event.trace_id), - span_id=string_to_span_id(event.payload.parent_span_id), - is_remote=True, + if isinstance(event.payload, SpanStartPayload): + # Get parent span if it exists + parent_span = None + for active_span in self._active_spans.values(): + if active_span.is_recording(): + parent_span = active_span + break + + # Create the context properly + context = trace.Context() + if parent_span: + context = trace.set_span_in_context(parent_span) + + span = tracer.start_span( + name=event.payload.name, + context=context, + attributes=event.attributes or {}, + start_time=int(event.timestamp.timestamp() * 1e9), + ) + self._active_spans[span_id] = span + + # Set the span as current + _ = trace.set_span_in_context(span) + trace.use_span(span, end_on_exit=False) + + elif isinstance(event.payload, SpanEndPayload): + # Retrieve and end the existing span + span = self._active_spans.get(span_id) + if span: + if event.attributes: + span.set_attributes(event.attributes) + + status = ( + trace.Status(status_code=trace.StatusCode.OK) + if event.payload.status == SpanStatus.OK + else trace.Status(status_code=trace.StatusCode.ERROR) ) - ) - elif isinstance(event.payload, SpanEndPayload): - span = trace.get_current_span() - span.set_status( - trace.Status( - trace.StatusCode.OK - if event.payload.status == SpanStatus.OK - else trace.StatusCode.ERROR - ) - ) - span.end(end_time=event.timestamp) + span.set_status(status) + span.end(end_time=int(event.timestamp.timestamp() * 1e9)) + + # Remove from active spans + del self._active_spans[span_id] async def get_trace(self, trace_id: str) -> Trace: - # we need to look up the root span id - raise NotImplementedError("not yet no") + raise NotImplementedError("Trace retrieval not implemented yet") # Usage example async def main(): - telemetry = OpenTelemetryTelemetry("my-service") + telemetry = OpenTelemetryAdapter(OpenTelemetryConfig()) await telemetry.initialize() + # # Log a metric event + # await telemetry.log_event( + # MetricEvent( + # trace_id="trace123", + # span_id="span456", + # timestamp=datetime.now(), + # metric="my_metric", + # value=42, + # unit="count", + # ) + # ) + + # Log a structured event (span start) + trace_id = generate_short_uuid(16) + span_id = generate_short_uuid(8) + span_id_2 = generate_short_uuid(8) + await telemetry.log_event( + StructuredLogEvent( + trace_id=trace_id, + span_id=span_id, + timestamp=datetime.now(), + payload=SpanStartPayload(name="my_operation"), + ) + ) + # Log an unstructured event await telemetry.log_event( UnstructuredLogEvent( - trace_id="trace123", - span_id="span456", + trace_id=trace_id, + span_id=span_id, timestamp=datetime.now(), message="This is a log message", severity=LogSeverity.INFO, ) ) - # Log a metric event await telemetry.log_event( - MetricEvent( - trace_id="trace123", - span_id="span456", + StructuredLogEvent( + trace_id=trace_id, + span_id=span_id_2, timestamp=datetime.now(), - metric="my_metric", - value=42, - unit="count", + payload=SpanStartPayload(name="my_operation_2"), + ) + ) + await telemetry.log_event( + UnstructuredLogEvent( + trace_id=trace_id, + span_id=span_id_2, + timestamp=datetime.now(), + message="This is a log message 2", + severity=LogSeverity.INFO, ) ) - # Log a structured event (span start) await telemetry.log_event( StructuredLogEvent( - trace_id="trace123", - span_id="span789", + trace_id=trace_id, + span_id=span_id_2, timestamp=datetime.now(), - payload=SpanStartPayload(name="my_operation"), + payload=SpanEndPayload(status=SpanStatus.OK), ) ) - - # Log a structured event (span end) await telemetry.log_event( StructuredLogEvent( - trace_id="trace123", - span_id="span789", + trace_id=trace_id, + span_id=span_id, timestamp=datetime.now(), payload=SpanEndPayload(status=SpanStatus.OK), ) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 3383f7a7a..689c79508 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import * # noqa: F403 log = logging.getLogger(__name__) -def generate_short_uuid(len: int = 12): +def generate_short_uuid(len: int = 8): full_uuid = uuid.uuid4() uuid_bytes = full_uuid.bytes encoded = base64.urlsafe_b64encode(uuid_bytes) @@ -130,7 +130,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None): log.info("No Telemetry implementation set. Skipping trace initialization...") return - trace_id = generate_short_uuid() + trace_id = generate_short_uuid(16) context = TraceContext(BACKGROUND_LOGGER, trace_id) context.push_span(name, {"__root__": True, **(attributes or {})})