mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
chore: add mypy inference fp8_impls
Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
d880c2df0e
commit
1c08a1cae9
7 changed files with 38 additions and 25 deletions
|
@ -54,6 +54,7 @@ class DistributionInspectImpl(Inspect):
|
|||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e in endpoints
|
||||
if e.methods
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
@ -63,10 +64,11 @@ class DistributionInspectImpl(Inspect):
|
|||
[
|
||||
RouteInfo(
|
||||
route=e.path,
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
method=next(iter([m for m in e.methods if m != "HEAD"])) if e.methods else "POST",
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
if e.methods
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -64,9 +64,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
routes.append(Route(path=path, methods=[http_method], name=name, endpoint=None))
|
||||
|
||||
apis[api] = routes
|
||||
|
||||
|
@ -95,6 +93,8 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
|||
impl = impls[api]
|
||||
func = getattr(impl, route.name)
|
||||
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||
if route.methods is None:
|
||||
continue # Skip if no methods are available
|
||||
available_methods = [m for m in route.methods if m != "HEAD"]
|
||||
if not available_methods:
|
||||
continue # Skip if only HEAD method is available
|
||||
|
|
|
@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
|
|||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
||||
ckpt_paths = np.array(sorted(ckpt_paths))
|
||||
ckpt_paths_array = np.array(sorted(ckpt_paths))
|
||||
|
||||
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
|
||||
old_mp_size = len(ckpt_paths)
|
||||
old_mp_size = len(ckpt_paths_array)
|
||||
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
|
||||
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
|
||||
paths = ckpt_paths[old_mp_ranks] # type: ignore
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
|
||||
paths = ckpt_paths_array[old_mp_ranks]
|
||||
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
|
||||
|
||||
if new_mp_size == old_mp_size:
|
||||
|
|
|
@ -34,6 +34,9 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
if span.start_time is None:
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
|
@ -46,6 +49,9 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
if span.end_time is None:
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
|
@ -59,8 +65,9 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
if span.start_time is not None and span.end_time is not None:
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
|
||||
|
@ -76,10 +83,13 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
message = json.dumps(message, indent=2)
|
||||
severity = "info"
|
||||
message = event.name
|
||||
if event.attributes:
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
|
@ -87,9 +97,10 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
msg_color = severity_colors.get(str(severity), COLORS["white"])
|
||||
|
||||
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
|
||||
severity_str = str(severity).upper() if severity else "INFO"
|
||||
print(f" {event_time} {msg_color}[{severity_str}] {message}{COLORS['reset']}")
|
||||
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
|
|
|
@ -10,8 +10,7 @@ import sqlite3
|
|||
import threading
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
|
||||
from opentelemetry.trace.span import format_span_id, format_trace_id
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
|
||||
|
@ -93,11 +92,11 @@ class SQLiteSpanProcessor(SpanProcessor):
|
|||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
def on_start(self, span: ReadableSpan, parent_context=None):
|
||||
"""Called when a span starts."""
|
||||
pass
|
||||
|
||||
def on_end(self, span: Span):
|
||||
def on_end(self, span: ReadableSpan):
|
||||
"""Called when a span ends. Export the span data to SQLite."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
|
|
|
@ -168,7 +168,7 @@ def _process_vllm_chat_completion_end_of_stream(
|
|||
last_chunk_content: str | None,
|
||||
current_event_type: ChatCompletionResponseEventType,
|
||||
tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
|
||||
) -> list[OpenAIChatCompletionChunk]:
|
||||
) -> list[ChatCompletionResponseStreamChunk]:
|
||||
chunks = []
|
||||
|
||||
if finish_reason is not None:
|
||||
|
@ -247,9 +247,10 @@ async def _process_vllm_chat_completion_stream_response(
|
|||
if choice.delta.tool_calls:
|
||||
for delta_tool_call in choice.delta.tool_calls:
|
||||
tool_call = convert_tool_call(delta_tool_call)
|
||||
if delta_tool_call.index not in tool_call_bufs:
|
||||
tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
|
||||
tool_call_buf = tool_call_bufs[delta_tool_call.index]
|
||||
index_str = str(delta_tool_call.index)
|
||||
if index_str not in tool_call_bufs:
|
||||
tool_call_bufs[index_str] = UnparseableToolCall()
|
||||
tool_call_buf = tool_call_bufs[index_str]
|
||||
tool_call_buf.tool_name += str(tool_call.tool_name)
|
||||
tool_call_buf.call_id += tool_call.call_id
|
||||
tool_call_buf.arguments += (
|
||||
|
|
|
@ -251,7 +251,7 @@ exclude = [
|
|||
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
|
||||
"^llama_stack/models/llama/llama4/",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
|
||||
"^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
|
||||
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
|
||||
"^llama_stack/providers/inline/inference/vllm/",
|
||||
"^llama_stack/providers/inline/post_training/common/validator\\.py$",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue