bug fixes to make this work, trace creation worked - spans dont yet

This commit is contained in:
Ashwin Bharambe 2024-09-19 08:56:03 -07:00
parent 84ebed9c9f
commit 6e5ca1350e
3 changed files with 25 additions and 10 deletions

View file

@ -6,19 +6,20 @@
import asyncio
import json
from typing import Any, AsyncGenerator
from typing import Any, AsyncGenerator, List, Optional
import fire
import httpx
from pydantic import BaseModel
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from termcolor import cprint
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .event_logger import EventLogger
from llama_stack.apis.inference import * # noqa: F403
async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference:
return InferenceClient(config.url)

View file

@ -23,6 +23,21 @@ from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
def string_to_trace_id(s: str) -> int:
# Convert the string to bytes and then to an integer
return int.from_bytes(s.encode(), byteorder="big", signed=False)
def string_to_span_id(s: str) -> int:
# Use only the first 8 bytes (64 bits) for span ID
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
def is_tracing_enabled(tracer):
with tracer.start_as_current_span("check_tracing") as span:
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
self.config = config
@ -92,23 +107,24 @@ class OpenTelemetryAdapter(Telemetry):
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=int(event.span_id, 16),
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.span_id),
is_remote=True,
)
)
)
span = self.tracer.start_span(
name=event.payload.name,
context=context,
kind=trace.SpanKind.INTERNAL,
context=context,
attributes=event.attributes,
)
if event.payload.parent_span_id:
span.set_parent(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=int(event.payload.parent_span_id, 16),
trace_id=string_to_trace_id(event.trace_id),
span_id=string_to_span_id(event.payload.parent_span_id),
is_remote=True,
)
)

View file

@ -223,13 +223,11 @@ class SpanContextManager:
def __call__(self, func: Callable):
@wraps(func)
def sync_wrapper(*args, **kwargs):
print("sync wrapper")
with self:
return func(*args, **kwargs)
@wraps(func)
async def async_wrapper(*args, **kwargs):
print("async wrapper")
async with self:
return await func(*args, **kwargs)