Merge branch 'api_updates_3' into fix_cli_api_updates_3

commit 28065ca53a
Author: Xi Yan
Date: 2024-09-22 22:05:22 -07:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 345 additions and 80 deletions


@@ -6,19 +6,20 @@
 import asyncio
 import json
-from typing import Any, AsyncGenerator
+from typing import Any, AsyncGenerator, List, Optional
 
 import fire
 import httpx
 from pydantic import BaseModel
-from llama_models.llama3.api import *  # noqa: F403
-from llama_stack.apis.inference import *  # noqa: F403
 from termcolor import cprint
 
 from llama_stack.distribution.datatypes import RemoteProviderConfig
 
 from .event_logger import EventLogger
+from llama_stack.apis.inference import *  # noqa: F403
 
 
 async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
     return InferenceClient(config.url)


@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.


@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .config import OpenTelemetryConfig
+
+
+async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
+    from .opentelemetry import OpenTelemetryAdapter
+
+    impl = OpenTelemetryAdapter(config)
+    await impl.initialize()
+    return impl


@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel
+
+
+class OpenTelemetryConfig(BaseModel):
+    jaeger_host: str = "localhost"
+    jaeger_port: int = 6831
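
For reference, 6831/udp is the default Jaeger agent (thrift-compact) port, which matches the JaegerExporter wiring in the adapter below. A minimal sketch of building the provider through its entry point, assuming the package path given in this commit's provider spec (llama_stack.providers.adapters.telemetry.opentelemetry):

import asyncio

from llama_stack.providers.adapters.telemetry.opentelemetry import (
    OpenTelemetryConfig,
    get_adapter_impl,
)


async def build_adapter():
    # Defaults point at a local Jaeger agent on localhost:6831/udp.
    config = OpenTelemetryConfig(jaeger_host="localhost", jaeger_port=6831)
    return await get_adapter_impl(config, _deps={})


adapter = asyncio.run(build_adapter())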


@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+
+from opentelemetry import metrics, trace
+from opentelemetry.exporter.jaeger.thrift import JaegerExporter
+from opentelemetry.sdk.metrics import MeterProvider
+from opentelemetry.sdk.metrics.export import (
+    ConsoleMetricExporter,
+    PeriodicExportingMetricReader,
+)
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+from opentelemetry.semconv.resource import ResourceAttributes
+
+from llama_stack.apis.telemetry import *  # noqa: F403
+
+from .config import OpenTelemetryConfig
+
+
+def string_to_trace_id(s: str) -> int:
+    # Convert the string to bytes and then to an integer
+    return int.from_bytes(s.encode(), byteorder="big", signed=False)
+
+
+def string_to_span_id(s: str) -> int:
+    # Use only the first 8 bytes (64 bits) for span ID
+    return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
+
+
+def is_tracing_enabled(tracer):
+    with tracer.start_as_current_span("check_tracing") as span:
+        return span.is_recording()
+
+
+class OpenTelemetryAdapter(Telemetry):
+    def __init__(self, config: OpenTelemetryConfig):
+        self.config = config
+
+        self.resource = Resource.create(
+            {ResourceAttributes.SERVICE_NAME: "foobar-service"}
+        )
+
+        # Set up tracing with Jaeger exporter
+        jaeger_exporter = JaegerExporter(
+            agent_host_name=self.config.jaeger_host,
+            agent_port=self.config.jaeger_port,
+        )
+        trace_provider = TracerProvider(resource=self.resource)
+        trace_processor = BatchSpanProcessor(jaeger_exporter)
+        trace_provider.add_span_processor(trace_processor)
+        trace.set_tracer_provider(trace_provider)
+        self.tracer = trace.get_tracer(__name__)
+
+        # Set up metrics
+        metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
+        metric_provider = MeterProvider(
+            resource=self.resource, metric_readers=[metric_reader]
+        )
+        metrics.set_meter_provider(metric_provider)
+        self.meter = metrics.get_meter(__name__)
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        trace.get_tracer_provider().shutdown()
+        metrics.get_meter_provider().shutdown()
+
+    async def log_event(self, event: Event) -> None:
+        if isinstance(event, UnstructuredLogEvent):
+            self._log_unstructured(event)
+        elif isinstance(event, MetricEvent):
+            self._log_metric(event)
+        elif isinstance(event, StructuredLogEvent):
+            self._log_structured(event)
+
+    def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
+        span = trace.get_current_span()
+        span.add_event(
+            name=event.message,
+            attributes={"severity": event.severity.value, **event.attributes},
+            timestamp=event.timestamp,
+        )
+
+    def _log_metric(self, event: MetricEvent) -> None:
+        if isinstance(event.value, int):
+            self.meter.create_counter(
+                name=event.metric,
+                unit=event.unit,
+                description=f"Counter for {event.metric}",
+            ).add(event.value, attributes=event.attributes)
+        elif isinstance(event.value, float):
+            self.meter.create_gauge(
+                name=event.metric,
+                unit=event.unit,
+                description=f"Gauge for {event.metric}",
+            ).set(event.value, attributes=event.attributes)
+
+    def _log_structured(self, event: StructuredLogEvent) -> None:
+        if isinstance(event.payload, SpanStartPayload):
+            context = trace.set_span_in_context(
+                trace.NonRecordingSpan(
+                    trace.SpanContext(
+                        trace_id=string_to_trace_id(event.trace_id),
+                        span_id=string_to_span_id(event.span_id),
+                        is_remote=True,
+                    )
+                )
+            )
+            span = self.tracer.start_span(
+                name=event.payload.name,
+                kind=trace.SpanKind.INTERNAL,
+                context=context,
+                attributes=event.attributes,
+            )
+            if event.payload.parent_span_id:
+                span.set_parent(
+                    trace.SpanContext(
+                        trace_id=string_to_trace_id(event.trace_id),
+                        span_id=string_to_span_id(event.payload.parent_span_id),
+                        is_remote=True,
+                    )
+                )
+        elif isinstance(event.payload, SpanEndPayload):
+            span = trace.get_current_span()
+            span.set_status(
+                trace.Status(
+                    trace.StatusCode.OK
+                    if event.payload.status == SpanStatus.OK
+                    else trace.StatusCode.ERROR
+                )
+            )
+            span.end(end_time=event.timestamp)
+
+    async def get_trace(self, trace_id: str) -> Trace:
+        # we need to look up the root span id
+        raise NotImplementedError("not yet implemented")
+
+
+# Usage example
+async def main():
+    telemetry = OpenTelemetryAdapter(OpenTelemetryConfig())
+    await telemetry.initialize()
+
+    # Log an unstructured event
+    await telemetry.log_event(
+        UnstructuredLogEvent(
+            trace_id="trace123",
+            span_id="span456",
+            timestamp=datetime.now(),
+            message="This is a log message",
+            severity=LogSeverity.INFO,
+        )
+    )
+
+    # Log a metric event
+    await telemetry.log_event(
+        MetricEvent(
+            trace_id="trace123",
+            span_id="span456",
+            timestamp=datetime.now(),
+            metric="my_metric",
+            value=42,
+            unit="count",
+        )
+    )
+
+    # Log a structured event (span start)
+    await telemetry.log_event(
+        StructuredLogEvent(
+            trace_id="trace123",
+            span_id="span789",
+            timestamp=datetime.now(),
+            payload=SpanStartPayload(name="my_operation"),
+        )
+    )
+
+    # Log a structured event (span end)
+    await telemetry.log_event(
+        StructuredLogEvent(
+            trace_id="trace123",
+            span_id="span789",
+            timestamp=datetime.now(),
+            payload=SpanEndPayload(status=SpanStatus.OK),
+        )
+    )
+
+    await telemetry.shutdown()
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(main())
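
One property of the id helpers above is worth noting: string_to_span_id keeps only the first 8 bytes, so span ids that share an 8-byte prefix collide after conversion. A standalone check (the helper is copied verbatim so the snippet runs on its own):

def string_to_span_id(s: str) -> int:
    # Same as above: only the first 8 bytes (64 bits) are kept.
    return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)


# "span-000-aaa" and "span-000-bbb" share their first 8 bytes ("span-000"),
# so both map to the same 64-bit span id.
assert string_to_span_id("span-000-aaa") == string_to_span_id("span-000-bbb")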


@@ -26,6 +26,7 @@ from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 from .rag.context_retriever import generate_rag_query
@@ -138,6 +139,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
+    @tracing.span("create_and_execute_turn")
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
@@ -266,6 +268,7 @@
         yield final_response
 
+    @tracing.span("run_shields")
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
@@ -348,9 +351,10 @@
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
-        rag_context, bank_ids = await self._retrieve_context(
-            session_id, input_messages, attachments
-        )
+        with tracing.span("retrieve_rag_context"):
+            rag_context, bank_ids = await self._retrieve_context(
+                session_id, input_messages, attachments
+            )
 
         step_id = str(uuid.uuid4())
         yield AgentTurnResponseStreamChunk(
@@ -403,55 +407,57 @@
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(
-            self.agent_config.model,
-            input_messages,
-            tools=self._get_tools(),
-            tool_prompt_format=self.agent_config.tool_prompt_format,
-            stream=True,
-            sampling_params=sampling_params,
-        ):
-            event = chunk.event
-            if event.event_type == ChatCompletionResponseEventType.start:
-                continue
-            elif event.event_type == ChatCompletionResponseEventType.complete:
-                stop_reason = StopReason.end_of_turn
-                continue
-
-            delta = event.delta
-            if isinstance(delta, ToolCallDelta):
-                if delta.parse_status == ToolCallParseStatus.success:
-                    tool_calls.append(delta.content)
-
-                if stream:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta="",
-                                tool_call_delta=delta,
-                            )
-                        )
-                    )
-
-            elif isinstance(delta, str):
-                content += delta
-                if stream and event.stop_reason is None:
-                    yield AgentTurnResponseStreamChunk(
-                        event=AgentTurnResponseEvent(
-                            payload=AgentTurnResponseStepProgressPayload(
-                                step_type=StepType.inference.value,
-                                step_id=step_id,
-                                model_response_text_delta=event.delta,
-                            )
-                        )
-                    )
-            else:
-                raise ValueError(f"Unexpected delta type {type(delta)}")
-
-            if event.stop_reason is not None:
-                stop_reason = event.stop_reason
+        with tracing.span("inference"):
+            async for chunk in self.inference_api.chat_completion(
+                self.agent_config.model,
+                input_messages,
+                tools=self._get_tools(),
+                tool_prompt_format=self.agent_config.tool_prompt_format,
+                stream=True,
+                sampling_params=sampling_params,
+            ):
+                event = chunk.event
+                if event.event_type == ChatCompletionResponseEventType.start:
+                    continue
+                elif event.event_type == ChatCompletionResponseEventType.complete:
+                    stop_reason = StopReason.end_of_turn
+                    continue
+
+                delta = event.delta
+                if isinstance(delta, ToolCallDelta):
+                    if delta.parse_status == ToolCallParseStatus.success:
+                        tool_calls.append(delta.content)
+
+                    if stream:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta="",
+                                    tool_call_delta=delta,
+                                )
+                            )
+                        )
+
+                elif isinstance(delta, str):
+                    content += delta
+                    if stream and event.stop_reason is None:
+                        yield AgentTurnResponseStreamChunk(
+                            event=AgentTurnResponseEvent(
+                                payload=AgentTurnResponseStepProgressPayload(
+                                    step_type=StepType.inference.value,
+                                    step_id=step_id,
+                                    model_response_text_delta=event.delta,
+                                )
+                            )
+                        )
+                else:
+                    raise ValueError(f"Unexpected delta type {type(delta)}")
+
+                if event.stop_reason is not None:
+                    stop_reason = event.stop_reason
 
         stop_reason = stop_reason or StopReason.out_of_tokens
         message = CompletionMessage(
@@ -528,14 +534,15 @@
                 )
             )
 
-            result_messages = await execute_tool_call_maybe(
-                self.tools_dict,
-                [message],
-            )
-            assert (
-                len(result_messages) == 1
-            ), "Currently not supporting multiple messages"
-            result_message = result_messages[0]
+            with tracing.span("tool_execution"):
+                result_messages = await execute_tool_call_maybe(
+                    self.tools_dict,
+                    [message],
+                )
+                assert (
+                    len(result_messages) == 1
+                ), "Currently not supporting multiple messages"
+                result_message = result_messages[0]
 
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -669,7 +676,8 @@
                 )
                 for a in attachments
             ]
-            await self.memory_api.insert_documents(bank_id, documents)
+            with tracing.span("insert_documents"):
+                await self.memory_api.insert_documents(bank_id, documents)
         else:
             session_info = await self.storage.get_session_info(session_id)
             if session_info.memory_bank_id:
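
Taken together, the ChatAgent changes make each turn emit a nested span tree: create_and_execute_turn at the root, with run_shields, retrieve_rag_context, inference, tool_execution, and insert_documents as children. A condensed, hypothetical illustration of the pattern (the handler and helpers below are stand-ins, not code from this commit):

from llama_stack.providers.utils.telemetry import tracing


@tracing.span("create_and_execute_turn")
async def handle_turn(request):
    with tracing.span("retrieve_rag_context"):
        rag_context = await retrieve_context(request)  # hypothetical helper
    with tracing.span("inference"):
        reply = await run_inference(request, rag_context)  # hypothetical helper
    return reply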


@@ -27,4 +27,18 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
         ),
     ),
+    remote_provider_spec(
+        api=Api.telemetry,
+        adapter=AdapterSpec(
+            adapter_id="opentelemetry-jaeger",
+            pip_packages=[
+                "opentelemetry-api",
+                "opentelemetry-sdk",
+                "opentelemetry-exporter-jaeger",
+                "opentelemetry-semantic-conventions",
+            ],
+            module="llama_stack.providers.adapters.telemetry.opentelemetry",
+            config_class="llama_stack.providers.adapters.telemetry.opentelemetry.OpenTelemetryConfig",
+        ),
+    ),
 ]


@@ -12,7 +12,7 @@ import threading
 import uuid
 from datetime import datetime
 from functools import wraps
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List
 
 from llama_stack.apis.telemetry import *  # noqa: F403
 
@@ -196,33 +196,40 @@
         pass
 
 
-def span(name: str, attributes: Dict[str, Any] = None):
-    def decorator(func):
-        @wraps(func)
-        def sync_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
-
-        @wraps(func)
-        async def async_wrapper(*args, **kwargs):
-            try:
-                global CURRENT_TRACE_CONTEXT
-                context = CURRENT_TRACE_CONTEXT
-                if context:
-                    context.push_span(name, attributes)
-                result = await func(*args, **kwargs)
-            finally:
-                context.pop_span()
-            return result
+class SpanContextManager:
+    def __init__(self, name: str, attributes: Dict[str, Any] = None):
+        self.name = name
+        self.attributes = attributes
+
+    def __enter__(self):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.push_span(self.name, self.attributes)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        global CURRENT_TRACE_CONTEXT
+        context = CURRENT_TRACE_CONTEXT
+        if context:
+            context.pop_span()
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+    def __call__(self, func: Callable):
+        @wraps(func)
+        def sync_wrapper(*args, **kwargs):
+            with self:
+                return func(*args, **kwargs)
+
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs):
+            async with self:
+                return await func(*args, **kwargs)
 
         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -233,4 +240,6 @@ def span(name: str, attributes: Dict[str, Any] = None):
         return wrapper
 
-    return decorator
+
+def span(name: str, attributes: Dict[str, Any] = None):
+    return SpanContextManager(name, attributes)
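
Because span() now returns a SpanContextManager, the same helper works as a decorator, a with block, or an async with block. An illustrative sketch (function names are hypothetical):

from llama_stack.providers.utils.telemetry import tracing


@tracing.span("load_config", {"source": "disk"})
def load_config() -> dict:
    return {}  # hypothetical stand-in


async def serve_request() -> None:
    # async with works because SpanContextManager defines __aenter__/__aexit__.
    async with tracing.span("serve_request"):
        load_config()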