diff --git a/llama_stack/apis/inference/client.py b/llama_stack/apis/inference/client.py index f7de36a74..4df138841 100644 --- a/llama_stack/apis/inference/client.py +++ b/llama_stack/apis/inference/client.py @@ -6,19 +6,20 @@ import asyncio import json -from typing import Any, AsyncGenerator +from typing import Any, AsyncGenerator, List, Optional import fire import httpx from pydantic import BaseModel + +from llama_models.llama3.api import * # noqa: F403 +from llama_stack.apis.inference import * # noqa: F403 from termcolor import cprint from llama_stack.distribution.datatypes import RemoteProviderConfig from .event_logger import EventLogger -from llama_stack.apis.inference import * # noqa: F403 - async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Inference: return InferenceClient(config.url) diff --git a/llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py b/llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py index 44e49346e..03e8f7d53 100644 --- a/llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py +++ b/llama_stack/providers/adapters/telemetry/opentelemetry/opentelemetry.py @@ -23,6 +23,21 @@ from llama_stack.apis.telemetry import * # noqa: F403 from .config import OpenTelemetryConfig +def string_to_trace_id(s: str) -> int: + # Convert the string to bytes and then to an integer + return int.from_bytes(s.encode(), byteorder="big", signed=False) + + +def string_to_span_id(s: str) -> int: + # Use only the first 8 bytes (64 bits) for span ID + return int.from_bytes(s.encode()[:8], byteorder="big", signed=False) + + +def is_tracing_enabled(tracer): + with tracer.start_as_current_span("check_tracing") as span: + return span.is_recording() + + class OpenTelemetryAdapter(Telemetry): def __init__(self, config: OpenTelemetryConfig): self.config = config @@ -92,23 +107,24 @@ class OpenTelemetryAdapter(Telemetry): context = trace.set_span_in_context( trace.NonRecordingSpan( trace.SpanContext( - trace_id=int(event.trace_id, 16), - span_id=int(event.span_id, 16), + trace_id=string_to_trace_id(event.trace_id), + span_id=string_to_span_id(event.span_id), is_remote=True, ) ) ) span = self.tracer.start_span( name=event.payload.name, - context=context, kind=trace.SpanKind.INTERNAL, + context=context, attributes=event.attributes, ) + if event.payload.parent_span_id: span.set_parent( trace.SpanContext( - trace_id=int(event.trace_id, 16), - span_id=int(event.payload.parent_span_id, 16), + trace_id=string_to_trace_id(event.trace_id), + span_id=string_to_span_id(event.payload.parent_span_id), is_remote=True, ) ) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 45868b408..9fffc0f99 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -223,13 +223,11 @@ class SpanContextManager: def __call__(self, func: Callable): @wraps(func) def sync_wrapper(*args, **kwargs): - print("sync wrapper") with self: return func(*args, **kwargs) @wraps(func) async def async_wrapper(*args, **kwargs): - print("async wrapper") async with self: return await func(*args, **kwargs)