mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-19 19:30:05 +00:00
chore: add mypy
inference parallel utils (#2670)
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
ca7edcd6a4
commit
b78b8e1486
2 changed files with 3 additions and 4 deletions
|
@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
|
|||
|
||||
|
||||
def mp_rank_0() -> bool:
|
||||
return get_model_parallel_rank() == 0
|
||||
return bool(get_model_parallel_rank() == 0)
|
||||
|
||||
|
||||
def encode_msg(msg: ProcessingMessage) -> bytes:
|
||||
|
@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
|
|||
reply_socket.send_multipart([client_id, encode_msg(obj)])
|
||||
|
||||
while True:
|
||||
tasks = [None]
|
||||
tasks: list[ProcessingMessage | None] = [None]
|
||||
if mp_rank_0():
|
||||
client_id, maybe_task_json = maybe_get_work(reply_socket)
|
||||
if maybe_task_json is not None:
|
||||
|
@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
|
|||
break
|
||||
|
||||
for obj in out:
|
||||
updates = [None]
|
||||
updates: list[ProcessingMessage | None] = [None]
|
||||
if mp_rank_0():
|
||||
_, update_json = maybe_get_work(reply_socket)
|
||||
update = maybe_parse_message(update_json)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue