diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 97e96b929..7ade75032 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel): def mp_rank_0() -> bool: - return get_model_parallel_rank() == 0 + return bool(get_model_parallel_rank() == 0) def encode_msg(msg: ProcessingMessage) -> bytes: @@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str): reply_socket.send_multipart([client_id, encode_msg(obj)]) while True: - tasks = [None] + tasks: list[ProcessingMessage | None] = [None] if mp_rank_0(): client_id, maybe_task_json = maybe_get_work(reply_socket) if maybe_task_json is not None: @@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str): break for obj in out: - updates = [None] + updates: list[ProcessingMessage | None] = [None] if mp_rank_0(): _, update_json = maybe_get_work(reply_socket) update = maybe_parse_message(update_json) diff --git a/pyproject.toml b/pyproject.toml index 22ad816d0..4d54bece0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,7 +254,6 @@ exclude = [ "^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama4/", - "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/",