working telemetry v0

This commit is contained in:
Dinesh Yeduguru 2024-11-22 09:22:48 -08:00
parent d790be28b3
commit 9cebac8a3c
10 changed files with 196 additions and 101 deletions

View file

@ -40,7 +40,7 @@ class ModelsClient(Models):
response = await client.post( response = await client.post(
f"{self.base_url}/models/register", f"{self.base_url}/models/register",
json={ json={
"model": json.loads(model.json()), "model": json.loads(model.model_dump_json()),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )

View file

@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
# Maybe this should be a parameter of the agentic instance # Maybe this should be a parameter of the agentic instance
# that can define its behavior in a custom way # that can define its behavior in a custom way
for m in turn.input_messages: for m in turn.input_messages:
msg = m.copy() msg = m.model_copy()
if isinstance(msg, UserMessage): if isinstance(msg, UserMessage):
msg.context = None msg.context = None
messages.append(msg) messages.append(msg)

View file

@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
await self.persistence_store.set( await self.persistence_store.set(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
value=agent_config.json(), value=agent_config.model_dump_json(),
) )
return AgentCreateResponse( return AgentCreateResponse(
agent_id=agent_id, agent_id=agent_id,

View file

@ -39,7 +39,7 @@ class AgentPersistence:
) )
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(), value=session_info.model_dump_json(),
) )
return session_id return session_id
@ -60,13 +60,13 @@ class AgentPersistence:
session_info.memory_bank_id = bank_id session_info.memory_bank_id = bank_id
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(), value=session_info.model_dump_json(),
) )
async def add_turn_to_session(self, session_id: str, turn: Turn): async def add_turn_to_session(self, session_id: str, turn: Turn):
await self.kvstore.set( await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.json(), value=turn.model_dump_json(),
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:

View file

@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}" key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=task_def.json(), value=task_def.model_dump_json(),
) )
self.eval_tasks[task_def.identifier] = task_def self.eval_tasks[task_def.identifier] = task_def

View file

@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex):
np.savetxt(buffer, np_index) np.savetxt(buffer, np_index)
data = { data = {
"id_by_index": self.id_by_index, "id_by_index": self.id_by_index,
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()}, "chunk_by_index": {
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
},
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}" key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
await self.kvstore.set( await self.kvstore.set(
key=key, key=key,
value=memory_bank.json(), value=memory_bank.model_dump_json(),
) )
# Store in cache # Store in cache

View file

@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
collection = await self.client.get_or_create_collection( collection = await self.client.get_or_create_collection(
name=memory_bank.identifier, name=memory_bank.identifier,
metadata={"bank": memory_bank.json()}, metadata={"bank": memory_bank.model_dump_json()},
) )
bank_index = BankWithIndex( bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection) bank=memory_bank, index=ChromaIndex(self.client, collection)

View file

@ -8,5 +8,5 @@ from pydantic import BaseModel
class OpenTelemetryConfig(BaseModel): class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost" otel_endpoint: str = "http://localhost:4318/v1/traces"
jaeger_port: int = 6831 service_name: str = "llama-stack"

View file

