chore: add mypy inference fp8_impls

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 00:22:45 +02:00
parent d880c2df0e
commit 1c08a1cae9
7 changed files with 38 additions and 25 deletions

View file

@@ -54,6 +54,7 @@ class DistributionInspectImpl(Inspect):
                         provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                     )
                     for e in endpoints
+                    if e.methods
                 ]
             )
         else:
@ -63,10 +64,11 @@ class DistributionInspectImpl(Inspect):
[ [
RouteInfo( RouteInfo(
route=e.path, route=e.path,
method=next(iter([m for m in e.methods if m != "HEAD"])), method=next(iter([m for m in e.methods if m != "HEAD"])) if e.methods else "POST",
provider_types=[p.provider_type for p in providers], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e in endpoints
if e.methods
] ]
) )
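
The guards above exist because the route's `methods` attribute is Optional (starlette-style routes type it as `set[str] | None`), so mypy rejects iterating it unguarded. A minimal sketch of the narrowing, using a stand-in type rather than the real llama_stack route class:

# Minimal sketch of the Optional-methods narrowing; FakeRoute is a
# stand-in, not the actual starlette/llama_stack route type.
from dataclasses import dataclass

@dataclass
class FakeRoute:
    path: str
    methods: set[str] | None

def first_non_head_method(route: FakeRoute) -> str:
    # the `if route.methods` guard narrows set[str] | None to set[str]
    return next(iter([m for m in route.methods if m != "HEAD"])) if route.methods else "POST"

assert first_non_head_method(FakeRoute("/health", {"GET", "HEAD"})) == "GET"
assert first_non_head_method(FakeRoute("/infer", None)) == "POST"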

View file

@@ -64,9 +64,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
             http_method = hdrs.METH_DELETE
         else:
             http_method = hdrs.METH_POST
-        routes.append(
-            Route(path=path, methods=[http_method], name=name, endpoint=None)
-        )  # setting endpoint to None since don't use a Router object
+        routes.append(Route(path=path, methods=[http_method], name=name, endpoint=None))

         apis[api] = routes
@@ -95,6 +93,8 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
         impl = impls[api]
         func = getattr(impl, route.name)
         # Get the first (and typically only) method from the set, filtering out HEAD
+        if route.methods is None:
+            continue  # Skip if no methods are available
         available_methods = [m for m in route.methods if m != "HEAD"]
         if not available_methods:
             continue  # Skip if only HEAD method is available

View file

@@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
     else:
         torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

-    ckpt_paths = np.array(sorted(ckpt_paths))
+    ckpt_paths_array = np.array(sorted(ckpt_paths))

     new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
-    old_mp_size = len(ckpt_paths)
+    old_mp_size = len(ckpt_paths_array)
     old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

-    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
-    paths = ckpt_paths[old_mp_ranks]  # type: ignore
+    print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
+    paths = ckpt_paths_array[old_mp_ranks]
     state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

     if new_mp_size == old_mp_size:
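
The rename is what removes the two `# type: ignore`s: re-binding `ckpt_paths` from `list[str]` to an `np.ndarray` gives one name two incompatible types, which mypy flags as an incompatible assignment. A toy version of the fix, with made-up shard names:

import numpy as np

ckpt_paths = ["consolidated.01.pth", "consolidated.00.pth"]  # hypothetical shards

# Re-binding `ckpt_paths = np.array(sorted(ckpt_paths))` would change the
# name's type from list[str] to np.ndarray mid-function; a fresh name keeps
# both types stable and allows fancy indexing without ignores.
ckpt_paths_array = np.array(sorted(ckpt_paths))
old_mp_ranks = [0]  # made-up rank selection
print(ckpt_paths_array[old_mp_ranks])  # ['consolidated.00.pth']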

View file

@@ -34,6 +34,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.start_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         print(
@@ -46,6 +49,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.end_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         span_context = (
@@ -59,8 +65,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         elif span.status.status_code != StatusCode.UNSET:
             span_context += f"{COLORS['reset']} [{span.status.status_code}]"

-        duration_ms = (span.end_time - span.start_time) / 1e6
-        span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
+        if span.start_time is not None and span.end_time is not None:
+            duration_ms = (span.end_time - span.start_time) / 1e6
+            span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"

         print(span_context)
@@ -76,10 +83,13 @@ class ConsoleSpanProcessor(SpanProcessor):
         for event in span.events:
             event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

-            severity = event.attributes.get("severity", "info")
-            message = event.attributes.get("message", event.name)
-            if isinstance(message, dict | list):
-                message = json.dumps(message, indent=2)
+            severity = "info"
+            message = event.name
+            if event.attributes:
+                severity = event.attributes.get("severity", "info")
+                message = event.attributes.get("message", event.name)
+                if isinstance(message, dict | list):
+                    message = json.dumps(message, indent=2)

             severity_colors = {
                 "error": f"{COLORS['bold']}{COLORS['red']}",
@@ -87,9 +97,10 @@ class ConsoleSpanProcessor(SpanProcessor):
                 "info": COLORS["white"],
                 "debug": COLORS["dim"],
             }

-            msg_color = severity_colors.get(severity, COLORS["white"])
-            print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
+            msg_color = severity_colors.get(str(severity), COLORS["white"])
+            severity_str = str(severity).upper() if severity else "INFO"
+            print(f" {event_time} {msg_color}[{severity_str}] {message}{COLORS['reset']}")

             if event.attributes:
                 for key, value in event.attributes.items():
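
All of the hunks above guard the same thing: on OpenTelemetry's ReadableSpan, `start_time` and `end_time` are `Optional[int]` epoch nanoseconds, and `event.attributes` is an Optional mapping. A small sketch of the timestamp guard in isolation:

from datetime import UTC, datetime

def fmt_ts(ns: int | None) -> str | None:
    # start_time/end_time are Optional[int] nanoseconds on ReadableSpan,
    # hence the early returns in the processor above
    if ns is None:
        return None
    return datetime.fromtimestamp(ns / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

print(fmt_ts(1_752_000_000_000_000_000))  # 18:40:00.000
print(fmt_ts(None))                       # None -> caller skips printing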

View file

@@ -10,8 +10,7 @@ import sqlite3
 import threading
 from datetime import UTC, datetime

-from opentelemetry.sdk.trace import SpanProcessor
-from opentelemetry.trace import Span
+from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
 from opentelemetry.trace.span import format_span_id, format_trace_id

 from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
@@ -93,11 +92,11 @@ class SQLiteSpanProcessor(SpanProcessor):
         conn.commit()
         cursor.close()

-    def on_start(self, span: Span, parent_context=None):
+    def on_start(self, span: ReadableSpan, parent_context=None):
         """Called when a span starts."""
         pass

-    def on_end(self, span: Span):
+    def on_end(self, span: ReadableSpan):
         """Called when a span ends. Export the span data to SQLite."""
         try:
             conn = self._get_connection()
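
The import swap matches the SDK's own `SpanProcessor` signatures: `on_end` is declared against `ReadableSpan`, and since the SDK `Span` subclasses `ReadableSpan`, widening `on_start`'s parameter is a contravariance-safe override that mypy accepts. A minimal, hypothetical processor showing the corrected annotations:

from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor

class StdoutSpanProcessor(SpanProcessor):
    """Hypothetical processor illustrating the ReadableSpan signatures."""

    def on_start(self, span: ReadableSpan, parent_context=None) -> None:
        pass  # nothing to record at start

    def on_end(self, span: ReadableSpan) -> None:
        # ReadableSpan exposes the finished span's data read-only
        if span.start_time is not None and span.end_time is not None:
            print(f"{span.name}: {(span.end_time - span.start_time) / 1e6:.2f}ms")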

View file

@@ -168,7 +168,7 @@ def _process_vllm_chat_completion_end_of_stream(
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
     tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
-) -> list[OpenAIChatCompletionChunk]:
+) -> list[ChatCompletionResponseStreamChunk]:
     chunks = []

     if finish_reason is not None:
@@ -247,9 +247,10 @@ async def _process_vllm_chat_completion_stream_response(
             if choice.delta.tool_calls:
                 for delta_tool_call in choice.delta.tool_calls:
                     tool_call = convert_tool_call(delta_tool_call)
-                    if delta_tool_call.index not in tool_call_bufs:
-                        tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
-                    tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                    index_str = str(delta_tool_call.index)
+                    if index_str not in tool_call_bufs:
+                        tool_call_bufs[index_str] = UnparseableToolCall()
+                    tool_call_buf = tool_call_bufs[index_str]
                     tool_call_buf.tool_name += str(tool_call.tool_name)
                     tool_call_buf.call_id += tool_call.call_id
                     tool_call_buf.arguments += (
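
The second hunk lines the buffer keys up with the annotation `dict[str, UnparseableToolCall]`: streaming deltas carry an `int` index, so it is stringified once and reused for every lookup. A stand-in illustration of the accumulation:

# Stand-in for UnparseableToolCall, just to show the str-keyed buffering:
class ToolCallBuf:
    def __init__(self) -> None:
        self.tool_name = ""
        self.arguments = ""

tool_call_bufs: dict[str, ToolCallBuf] = {}
for index, fragment in [(0, '{"city": '), (0, '"Paris"}')]:
    index_str = str(index)  # delta indices are ints; the dict is keyed by str
    buf = tool_call_bufs.setdefault(index_str, ToolCallBuf())
    buf.arguments += fragment

print(tool_call_bufs["0"].arguments)  # {"city": "Paris"}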

View file

@@ -251,7 +251,7 @@ exclude = [
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
+    "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",