chore: add mypy inference parallel utils (#2670)

# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
This PR adds static type (mypy) coverage to the meta-reference inference `parallel_utils` module in `llama-stack`.

Part of https://github.com/meta-llama/llama-stack/issues/2647

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-18 12:01:10 +02:00 committed by GitHub
parent ca7edcd6a4
commit b78b8e1486
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 3 additions and 4 deletions

View file

@@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
def mp_rank_0() -> bool: def mp_rank_0() -> bool:
return get_model_parallel_rank() == 0 return bool(get_model_parallel_rank() == 0)
def encode_msg(msg: ProcessingMessage) -> bytes: def encode_msg(msg: ProcessingMessage) -> bytes:
@@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
reply_socket.send_multipart([client_id, encode_msg(obj)]) reply_socket.send_multipart([client_id, encode_msg(obj)])
while True: while True:
tasks = [None] tasks: list[ProcessingMessage | None] = [None]
if mp_rank_0(): if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket) client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None: if maybe_task_json is not None:
@@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
break break
for obj in out: for obj in out:
updates = [None] updates: list[ProcessingMessage | None] = [None]
if mp_rank_0(): if mp_rank_0():
_, update_json = maybe_get_work(reply_socket) _, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json) update = maybe_parse_message(update_json)

View file

@@ -254,7 +254,6 @@ exclude = [
"^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/", "^llama_stack/models/llama/llama4/",
"^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/inference/vllm/",