mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
Telemetry API redesign (#525)
# What does this PR do? Change the Telemetry API to be able to support different use cases like returning traces for the UI and ability to export for Evals. Other changes: * Add a new trace_protocol decorator to decorate all our API methods so that any call to them will automatically get traced across all impls. * There is some issue with the decorator pattern of span creation when using async generators, where there are multiple yields with in the same context. I think its much more explicit by using the explicit context manager pattern using with. I moved the span creations in agent instance to be using with * Inject session id at the turn level, which should quickly give us all traces across turns for a given session Addresses #509 ## Test Plan ``` llama stack run /Users/dineshyv/.llama/distributions/llamastack-together/together-run.yaml PYTHONPATH=. python -m examples.agents.rag_with_memory_bank localhost 5000 curl -X POST 'http://localhost:5000/alpha/telemetry/query-traces' \ -H 'Content-Type: application/json' \ -d '{ "attribute_filters": [ { "key": "session_id", "op": "eq", "value": "dd667b87-ca4b-4d30-9265-5a0de318fc65" }], "limit": 100, "offset": 0, "order_by": ["start_time"] }' | jq . [ { "trace_id": "6902f54b83b4b48be18a6f422b13e16f", "root_span_id": "5f37b85543afc15a", "start_time": "2024-12-04T08:08:30.501587", "end_time": "2024-12-04T08:08:36.026463" }, { "trace_id": "92227dac84c0615ed741be393813fb5f", "root_span_id": "af7c5bb46665c2c8", "start_time": "2024-12-04T08:08:36.031170", "end_time": "2024-12-04T08:08:41.693301" }, { "trace_id": "7d578a6edac62f204ab479fba82f77b6", "root_span_id": "1d935e3362676896", "start_time": "2024-12-04T08:08:41.695204", "end_time": "2024-12-04T08:08:47.228016" }, { "trace_id": "dbd767d76991bc816f9f078907dc9ff2", "root_span_id": "f5a7ee76683b9602", "start_time": "2024-12-04T08:08:47.234578", "end_time": "2024-12-04T08:08:53.189412" } ] curl -X POST 'http://localhost:5000/alpha/telemetry/get-span-tree' \ -H 'Content-Type: application/json' \ -d '{ "span_id" : "6cceb4b48a156913", "max_depth": 2, "attributes_to_return": ["input"] }' | jq . % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 875 100 790 100 85 18462 1986 --:--:-- --:--:-- --:--:-- 20833 { "span_id": "6cceb4b48a156913", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "892a66d726c7f990", "name": "retrieve_rag_context", "start_time": "2024-12-04T09:28:21.781995", "end_time": "2024-12-04T09:28:21.913352", "attributes": { "input": [ "{\"role\":\"system\",\"content\":\"You are a helpful assistant\"}", "{\"role\":\"user\",\"content\":\"What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.\",\"context\":null}" ] }, "children": [ { "span_id": "1a2df181854064a8", "trace_id": "dafa796f6aaf925f511c04cd7c67fdda", "parent_span_id": "6cceb4b48a156913", "name": "MemoryRouter.query_documents", "start_time": "2024-12-04T09:28:21.787620", "end_time": "2024-12-04T09:28:21.906512", "attributes": { "input": null }, "children": [], "status": "ok" } ], "status": "ok" } ``` <img width="1677" alt="Screenshot 2024-12-04 at 9 42 56 AM" src="https://github.com/user-attachments/assets/4d3cea93-05ce-415a-93d9-4b1628631bf8">
This commit is contained in:
parent
16769256b7
commit
fcd6449519
34 changed files with 1551 additions and 245 deletions
|
@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
@ -418,6 +419,7 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Agents(Protocol):
|
||||
@webmethod(route="/agents/create")
|
||||
async def create_agent(
|
||||
|
|
|
@ -37,3 +37,8 @@ class DatasetIO(Protocol):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult: ...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows", method="POST")
|
||||
async def append_rows(
|
||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
||||
) -> None: ...
|
||||
|
|
|
@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
@ -220,6 +222,7 @@ class ModelStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -43,6 +44,7 @@ class MemoryBankStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -129,6 +130,7 @@ class MemoryBankInput(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory-banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
|
|
@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
|
@ -43,6 +44,7 @@ class ModelInput(CommonModelFields):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[Model]: ...
|
||||
|
|
|
@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
|
||||
|
@ -43,6 +45,7 @@ class ShieldStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.distribution.tracing import trace_protocol
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
|
@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[Shield]: ...
|
||||
|
|
|
@ -6,12 +6,24 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
@ -29,6 +41,11 @@ class Span(BaseModel):
|
|||
end_time: Optional[datetime] = None
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
|
@ -123,10 +140,49 @@ Event = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
output: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanWithChildren(Span):
|
||||
children: List["SpanWithChildren"] = Field(default_factory=list)
|
||||
status: Optional[SpanStatus] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
key: str
|
||||
op: Literal["eq", "ne", "gt", "lt"]
|
||||
value: Any
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/log-event")
|
||||
async def log_event(self, event: Event) -> None: ...
|
||||
|
||||
@webmethod(route="/telemetry/get-trace", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
@webmethod(route="/telemetry/log-event")
|
||||
async def log_event(
|
||||
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/telemetry/query-traces", method="POST")
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
@webmethod(route="/telemetry/get-span-tree", method="POST")
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren: ...
|
||||
|
|
|
@ -222,6 +222,12 @@ class DatasetIORouter(DatasetIO):
|
|||
filter_condition=filter_condition,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
)
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
def __init__(
|
||||
|
|
|
@ -43,9 +43,9 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.providers.inline.meta_reference.telemetry.console import (
|
||||
ConsoleConfig,
|
||||
ConsoleTelemetryImpl,
|
||||
from llama_stack.providers.inline.telemetry.meta_reference import (
|
||||
TelemetryAdapter,
|
||||
TelemetryConfig,
|
||||
)
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
@ -290,7 +290,7 @@ def main():
|
|||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig()))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
|
128
llama_stack/distribution/tracing.py
Normal file
128
llama_stack/distribution/tracing.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from functools import wraps
|
||||
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def serialize_value(value: Any) -> str:
|
||||
"""Helper function to serialize values to string representation."""
|
||||
try:
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump_json()
|
||||
elif isinstance(value, list) and value and isinstance(value[0], BaseModel):
|
||||
return json.dumps([item.model_dump_json() for item in value])
|
||||
elif hasattr(value, "to_dict"):
|
||||
return json.dumps(value.to_dict())
|
||||
elif isinstance(value, (dict, list, int, float, str, bool)):
|
||||
return json.dumps(value)
|
||||
else:
|
||||
return str(value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
|
||||
def trace_protocol(cls: Type[T]) -> Type[T]:
|
||||
"""
|
||||
A class decorator that automatically traces all methods in a protocol/base class
|
||||
and its inheriting classes.
|
||||
"""
|
||||
|
||||
def trace_method(method: Callable) -> Callable:
|
||||
is_async = asyncio.iscoroutinefunction(method)
|
||||
is_async_gen = inspect.isasyncgenfunction(method)
|
||||
|
||||
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
|
||||
class_name = self.__class__.__name__
|
||||
method_name = method.__name__
|
||||
|
||||
span_type = (
|
||||
"async_generator" if is_async_gen else "async" if is_async else "sync"
|
||||
)
|
||||
span_attributes = {
|
||||
"class": class_name,
|
||||
"method": method_name,
|
||||
"type": span_type,
|
||||
"args": serialize_value(args),
|
||||
}
|
||||
|
||||
return class_name, method_name, span_attributes
|
||||
|
||||
@wraps(method)
|
||||
async def async_gen_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> AsyncGenerator:
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
count = 0
|
||||
async for item in method(self, *args, **kwargs):
|
||||
yield item
|
||||
count += 1
|
||||
finally:
|
||||
span.set_attribute("chunk_count", count)
|
||||
|
||||
@wraps(method)
|
||||
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = await method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
span.set_attribute("error", str(e))
|
||||
raise
|
||||
|
||||
@wraps(method)
|
||||
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
class_name, method_name, span_attributes = create_span_context(
|
||||
self, *args, **kwargs
|
||||
)
|
||||
|
||||
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
|
||||
try:
|
||||
result = method(self, *args, **kwargs)
|
||||
span.set_attribute("output", serialize_value(result))
|
||||
return result
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
if is_async_gen:
|
||||
return async_gen_wrapper
|
||||
elif is_async:
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
original_init_subclass = getattr(cls, "__init_subclass__", None)
|
||||
|
||||
def __init_subclass__(cls_child, **kwargs): # noqa: N807
|
||||
if original_init_subclass:
|
||||
original_init_subclass(**kwargs)
|
||||
|
||||
for name, method in vars(cls_child).items():
|
||||
if inspect.isfunction(method) and not name.startswith("_"):
|
||||
setattr(cls_child, name, trace_method(method)) # noqa: B010
|
||||
|
||||
cls.__init_subclass__ = classmethod(__init_subclass__)
|
||||
|
||||
return cls
|
|
@ -144,87 +144,91 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
@tracing.span("create_and_execute_turn")
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
assert output_message is not None
|
||||
yield chunk
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
assert output_message is not None
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
@ -273,7 +277,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
yield final_response
|
||||
|
||||
@tracing.span("run_shields")
|
||||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
@ -281,23 +284,47 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
if len(shields) == 0:
|
||||
return
|
||||
with tracing.span("run_shields") as span:
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
|
@ -305,30 +332,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
|
@ -356,10 +365,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
with tracing.span("retrieve_rag_context"):
|
||||
with tracing.span("retrieve_rag_context") as span:
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", rag_context)
|
||||
span.set_attribute("bank_ids", bank_ids)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
@ -416,7 +430,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference"):
|
||||
with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
@ -436,7 +450,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
@ -466,6 +479,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
span.set_attribute("stop_reason", stop_reason)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute(
|
||||
"output", f"content: {content} tool_calls: {tool_calls}"
|
||||
)
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
|
||||
|
@ -549,7 +569,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
with tracing.span("tool_execution"):
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
|
@ -558,6 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
|
@ -3,14 +3,17 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
@ -131,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
||||
dataset_impl = dataset_info.dataset_impl
|
||||
dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat(
|
||||
[dataset_impl.df, new_rows_df], ignore_index=True
|
||||
)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
file_path = parsed_url.path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
dataset_impl.df.to_csv(file_path, index=False)
|
||||
elif parsed_url.scheme == "data":
|
||||
# For data URLs, we need to update the base64-encoded content
|
||||
if not parsed_url.path.startswith("text/csv;base64,"):
|
||||
raise ValueError("Data URL must be a base64-encoded CSV")
|
||||
|
||||
csv_buffer = dataset_impl.df.to_csv(index=False)
|
||||
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
dataset_info.dataset_def.url = URL(
|
||||
uri=f"data:text/csv;base64,{base64_content}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
|
||||
)
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import ConsoleConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConsoleConfig, _deps):
|
||||
from .console import ConsoleTelemetryImpl
|
||||
|
||||
impl = ConsoleTelemetryImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
TEXT = "text"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConsoleConfig(BaseModel):
|
||||
log_format: LogFormat = LogFormat.TEXT
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from .config import LogFormat
|
||||
|
||||
|
@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry):
|
|||
if formatted:
|
||||
print(formatted)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
raise NotImplementedError()
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
raise NotImplementedError("Console telemetry does not support trace querying")
|
||||
|
||||
async def get_spans(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> SpanWithChildren:
|
||||
raise NotImplementedError("Console telemetry does not support span querying")
|
||||
|
||||
|
||||
COLORS = {
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
from .telemetry import TelemetryAdapter
|
||||
|
||||
__all__ = ["TelemetryConfig", "TelemetryAdapter", "TelemetrySink"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
|
||||
impl = TelemetryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class TelemetrySink(str, Enum):
|
||||
JAEGER = "jaeger"
|
||||
SQLITE = "sqlite"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
otel_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/traces",
|
||||
description="The OpenTelemetry collector endpoint URL",
|
||||
)
|
||||
service_name: str = Field(
|
||||
default="llama-stack",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: jaeger, sqlite, console)",
|
||||
)
|
||||
sqlite_db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
|
||||
description="The path to the SQLite database to use for storing traces",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:['console', 'sqlite']}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:${runtime.base_dir}/trace_store.db}",
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
"""A SpanProcessor that prints spans to the console with color formatting."""
|
||||
|
||||
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
|
||||
"""Called when a span starts."""
|
||||
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['cyan']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
"""Called when a span ends."""
|
||||
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
|
||||
# Build the span context string
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['cyan']}{span.name}{COLORS['reset']} "
|
||||
)
|
||||
|
||||
# Add status if not OK
|
||||
if span.status.status_code != 0: # UNSET or ERROR
|
||||
status_color = (
|
||||
COLORS["red"] if span.status.status_code == 2 else COLORS["yellow"]
|
||||
)
|
||||
span_context += (
|
||||
f" {status_color}[{span.status.status_code}]{COLORS['reset']}"
|
||||
)
|
||||
|
||||
# Add duration
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f" {COLORS['dim']}({duration_ms:.2f}ms){COLORS['reset']}"
|
||||
|
||||
# Print the main span line
|
||||
print(span_context)
|
||||
|
||||
# Print attributes indented
|
||||
if span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
|
||||
# Print events indented
|
||||
for event in span.events:
|
||||
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
print(
|
||||
f" {COLORS['dim']}{event_time}{COLORS['reset']} "
|
||||
f"{COLORS['cyan']}[EVENT]{COLORS['reset']} {event.name}"
|
||||
)
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: float = None) -> bool:
|
||||
"""Force flush any pending spans."""
|
||||
return True
|
|
@ -0,0 +1,242 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
def __init__(self, conn_string, ttl_days=30):
|
||||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self.ttl_days = ttl_days
|
||||
self.cleanup_task = None
|
||||
self._thread_local = threading.local()
|
||||
self._connections: Dict[int, sqlite3.Connection] = {}
|
||||
self._lock = threading.Lock()
|
||||
self.setup_database()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""Get a thread-specific database connection."""
|
||||
thread_id = threading.get_ident()
|
||||
with self._lock:
|
||||
if thread_id not in self._connections:
|
||||
conn = sqlite3.connect(self.conn_string)
|
||||
self._connections[thread_id] = conn
|
||||
return self._connections[thread_id]
|
||||
|
||||
def setup_database(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS traces (
|
||||
trace_id TEXT PRIMARY KEY,
|
||||
service_name TEXT,
|
||||
root_span_id TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS spans (
|
||||
span_id TEXT PRIMARY KEY,
|
||||
trace_id TEXT REFERENCES traces(trace_id),
|
||||
parent_span_id TEXT,
|
||||
name TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
attributes TEXT,
|
||||
status TEXT,
|
||||
kind TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS span_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
span_id TEXT REFERENCES spans(span_id),
|
||||
name TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
attributes TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_traces_created_at
|
||||
ON traces(created_at)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
# Start periodic cleanup in a separate thread
|
||||
self.cleanup_task = threading.Thread(target=self._periodic_cleanup, daemon=True)
|
||||
self.cleanup_task.start()
|
||||
|
||||
def _cleanup_old_data(self):
|
||||
"""Delete records older than TTL."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cutoff_date = (datetime.now() - timedelta(days=self.ttl_days)).isoformat()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete old span events
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM span_events
|
||||
WHERE span_id IN (
|
||||
SELECT span_id FROM spans
|
||||
WHERE trace_id IN (
|
||||
SELECT trace_id FROM traces
|
||||
WHERE created_at < ?
|
||||
)
|
||||
)
|
||||
""",
|
||||
(cutoff_date,),
|
||||
)
|
||||
|
||||
# Delete old spans
|
||||
cursor.execute(
|
||||
"""
|
||||
DELETE FROM spans
|
||||
WHERE trace_id IN (
|
||||
SELECT trace_id FROM traces
|
||||
WHERE created_at < ?
|
||||
)
|
||||
""",
|
||||
(cutoff_date,),
|
||||
)
|
||||
|
||||
# Delete old traces
|
||||
cursor.execute("DELETE FROM traces WHERE created_at < ?", (cutoff_date,))
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(f"Error during cleanup: {e}")
|
||||
|
||||
def _periodic_cleanup(self):
|
||||
"""Run cleanup periodically."""
|
||||
import time
|
||||
|
||||
while True:
|
||||
time.sleep(3600) # Sleep for 1 hour
|
||||
self._cleanup_old_data()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
"""Called when a span starts."""
|
||||
pass
|
||||
|
||||
def on_end(self, span: Span):
|
||||
"""Called when a span ends. Export the span data to SQLite."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format(span.get_span_context().trace_id, "032x")
|
||||
span_id = format(span.get_span_context().span_id, "016x")
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format(parent_context.span_id, "016x")
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO traces (
|
||||
trace_id, service_name, root_span_id, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(trace_id) DO UPDATE SET
|
||||
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
|
||||
start_time = MIN(excluded.start_time, start_time),
|
||||
end_time = MAX(excluded.end_time, end_time)
|
||||
""",
|
||||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
# Insert into spans
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO spans (
|
||||
span_id, trace_id, parent_span_id, name,
|
||||
start_time, end_time, attributes, status,
|
||||
kind
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
),
|
||||
)
|
||||
|
||||
for event in span.events:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO span_events (
|
||||
span_id, name, timestamp, attributes
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(f"Error exporting span to SQLite: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
with self._lock:
|
||||
for conn in self._connections.values():
|
||||
if conn:
|
||||
conn.close()
|
||||
self._connections.clear()
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
pass
|
|
@ -0,0 +1,247 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(Telemetry):
|
||||
def __init__(self, config: TelemetryConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
if TelemetrySink.JAEGER in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(
|
||||
SQLiteSpanProcessor(self.config.sqlite_db_path)
|
||||
)
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**event.attributes,
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||
)
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(
|
||||
event.metric, event.unit
|
||||
)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(
|
||||
self, name: str, unit: str
|
||||
) -> metrics.UpDownCounter:
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||
self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
return await self.trace_store.query_traces(
|
||||
attribute_filters=attribute_filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
return await self.trace_store.get_materialized_span(
|
||||
span_id=span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
|
@ -14,9 +14,12 @@ def available_providers() -> List[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.telemetry,
|
||||
provider_type="inline::meta-reference",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.meta_reference.telemetry",
|
||||
config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig",
|
||||
pip_packages=[
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
],
|
||||
module="llama_stack.providers.inline.telemetry.meta_reference",
|
||||
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.telemetry,
|
||||
|
@ -27,18 +30,4 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.telemetry,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="opentelemetry-jaeger",
|
||||
pip_packages=[
|
||||
"opentelemetry-api",
|
||||
"opentelemetry-sdk",
|
||||
"opentelemetry-exporter-jaeger",
|
||||
"opentelemetry-semantic-conventions",
|
||||
],
|
||||
module="llama_stack.providers.remote.telemetry.opentelemetry",
|
||||
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
|
||||
|
@ -100,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_def = self.dataset_infos[dataset_id]
|
||||
loaded_dataset = load_hf_dataset(dataset_def)
|
||||
|
||||
# Convert rows to HF Dataset format
|
||||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
||||
# Concatenate the new rows with existing dataset
|
||||
updated_dataset = hf_datasets.concatenate_datasets(
|
||||
[loaded_dataset, new_dataset]
|
||||
)
|
||||
|
||||
if dataset_def.metadata.get("path", None):
|
||||
updated_dataset.push_to_hub(dataset_def.metadata["path"])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Uploading to URL-based datasets is not supported yet"
|
||||
)
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import OpenTelemetryConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
|
||||
from .opentelemetry import OpenTelemetryAdapter
|
||||
|
||||
impl = OpenTelemetryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OpenTelemetryConfig(BaseModel):
|
||||
otel_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/traces",
|
||||
description="The OpenTelemetry collector endpoint URL",
|
||||
)
|
||||
service_name: str = Field(
|
||||
default="llama-stack",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||
}
|
|
@ -5,6 +5,16 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.providers.remote.telemetry.opentelemetry.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.remote.telemetry.opentelemetry.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
|
@ -19,7 +29,7 @@ from opentelemetry.semconv.resource import ResourceAttributes
|
|||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from .config import OpenTelemetryConfig
|
||||
from .config import OpenTelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
"active_spans": {},
|
||||
|
@ -46,8 +56,9 @@ def is_tracing_enabled(tracer):
|
|||
|
||||
|
||||
class OpenTelemetryAdapter(Telemetry):
|
||||
def __init__(self, config: OpenTelemetryConfig):
|
||||
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
|
||||
self.config = config
|
||||
self.datasetio = deps[Api.datasetio]
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
@ -57,22 +68,29 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
# Set up metrics
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
if TelemetrySink.JAEGER in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(
|
||||
SQLiteSpanProcessor(self.config.sqlite_db_path)
|
||||
)
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -83,15 +101,17 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event) -> None:
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event)
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event)
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
|
@ -104,6 +124,7 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**event.attributes,
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
|
@ -154,11 +175,14 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent) -> None:
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
|
@ -170,7 +194,6 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
# Create a new trace context with the trace_id
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
|
@ -179,14 +202,9 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
start_time=int(event.timestamp.timestamp() * 1e9),
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
# Set as current span using context manager
|
||||
with trace.use_span(span, end_on_exit=False):
|
||||
pass # Let the span continue beyond this block
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
|
@ -199,10 +217,43 @@ class OpenTelemetryAdapter(Telemetry):
|
|||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
|
||||
|
||||
# Remove from active spans
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
raise NotImplementedError("Trace retrieval not implemented yet")
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
return await self.trace_store.query_traces(
|
||||
attribute_conditions=attribute_conditions,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
async def get_spans(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> SpanWithChildren:
|
||||
return await self.trace_store.get_spans(
|
||||
span_id=span_id,
|
||||
attribute_conditions=attribute_conditions,
|
||||
attribute_keys_to_return=attribute_keys_to_return,
|
||||
max_depth=max_depth,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
|
177
llama_stack/providers/utils/telemetry/sqlite.py
Normal file
177
llama_stack/providers/utils/telemetry/sqlite.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import (
|
||||
QueryCondition,
|
||||
SpanWithChildren,
|
||||
Trace,
|
||||
TraceStore,
|
||||
)
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
def __init__(self, conn_string: str):
|
||||
self.conn_string = conn_string
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
print(attribute_filters, attributes_to_return, limit, offset, order_by)
|
||||
|
||||
def build_attribute_select() -> str:
|
||||
if not attributes_to_return:
|
||||
return ""
|
||||
return "".join(
|
||||
f", json_extract(s.attributes, '$.{key}') as attr_{key}"
|
||||
for key in attributes_to_return
|
||||
)
|
||||
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
||||
conditions = [
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {condition.op} ?"
|
||||
for condition in attribute_filters
|
||||
]
|
||||
params = [condition.value for condition in attribute_filters]
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
return where_clause, params
|
||||
|
||||
def build_order_clause() -> str:
|
||||
if not order_by:
|
||||
return ""
|
||||
|
||||
order_clauses = []
|
||||
for field in order_by:
|
||||
desc = field.startswith("-")
|
||||
clean_field = field[1:] if desc else field
|
||||
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||
return " ORDER BY " + ", ".join(order_clauses)
|
||||
|
||||
# Build the main query
|
||||
base_query = """
|
||||
WITH matching_traces AS (
|
||||
SELECT DISTINCT t.trace_id
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
{where_clause}
|
||||
),
|
||||
filtered_traces AS (
|
||||
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
{attribute_select}
|
||||
FROM matching_traces mt
|
||||
JOIN traces t ON mt.trace_id = t.trace_id
|
||||
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||
{order_clause}
|
||||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
where_clause, params = build_where_clause()
|
||||
query = base_query.format(
|
||||
attribute_select=build_attribute_select(),
|
||||
where_clause=where_clause,
|
||||
order_clause=build_order_clause(),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Execute query and return results
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
Trace(
|
||||
trace_id=row["trace_id"],
|
||||
root_span_id=row["root_span_id"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
json_object = ", ".join(
|
||||
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||
for key in attributes_to_return
|
||||
)
|
||||
attributes_select = f"json_object({json_object})"
|
||||
|
||||
# SQLite CTE query with filtered attributes
|
||||
query = f"""
|
||||
WITH RECURSIVE span_tree AS (
|
||||
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
WHERE s.span_id = ?
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
JOIN span_tree st ON s.parent_span_id = st.span_id
|
||||
WHERE (? IS NULL OR st.depth < ?)
|
||||
)
|
||||
SELECT *
|
||||
FROM span_tree
|
||||
ORDER BY depth, start_time
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
|
||||
# Build span tree
|
||||
spans_by_id = {}
|
||||
root_span = None
|
||||
|
||||
for row in rows:
|
||||
span = SpanWithChildren(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
name=row["name"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
attributes=json.loads(row["filtered_attributes"]),
|
||||
status=row["status"].lower(),
|
||||
children=[],
|
||||
)
|
||||
|
||||
spans_by_id[span.span_id] = span
|
||||
|
||||
if span.span_id == span_id:
|
||||
root_span = span
|
||||
elif span.parent_span_id in spans_by_id:
|
||||
spans_by_id[span.parent_span_id].children.append(span)
|
||||
|
||||
return root_span
|
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
180
llama_stack/providers/utils/telemetry/sqlite_trace_store.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
|
||||
|
||||
|
||||
class TraceStore(Protocol):
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren: ...
|
||||
|
||||
|
||||
class SQLiteTraceStore(TraceStore):
|
||||
def __init__(self, conn_string: str):
|
||||
self.conn_string = conn_string
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
|
||||
def build_where_clause() -> tuple[str, list]:
|
||||
if not attribute_filters:
|
||||
return "", []
|
||||
|
||||
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
|
||||
|
||||
conditions = [
|
||||
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op]} ?"
|
||||
for condition in attribute_filters
|
||||
]
|
||||
params = [condition.value for condition in attribute_filters]
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
return where_clause, params
|
||||
|
||||
def build_order_clause() -> str:
|
||||
if not order_by:
|
||||
return ""
|
||||
|
||||
order_clauses = []
|
||||
for field in order_by:
|
||||
desc = field.startswith("-")
|
||||
clean_field = field[1:] if desc else field
|
||||
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
|
||||
return " ORDER BY " + ", ".join(order_clauses)
|
||||
|
||||
# Build the main query
|
||||
base_query = """
|
||||
WITH matching_traces AS (
|
||||
SELECT DISTINCT t.trace_id
|
||||
FROM traces t
|
||||
JOIN spans s ON t.trace_id = s.trace_id
|
||||
{where_clause}
|
||||
),
|
||||
filtered_traces AS (
|
||||
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
|
||||
FROM matching_traces mt
|
||||
JOIN traces t ON mt.trace_id = t.trace_id
|
||||
LEFT JOIN spans s ON t.trace_id = s.trace_id
|
||||
{order_clause}
|
||||
)
|
||||
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
|
||||
FROM filtered_traces
|
||||
LIMIT {limit} OFFSET {offset}
|
||||
"""
|
||||
|
||||
where_clause, params = build_where_clause()
|
||||
query = base_query.format(
|
||||
where_clause=where_clause,
|
||||
order_clause=build_order_clause(),
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Execute query and return results
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, params) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
return [
|
||||
Trace(
|
||||
trace_id=row["trace_id"],
|
||||
root_span_id=row["root_span_id"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_materialized_span(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
# Build the attributes selection
|
||||
attributes_select = "s.attributes"
|
||||
if attributes_to_return:
|
||||
json_object = ", ".join(
|
||||
f"'{key}', json_extract(s.attributes, '$.{key}')"
|
||||
for key in attributes_to_return
|
||||
)
|
||||
attributes_select = f"json_object({json_object})"
|
||||
|
||||
# SQLite CTE query with filtered attributes
|
||||
query = f"""
|
||||
WITH RECURSIVE span_tree AS (
|
||||
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
WHERE s.span_id = ?
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
|
||||
FROM spans s
|
||||
JOIN span_tree st ON s.parent_span_id = st.span_id
|
||||
WHERE (? IS NULL OR st.depth < ?)
|
||||
)
|
||||
SELECT *
|
||||
FROM span_tree
|
||||
ORDER BY depth, start_time
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(self.conn_string) as conn:
|
||||
conn.row_factory = aiosqlite.Row
|
||||
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
raise ValueError(f"Span {span_id} not found")
|
||||
|
||||
# Build span tree
|
||||
spans_by_id = {}
|
||||
root_span = None
|
||||
|
||||
for row in rows:
|
||||
span = SpanWithChildren(
|
||||
span_id=row["span_id"],
|
||||
trace_id=row["trace_id"],
|
||||
parent_span_id=row["parent_span_id"],
|
||||
name=row["name"],
|
||||
start_time=datetime.fromisoformat(row["start_time"]),
|
||||
end_time=datetime.fromisoformat(row["end_time"]),
|
||||
attributes=json.loads(row["filtered_attributes"]),
|
||||
status=row["status"].lower(),
|
||||
children=[],
|
||||
)
|
||||
|
||||
spans_by_id[span.span_id] = span
|
||||
|
||||
if span.span_id == span_id:
|
||||
root_span = span
|
||||
elif span.parent_span_id in spans_by_id:
|
||||
spans_by_id[span.parent_span_id].children.append(span)
|
||||
|
||||
return root_span
|
|
@ -69,7 +69,7 @@ class TraceContext:
|
|||
self.logger = logger
|
||||
self.trace_id = trace_id
|
||||
|
||||
def push_span(self, name: str, attributes: Dict[str, Any] = None):
|
||||
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
|
||||
current_span = self.get_current_span()
|
||||
span = Span(
|
||||
span_id=generate_short_uuid(),
|
||||
|
@ -94,6 +94,7 @@ class TraceContext:
|
|||
)
|
||||
|
||||
self.spans.append(span)
|
||||
return span
|
||||
|
||||
def pop_span(self, status: SpanStatus = SpanStatus.OK):
|
||||
span = self.spans.pop()
|
||||
|
@ -203,12 +204,13 @@ class SpanContextManager:
|
|||
def __init__(self, name: str, attributes: Dict[str, Any] = None):
|
||||
self.name = name
|
||||
self.attributes = attributes
|
||||
self.span = None
|
||||
|
||||
def __enter__(self):
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.push_span(self.name, self.attributes)
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
|
@ -217,11 +219,24 @@ class SpanContextManager:
|
|||
if context:
|
||||
context.pop_span()
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.span:
|
||||
if self.span.attributes is None:
|
||||
self.span.attributes = {}
|
||||
self.span.attributes[key] = value
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.__enter__()
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
self.span = context.push_span(self.name, self.attributes)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
self.__exit__(exc_type, exc_value, traceback)
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
context.pop_span()
|
||||
|
||||
def __call__(self, func: Callable):
|
||||
@wraps(func)
|
||||
|
@ -246,3 +261,11 @@ class SpanContextManager:
|
|||
|
||||
def span(name: str, attributes: Dict[str, Any] = None):
|
||||
return SpanContextManager(name, attributes)
|
||||
|
||||
|
||||
def get_current_span() -> Optional[Span]:
|
||||
global CURRENT_TRACE_CONTEXT
|
||||
context = CURRENT_TRACE_CONTEXT
|
||||
if context:
|
||||
return context.get_current_span()
|
||||
return None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue