mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 00:34:44 +00:00
use postgres to store traces and query
This commit is contained in:
parent
4dd08e5595
commit
7a2c1126bb
5 changed files with 235 additions and 9 deletions
|
@ -4,14 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
|
|
||||||
from .config import OpenTelemetryConfig
|
from .config import OpenTelemetryConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
|
async def get_adapter_impl(config: OpenTelemetryConfig, deps):
|
||||||
from .opentelemetry import OpenTelemetryAdapter
|
from .opentelemetry import OpenTelemetryAdapter
|
||||||
|
|
||||||
trace_store = JaegerTraceStore(config.jaeger_query_endpoint, config.service_name)
|
impl = OpenTelemetryAdapter(config, deps)
|
||||||
impl = OpenTelemetryAdapter(config, trace_store, deps)
|
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -18,10 +18,18 @@ class OpenTelemetryConfig(BaseModel):
|
||||||
default="llama-stack",
|
default="llama-stack",
|
||||||
description="The service name to use for telemetry",
|
description="The service name to use for telemetry",
|
||||||
)
|
)
|
||||||
|
trace_store: str = Field(
|
||||||
|
default="postgres",
|
||||||
|
description="The trace store to use for telemetry",
|
||||||
|
)
|
||||||
jaeger_query_endpoint: str = Field(
|
jaeger_query_endpoint: str = Field(
|
||||||
default="http://localhost:16686/api/traces",
|
default="http://localhost:16686/api/traces",
|
||||||
description="The Jaeger query endpoint URL",
|
description="The Jaeger query endpoint URL",
|
||||||
)
|
)
|
||||||
|
postgres_conn_string: str = Field(
|
||||||
|
default="host=localhost dbname=llama_stack user=llama_stack password=llama_stack port=5432",
|
||||||
|
description="The PostgreSQL connection string to use for storing traces",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
|
|
@ -18,6 +18,11 @@ from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
from opentelemetry.semconv.resource import ResourceAttributes
|
from opentelemetry.semconv.resource import ResourceAttributes
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
from llama_stack.providers.remote.telemetry.opentelemetry.postgres_processor import (
|
||||||
|
PostgresSpanProcessor,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.telemetry.jaeger import JaegerTraceStore
|
||||||
|
from llama_stack.providers.utils.telemetry.postgres import PostgresTraceStore
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
@ -49,12 +54,18 @@ def is_tracing_enabled(tracer):
|
||||||
|
|
||||||
|
|
||||||
class OpenTelemetryAdapter(Telemetry):
|
class OpenTelemetryAdapter(Telemetry):
|
||||||
def __init__(
|
def __init__(self, config: OpenTelemetryConfig, deps) -> None:
|
||||||
self, config: OpenTelemetryConfig, trace_store: TraceStore, deps
|
|
||||||
) -> None:
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.datasetio = deps[Api.datasetio]
|
self.datasetio = deps[Api.datasetio]
|
||||||
self.trace_store = trace_store
|
|
||||||
|
if config.trace_store == "jaeger":
|
||||||
|
self.trace_store = JaegerTraceStore(
|
||||||
|
config.jaeger_query_endpoint, config.service_name
|
||||||
|
)
|
||||||
|
elif config.trace_store == "postgres":
|
||||||
|
self.trace_store = PostgresTraceStore(config.postgres_conn_string)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid trace store: {config.trace_store}")
|
||||||
|
|
||||||
resource = Resource.create(
|
resource = Resource.create(
|
||||||
{
|
{
|
||||||
|
@ -69,6 +80,9 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
)
|
)
|
||||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||||
|
trace.get_tracer_provider().add_span_processor(
|
||||||
|
PostgresSpanProcessor(self.config.postgres_conn_string)
|
||||||
|
)
|
||||||
# Set up metrics
|
# Set up metrics
|
||||||
metric_reader = PeriodicExportingMetricReader(
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
OTLPMetricExporter(
|
OTLPMetricExporter(
|
||||||
|
@ -252,8 +266,8 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
results.append(
|
results.append(
|
||||||
EvalTrace(
|
EvalTrace(
|
||||||
step=child.span.name,
|
step=child.span.name,
|
||||||
input=child.span.attributes.get("input", ""),
|
input=str(child.span.attributes.get("input", "")),
|
||||||
output=child.span.attributes.get("output", ""),
|
output=str(child.span.attributes.get("output", "")),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
expected_output="",
|
expected_output="",
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
from opentelemetry.sdk.trace import SpanProcessor
|
||||||
|
from opentelemetry.trace import Span
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresSpanProcessor(SpanProcessor):
|
||||||
|
def __init__(self, conn_string):
|
||||||
|
"""Initialize the PostgreSQL span processor with a connection string."""
|
||||||
|
self.conn_string = conn_string
|
||||||
|
self.conn = None
|
||||||
|
self.setup_database()
|
||||||
|
|
||||||
|
def setup_database(self):
|
||||||
|
"""Create the necessary table if it doesn't exist."""
|
||||||
|
with psycopg2.connect(self.conn_string) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS traces (
|
||||||
|
trace_id TEXT,
|
||||||
|
span_id TEXT,
|
||||||
|
parent_span_id TEXT,
|
||||||
|
name TEXT,
|
||||||
|
start_time TIMESTAMP,
|
||||||
|
end_time TIMESTAMP,
|
||||||
|
attributes JSONB,
|
||||||
|
status TEXT,
|
||||||
|
kind TEXT,
|
||||||
|
service_name TEXT,
|
||||||
|
session_id TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def on_start(self, span: Span, parent_context=None):
|
||||||
|
"""Called when a span starts."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_end(self, span: Span):
|
||||||
|
"""Called when a span ends. Export the span data to PostgreSQL."""
|
||||||
|
try:
|
||||||
|
with psycopg2.connect(self.conn_string) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO traces (
|
||||||
|
trace_id, span_id, parent_span_id, name,
|
||||||
|
start_time, end_time, attributes, status,
|
||||||
|
kind, service_name, session_id
|
||||||
|
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
format(span.get_span_context().trace_id, "032x"),
|
||||||
|
format(span.get_span_context().span_id, "016x"),
|
||||||
|
(
|
||||||
|
format(span.parent.span_id, "016x")
|
||||||
|
if span.parent
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
span.name,
|
||||||
|
datetime.fromtimestamp(span.start_time / 1e9),
|
||||||
|
datetime.fromtimestamp(span.end_time / 1e9),
|
||||||
|
json.dumps(dict(span.attributes)),
|
||||||
|
span.status.status_code.name,
|
||||||
|
span.kind.name,
|
||||||
|
span.resource.attributes.get("service.name", "unknown"),
|
||||||
|
span.attributes.get("session_id", None),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error exporting span to PostgreSQL: {e}")
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Cleanup any resources."""
|
||||||
|
if self.conn:
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
def force_flush(self, timeout_millis=30000):
|
||||||
|
"""Force export of spans."""
|
||||||
|
pass
|
114
llama_stack/providers/utils/telemetry/postgres.py
Normal file
114
llama_stack/providers/utils/telemetry/postgres.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
from llama_stack.apis.telemetry import Span, SpanNode, Trace, TraceStore, TraceTree
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresTraceStore(TraceStore):
|
||||||
|
def __init__(self, conn_string: str):
|
||||||
|
self.conn_string = conn_string
|
||||||
|
|
||||||
|
async def get_trace(self, trace_id: str) -> Optional[TraceTree]:
|
||||||
|
try:
|
||||||
|
with psycopg2.connect(self.conn_string) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
# Fetch all spans for the trace
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT trace_id, span_id, parent_span_id, name,
|
||||||
|
start_time, end_time, attributes
|
||||||
|
FROM traces
|
||||||
|
WHERE trace_id = %s
|
||||||
|
""",
|
||||||
|
(trace_id,),
|
||||||
|
)
|
||||||
|
spans_data = cur.fetchall()
|
||||||
|
|
||||||
|
if not spans_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# First pass: Build span map
|
||||||
|
span_map = {}
|
||||||
|
for span_data in spans_data:
|
||||||
|
# Ensure attributes is a string before parsing
|
||||||
|
attributes = span_data[6]
|
||||||
|
if isinstance(attributes, dict):
|
||||||
|
attributes = json.dumps(attributes)
|
||||||
|
|
||||||
|
span = Span(
|
||||||
|
span_id=span_data[1],
|
||||||
|
trace_id=span_data[0],
|
||||||
|
name=span_data[3],
|
||||||
|
start_time=span_data[4],
|
||||||
|
end_time=span_data[5],
|
||||||
|
parent_span_id=span_data[2],
|
||||||
|
attributes=json.loads(
|
||||||
|
attributes
|
||||||
|
), # Now safely parse the JSON string
|
||||||
|
)
|
||||||
|
span_map[span.span_id] = SpanNode(span=span)
|
||||||
|
|
||||||
|
# Second pass: Build parent-child relationships
|
||||||
|
root_node = None
|
||||||
|
for span_node in span_map.values():
|
||||||
|
parent_id = span_node.span.parent_span_id
|
||||||
|
if parent_id and parent_id in span_map:
|
||||||
|
span_map[parent_id].children.append(span_node)
|
||||||
|
elif not parent_id:
|
||||||
|
root_node = span_node
|
||||||
|
|
||||||
|
trace = Trace(
|
||||||
|
trace_id=trace_id,
|
||||||
|
root_span_id=root_node.span.span_id if root_node else "",
|
||||||
|
start_time=(
|
||||||
|
root_node.span.start_time if root_node else datetime.now()
|
||||||
|
),
|
||||||
|
end_time=root_node.span.end_time if root_node else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TraceTree(trace=trace, root=root_node)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Error querying PostgreSQL trace structure: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
async def get_traces_for_sessions(self, session_ids: List[str]) -> List[Trace]:
|
||||||
|
traces = []
|
||||||
|
try:
|
||||||
|
with psycopg2.connect(self.conn_string) as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
# Query traces for all session IDs
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT DISTINCT trace_id, MIN(start_time) as start_time
|
||||||
|
FROM traces
|
||||||
|
WHERE attributes->>'session_id' = ANY(%s)
|
||||||
|
GROUP BY trace_id
|
||||||
|
""",
|
||||||
|
(session_ids,),
|
||||||
|
)
|
||||||
|
traces_data = cur.fetchall()
|
||||||
|
|
||||||
|
for trace_data in traces_data:
|
||||||
|
traces.append(
|
||||||
|
Trace(
|
||||||
|
trace_id=trace_data[0],
|
||||||
|
root_span_id="",
|
||||||
|
start_time=trace_data[1],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error querying PostgreSQL traces: {str(e)}") from e
|
||||||
|
|
||||||
|
return traces
|
Loading…
Add table
Add a link
Reference in a new issue