diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index f7de36a74..4df138841 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -6,19 +6,20 @@ import asyncio import json -from typing import Any, AsyncGenerator +from typing import Any, AsyncGenerator, List, Optional import fire import httpx from pydantic import BaseModel + +from llama_models.llama3.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 from termcolor import cprint from llama_stack.distribution.datatypes import RemoteProviderConfig from .event_logger import EventLogger -from llama_stack.apis.inference import * # noqa: F403 - async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: return InferenceClient(config.url) diff --git a/llama_stack/providers/adapters/telemetry/__init__.py b/llama_stack/providers/adapters/telemetry/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/adapters/telemetry/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py b/llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py new file mode 100644 index 000000000..0842afe2d --- /dev/null +++ b/llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
# --- llama_stack/providers/adapters/telemetry/opentelemetry/__init__.py ---

from .config import OpenTelemetryConfig


async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
    """Build and initialize the OpenTelemetry telemetry adapter.

    Args:
        config: Jaeger connection settings for the adapter.
        _deps: Dependency mapping passed in by the caller; not used here.

    Returns:
        An ``OpenTelemetryAdapter`` whose ``initialize()`` has been awaited.
    """
    # Imported inside the function so that importing this package does not
    # import the adapter module (and the opentelemetry packages it needs).
    from .opentelemetry import OpenTelemetryAdapter

    adapter = OpenTelemetryAdapter(config)
    await adapter.initialize()
    return adapter


# --- llama_stack/providers/adapters/telemetry/opentelemetry/config.py ---

from pydantic import BaseModel


class OpenTelemetryConfig(BaseModel):
    """Settings consumed by the OpenTelemetry adapter's Jaeger exporter."""

    # Host name handed to the Jaeger exporter as ``agent_host_name``.
    jaeger_host: str = "localhost"
    # Port handed to the Jaeger exporter as ``agent_port``.
    jaeger_port: int = 6831
import hashlib
from datetime import datetime
from typing import Any, Dict, Optional

from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
    ConsoleMetricExporter,
    PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes

from llama_stack.apis.telemetry import *  # noqa: F403

from .config import OpenTelemetryConfig

# OpenTelemetry identifiers have fixed widths: 128-bit trace ids and 64-bit
# span ids. Anything wider cannot be represented in a SpanContext.
_TRACE_ID_BYTES = 16
_SPAN_ID_BYTES = 8
_NS_PER_SECOND = 1_000_000_000

# NOTE(review): placeholder service name kept from the original code; it
# should probably become a config option.
DEFAULT_SERVICE_NAME = "foobar-service"


def _string_to_id(s: str, num_bytes: int) -> int:
    """Map an arbitrary string to an unsigned integer of at most num_bytes.

    Strings whose UTF-8 encoding already fits are converted directly, so short
    ids keep the value they had before. Longer strings are folded through
    SHA-256 and truncated, which keeps the result deterministic and within
    range while still depending on the whole string.
    """
    raw = s.encode()
    if len(raw) > num_bytes:
        raw = hashlib.sha256(raw).digest()[:num_bytes]
    return int.from_bytes(raw, byteorder="big", signed=False)


def string_to_trace_id(s: str) -> int:
    """Convert a string trace id into a 128-bit OpenTelemetry trace id.

    The previous implementation converted every byte of the string, so any id
    longer than 16 bytes produced an integer wider than 128 bits.
    """
    return _string_to_id(s, _TRACE_ID_BYTES)


def string_to_span_id(s: str) -> int:
    """Convert a string span id into a 64-bit OpenTelemetry span id.

    The previous implementation kept only the first 8 bytes, so ids sharing
    an 8-byte prefix collided; longer ids are now hashed instead.
    """
    return _string_to_id(s, _SPAN_ID_BYTES)


def _to_unix_nanos(timestamp) -> Optional[int]:
    """Convert an event timestamp to integer nanoseconds since the epoch.

    The OpenTelemetry span API takes timestamps as integer nanoseconds, while
    events in this file carry ``datetime`` objects. ``None`` and integers are
    passed through unchanged.
    """
    if timestamp is None or isinstance(timestamp, int):
        return timestamp
    # Integer arithmetic avoids the precision loss of float seconds * 1e9.
    whole_seconds = int(timestamp.replace(microsecond=0).timestamp())
    return whole_seconds * _NS_PER_SECOND + timestamp.microsecond * 1_000


def is_tracing_enabled(tracer):
    """Return whether a span started on *tracer* is actually recording."""
    with tracer.start_as_current_span("check_tracing") as span:
        return span.is_recording()


class OpenTelemetryAdapter(Telemetry):
    """Telemetry implementation that forwards events to OpenTelemetry.

    Spans are exported through a Jaeger exporter configured from
    ``OpenTelemetryConfig``; metrics are exported to the console.
    """

    def __init__(self, config: OpenTelemetryConfig):
        self.config = config

        self.resource = Resource.create(
            {ResourceAttributes.SERVICE_NAME: DEFAULT_SERVICE_NAME}
        )

        # Set up tracing with Jaeger exporter
        jaeger_exporter = JaegerExporter(
            agent_host_name=self.config.jaeger_host,
            agent_port=self.config.jaeger_port,
        )
        trace_provider = TracerProvider(resource=self.resource)
        trace_processor = BatchSpanProcessor(jaeger_exporter)
        trace_provider.add_span_processor(trace_processor)
        trace.set_tracer_provider(trace_provider)
        self.tracer = trace.get_tracer(__name__)

        # Set up metrics
        metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
        metric_provider = MeterProvider(
            resource=self.resource, metric_readers=[metric_reader]
        )
        metrics.set_meter_provider(metric_provider)
        self.meter = metrics.get_meter(__name__)

        # Instruments are created once per metric name and reused, instead of
        # being re-created on every metric event.
        self._counters: Dict[str, Any] = {}
        self._gauges: Dict[str, Any] = {}
        # Spans that have started but not yet ended, keyed by the event's
        # string span id, so the matching end event can find its span.
        self._open_spans: Dict[str, Any] = {}

    async def initialize(self) -> None:
        """No-op: all setup happens in ``__init__``."""
        pass

    async def shutdown(self) -> None:
        """Shut down the global tracer and meter providers."""
        trace.get_tracer_provider().shutdown()
        metrics.get_meter_provider().shutdown()

    async def log_event(self, event: Event) -> None:
        """Dispatch *event* to the handler for its type.

        Events that are none of the three known types are ignored.
        """
        if isinstance(event, UnstructuredLogEvent):
            self._log_unstructured(event)
        elif isinstance(event, MetricEvent):
            self._log_metric(event)
        elif isinstance(event, StructuredLogEvent):
            self._log_structured(event)

    def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
        """Record a log message as a span event.

        The event is attached to the open span with the same span id when
        there is one, otherwise to the current span as before.
        """
        span = self._open_spans.get(event.span_id)
        if span is None:
            span = trace.get_current_span()
        span.add_event(
            name=event.message,
            attributes={
                "severity": event.severity.value,
                # Tolerate events that carry no attributes.
                **(event.attributes or {}),
            },
            timestamp=_to_unix_nanos(event.timestamp),
        )

    def _log_metric(self, event: MetricEvent) -> None:
        """Record an integer value on a counter or a float value on a gauge.

        Values of any other type are ignored, as before.
        """
        if isinstance(event.value, int):
            counter = self._counters.get(event.metric)
            if counter is None:
                counter = self.meter.create_counter(
                    name=event.metric,
                    unit=event.unit,
                    description=f"Counter for {event.metric}",
                )
                self._counters[event.metric] = counter
            counter.add(event.value, attributes=event.attributes)
        elif isinstance(event.value, float):
            gauge = self._gauges.get(event.metric)
            if gauge is None:
                gauge = self.meter.create_gauge(
                    name=event.metric,
                    unit=event.unit,
                    description=f"Gauge for {event.metric}",
                )
                self._gauges[event.metric] = gauge
            gauge.set(event.value, attributes=event.attributes)

    def _log_structured(self, event: StructuredLogEvent) -> None:
        """Start or end an OpenTelemetry span for a span start/end payload."""
        if isinstance(event.payload, SpanStartPayload):
            self._start_span(event)
        elif isinstance(event.payload, SpanEndPayload):
            self._end_span(event)

    def _start_span(self, event: StructuredLogEvent) -> None:
        """Start a span for *event* and remember it until its end event."""
        parent_id = event.payload.parent_span_id
        parent_span = self._open_spans.get(parent_id) if parent_id else None

        if parent_span is not None:
            # The parent is a span this adapter started: link to it directly
            # so the child inherits its real trace and span ids.
            context = trace.set_span_in_context(parent_span)
        else:
            # No live parent object. A synthetic remote parent is the only
            # way to pin the trace id through the public API. It must be
            # flagged as sampled, or the default parent-based sampler drops
            # the new span.
            context = trace.set_span_in_context(
                trace.NonRecordingSpan(
                    trace.SpanContext(
                        trace_id=string_to_trace_id(event.trace_id),
                        span_id=string_to_span_id(parent_id or event.span_id),
                        is_remote=True,
                        trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
                    )
                )
            )

        span = self.tracer.start_span(
            name=event.payload.name,
            kind=trace.SpanKind.INTERNAL,
            context=context,
            attributes=event.attributes,
            start_time=_to_unix_nanos(event.timestamp),
        )
        self._open_spans[event.span_id] = span

    def _end_span(self, event: StructuredLogEvent) -> None:
        """End the span started for this span id, setting its final status."""
        span = self._open_spans.pop(event.span_id, None)
        if span is None:
            # Unknown span id: fall back to the current span, as before.
            span = trace.get_current_span()

        if event.payload.status == SpanStatus.OK:
            status_code = trace.StatusCode.OK
        else:
            status_code = trace.StatusCode.ERROR
        span.set_status(trace.Status(status_code))
        span.end(end_time=_to_unix_nanos(event.timestamp))

    async def get_trace(self, trace_id: str) -> Trace:
        """Not supported yet; always raises ``NotImplementedError``."""
        # we need to look up the root span id
        raise NotImplementedError(
            "get_trace is not yet implemented for the OpenTelemetry adapter"
        )


