diff --git a/llama_stack/apis/models/client.py b/llama_stack/apis/models/client.py
index 34541b96e..1a72d8043 100644
--- a/llama_stack/apis/models/client.py
+++ b/llama_stack/apis/models/client.py
@@ -40,7 +40,7 @@ class ModelsClient(Models):
             response = await client.post(
                 f"{self.base_url}/models/register",
                 json={
-                    "model": json.loads(model.json()),
+                    "model": json.loads(model.model_dump_json()),
                 },
                 headers={"Content-Type": "application/json"},
             )
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index b8ff0e785..8116e2b39 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -17,13 +17,11 @@ import warnings
 from contextlib import asynccontextmanager
 from pathlib import Path
-from ssl import SSLError
-from typing import Any, Dict, Optional
+from typing import Any, Union
 
-import httpx
 import yaml
-from fastapi import Body, FastAPI, HTTPException, Request, Response
+from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import BaseModel, ValidationError
@@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
     setup_logger,
-    SpanStatus,
     start_trace,
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403
@@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
     )
 
 
-async def passthrough(
-    request: Request,
-    downstream_url: str,
-    downstream_headers: Optional[Dict[str, str]] = None,
-):
-    await start_trace(request.path, {"downstream_url": downstream_url})
-
-    headers = dict(request.headers)
-    headers.pop("host", None)
-    headers.update(downstream_headers or {})
-
-    content = await request.body()
-
-    client = httpx.AsyncClient()
-    erred = False
-    try:
-        req = client.build_request(
-            method=request.method,
-            url=downstream_url,
-            headers=headers,
-            content=content,
-            params=request.query_params,
-        )
-        response = await client.send(req, stream=True)
-
-        async def stream_response():
-            async for chunk in response.aiter_raw(chunk_size=64):
-                yield chunk
-
-            await response.aclose()
-            await client.aclose()
-
-        return StreamingResponse(
-            stream_response(),
-            status_code=response.status_code,
-            headers=dict(response.headers),
-            media_type=response.headers.get("content-type"),
-        )
-
-    except httpx.ReadTimeout:
-        erred = True
-        return Response(content="Downstream server timed out", status_code=504)
-    except httpx.NetworkError as e:
-        erred = True
-        return Response(content=f"Network error: {str(e)}", status_code=502)
-    except httpx.TooManyRedirects:
-        erred = True
-        return Response(content="Too many redirects", status_code=502)
-    except SSLError as e:
-        erred = True
-        return Response(content=f"SSL error: {str(e)}", status_code=502)
-    except httpx.HTTPStatusError as e:
-        erred = True
-        return Response(content=str(e), status_code=e.response.status_code)
-    except Exception as e:
-        erred = True
-        return Response(content=f"Unexpected error: {str(e)}", status_code=500)
-    finally:
-        await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
-
-
 def handle_sigint(app, *args, **kwargs):
     print("SIGINT or CTRL-C detected. Exiting gracefully...")
@@ -217,7 +153,6 @@ async def maybe_await(value):
 
 
 async def sse_generator(event_gen):
-    await start_trace("sse_generator")
     try:
         event_gen = await event_gen
         async for item in event_gen:
@@ -235,14 +170,10 @@ async def sse_generator(event_gen):
                 },
             }
         )
-    finally:
-        await end_trace()
 
 
 def create_dynamic_typed_route(func: Any, method: str):
     async def endpoint(request: Request, **kwargs):
-        await start_trace(func.__name__)
-
         set_request_provider_data(request.headers)
 
         is_streaming = is_streaming_request(func.__name__, request, **kwargs)
@@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
         except Exception as e:
             traceback.print_exception(e)
             raise translate_exception(e) from e
-        finally:
-            await end_trace()
 
     sig = inspect.signature(func)
     new_params = [
@@ -282,6 +211,19 @@ def create_dynamic_typed_route(func: Any, method: str):
     return endpoint
 
 
+class TracingMiddleware:
+    def __init__(self, app):
+        self.app = app
+
+    async def __call__(self, scope, receive, send):
+        path = scope["path"]
+        await start_trace(path, {"location": "server"})
+        try:
+            return await self.app(scope, receive, send)
+        finally:
+            await end_trace()
+
+
 def main():
     """Start the LlamaStack server."""
     parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
@@ -338,6 +280,7 @@ def main():
     print(yaml.dump(config.model_dump(), indent=2))
 
     app = FastAPI(lifespan=lifespan)
+    app.add_middleware(TracingMiddleware)
 
     try:
         impls = asyncio.run(construct_stack(config))
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index e1713c0e3..8f800ad6f 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
         # May be this should be a parameter of the agentic instance
         # that can define its behavior in a custom way
         for m in turn.input_messages:
-            msg = m.copy()
+            msg = m.model_copy()
             if isinstance(msg, UserMessage):
                 msg.context = None
             messages.append(msg)
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 13d9044fd..f33aadde3 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
 
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
-            value=agent_config.json(),
+            value=agent_config.model_dump_json(),
         )
         return AgentCreateResponse(
             agent_id=agent_id,
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index d51e25a32..1c99e3d75 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -39,7 +39,7 @@ class AgentPersistence:
         )
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.json(),
+            value=session_info.model_dump_json(),
         )
         return session_id
 
@@ -60,13 +60,13 @@ class AgentPersistence:
         session_info.memory_bank_id = bank_id
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
-            value=session_info.json(),
+            value=session_info.model_dump_json(),
         )
 
     async def add_turn_to_session(self, session_id: str, turn: Turn):
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
-            value=turn.json(),
+            value=turn.model_dump_json(),
         )
 
     async def get_session_turns(self, session_id: str) -> List[Turn]:
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index d1df869b4..c6cacfcc3 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
         await self.kvstore.set(
             key=key,
-            value=task_def.json(),
+            value=task_def.model_dump_json(),
         )
         self.eval_tasks[task_def.identifier] = task_def
diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py
index 95791bc69..dfefefeb8 100644
--- a/llama_stack/providers/inline/memory/faiss/faiss.py
+++ b/llama_stack/providers/inline/memory/faiss/faiss.py
@@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex):
         np.savetxt(buffer, np_index)
         data = {
             "id_by_index": self.id_by_index,
-            "chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
+            "chunk_by_index": {
+                k: v.model_dump_json() for k, v in self.chunk_by_index.items()
+            },
             "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
         }
 
@@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
         key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
         await self.kvstore.set(
             key=key,
-            value=memory_bank.json(),
+            value=memory_bank.model_dump_json(),
         )
 
         # Store in cache
diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py
index 20185aade..207f6b54d 100644
--- a/llama_stack/providers/remote/memory/chroma/chroma.py
+++ b/llama_stack/providers/remote/memory/chroma/chroma.py
@@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
 
         collection = await self.client.get_or_create_collection(
             name=memory_bank.identifier,
-            metadata={"bank": memory_bank.json()},
+            metadata={"bank": memory_bank.model_dump_json()},
         )
         bank_index = BankWithIndex(
             bank=memory_bank, index=ChromaIndex(self.client, collection)
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/config.py b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
index 71a82aed9..5e9dff1a1 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/config.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/config.py
@@ -4,9 +4,24 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from pydantic import BaseModel
+from typing import Any, Dict
+
+from pydantic import BaseModel, Field
 
 
 class OpenTelemetryConfig(BaseModel):
-    jaeger_host: str = "localhost"
-    jaeger_port: int = 6831
+    otel_endpoint: str = Field(
+        default="http://localhost:4318/v1/traces",
+        description="The OpenTelemetry collector endpoint URL",
+    )
+    service_name: str = Field(
+        default="llama-stack",
+        description="The service name to use for telemetry",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
+        return {
+            "otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
+            "service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
+        }
diff --git a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
index 03e8f7d53..c9830fd9d 100644
--- a/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
+++ b/llama_stack/providers/remote/telemetry/opentelemetry/opentelemetry.py
@@ -4,24 +4,31 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from datetime import datetime
+import threading
 
 from opentelemetry import metrics, trace
-from opentelemetry.exporter.jaeger.thrift import JaegerExporter
+from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.metrics import MeterProvider
-from opentelemetry.sdk.metrics.export import (
-    ConsoleMetricExporter,
-    PeriodicExportingMetricReader,
-)
+from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
 from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.semconv.resource import ResourceAttributes
+
 from llama_stack.apis.telemetry import *  # noqa: F403
 
 from .config import OpenTelemetryConfig
 
+_GLOBAL_STORAGE = {
+    "active_spans": {},
+    "counters": {},
+    "gauges": {},
+    "up_down_counters": {},
+}
+_global_lock = threading.Lock()
+
 
 def string_to_trace_id(s: str) -> int:
     # Convert the string to bytes and then to an integer
@@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
     def __init__(self, config: OpenTelemetryConfig):
         self.config = config
 
-        self.resource = Resource.create(
-            {ResourceAttributes.SERVICE_NAME: "foobar-service"}
+        resource = Resource.create(
+            {
+                ResourceAttributes.SERVICE_NAME: self.config.service_name,
+            }
         )
 
-        # Set up tracing with Jaeger exporter
-        jaeger_exporter = JaegerExporter(
-            agent_host_name=self.config.jaeger_host,
-            agent_port=self.config.jaeger_port,
+        provider = TracerProvider(resource=resource)
+        trace.set_tracer_provider(provider)
+        otlp_exporter = OTLPSpanExporter(
+            endpoint=self.config.otel_endpoint,
         )
-        trace_provider = TracerProvider(resource=self.resource)
-        trace_processor = BatchSpanProcessor(jaeger_exporter)
-        trace_provider.add_span_processor(trace_processor)
-        trace.set_tracer_provider(trace_provider)
-        self.tracer = trace.get_tracer(__name__)
-
+        span_processor = BatchSpanProcessor(otlp_exporter)
+        trace.get_tracer_provider().add_span_processor(span_processor)
         # Set up metrics
-        metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
+        metric_reader = PeriodicExportingMetricReader(
+            OTLPMetricExporter(
+                endpoint=self.config.otel_endpoint,
+            )
+        )
         metric_provider = MeterProvider(
-            resource=self.resource, metric_readers=[metric_reader]
+            resource=resource, metric_readers=[metric_reader]
         )
         metrics.set_meter_provider(metric_provider)
         self.meter = metrics.get_meter(__name__)
+        self._lock = _global_lock
 
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
+        trace.get_tracer_provider().force_flush()
         trace.get_tracer_provider().shutdown()
         metrics.get_meter_provider().shutdown()
@@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry):
             self._log_structured(event)
 
     def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
-        span = trace.get_current_span()
-        span.add_event(
-            name=event.message,
-            attributes={"severity": event.severity.value, **event.attributes},
-            timestamp=event.timestamp,
-        )
+        with self._lock:
+            # Use global storage instead of instance storage
+            span_id = string_to_span_id(event.span_id)
+            span = _GLOBAL_STORAGE["active_spans"].get(span_id)
+
+            if span:
+                timestamp_ns = int(event.timestamp.timestamp() * 1e9)
+                span.add_event(
+                    name=event.type,
+                    attributes={
+                        "message": event.message,
+                        "severity": event.severity.value,
+                        **event.attributes,
+                    },
+                    timestamp=timestamp_ns,
+                )
+            else:
+                print(
+                    f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
+                )
+
+    def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
+        if name not in _GLOBAL_STORAGE["counters"]:
+            _GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
+                name=name,
+                unit=unit,
+                description=f"Counter for {name}",
+            )
+        return _GLOBAL_STORAGE["counters"][name]
+
+    def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
+        if name not in _GLOBAL_STORAGE["gauges"]:
+            _GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
+                name=name,
+                unit=unit,
+                description=f"Gauge for {name}",
+            )
+        return _GLOBAL_STORAGE["gauges"][name]
 
     def _log_metric(self, event: MetricEvent) -> None:
         if isinstance(event.value, int):
-            self.meter.create_counter(
-                name=event.metric,
-                unit=event.unit,
-                description=f"Counter for {event.metric}",
-            ).add(event.value, attributes=event.attributes)
+            counter = self._get_or_create_counter(event.metric, event.unit)
+            counter.add(event.value, attributes=event.attributes)
         elif isinstance(event.value, float):
-            self.meter.create_gauge(
-                name=event.metric,
-                unit=event.unit,
-                description=f"Gauge for {event.metric}",
-            ).set(event.value, attributes=event.attributes)
+            up_down_counter = self._get_or_create_up_down_counter(
+                event.metric, event.unit
+            )
+            up_down_counter.add(event.value, attributes=event.attributes)
+
+    def _get_or_create_up_down_counter(
+        self, name: str, unit: str
+    ) -> metrics.UpDownCounter:
+        if name not in _GLOBAL_STORAGE["up_down_counters"]:
+            _GLOBAL_STORAGE["up_down_counters"][name] = (
+                self.meter.create_up_down_counter(
+                    name=name,
+                    unit=unit,
+                    description=f"UpDownCounter for {name}",
+                )
+            )
+        return _GLOBAL_STORAGE["up_down_counters"][name]
 
     def _log_structured(self, event: StructuredLogEvent) -> None:
-        if isinstance(event.payload, SpanStartPayload):
-            context = trace.set_span_in_context(
-                trace.NonRecordingSpan(
-                    trace.SpanContext(
-                        trace_id=string_to_trace_id(event.trace_id),
-                        span_id=string_to_span_id(event.span_id),
-                        is_remote=True,
-                    )
-                )
-            )
-            span = self.tracer.start_span(
-                name=event.payload.name,
-                kind=trace.SpanKind.INTERNAL,
-                context=context,
-                attributes=event.attributes,
-            )
+        with self._lock:
+            span_id = string_to_span_id(event.span_id)
+            trace_id = string_to_trace_id(event.trace_id)
+            tracer = trace.get_tracer(__name__)
 
-            if event.payload.parent_span_id:
-                span.set_parent(
-                    trace.SpanContext(
-                        trace_id=string_to_trace_id(event.trace_id),
-                        span_id=string_to_span_id(event.payload.parent_span_id),
-                        is_remote=True,
+            if isinstance(event.payload, SpanStartPayload):
+                # Check if span already exists to prevent duplicates
+                if span_id in _GLOBAL_STORAGE["active_spans"]:
+                    return
+
+                parent_span = None
+                if event.payload.parent_span_id:
+                    parent_span_id = string_to_span_id(event.payload.parent_span_id)
+                    parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
+
+                # Create a new trace context with the trace_id
+                context = trace.Context(trace_id=trace_id)
+                if parent_span:
+                    context = trace.set_span_in_context(parent_span, context)
+
+                span = tracer.start_span(
+                    name=event.payload.name,
+                    context=context,
+                    attributes=event.attributes or {},
+                    start_time=int(event.timestamp.timestamp() * 1e9),
+                )
+                _GLOBAL_STORAGE["active_spans"][span_id] = span
+
+                # Set as current span using context manager
+                with trace.use_span(span, end_on_exit=False):
+                    pass  # Let the span continue beyond this block
+
+            elif isinstance(event.payload, SpanEndPayload):
+                span = _GLOBAL_STORAGE["active_spans"].get(span_id)
+                if span:
+                    if event.attributes:
+                        span.set_attributes(event.attributes)
+
+                    status = (
+                        trace.Status(status_code=trace.StatusCode.OK)
+                        if event.payload.status == SpanStatus.OK
+                        else trace.Status(status_code=trace.StatusCode.ERROR)
                     )
-                )
-        elif isinstance(event.payload, SpanEndPayload):
-            span = trace.get_current_span()
-            span.set_status(
-                trace.Status(
-                    trace.StatusCode.OK
-                    if event.payload.status == SpanStatus.OK
-                    else trace.StatusCode.ERROR
-                )
-            )
-            span.end(end_time=event.timestamp)
+                    span.set_status(status)
+                    span.end(end_time=int(event.timestamp.timestamp() * 1e9))
+
+                    # Remove from active spans
+                    _GLOBAL_STORAGE["active_spans"].pop(span_id, None)
 
     async def get_trace(self, trace_id: str) -> Trace:
-        # we need to look up the root span id
-        raise NotImplementedError("not yet no")
-
-
-# Usage example
-async def main():
-    telemetry = OpenTelemetryTelemetry("my-service")
-    await telemetry.initialize()
-
-    # Log an unstructured event
-    await telemetry.log_event(
-        UnstructuredLogEvent(
-            trace_id="trace123",
-            span_id="span456",
-            timestamp=datetime.now(),
-            message="This is a log message",
-            severity=LogSeverity.INFO,
-        )
-    )
-
-    # Log a metric event
-    await telemetry.log_event(
-        MetricEvent(
-            trace_id="trace123",
-            span_id="span456",
-            timestamp=datetime.now(),
-            metric="my_metric",
-            value=42,
-            unit="count",
-        )
-    )
-
-    # Log a structured event (span start)
-    await telemetry.log_event(
-        StructuredLogEvent(
-            trace_id="trace123",
-            span_id="span789",
-            timestamp=datetime.now(),
-            payload=SpanStartPayload(name="my_operation"),
-        )
-    )
-
-    # Log a structured event (span end)
-    await telemetry.log_event(
-        StructuredLogEvent(
-            trace_id="trace123",
-            span_id="span789",
-            timestamp=datetime.now(),
-            payload=SpanEndPayload(status=SpanStatus.OK),
-        )
-    )
-
-    await telemetry.shutdown()
-
-
-if __name__ == "__main__":
-    import asyncio
-
-    asyncio.run(main())
+        raise NotImplementedError("Trace retrieval not implemented yet")
diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py
index 3383f7a7a..b53dc0df9 100644
--- a/llama_stack/providers/utils/telemetry/tracing.py
+++ b/llama_stack/providers/utils/telemetry/tracing.py
@@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import *  # noqa: F403
 log = logging.getLogger(__name__)
 
 
-def generate_short_uuid(len: int = 12):
+def generate_short_uuid(len: int = 8):
     full_uuid = uuid.uuid4()
     uuid_bytes = full_uuid.bytes
     encoded = base64.urlsafe_b64encode(uuid_bytes)
@@ -123,18 +123,19 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
     logger.addHandler(TelemetryHandler())
 
 
-async def start_trace(name: str, attributes: Dict[str, Any] = None):
+async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
     global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
 
     if BACKGROUND_LOGGER is None:
         log.info("No Telemetry implementation set. Skipping trace initialization...")
         return
 
-    trace_id = generate_short_uuid()
+    trace_id = generate_short_uuid(16)
     context = TraceContext(BACKGROUND_LOGGER, trace_id)
     context.push_span(name, {"__root__": True, **(attributes or {})})
 
     CURRENT_TRACE_CONTEXT = context
+    return context
 
 
 async def end_trace(status: SpanStatus = SpanStatus.OK):
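Note (illustrative sketch, not part of the patch): the pieces above are meant to be wired together as follows. The server now installs an ASGI TracingMiddleware so every request runs inside start_trace()/end_trace(), and the OpenTelemetry provider reads an OTLP endpoint and service name from OpenTelemetryConfig instead of the old Jaeger host/port pair. The import paths and defaults below mirror the files touched in this diff; everything else is assumed for demonstration only.

    # Sketch only: exercise the new config defaults and the middleware wiring.
    from fastapi import FastAPI

    from llama_stack.distribution.server.server import TracingMiddleware
    from llama_stack.providers.remote.telemetry.opentelemetry.config import (
        OpenTelemetryConfig,
    )

    config = OpenTelemetryConfig()
    print(config.otel_endpoint)                     # http://localhost:4318/v1/traces
    print(config.service_name)                      # llama-stack
    print(OpenTelemetryConfig.sample_run_config())  # env-substitutable run config

    app = FastAPI()
    app.add_middleware(TracingMiddleware)           # same call main() now makes in server.py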