forked from phoenix-oss/llama-stack-mirror
Fix opentelemetry adapter (#510)
# What does this PR do? This PR fixes some of the issues with our telemetry setup to enable logs to be delivered to opentelemetry and jaeger. Main fixes 1) Updates the open telemetry provider to use the latest oltp exports instead of deprected ones. 2) Adds a tracing middleware, which injects traces into each HTTP request that the server recieves and this is going to be the root trace. Previously, we did this in the create_dynamic_route method, which is actually not the actual exectuion flow, but more of a config and this causes the traces to end prematurely. Through middleware, we plugin the trace start and end at the right location. 3) We manage our own methods to create traces and spans and this does not fit well with Opentelemetry SDK since it does not support provide a way to take in traces and spans that are already created. it expects us to use the SDK to create them. For now, I have a hacky approach of just maintaining a map from our internal telemetry objects to the open telemetry specfic ones. This is not the ideal solution. I will explore other ways to get around this issue. for now, to have something that works, i am going to keep this as is. Addresses: #509
This commit is contained in:
parent
beab798a1d
commit
501e7c9d64
11 changed files with 185 additions and 217 deletions
|
@ -40,7 +40,7 @@ class ModelsClient(Models):
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
f"{self.base_url}/models/register",
|
f"{self.base_url}/models/register",
|
||||||
json={
|
json={
|
||||||
"model": json.loads(model.json()),
|
"model": json.loads(model.model_dump_json()),
|
||||||
},
|
},
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,13 +17,11 @@ import warnings
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from ssl import SSLError
|
from typing import Any, Union
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
SpanStatus,
|
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def passthrough(
|
|
||||||
request: Request,
|
|
||||||
downstream_url: str,
|
|
||||||
downstream_headers: Optional[Dict[str, str]] = None,
|
|
||||||
):
|
|
||||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
|
||||||
|
|
||||||
headers = dict(request.headers)
|
|
||||||
headers.pop("host", None)
|
|
||||||
headers.update(downstream_headers or {})
|
|
||||||
|
|
||||||
content = await request.body()
|
|
||||||
|
|
||||||
client = httpx.AsyncClient()
|
|
||||||
erred = False
|
|
||||||
try:
|
|
||||||
req = client.build_request(
|
|
||||||
method=request.method,
|
|
||||||
url=downstream_url,
|
|
||||||
headers=headers,
|
|
||||||
content=content,
|
|
||||||
params=request.query_params,
|
|
||||||
)
|
|
||||||
response = await client.send(req, stream=True)
|
|
||||||
|
|
||||||
async def stream_response():
|
|
||||||
async for chunk in response.aiter_raw(chunk_size=64):
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
await response.aclose()
|
|
||||||
await client.aclose()
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
stream_response(),
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=dict(response.headers),
|
|
||||||
media_type=response.headers.get("content-type"),
|
|
||||||
)
|
|
||||||
|
|
||||||
except httpx.ReadTimeout:
|
|
||||||
erred = True
|
|
||||||
return Response(content="Downstream server timed out", status_code=504)
|
|
||||||
except httpx.NetworkError as e:
|
|
||||||
erred = True
|
|
||||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
|
||||||
except httpx.TooManyRedirects:
|
|
||||||
erred = True
|
|
||||||
return Response(content="Too many redirects", status_code=502)
|
|
||||||
except SSLError as e:
|
|
||||||
erred = True
|
|
||||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
erred = True
|
|
||||||
return Response(content=str(e), status_code=e.response.status_code)
|
|
||||||
except Exception as e:
|
|
||||||
erred = True
|
|
||||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
|
||||||
finally:
|
|
||||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_sigint(app, *args, **kwargs):
|
def handle_sigint(app, *args, **kwargs):
|
||||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||||
|
|
||||||
|
@ -217,7 +153,6 @@ async def maybe_await(value):
|
||||||
|
|
||||||
|
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen):
|
||||||
await start_trace("sse_generator")
|
|
||||||
try:
|
try:
|
||||||
event_gen = await event_gen
|
event_gen = await event_gen
|
||||||
async for item in event_gen:
|
async for item in event_gen:
|
||||||
|
@ -235,14 +170,10 @@ async def sse_generator(event_gen):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
await end_trace()
|
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_typed_route(func: Any, method: str):
|
def create_dynamic_typed_route(func: Any, method: str):
|
||||||
async def endpoint(request: Request, **kwargs):
|
async def endpoint(request: Request, **kwargs):
|
||||||
await start_trace(func.__name__)
|
|
||||||
|
|
||||||
set_request_provider_data(request.headers)
|
set_request_provider_data(request.headers)
|
||||||
|
|
||||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||||
|
@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exception(e)
|
traceback.print_exception(e)
|
||||||
raise translate_exception(e) from e
|
raise translate_exception(e) from e
|
||||||
finally:
|
|
||||||
await end_trace()
|
|
||||||
|
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
new_params = [
|
new_params = [
|
||||||
|
@ -282,6 +211,19 @@ def create_dynamic_typed_route(func: Any, method: str):
|
||||||
return endpoint
|
return endpoint
|
||||||
|
|
||||||
|
|
||||||
|
class TracingMiddleware:
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
path = scope["path"]
|
||||||
|
await start_trace(path, {"location": "server"})
|
||||||
|
try:
|
||||||
|
return await self.app(scope, receive, send)
|
||||||
|
finally:
|
||||||
|
await end_trace()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Start the LlamaStack server."""
|
"""Start the LlamaStack server."""
|
||||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||||
|
@ -338,6 +280,7 @@ def main():
|
||||||
print(yaml.dump(config.model_dump(), indent=2))
|
print(yaml.dump(config.model_dump(), indent=2))
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.add_middleware(TracingMiddleware)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
|
|
|
@ -113,7 +113,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# May be this should be a parameter of the agentic instance
|
# May be this should be a parameter of the agentic instance
|
||||||
# that can define its behavior in a custom way
|
# that can define its behavior in a custom way
|
||||||
for m in turn.input_messages:
|
for m in turn.input_messages:
|
||||||
msg = m.copy()
|
msg = m.model_copy()
|
||||||
if isinstance(msg, UserMessage):
|
if isinstance(msg, UserMessage):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
|
@ -52,7 +52,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
await self.persistence_store.set(
|
await self.persistence_store.set(
|
||||||
key=f"agent:{agent_id}",
|
key=f"agent:{agent_id}",
|
||||||
value=agent_config.json(),
|
value=agent_config.model_dump_json(),
|
||||||
)
|
)
|
||||||
return AgentCreateResponse(
|
return AgentCreateResponse(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
|
|
@ -39,7 +39,7 @@ class AgentPersistence:
|
||||||
)
|
)
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
value=session_info.json(),
|
value=session_info.model_dump_json(),
|
||||||
)
|
)
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
|
@ -60,13 +60,13 @@ class AgentPersistence:
|
||||||
session_info.memory_bank_id = bank_id
|
session_info.memory_bank_id = bank_id
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"session:{self.agent_id}:{session_id}",
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
value=session_info.json(),
|
value=session_info.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||||
value=turn.json(),
|
value=turn.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||||
|
|
|
@ -72,7 +72,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
||||||
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
|
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=key,
|
key=key,
|
||||||
value=task_def.json(),
|
value=task_def.model_dump_json(),
|
||||||
)
|
)
|
||||||
self.eval_tasks[task_def.identifier] = task_def
|
self.eval_tasks[task_def.identifier] = task_def
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,9 @@ class FaissIndex(EmbeddingIndex):
|
||||||
np.savetxt(buffer, np_index)
|
np.savetxt(buffer, np_index)
|
||||||
data = {
|
data = {
|
||||||
"id_by_index": self.id_by_index,
|
"id_by_index": self.id_by_index,
|
||||||
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
"chunk_by_index": {
|
||||||
|
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
|
||||||
|
},
|
||||||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,7 +164,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
key=key,
|
key=key,
|
||||||
value=memory_bank.json(),
|
value=memory_bank.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store in cache
|
# Store in cache
|
||||||
|
|
|
@ -107,7 +107,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
collection = await self.client.get_or_create_collection(
|
collection = await self.client.get_or_create_collection(
|
||||||
name=memory_bank.identifier,
|
name=memory_bank.identifier,
|
||||||
metadata={"bank": memory_bank.json()},
|
metadata={"bank": memory_bank.model_dump_json()},
|
||||||
)
|
)
|
||||||
bank_index = BankWithIndex(
|
bank_index = BankWithIndex(
|
||||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||||
|
|
|
@ -4,9 +4,24 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class OpenTelemetryConfig(BaseModel):
|
class OpenTelemetryConfig(BaseModel):
|
||||||
jaeger_host: str = "localhost"
|
otel_endpoint: str = Field(
|
||||||
jaeger_port: int = 6831
|
default="http://localhost:4318/v1/traces",
|
||||||
|
description="The OpenTelemetry collector endpoint URL",
|
||||||
|
)
|
||||||
|
service_name: str = Field(
|
||||||
|
default="llama-stack",
|
||||||
|
description="The service name to use for telemetry",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
|
||||||
|
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||||
|
}
|
||||||
|
|
|
@ -4,24 +4,31 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
import threading
|
||||||
|
|
||||||
from opentelemetry import metrics, trace
|
from opentelemetry import metrics, trace
|
||||||
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
|
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||||
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||||
from opentelemetry.sdk.metrics import MeterProvider
|
from opentelemetry.sdk.metrics import MeterProvider
|
||||||
from opentelemetry.sdk.metrics.export import (
|
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||||
ConsoleMetricExporter,
|
|
||||||
PeriodicExportingMetricReader,
|
|
||||||
)
|
|
||||||
from opentelemetry.sdk.resources import Resource
|
from opentelemetry.sdk.resources import Resource
|
||||||
from opentelemetry.sdk.trace import TracerProvider
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||||
from opentelemetry.semconv.resource import ResourceAttributes
|
from opentelemetry.semconv.resource import ResourceAttributes
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.apis.telemetry import * # noqa: F403
|
from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
|
|
||||||
from .config import OpenTelemetryConfig
|
from .config import OpenTelemetryConfig
|
||||||
|
|
||||||
|
_GLOBAL_STORAGE = {
|
||||||
|
"active_spans": {},
|
||||||
|
"counters": {},
|
||||||
|
"gauges": {},
|
||||||
|
"up_down_counters": {},
|
||||||
|
}
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def string_to_trace_id(s: str) -> int:
|
def string_to_trace_id(s: str) -> int:
|
||||||
# Convert the string to bytes and then to an integer
|
# Convert the string to bytes and then to an integer
|
||||||
|
@ -42,33 +49,37 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
def __init__(self, config: OpenTelemetryConfig):
|
def __init__(self, config: OpenTelemetryConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self.resource = Resource.create(
|
resource = Resource.create(
|
||||||
{ResourceAttributes.SERVICE_NAME: "foobar-service"}
|
{
|
||||||
|
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up tracing with Jaeger exporter
|
provider = TracerProvider(resource=resource)
|
||||||
jaeger_exporter = JaegerExporter(
|
trace.set_tracer_provider(provider)
|
||||||
agent_host_name=self.config.jaeger_host,
|
otlp_exporter = OTLPSpanExporter(
|
||||||
agent_port=self.config.jaeger_port,
|
endpoint=self.config.otel_endpoint,
|
||||||
)
|
)
|
||||||
trace_provider = TracerProvider(resource=self.resource)
|
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||||
trace_processor = BatchSpanProcessor(jaeger_exporter)
|
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||||
trace_provider.add_span_processor(trace_processor)
|
|
||||||
trace.set_tracer_provider(trace_provider)
|
|
||||||
self.tracer = trace.get_tracer(__name__)
|
|
||||||
|
|
||||||
# Set up metrics
|
# Set up metrics
|
||||||
metric_reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
|
metric_reader = PeriodicExportingMetricReader(
|
||||||
|
OTLPMetricExporter(
|
||||||
|
endpoint=self.config.otel_endpoint,
|
||||||
|
)
|
||||||
|
)
|
||||||
metric_provider = MeterProvider(
|
metric_provider = MeterProvider(
|
||||||
resource=self.resource, metric_readers=[metric_reader]
|
resource=resource, metric_readers=[metric_reader]
|
||||||
)
|
)
|
||||||
metrics.set_meter_provider(metric_provider)
|
metrics.set_meter_provider(metric_provider)
|
||||||
self.meter = metrics.get_meter(__name__)
|
self.meter = metrics.get_meter(__name__)
|
||||||
|
self._lock = _global_lock
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
|
trace.get_tracer_provider().force_flush()
|
||||||
trace.get_tracer_provider().shutdown()
|
trace.get_tracer_provider().shutdown()
|
||||||
metrics.get_meter_provider().shutdown()
|
metrics.get_meter_provider().shutdown()
|
||||||
|
|
||||||
|
@ -81,121 +92,117 @@ class OpenTelemetryAdapter(Telemetry):
|
||||||
self._log_structured(event)
|
self._log_structured(event)
|
||||||
|
|
||||||
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
|
||||||
span = trace.get_current_span()
|
with self._lock:
|
||||||
span.add_event(
|
# Use global storage instead of instance storage
|
||||||
name=event.message,
|
span_id = string_to_span_id(event.span_id)
|
||||||
attributes={"severity": event.severity.value, **event.attributes},
|
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||||
timestamp=event.timestamp,
|
|
||||||
)
|
if span:
|
||||||
|
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||||
|
span.add_event(
|
||||||
|
name=event.type,
|
||||||
|
attributes={
|
||||||
|
"message": event.message,
|
||||||
|
"severity": event.severity.value,
|
||||||
|
**event.attributes,
|
||||||
|
},
|
||||||
|
timestamp=timestamp_ns,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||||
|
if name not in _GLOBAL_STORAGE["counters"]:
|
||||||
|
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"Counter for {name}",
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["counters"][name]
|
||||||
|
|
||||||
|
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||||
|
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||||
|
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"Gauge for {name}",
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["gauges"][name]
|
||||||
|
|
||||||
def _log_metric(self, event: MetricEvent) -> None:
|
def _log_metric(self, event: MetricEvent) -> None:
|
||||||
if isinstance(event.value, int):
|
if isinstance(event.value, int):
|
||||||
self.meter.create_counter(
|
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||||
name=event.metric,
|
counter.add(event.value, attributes=event.attributes)
|
||||||
unit=event.unit,
|
|
||||||
description=f"Counter for {event.metric}",
|
|
||||||
).add(event.value, attributes=event.attributes)
|
|
||||||
elif isinstance(event.value, float):
|
elif isinstance(event.value, float):
|
||||||
self.meter.create_gauge(
|
up_down_counter = self._get_or_create_up_down_counter(
|
||||||
name=event.metric,
|
event.metric, event.unit
|
||||||
unit=event.unit,
|
)
|
||||||
description=f"Gauge for {event.metric}",
|
up_down_counter.add(event.value, attributes=event.attributes)
|
||||||
).set(event.value, attributes=event.attributes)
|
|
||||||
|
def _get_or_create_up_down_counter(
|
||||||
|
self, name: str, unit: str
|
||||||
|
) -> metrics.UpDownCounter:
|
||||||
|
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||||
|
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||||
|
self.meter.create_up_down_counter(
|
||||||
|
name=name,
|
||||||
|
unit=unit,
|
||||||
|
description=f"UpDownCounter for {name}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||||
|
|
||||||
def _log_structured(self, event: StructuredLogEvent) -> None:
|
def _log_structured(self, event: StructuredLogEvent) -> None:
|
||||||
if isinstance(event.payload, SpanStartPayload):
|
with self._lock:
|
||||||
context = trace.set_span_in_context(
|
span_id = string_to_span_id(event.span_id)
|
||||||
trace.NonRecordingSpan(
|
trace_id = string_to_trace_id(event.trace_id)
|
||||||
trace.SpanContext(
|
tracer = trace.get_tracer(__name__)
|
||||||
trace_id=string_to_trace_id(event.trace_id),
|
|
||||||
span_id=string_to_span_id(event.span_id),
|
|
||||||
is_remote=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
span = self.tracer.start_span(
|
|
||||||
name=event.payload.name,
|
|
||||||
kind=trace.SpanKind.INTERNAL,
|
|
||||||
context=context,
|
|
||||||
attributes=event.attributes,
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.payload.parent_span_id:
|
if isinstance(event.payload, SpanStartPayload):
|
||||||
span.set_parent(
|
# Check if span already exists to prevent duplicates
|
||||||
trace.SpanContext(
|
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||||
trace_id=string_to_trace_id(event.trace_id),
|
return
|
||||||
span_id=string_to_span_id(event.payload.parent_span_id),
|
|
||||||
is_remote=True,
|
parent_span = None
|
||||||
|
if event.payload.parent_span_id:
|
||||||
|
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||||
|
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||||
|
|
||||||
|
# Create a new trace context with the trace_id
|
||||||
|
context = trace.Context(trace_id=trace_id)
|
||||||
|
if parent_span:
|
||||||
|
context = trace.set_span_in_context(parent_span, context)
|
||||||
|
|
||||||
|
span = tracer.start_span(
|
||||||
|
name=event.payload.name,
|
||||||
|
context=context,
|
||||||
|
attributes=event.attributes or {},
|
||||||
|
start_time=int(event.timestamp.timestamp() * 1e9),
|
||||||
|
)
|
||||||
|
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||||
|
|
||||||
|
# Set as current span using context manager
|
||||||
|
with trace.use_span(span, end_on_exit=False):
|
||||||
|
pass # Let the span continue beyond this block
|
||||||
|
|
||||||
|
elif isinstance(event.payload, SpanEndPayload):
|
||||||
|
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||||
|
if span:
|
||||||
|
if event.attributes:
|
||||||
|
span.set_attributes(event.attributes)
|
||||||
|
|
||||||
|
status = (
|
||||||
|
trace.Status(status_code=trace.StatusCode.OK)
|
||||||
|
if event.payload.status == SpanStatus.OK
|
||||||
|
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||||
)
|
)
|
||||||
)
|
span.set_status(status)
|
||||||
elif isinstance(event.payload, SpanEndPayload):
|
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
|
||||||
span = trace.get_current_span()
|
|
||||||
span.set_status(
|
# Remove from active spans
|
||||||
trace.Status(
|
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||||
trace.StatusCode.OK
|
|
||||||
if event.payload.status == SpanStatus.OK
|
|
||||||
else trace.StatusCode.ERROR
|
|
||||||
)
|
|
||||||
)
|
|
||||||
span.end(end_time=event.timestamp)
|
|
||||||
|
|
||||||
async def get_trace(self, trace_id: str) -> Trace:
|
async def get_trace(self, trace_id: str) -> Trace:
|
||||||
# we need to look up the root span id
|
raise NotImplementedError("Trace retrieval not implemented yet")
|
||||||
raise NotImplementedError("not yet no")
|
|
||||||
|
|
||||||
|
|
||||||
# Usage example
|
|
||||||
async def main():
|
|
||||||
telemetry = OpenTelemetryTelemetry("my-service")
|
|
||||||
await telemetry.initialize()
|
|
||||||
|
|
||||||
# Log an unstructured event
|
|
||||||
await telemetry.log_event(
|
|
||||||
UnstructuredLogEvent(
|
|
||||||
trace_id="trace123",
|
|
||||||
span_id="span456",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
message="This is a log message",
|
|
||||||
severity=LogSeverity.INFO,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log a metric event
|
|
||||||
await telemetry.log_event(
|
|
||||||
MetricEvent(
|
|
||||||
trace_id="trace123",
|
|
||||||
span_id="span456",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
metric="my_metric",
|
|
||||||
value=42,
|
|
||||||
unit="count",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log a structured event (span start)
|
|
||||||
await telemetry.log_event(
|
|
||||||
StructuredLogEvent(
|
|
||||||
trace_id="trace123",
|
|
||||||
span_id="span789",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
payload=SpanStartPayload(name="my_operation"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log a structured event (span end)
|
|
||||||
await telemetry.log_event(
|
|
||||||
StructuredLogEvent(
|
|
||||||
trace_id="trace123",
|
|
||||||
span_id="span789",
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
payload=SpanEndPayload(status=SpanStatus.OK),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await telemetry.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ from llama_stack.apis.telemetry import * # noqa: F403
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def generate_short_uuid(len: int = 12):
|
def generate_short_uuid(len: int = 8):
|
||||||
full_uuid = uuid.uuid4()
|
full_uuid = uuid.uuid4()
|
||||||
uuid_bytes = full_uuid.bytes
|
uuid_bytes = full_uuid.bytes
|
||||||
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
encoded = base64.urlsafe_b64encode(uuid_bytes)
|
||||||
|
@ -123,18 +123,19 @@ def setup_logger(api: Telemetry, level: int = logging.INFO):
|
||||||
logger.addHandler(TelemetryHandler())
|
logger.addHandler(TelemetryHandler())
|
||||||
|
|
||||||
|
|
||||||
async def start_trace(name: str, attributes: Dict[str, Any] = None):
|
async def start_trace(name: str, attributes: Dict[str, Any] = None) -> TraceContext:
|
||||||
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
global CURRENT_TRACE_CONTEXT, BACKGROUND_LOGGER
|
||||||
|
|
||||||
if BACKGROUND_LOGGER is None:
|
if BACKGROUND_LOGGER is None:
|
||||||
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
log.info("No Telemetry implementation set. Skipping trace initialization...")
|
||||||
return
|
return
|
||||||
|
|
||||||
trace_id = generate_short_uuid()
|
trace_id = generate_short_uuid(16)
|
||||||
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
context = TraceContext(BACKGROUND_LOGGER, trace_id)
|
||||||
context.push_span(name, {"__root__": True, **(attributes or {})})
|
context.push_span(name, {"__root__": True, **(attributes or {})})
|
||||||
|
|
||||||
CURRENT_TRACE_CONTEXT = context
|
CURRENT_TRACE_CONTEXT = context
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
async def end_trace(status: SpanStatus = SpanStatus.OK):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue