From a40af5b91bc9c4af7cba31afaee085dc2bd673e2 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 27 Oct 2025 23:04:59 -0700 Subject: [PATCH] fix(mypy): resolve union type and list annotation errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - batches.py: Fix bytes/memoryview union type narrowing issue - encoder_utils.py: Add type annotation for masks_list 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../models/llama/llama3/multimodal/encoder_utils.py | 2 +- .../providers/inline/batches/reference/batches.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 0cc5aec81..a87d77cc3 100644 --- a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -141,7 +141,7 @@ def build_encoder_attention_mask( """ Build vision encoder attention mask that omits padding tokens. """ - masks_list = [] + masks_list: list[torch.Tensor] = [] for arx in ar: mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype) mask_i[: arx[0] * arx[1], :ntok] = 0 diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index 241218dca..7c4358b84 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -358,11 +358,10 @@ class ReferenceBatchesImpl(Batches): # TODO(SECURITY): do something about large files file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) - # Handle both bytes and memoryview types - body = file_content_response.body - if isinstance(body, memoryview): - body = bytes(body) - file_content = body.decode("utf-8") + # Handle both bytes and memoryview types - convert to bytes unconditionally + # (bytes(x) returns x if already bytes, creates new bytes from memoryview otherwise) + body_bytes = bytes(file_content_response.body) + file_content = body_bytes.decode("utf-8") for line_num, line in enumerate(file_content.strip().split("\n"), 1): if line.strip(): # skip empty lines try: