bug fixes to make this work, trace creation worked - spans dont yet

This commit is contained in:
Ashwin Bharambe 2024-09-19 08:56:03 -07:00
parent 84ebed9c9f
commit 6e5ca1350e
3 changed files with 25 additions and 10 deletions

View file

@ -6,19 +6,20 @@
import asyncio import asyncio
import json import json
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator, List, Optional
import fire import fire
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from termcolor import cprint from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig from llama_stack.distribution.datatypes import RemoteProviderConfig
from .event_logger import EventLogger from .event_logger import EventLogger
from llama_stack.apis.inference import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(config.url) return InferenceClient(config.url)

View file

@ -23,6 +23,21 @@ from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig from .config import OpenTelemetryConfig
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry): class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig): def __init__(self, config: OpenTelemetryConfig):
self.config = config self.config = config
@ -92,23 +107,24 @@ class OpenTelemetryAdapter(Telemetry):
context = trace.set_span_in_context( context = trace.set_span_in_context(
trace.NonRecordingSpan( trace.NonRecordingSpan(
trace.SpanContext( trace.SpanContext(
trace_id=int(event.trace_id, 16), trace_id=string_to_trace_id(event.trace_id),
span_id=int(event.span_id, 16), span_id=string_to_span_id(event.span_id),
is_remote=True, is_remote=True,
) )
) )
) )
span = self.tracer.start_span( span = self.tracer.start_span(
name=event.payload.name, name=event.payload.name,
context=context,
kind=trace.SpanKind.INTERNAL, kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes, attributes=event.attributes,
) )
if event.payload.parent_span_id: if event.payload.parent_span_id:
span.set_parent( span.set_parent(
trace.SpanContext( trace.SpanContext(
trace_id=int(event.trace_id, 16), trace_id=string_to_trace_id(event.trace_id),
span_id=int(event.payload.parent_span_id, 16), span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True, is_remote=True,
) )
) )

View file

@ -223,13 +223,11 @@ class SpanContextManager:
def __call__(self, func: Callable): def __call__(self, func: Callable):
@wraps(func) @wraps(func)
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
print("sync wrapper")
with self: with self:
return func(*args, **kwargs) return func(*args, **kwargs)
@wraps(func) @wraps(func)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
print("async wrapper")
async with self: async with self:
return await func(*args, **kwargs) return await func(*args, **kwargs)