chore: add mypy inference fp8_impls

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-09 00:22:45 +02:00
parent d880c2df0e
commit 1c08a1cae9
7 changed files with 38 additions and 25 deletions

View file

@ -34,6 +34,9 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
if span.start_time is None:
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
print(
@ -46,6 +49,9 @@ class ConsoleSpanProcessor(SpanProcessor):
if span.attributes and span.attributes.get("__autotraced__"):
return
if span.end_time is None:
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = (
@ -59,8 +65,9 @@ class ConsoleSpanProcessor(SpanProcessor):
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
if span.start_time is not None and span.end_time is not None:
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
@ -76,10 +83,13 @@ class ConsoleSpanProcessor(SpanProcessor):
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict | list):
message = json.dumps(message, indent=2)
severity = "info"
message = event.name
if event.attributes:
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict | list):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
@ -87,9 +97,10 @@ class ConsoleSpanProcessor(SpanProcessor):
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
msg_color = severity_colors.get(str(severity), COLORS["white"])
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
severity_str = str(severity).upper() if severity else "INFO"
print(f" {event_time} {msg_color}[{severity_str}] {message}{COLORS['reset']}")
if event.attributes:
for key, value in event.attributes.items():

View file

@ -10,8 +10,7 @@ import sqlite3
import threading
from datetime import UTC, datetime
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
from opentelemetry.trace.span import format_span_id, format_trace_id
from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
@ -93,11 +92,11 @@ class SQLiteSpanProcessor(SpanProcessor):
conn.commit()
cursor.close()
def on_start(self, span: Span, parent_context=None):
def on_start(self, span: ReadableSpan, parent_context=None):
"""Called when a span starts."""
pass
def on_end(self, span: Span):
def on_end(self, span: ReadableSpan):
"""Called when a span ends. Export the span data to SQLite."""
try:
conn = self._get_connection()

View file

@ -168,7 +168,7 @@ def _process_vllm_chat_completion_end_of_stream(
last_chunk_content: str | None,
current_event_type: ChatCompletionResponseEventType,
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
) -> list[OpenAIChatCompletionChunk]:
) -> list[ChatCompletionResponseStreamChunk]:
chunks = []
if finish_reason is not None:
@ -247,9 +247,10 @@ async def _process_vllm_chat_completion_stream_response(
if choice.delta.tool_calls:
for delta_tool_call in choice.delta.tool_calls:
tool_call = convert_tool_call(delta_tool_call)
if delta_tool_call.index not in tool_call_bufs:
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[delta_tool_call.index]
index_str = str(delta_tool_call.index)
if index_str not in tool_call_bufs:
tool_call_bufs[index_str] = UnparseableToolCall()
tool_call_buf = tool_call_bufs[index_str]
tool_call_buf.tool_name += str(tool_call.tool_name)
tool_call_buf.call_id += tool_call.call_id
tool_call_buf.arguments += (