chore: add mypy inference fp8_impls

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 00:22:45 +02:00
parent d880c2df0e
commit 1c08a1cae9
7 changed files with 38 additions and 25 deletions


@@ -54,6 +54,7 @@ class DistributionInspectImpl(Inspect):
                     provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                 )
                 for e in endpoints
+                if e.methods
             ]
         )
     else:
else:
@@ -63,10 +64,11 @@ class DistributionInspectImpl(Inspect):
             [
                 RouteInfo(
                     route=e.path,
-                    method=next(iter([m for m in e.methods if m != "HEAD"])),
+                    method=next(iter([m for m in e.methods if m != "HEAD"])) if e.methods else "POST",
                     provider_types=[p.provider_type for p in providers],
                 )
                 for e in endpoints
+                if e.methods
             ]
         )
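The fix in both hunks is the usual way to satisfy mypy when a field is `Optional`: the comprehension filter `if e.methods` narrows away endpoints whose `methods` is `None` or empty, and the inline `if e.methods else "POST"` fallback covers the expression form. A minimal sketch of the same guard, with an illustrative `Endpoint` dataclass standing in for the project's endpoint type and `"POST"` as the assumed default:

```python
from dataclasses import dataclass


@dataclass
class Endpoint:
    path: str
    methods: list[str] | None  # Optional, so mypy forces a guard before iterating


def primary_method(e: Endpoint) -> str:
    # Same shape as the diff: prefer the first non-HEAD method,
    # fall back to "POST" when methods is None or empty.
    return next(iter([m for m in e.methods if m != "HEAD"])) if e.methods else "POST"


endpoints = [Endpoint("/health", ["GET", "HEAD"]), Endpoint("/infer", None)]
print([primary_method(e) for e in endpoints])  # ['GET', 'POST']
```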


@@ -64,9 +64,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
                 http_method = hdrs.METH_DELETE
             else:
                 http_method = hdrs.METH_POST
-            routes.append(
-                Route(path=path, methods=[http_method], name=name, endpoint=None)
-            )  # setting endpoint to None since don't use a Router object
+            routes.append(Route(path=path, methods=[http_method], name=name, endpoint=None))
         apis[api] = routes
@@ -95,6 +93,8 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
             impl = impls[api]
             func = getattr(impl, route.name)
             # Get the first (and typically only) method from the set, filtering out HEAD
+            if route.methods is None:
+                continue  # Skip if no methods are available
             available_methods = [m for m in route.methods if m != "HEAD"]
             if not available_methods:
                 continue  # Skip if only HEAD method is available
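Here the narrowing happens with an explicit `is None` check plus `continue`, since the loop body needs `route.methods` as a concrete set before it can filter out HEAD. A self-contained sketch of the same filtering logic (`pick_method` is a made-up helper, not project code):

```python
def pick_method(methods: set[str] | None) -> str | None:
    # Narrow the Optional first; after this check mypy treats
    # `methods` as a plain set[str].
    if methods is None:
        return None  # route registered without methods
    available = [m for m in methods if m != "HEAD"]
    if not available:
        return None  # only HEAD was registered
    return available[0]


assert pick_method(None) is None
assert pick_method({"HEAD"}) is None
assert pick_method({"GET", "HEAD"}) == "GET"
```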


@@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
     else:
         torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

-    ckpt_paths = np.array(sorted(ckpt_paths))
+    ckpt_paths_array = np.array(sorted(ckpt_paths))

     new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
-    old_mp_size = len(ckpt_paths)
+    old_mp_size = len(ckpt_paths_array)
     old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

-    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
-    paths = ckpt_paths[old_mp_ranks]  # type: ignore
+    print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
+    paths = ckpt_paths_array[old_mp_ranks]
     state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

     if new_mp_size == old_mp_size:
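The `# type: ignore` comments were only needed because `ckpt_paths` changed type mid-function, from a list of paths to an `np.ndarray`; binding the array to its own name keeps each variable single-typed, so the fancy indexing type-checks cleanly. A rough standalone sketch of the renamed-variable pattern (`select_shards` and its arguments are illustrative, not project code):

```python
import numpy as np


def select_shards(ckpt_paths: list[str], ranks: list[int]) -> list[str]:
    # Giving the ndarray a new name avoids reusing `ckpt_paths` with a
    # different type, which is what forced the `# type: ignore` before.
    ckpt_paths_array = np.array(sorted(ckpt_paths))
    return [str(p) for p in ckpt_paths_array[ranks]]


print(select_shards(["shard-1.pth", "shard-0.pth"], [0]))  # ['shard-0.pth']
```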


@@ -34,6 +34,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.start_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         print(
@@ -46,6 +49,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.end_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         span_context = (
@@ -59,8 +65,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         elif span.status.status_code != StatusCode.UNSET:
             span_context += f"{COLORS['reset']} [{span.status.status_code}]"

-        duration_ms = (span.end_time - span.start_time) / 1e6
-        span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
+        if span.start_time is not None and span.end_time is not None:
+            duration_ms = (span.end_time - span.start_time) / 1e6
+            span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"

         print(span_context)
@@ -76,10 +83,13 @@ class ConsoleSpanProcessor(SpanProcessor):
         for event in span.events:
             event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

-            severity = event.attributes.get("severity", "info")
-            message = event.attributes.get("message", event.name)
-            if isinstance(message, dict | list):
-                message = json.dumps(message, indent=2)
+            severity = "info"
+            message = event.name
+            if event.attributes:
+                severity = event.attributes.get("severity", "info")
+                message = event.attributes.get("message", event.name)
+                if isinstance(message, dict | list):
+                    message = json.dumps(message, indent=2)

             severity_colors = {
                 "error": f"{COLORS['bold']}{COLORS['red']}",
@@ -87,9 +97,10 @@ class ConsoleSpanProcessor(SpanProcessor):
                 "info": COLORS["white"],
                 "debug": COLORS["dim"],
             }
-            msg_color = severity_colors.get(severity, COLORS["white"])
+            msg_color = severity_colors.get(str(severity), COLORS["white"])

-            print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
+            severity_str = str(severity).upper() if severity else "INFO"
+            print(f" {event_time} {msg_color}[{severity_str}] {message}{COLORS['reset']}")

             if event.attributes:
                 for key, value in event.attributes.items():
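These hunks all follow the same shape: OpenTelemetry's span objects expose `start_time`/`end_time` as optional nanosecond timestamps and `attributes` as an optional mapping, so every read is preceded by a narrowing check or a safe default. A compact sketch of just the duration guard in isolation (`format_duration` and the sample values are illustrative):

```python
from datetime import UTC, datetime


def format_duration(start_ns: int | None, end_ns: int | None) -> str:
    # Both timestamps are Optional[int] nanoseconds; mypy rejects the
    # subtraction until both are narrowed to int.
    if start_ns is None or end_ns is None:
        return ""
    return f"({(end_ns - start_ns) / 1e6:.2f}ms)"


start = int(datetime.now(tz=UTC).timestamp() * 1e9)
print(format_duration(start, start + 2_500_000))  # (2.50ms)
print(format_duration(None, start))               # '' (unfinished span)
```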


@@ -10,8 +10,7 @@ import sqlite3
 import threading
 from datetime import UTC, datetime

-from opentelemetry.sdk.trace import SpanProcessor
-from opentelemetry.trace import Span
+from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
 from opentelemetry.trace.span import format_span_id, format_trace_id

 from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
@@ -93,11 +92,11 @@ class SQLiteSpanProcessor(SpanProcessor):
         conn.commit()
         cursor.close()

-    def on_start(self, span: Span, parent_context=None):
+    def on_start(self, span: ReadableSpan, parent_context=None):
         """Called when a span starts."""
         pass

-    def on_end(self, span: Span):
+    def on_end(self, span: ReadableSpan):
         """Called when a span ends. Export the span data to SQLite."""
         try:
             conn = self._get_connection()
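Typing the processor hooks against the SDK's own span classes is what lets mypy check the body of `on_end`: `ReadableSpan` is the finished, attribute-bearing SDK type, whereas the API-level `Span` carries none of the fields the exporter reads. A minimal sketch of a `SpanProcessor` subclass using the same signatures as the diff (class name and print body are illustrative; note the stock SDK annotates `on_end` with `ReadableSpan`, and this commit applies the same type to `on_start`):

```python
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor


class PrintingSpanProcessor(SpanProcessor):
    def on_start(self, span: ReadableSpan, parent_context=None) -> None:
        pass  # nothing to record at start

    def on_end(self, span: ReadableSpan) -> None:
        # ReadableSpan exposes the finished span's data for export.
        print(span.name, span.start_time, span.end_time)
```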


@@ -168,7 +168,7 @@ def _process_vllm_chat_completion_end_of_stream(
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
     tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
-) -> list[OpenAIChatCompletionChunk]:
+) -> list[ChatCompletionResponseStreamChunk]:
     chunks = []

     if finish_reason is not None:
@@ -247,9 +247,10 @@ async def _process_vllm_chat_completion_stream_response(
             if choice.delta.tool_calls:
                 for delta_tool_call in choice.delta.tool_calls:
                     tool_call = convert_tool_call(delta_tool_call)
-                    if delta_tool_call.index not in tool_call_bufs:
-                        tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
-                    tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                    index_str = str(delta_tool_call.index)
+                    if index_str not in tool_call_bufs:
+                        tool_call_bufs[index_str] = UnparseableToolCall()
+                    tool_call_buf = tool_call_bufs[index_str]
                     tool_call_buf.tool_name += str(tool_call.tool_name)
                     tool_call_buf.call_id += tool_call.call_id
                     tool_call_buf.arguments += (
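`tool_call_bufs` is declared as `dict[str, UnparseableToolCall]`, but `delta_tool_call.index` arrives as an `int`, so the diff coerces the index with `str(...)` before using it as a key. A toy reproduction of the buffering pattern (`ToolCallBuf` stands in for `UnparseableToolCall`; the streamed fragments are invented):

```python
from dataclasses import dataclass


@dataclass
class ToolCallBuf:
    # Accumulates streamed tool-call fragments across chunks.
    tool_name: str = ""
    call_id: str = ""
    arguments: str = ""


bufs: dict[str, ToolCallBuf] = {}
for index, fragment in [(0, '{"ci'), (0, 'ty": "Paris"}')]:
    key = str(index)  # the dict is declared with str keys; coerce the int index
    buf = bufs.setdefault(key, ToolCallBuf())
    buf.arguments += fragment

print(bufs["0"].arguments)  # {"city": "Paris"}
```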


@@ -251,7 +251,7 @@ exclude = [
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",