@ -4,20 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import threading
from datetime import datetime from datetime import datetime
from opentelemetry import metrics, trace from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import ( from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.providers.utils.telemetry.tracing import generate_short_uuid
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig from .config import OpenTelemetryConfig
@ -42,33 +43,42 @@ class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig): def __init__(self, config: OpenTelemetryConfig):
self.config = config self.config = config
self.resource = Resource.create( resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"} {
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
) )
# Set up tracing with Jaeger exporter provider = TracerProvider(resource=resource)
jaeger_exporter = JaegerExporter( trace.set_tracer_provider(provider)
agent_host_name=self.config.jaeger_host, otlp_exporter = OTLPSpanExporter(
agent_port=self.config.jaeger_port, endpoint=self.config.otel_endpoint,
) )
trace_provider = TracerProvider(resource=self.resource) span_processor = BatchSpanProcessor(otlp_exporter)
trace_processor = BatchSpanProcessor(jaeger_exporter) trace.get_tracer_provider().add_span_processor(span_processor)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
# Set up metrics # Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter()) metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider( metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader] resource=resource, metric_readers=[metric_reader]
) )
metrics.set_meter_provider(metric_provider) metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__) self.meter = metrics.get_meter(__name__)
# Initialize metric storage
self._counters = {}
self._gauges = {}
self._up_down_counters = {}
self._active_spans = {}
self._lock = threading.Lock()
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown() trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown() metrics.get_meter_provider().shutdown()
@ -81,112 +91,195 @@ class OpenTelemetryAdapter(Telemetry):
self._log_structured(event) self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None: def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span() with self._lock:
span.add_event( # Check if there's an existing span in the cache
name=event.message, span_id = string_to_span_id(event.span_id)
attributes={"severity": event.severity.value, **event.attributes}, span = self._active_spans.get(span_id)
timestamp=event.timestamp,
) if span:
# Use existing span
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=timestamp_ns,
)
else:
print(
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
)
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
    """Return the counter registered under ``name``, creating it on first use.

    Args:
        name: Metric name; also the key into ``self._counters``.
        unit: Unit string passed to the meter when the counter is created.

    Returns:
        The instrument stored in ``self._counters`` for ``name``.
    """
    try:
        # Fast path: the instrument was already created by an earlier call.
        return self._counters[name]
    except KeyError:
        pass

    # First request for this name: build the instrument once and remember it
    # so later calls reuse the same object.
    counter = self.meter.create_counter(
        name=name,
        unit=unit,
        description=f"Counter for {name}",
    )
    self._counters[name] = counter
    return counter
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
    """Return the gauge registered under ``name``, creating it on first use.

    Args:
        name: Metric name; also the key into ``self._gauges``.
        unit: Unit string passed to the meter when the gauge is created.

    Returns:
        The instrument stored in ``self._gauges`` for ``name``.
    """
    # NOTE(review): the return annotation says ObservableGauge, but the
    # instrument is built with ``create_gauge`` -- confirm the annotation
    # matches what that call actually returns.
    try:
        # Fast path: the instrument was already created by an earlier call.
        return self._gauges[name]
    except KeyError:
        pass

    # First request for this name: build the instrument once and remember it.
    gauge = self.meter.create_gauge(
        name=name,
        unit=unit,
        description=f"Gauge for {name}",
    )
    self._gauges[name] = gauge
    return gauge
