Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
commit 1c08a1cae9 (parent d880c2df0e)

chore: add mypy inference fp8_impls

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>

7 changed files with 38 additions and 25 deletions

@@ -54,6 +54,7 @@ class DistributionInspectImpl(Inspect):
                             provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                         )
                         for e in endpoints
+                        if e.methods
                     ]
                 )
             else:
@@ -63,10 +64,11 @@ class DistributionInspectImpl(Inspect):
                     [
                         RouteInfo(
                             route=e.path,
-                            method=next(iter([m for m in e.methods if m != "HEAD"])),
+                            method=next(iter([m for m in e.methods if m != "HEAD"])) if e.methods else "POST",
                             provider_types=[p.provider_type for p in providers],
                         )
                         for e in endpoints
+                        if e.methods
                     ]
                 )

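Both DistributionInspectImpl hunks above apply the same fix: a route's `methods` field is an optional set of strings, so mypy rejects iterating it unconditionally; the comprehension filter `if e.methods` and the `else "POST"` fallback narrow the type. A minimal, self-contained sketch of that narrowing idiom (the `SimpleRoute` class below is a stand-in for illustration, not the project's RouteInfo/Route types):

# Minimal sketch of the Optional-narrowing idiom used above (hypothetical types).
from dataclasses import dataclass


@dataclass
class SimpleRoute:
    path: str
    methods: set[str] | None  # may be None, which mypy will not let us iterate blindly


def primary_method(route: SimpleRoute) -> str:
    # Guarding with `if route.methods` narrows `set[str] | None` to `set[str]`,
    # so the comprehension type-checks; fall back to "POST" otherwise.
    return next(iter([m for m in route.methods if m != "HEAD"])) if route.methods else "POST"


routes = [SimpleRoute("/models", {"GET", "HEAD"}), SimpleRoute("/health", None)]
print([primary_method(r) for r in routes if r.methods])  # ['GET'] -- routes without methods are filtered out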
@@ -64,9 +64,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
                 http_method = hdrs.METH_DELETE
             else:
                 http_method = hdrs.METH_POST
-            routes.append(
-                Route(path=path, methods=[http_method], name=name, endpoint=None)
-            )  # setting endpoint to None since don't use a Router object
+            routes.append(Route(path=path, methods=[http_method], name=name, endpoint=None))

         apis[api] = routes

@@ -95,6 +93,8 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
             impl = impls[api]
             func = getattr(impl, route.name)
             # Get the first (and typically only) method from the set, filtering out HEAD
+            if route.methods is None:
+                continue  # Skip if no methods are available
             available_methods = [m for m in route.methods if m != "HEAD"]
             if not available_methods:
                 continue  # Skip if only HEAD method is available
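The `initialize_route_impls` hunk uses the other common narrowing form: bail out of the loop iteration before touching a possibly-None attribute. A tiny illustration with hypothetical data, not the stack's own types:

# Tiny illustration of narrowing by early `continue` (hypothetical data).
routes: list[dict[str, set[str] | None]] = [
    {"methods": {"GET", "HEAD"}},
    {"methods": None},
]

for route in routes:
    methods = route["methods"]
    if methods is None:
        continue  # after this, mypy treats `methods` as set[str]
    print([m for m in methods if m != "HEAD"])  # ['GET']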
@@ -42,14 +42,14 @@ def maybe_reshard_state_dict(
     else:
         torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

-    ckpt_paths = np.array(sorted(ckpt_paths))
+    ckpt_paths_array = np.array(sorted(ckpt_paths))

     new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
-    old_mp_size = len(ckpt_paths)
+    old_mp_size = len(ckpt_paths_array)
     old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

-    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
-    paths = ckpt_paths[old_mp_ranks]  # type: ignore
+    print(f"Loading checkpoint shards:\n{str(ckpt_paths_array[old_mp_ranks])}")
+    paths = ckpt_paths_array[old_mp_ranks]
     state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

     if new_mp_size == old_mp_size:
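The resharding hunk drops two `# type: ignore` comments by giving the NumPy array its own name instead of re-binding `ckpt_paths` (a list of paths) to an ndarray; mypy keeps the type it inferred from the first assignment, so re-binding the same name to a different type gets flagged. A standalone sketch of the same idea, assuming (as the hunk suggests) that the real code selects shards by fancy indexing with a list of ranks:

# Sketch: keep distinct names for distinct types instead of silencing mypy.
import numpy as np

ckpt_paths = ["shard_01.pth", "shard_00.pth", "shard_03.pth", "shard_02.pth"]  # hypothetical shard files

# Re-assigning `ckpt_paths = np.array(sorted(ckpt_paths))` would change the variable's type,
# forcing `# type: ignore` on later uses; a new name keeps both types honest.
ckpt_paths_array = np.array(sorted(ckpt_paths))

old_mp_size = len(ckpt_paths_array)
old_mp_ranks = [0, 2]  # stand-in for map_mp_rank(...)
print(ckpt_paths_array[old_mp_ranks])  # NumPy fancy indexing selects the shards for this rank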
@@ -34,6 +34,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.start_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         print(
@@ -46,6 +49,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         if span.attributes and span.attributes.get("__autotraced__"):
             return

+        if span.end_time is None:
+            return
+
         timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

         span_context = (
@@ -59,6 +65,7 @@ class ConsoleSpanProcessor(SpanProcessor):
         elif span.status.status_code != StatusCode.UNSET:
             span_context += f"{COLORS['reset']} [{span.status.status_code}]"

-        duration_ms = (span.end_time - span.start_time) / 1e6
-        span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
+        if span.start_time is not None and span.end_time is not None:
+            duration_ms = (span.end_time - span.start_time) / 1e6
+            span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"

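In the OpenTelemetry Python SDK, a span's `start_time` and `end_time` are optional nanosecond integers, so the three ConsoleSpanProcessor hunks above guard them before doing timestamp math. A self-contained sketch of that guard, using plain ints rather than a real span object:

# Sketch of guarding optional nanosecond timestamps before arithmetic/formatting.
from datetime import UTC, datetime


def format_duration(start_ns: int | None, end_ns: int | None) -> str:
    if start_ns is None or end_ns is None:
        return "(in progress)"  # below this guard, mypy knows both values are int
    started = datetime.fromtimestamp(start_ns / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
    duration_ms = (end_ns - start_ns) / 1e6
    return f"{started} ({duration_ms:.2f}ms)"


print(format_duration(1_700_000_000_000_000_000, 1_700_000_000_250_000_000))  # ... (250.00ms)
print(format_duration(1_700_000_000_000_000_000, None))  # (in progress)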
@@ -76,6 +83,9 @@ class ConsoleSpanProcessor(SpanProcessor):
         for event in span.events:
             event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]

-            severity = event.attributes.get("severity", "info")
-            message = event.attributes.get("message", event.name)
+            severity = "info"
+            message = event.name
+            if event.attributes:
+                severity = event.attributes.get("severity", "info")
+                message = event.attributes.get("message", event.name)
             if isinstance(message, dict | list):
@@ -87,9 +97,10 @@ class ConsoleSpanProcessor(SpanProcessor):
                 "info": COLORS["white"],
                 "debug": COLORS["dim"],
             }
-            msg_color = severity_colors.get(severity, COLORS["white"])
+            msg_color = severity_colors.get(str(severity), COLORS["white"])

-            print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
+            severity_str = str(severity).upper() if severity else "INFO"
+            print(f" {event_time} {msg_color}[{severity_str}] {message}{COLORS['reset']}")

             if event.attributes:
                 for key, value in event.attributes.items():
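OpenTelemetry attribute values are a union type (str, bool, int, float, or sequences of those), so `event.attributes.get(...)` does not come back as a plain `str`; the hunk coerces with `str(...)` before calling `.upper()` and doing the color lookup. A small sketch of the coercion, assuming a plain dict in place of real span attributes:

# Sketch: coerce a union-typed attribute to str before string-only operations.
attributes: dict[str, str | bool | int | float] = {"severity": "warn", "code": 503}

severity = attributes.get("severity", "info")  # typed as str | bool | int | float
severity_str = str(severity).upper() if severity else "INFO"
print(severity_str)  # WARN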
@@ -10,8 +10,7 @@ import sqlite3
 import threading
 from datetime import UTC, datetime

-from opentelemetry.sdk.trace import SpanProcessor
-from opentelemetry.trace import Span
+from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
 from opentelemetry.trace.span import format_span_id, format_trace_id

 from llama_stack.providers.utils.telemetry.tracing import LOCAL_ROOT_SPAN_MARKER
@@ -93,11 +92,11 @@ class SQLiteSpanProcessor(SpanProcessor):
         conn.commit()
         cursor.close()

-    def on_start(self, span: Span, parent_context=None):
+    def on_start(self, span: ReadableSpan, parent_context=None):
         """Called when a span starts."""
         pass

-    def on_end(self, span: Span):
+    def on_end(self, span: ReadableSpan):
         """Called when a span ends. Export the span data to SQLite."""
         try:
             conn = self._get_connection()
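The two SQLiteSpanProcessor hunks swap the annotation on the processor hooks from the API-level `opentelemetry.trace.Span` to the SDK's `ReadableSpan`, which is the type the SDK actually hands to `on_end`; annotating with the API class can trip mypy's override check against the `SpanProcessor` base. A minimal subclass sketch with the corrected annotations (requires the opentelemetry-sdk package; the print body is illustrative only, not the SQLite exporter from this diff):

# Minimal sketch of a SpanProcessor whose overrides use the SDK's ReadableSpan type.
from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor


class PrintingSpanProcessor(SpanProcessor):
    def on_start(self, span: ReadableSpan, parent_context=None) -> None:
        pass  # nothing to do at span start

    def on_end(self, span: ReadableSpan) -> None:
        # ReadableSpan exposes the finished span's name and timing.
        print(f"span ended: {span.name}")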
@@ -168,7 +168,7 @@ def _process_vllm_chat_completion_end_of_stream(
     last_chunk_content: str | None,
     current_event_type: ChatCompletionResponseEventType,
     tool_call_bufs: dict[str, UnparseableToolCall] | None = None,
-) -> list[OpenAIChatCompletionChunk]:
+) -> list[ChatCompletionResponseStreamChunk]:
     chunks = []

     if finish_reason is not None:
@@ -247,9 +247,10 @@ async def _process_vllm_chat_completion_stream_response(
             if choice.delta.tool_calls:
                 for delta_tool_call in choice.delta.tool_calls:
                     tool_call = convert_tool_call(delta_tool_call)
-                    if delta_tool_call.index not in tool_call_bufs:
-                        tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
-                    tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                    index_str = str(delta_tool_call.index)
+                    if index_str not in tool_call_bufs:
+                        tool_call_bufs[index_str] = UnparseableToolCall()
+                    tool_call_buf = tool_call_bufs[index_str]
                     tool_call_buf.tool_name += str(tool_call.tool_name)
                     tool_call_buf.call_id += tool_call.call_id
                     tool_call_buf.arguments += (
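`tool_call_bufs` is declared as `dict[str, UnparseableToolCall]`, while a streaming tool-call delta's `index` is an integer, so the hunk converts the index to `str` before using it as a key. The same mismatch in miniature, with plain stand-in types and no vLLM client involved:

# Miniature of the key-type fix: the buffer dict is keyed by str, the delta index is an int.
from dataclasses import dataclass


@dataclass
class ToolCallBuf:  # stand-in for UnparseableToolCall
    tool_name: str = ""
    arguments: str = ""


tool_call_bufs: dict[str, ToolCallBuf] = {}

for index, fragment in [(0, '{"city": '), (0, '"Paris"}')]:
    index_str = str(index)  # keep the key type consistent with the annotation
    if index_str not in tool_call_bufs:
        tool_call_bufs[index_str] = ToolCallBuf(tool_name="get_weather")
    tool_call_bufs[index_str].arguments += fragment

print(tool_call_bufs["0"].arguments)  # {"city": "Paris"}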
@@ -251,7 +251,7 @@ exclude = [
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
     "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
-    "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
+    "^llama_stack/providers/inline/inference/meta_reference/quantization/loader\\.py$",
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",
     "^llama_stack/providers/inline/post_training/common/validator\\.py$",
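The pyproject.toml hunk swaps which meta_reference quantization module sits on the mypy exclude list: `fp8_impls.py` is removed (so it is now type-checked, matching the commit title) and `loader.py` is added. One way to spot-check a single module locally is mypy's Python API, sketched below (assumes the mypy package is installed; `api.run` takes the same arguments as the CLI):

# Sketch: run mypy on one module via its Python API (requires the mypy package).
from mypy import api

stdout, stderr, exit_code = api.run(
    ["llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py"]
)
print(stdout or stderr, f"exit code: {exit_code}")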