# Usage example
async def main():
    """Send one event of each kind through the adapter, then shut down."""
    telemetry = OpenTelemetryAdapter(OpenTelemetryConfig())
    await telemetry.initialize()

    try:
        # Log an unstructured event
        await telemetry.log_event(
            UnstructuredLogEvent(
                trace_id="trace123",
                span_id="span456",
                timestamp=datetime.now(),
                message="This is a log message",
                severity=LogSeverity.INFO,
            )
        )

        # Log a metric event
        await telemetry.log_event(
            MetricEvent(
                trace_id="trace123",
                span_id="span456",
                timestamp=datetime.now(),
                metric="my_metric",
                value=42,
                unit="count",
            )
        )

        # Log a structured event (span start)
        await telemetry.log_event(
            StructuredLogEvent(
                trace_id="trace123",
                span_id="span789",
                timestamp=datetime.now(),
                payload=SpanStartPayload(name="my_operation"),
            )
        )

        # Log a structured event (span end)
        await telemetry.log_event(
            StructuredLogEvent(
                trace_id="trace123",
                span_id="span789",
                timestamp=datetime.now(),
                payload=SpanEndPayload(status=SpanStatus.OK),
            )
        )
    finally:
        # Shut down even if an event fails, so pending data is still flushed.
        await telemetry.shutdown()


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
generate_rag_query @@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + @tracing.span("create_and_execute_turn") async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: @@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin): yield final_response + @tracing.span("run_shields") async def run_multiple_shields_wrapper( self, turn_id: str, @@ -348,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation - rag_context, bank_ids = await self._retrieve_context( - session_id, input_messages, attachments - ) + with tracing.span("retrieve_rag_context"): + rag_context, bank_ids = await self._retrieve_context( + session_id, input_messages, attachments + ) step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( @@ -403,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin): tool_calls = [] content = "" stop_reason = None - async for chunk in self.inference_api.chat_completion( - self.agent_config.model, - input_messages, - tools=self._get_tools(), - tool_prompt_format=self.agent_config.tool_prompt_format, - stream=True, - sampling_params=sampling_params, - ): - event = chunk.event - if event.event_type == ChatCompletionResponseEventType.start: - continue - elif event.event_type == ChatCompletionResponseEventType.complete: - stop_reason = StopReason.end_of_turn - continue - delta = event.delta - if isinstance(delta, ToolCallDelta): - if delta.parse_status == ToolCallParseStatus.success: - tool_calls.append(delta.content) + with tracing.span("inference"): + async for chunk in self.inference_api.chat_completion( + self.agent_config.model, + input_messages, + tools=self._get_tools(), + tool_prompt_format=self.agent_config.tool_prompt_format, + stream=True, + sampling_params=sampling_params, + ): + event = 
chunk.event + if event.event_type == ChatCompletionResponseEventType.start: + continue + elif event.event_type == ChatCompletionResponseEventType.complete: + stop_reason = StopReason.end_of_turn + continue - if stream: - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, - step_id=step_id, - model_response_text_delta="", - tool_call_delta=delta, + delta = event.delta + if isinstance(delta, ToolCallDelta): + if delta.parse_status == ToolCallParseStatus.success: + tool_calls.append(delta.content) + + if stream: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.inference.value, + step_id=step_id, + model_response_text_delta="", + tool_call_delta=delta, + ) ) ) - ) - elif isinstance(delta, str): - content += delta - if stream and event.stop_reason is None: - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, - step_id=step_id, - model_response_text_delta=event.delta, + elif isinstance(delta, str): + content += delta + if stream and event.stop_reason is None: + yield AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseStepProgressPayload( + step_type=StepType.inference.value, + step_id=step_id, + model_response_text_delta=event.delta, + ) ) ) - ) - else: - raise ValueError(f"Unexpected delta type {type(delta)}") + else: + raise ValueError(f"Unexpected delta type {type(delta)}") - if event.stop_reason is not None: - stop_reason = event.stop_reason + if event.stop_reason is not None: + stop_reason = event.stop_reason stop_reason = stop_reason or StopReason.out_of_tokens message = CompletionMessage( @@ -528,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin): ) ) - result_messages = await execute_tool_call_maybe( - self.tools_dict, - [message], - ) - 
assert ( - len(result_messages) == 1 - ), "Currently not supporting multiple messages" - result_message = result_messages[0] + with tracing.span("tool_execution"): + result_messages = await execute_tool_call_maybe( + self.tools_dict, + [message], + ) + assert ( + len(result_messages) == 1 + ), "Currently not supporting multiple messages" + result_message = result_messages[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -669,7 +676,8 @@ class ChatAgent(ShieldRunnerMixin): ) for a in attachments ] - await self.memory_api.insert_documents(bank_id, documents) + with tracing.span("insert_documents"): + await self.memory_api.insert_documents(bank_id, documents) else: session_info = await self.storage.get_session_info(session_id) if session_info.memory_bank_id: diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index 363578749..02b71077e 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -27,4 +27,18 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig", ), ), + remote_provider_spec( + api=Api.telemetry, + adapter=AdapterSpec( + adapter_id="opentelemetry-jaeger", + pip_packages=[ + "opentelemetry-api", + "opentelemetry-sdk", + "opentelemetry-exporter-jaeger", + "opentelemetry-semantic-conventions", + ], + module="llama_stack.providers.adapters.telemetry.opentelemetry", + config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig", + ), + ), ] diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 5284dfac0..9fffc0f99 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -12,7 +12,7 @@ import threading import uuid from datetime import datetime from functools import wraps -from typing import Any, Dict, List 
+from typing import Any, Callable, Dict, List from llama_stack.apis.telemetry import * # noqa: F403 @@ -196,33 +196,40 @@ class TelemetryHandler(logging.Handler): pass -def span(name: str, attributes: Dict[str, Any] = None): - def decorator(func): +class SpanContextManager: + def __init__(self, name: str, attributes: Dict[str, Any] = None): + self.name = name + self.attributes = attributes + + def __enter__(self): + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + context.push_span(self.name, self.attributes) + return self + + def __exit__(self, exc_type, exc_value, traceback): + global CURRENT_TRACE_CONTEXT + context = CURRENT_TRACE_CONTEXT + if context: + context.pop_span() + + async def __aenter__(self): + return self.__enter__() + + async def __aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, exc_value, traceback) + + def __call__(self, func: Callable): @wraps(func) def sync_wrapper(*args, **kwargs): - try: - global CURRENT_TRACE_CONTEXT - - context = CURRENT_TRACE_CONTEXT - if context: - context.push_span(name, attributes) - result = func(*args, **kwargs) - finally: - context.pop_span() - return result + with self: + return func(*args, **kwargs) @wraps(func) async def async_wrapper(*args, **kwargs): - try: - global CURRENT_TRACE_CONTEXT - - context = CURRENT_TRACE_CONTEXT - if context: - context.push_span(name, attributes) - result = await func(*args, **kwargs) - finally: - context.pop_span() - return result + async with self: + return await func(*args, **kwargs) @wraps(func) def wrapper(*args, **kwargs): @@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None): return wrapper - return decorator + +def span(name: str, attributes: Dict[str, Any] = None): + return SpanContextManager(name, attributes)