Merge branch 'api_updates_3' into fix_cli_api_updates_3

This commit is contained in:
Xi Yan 2024-09-22 22:05:22 -07:00 committed by GitHub
commit 28065ca53a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 345 additions and 80 deletions

View file

@ -6,19 +6,20 @@
import asyncio import asyncio
import json import json
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator, List, Optional
import fire import fire
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from .event_logger import EventLogger from .event_logger import EventLogger
from llama_stack.apis.inference import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(config.url) return InferenceClient(config.url)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
from .opentelemetry import OpenTelemetryAdapter
impl = OpenTelemetryAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class OpenTelemetryConfig(BaseModel):
jaeger_host: str = "localhost"
jaeger_port: int = 6831

View file

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from opentelemetry import metrics, trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
ConsoleMetricExporter,
PeriodicExportingMetricReader,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
self.resource = Resource.create(
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
)
# Set up tracing with Jaeger exporter
jaeger_exporter = JaegerExporter(
agent_host_name=self.config.jaeger_host,
agent_port=self.config.jaeger_port,
)
trace_provider = TracerProvider(resource=self.resource)
trace_processor = BatchSpanProcessor(jaeger_exporter)
trace_provider.add_span_processor(trace_processor)
trace.set_tracer_provider(trace_provider)
self.tracer = trace.get_tracer(__name__)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
metric_provider = MeterProvider(
resource=self.resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
span = trace.get_current_span()
span.add_event(
name=event.message,
attributes={"severity": event.severity.value, **event.attributes},
timestamp=event.timestamp,
)
def _log_metric(self, event: MetricEvent) -> None:
if isinstance(event.value, int):
self.meter.create_counter(
name=event.metric,
unit=event.unit,
description=f"Counter for {event.metric}",
).add(event.value, attributes=event.attributes)
elif isinstance(event.value, float):
self.meter.create_gauge(
name=event.metric,
unit=event.unit,
description=f"Gauge for {event.metric}",
).set(event.value, attributes=event.attributes)
def _log_structured(self, event: StructuredLogEvent) -> None:
if isinstance(event.payload, SpanStartPayload):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
)
)
elif isinstance(event.payload, SpanEndPayload):
span = trace.get_current_span()
span.set_status(
trace.Status(
trace.StatusCode.OK
if event.payload.status == SpanStatus.OK
else trace.StatusCode.ERROR
)
)
span.end(end_time=event.timestamp)
async def get_trace(self, trace_id: str) -> Trace:
# we need to look up the root span id
raise NotImplementedError("not yet no")
# Usage example
async def main():
telemetry = OpenTelemetryTelemetry("my-service")
await telemetry.initialize()
# Log an unstructured event
await telemetry.log_event(
UnstructuredLogEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
message="This is a log message",
severity=LogSeverity.INFO,
)
)
# Log a metric event
await telemetry.log_event(
MetricEvent(
trace_id="trace123",
span_id="span456",
timestamp=datetime.now(),
metric="my_metric",
value=42,
unit="count",
)
)
# Log a structured event (span start)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanStartPayload(name="my_operation"),
)
)
# Log a structured event (span end)
await telemetry.log_event(
StructuredLogEvent(
trace_id="trace123",
span_id="span789",
timestamp=datetime.now(),
payload=SpanEndPayload(status=SpanStatus.OK),
)
)
await telemetry.shutdown()
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View file

@ -26,6 +26,7 @@ from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query from .rag.context_retriever import generate_rag_query
@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
return await self.storage.create_session(name) return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
@ -266,6 +268,7 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response yield final_response
@tracing.span("run_shields")
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
@ -348,6 +351,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation # or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context"):
rag_context, bank_ids = await self._retrieve_context( rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments session_id, input_messages, attachments
) )
@ -403,6 +407,8 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls = [] tool_calls = []
content = "" content = ""
stop_reason = None stop_reason = None
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion( async for chunk in self.inference_api.chat_completion(
self.agent_config.model, self.agent_config.model,
input_messages, input_messages,
@ -528,6 +534,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
with tracing.span("tool_execution"):
result_messages = await execute_tool_call_maybe( result_messages = await execute_tool_call_maybe(
self.tools_dict, self.tools_dict,
[message], [message],
@ -669,6 +676,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
for a in attachments for a in attachments
] ]
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents) await self.memory_api.insert_documents(bank_id, documents)
else: else:
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)

View file

@ -27,4 +27,18 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig", config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
), ),
), ),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_id="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-jaeger",
"opentelemetry-semantic-conventions",
],
module="llama_stack.providers.adapters.telemetry.opentelemetry",
config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig",
),
),
] ]

View file

@ -12,7 +12,7 @@ import threading
import uuid import uuid
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
from typing import Any, Dict, List from typing import Any, Callable, Dict, List
from llama_stack.apis.telemetry import * # noqa: F403 from llama_stack.apis.telemetry import * # noqa: F403
@ -196,33 +196,40 @@ class TelemetryHandler(logging.Handler):
pass pass
def span(name: str, attributes: Dict[str, Any] = None): class SpanContextManager:
def decorator(func): def __init__(self, name: str, attributes: Dict[str, Any] = None):
@wraps(func) self.name = name
def sync_wrapper(*args, **kwargs): self.attributes = attributes
try:
global CURRENT_TRACE_CONTEXT
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT context = CURRENT_TRACE_CONTEXT
if context: if context:
context.push_span(name, attributes)
result = func(*args, **kwargs)
finally:
context.pop_span() context.pop_span()
return result
async def __aenter__(self):
return self.__enter__()
async def __aexit__(self, exc_type, exc_value, traceback):
self.__exit__(exc_type, exc_value, traceback)
def __call__(self, func: Callable):
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
try: async with self:
global CURRENT_TRACE_CONTEXT return await func(*args, **kwargs)
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(name, attributes)
result = await func(*args, **kwargs)
finally:
context.pop_span()
return result
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
return wrapper return wrapper
return decorator
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)