def _log_metric(self, event: MetricEvent) -> None: def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int): if isinstance(event.value, int):
self.meter.create_counter( counter = self._get_or_create_counter(event.metric, event.unit)
name=event.metric, counter.add(event.value, attributes=event.attributes)
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
elif isinstance(event.value, float): elif isinstance(event.value, float):
self.meter.create_gauge( up_down_counter = self._get_or_create_up_down_counter(
name=event.metric, event.metric, event.unit
unit=event.unit, )
description=f"Gauge for {event.metric}", up_down_counter.add(event.value, attributes=event.attributes)
).set(event.value, attributes=event.attributes)
def _get_or_create_up_down_counter(
    self, name: str, unit: str
) -> metrics.UpDownCounter:
    """Return the up/down counter registered under ``name``, creating it on first use.

    Args:
        name: Metric name; also the key into ``self._up_down_counters``.
        unit: Unit string passed to the meter when the instrument is created.

    Returns:
        The instrument stored in ``self._up_down_counters`` for ``name``.
    """
    try:
        # Fast path: the instrument was already created by an earlier call.
        return self._up_down_counters[name]
    except KeyError:
        pass

    # First request for this name: build the instrument once and remember it.
    instrument = self.meter.create_up_down_counter(
        name=name,
        unit=unit,
        description=f"UpDownCounter for {name}",
    )
    self._up_down_counters[name] = instrument
    return instrument
def _log_structured(self, event: StructuredLogEvent) -> None: def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload): with self._lock:
context = trace.set_span_in_context( trace_id = string_to_trace_id(event.trace_id)
trace.NonRecordingSpan( span_id = string_to_span_id(event.span_id)
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id), tracer = trace.get_tracer(__name__)
span_id=string_to_span_id(event.span_id), span_context = trace.SpanContext(
is_remote=True, trace_id=trace_id,
) span_id=span_id,
) is_remote=True,
) trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
span = self.tracer.start_span( trace_state=trace.TraceState(),
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
) )
if event.payload.parent_span_id: if isinstance(event.payload, SpanStartPayload):
span.set_parent( # Get parent span if it exists
trace.SpanContext( parent_span = None
trace_id=string_to_trace_id(event.trace_id), for active_span in self._active_spans.values():
span_id=string_to_span_id(event.payload.parent_span_id), if active_span.is_recording():
is_remote=True, parent_span = active_span
break
# Create the context properly
context = trace.Context()
if parent_span:
context = trace.set_span_in_context(parent_span)
span = tracer.start_span(
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
self._active_spans[span_id] = span
# Set the span as current
_ = trace.set_span_in_context(span)
trace.use_span(span, end_on_exit=False)
elif isinstance(event.payload, SpanEndPayload):
# Retrieve and end the existing span
span = self._active_spans.get(span_id)
if span:
if event.attributes:
span.set_attributes(event.attributes)
status = (
trace.Status(status_code=trace.StatusCode.OK)
if event.payload.status == SpanStatus.OK
else trace.Status(status_code=trace.StatusCode.ERROR)
) )
) span.set_status(status)
elif isinstance(event.payload, SpanEndPayload): span.end(end_time=int(event.timestamp.timestamp() * 1e9))
span = trace.get_current_span()
span.set_status( # Remove from active spans
trace.Status( del self._active_spans[span_id]
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
async def get_trace(self, trace_id: str) -> Trace: async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id raise NotImplementedError("Trace retrieval not implemented yet")
raise NotImplementedError("not yet no")
# Usage example # Usage example
async def main(): async def main():
telemetry = OpenTelemetryTelemetry("my-service") telemetry = OpenTelemetryAdapter(OpenTelemetryConfig())
await telemetry.initialize() await telemetry.initialize()
# # Log a metric event
# await telemetry.log_event(
# MetricEvent(
# trace_id="trace123",
# span_id="span456",
# timestamp=datetime.now(),
# metric="my_metric",
# value=42,
# unit="count",
# )
# )
# Log a structured event (span start)
trace_id = generate_short_uuid(16)
span_id = generate_short_uuid(8)
span_id_2 = generate_short_uuid(8)
await telemetry.log_event(
StructuredLogEvent(
trace_id=trace_id,
span_id=span_id,
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log an unstructured event # Log an unstructured event
await telemetry.log_event( await telemetry.log_event(
UnstructuredLogEvent( UnstructuredLogEvent(
trace_id="trace123", trace_id=trace_id,
span_id="span456", span_id=span_id,
timestamp=datetime.now(), timestamp=datetime.now(),
message="This is a log message", message="This is a log message",
severity=LogSeverity.INFO, severity=LogSeverity.INFO,
) )
) )
# Log a metric event
await telemetry.log_event( await telemetry.log_event(
MetricEvent( StructuredLogEvent(
trace_id="trace123", trace_id=trace_id,
span_id="span456", span_id=span_id_2,
timestamp=datetime.now(), timestamp=datetime.now(),
metric="my_metric", payload=SpanStartPayload(name="my_operation_2"),
value=42, )
unit="count", )
await telemetry.log_event(
UnstructuredLogEvent(
trace_id=trace_id,
span_id=span_id_2,
timestamp=datetime.now(),
message="This is a log message 2",
severity=LogSeverity.INFO,
) )
) )
# Log a structured event (span start)
await telemetry.log_event( await telemetry.log_event(
StructuredLogEvent( StructuredLogEvent(
trace_id="trace123", trace_id=trace_id,
span_id="span789", span_id=span_id_2,
timestamp=datetime.now(), timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"), payload=SpanEndPayload(status=SpanStatus.OK),
) )
) )
# Log a structured event (span end)
await telemetry.log_event( await telemetry.log_event(
StructuredLogEvent( StructuredLogEvent(
trace_id="trace123", trace_id=trace_id,
span_id="span789", span_id=span_id,
timestamp=datetime.now(), timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK), payload=SpanEndPayload(status=SpanStatus.OK),
) )

View file

@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import * # noqa: F403
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def generate_short_uuid(len: int = 12): def generate_short_uuid(len: int = 8):
full_uuid = uuid.uuid4() full_uuid = uuid.uuid4()
uuid_bytes = full_uuid.bytes uuid_bytes = full_uuid.bytes
encoded = base64.urlsafe_b64encode(uuid_bytes) encoded = base64.urlsafe_b64encode(uuid_bytes)
@ -130,7 +130,7 @@ async def start_trace(name: str, attributes: Dict[str, Any] = None):
log.info("No Telemetry implementation set. Skipping trace initialization...") log.info("No Telemetry implementation set. Skipping trace initialization...")
return return
trace_id = generate_short_uuid() trace_id = generate_short_uuid(16)
context = TraceContext(BACKGROUND_LOGGER, trace_id) context = TraceContext(BACKGROUND_LOGGER, trace_id)
context.push_span(name, {"__root__": True, **(attributes or {})}) context.push_span(name, {"__root__": True, **(attributes or {})})