From d009dc29f76990e4b269e5c17116a1bafb7f8835 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Tue, 28 Oct 2025 10:37:27 -0700
Subject: [PATCH] fix(mypy): resolve provider utility and testing type issues (#3935)

Fixes mypy type errors in provider utilities and testing infrastructure:

- `mcp.py`: Cast incompatible client types, wrap image data properly
- `batches.py`: Rename walrus variable to avoid shadowing
- `api_recorder.py`: Use cast for Pydantic field annotation

No functional changes.

---------

Co-authored-by: Claude
---
 src/llama_stack/core/server/routes.py        |  7 ++---
 .../llama/llama3/multimodal/encoder_utils.py |  2 +-
 .../inline/batches/reference/batches.py      | 27 +++++++++----------
 .../remote/inference/together/together.py    | 10 ++++---
 src/llama_stack/providers/utils/tools/mcp.py |  8 +++---
 src/llama_stack/testing/api_recorder.py      |  4 ++-
 6 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py
index 4970d0bf8..48a961318 100644
--- a/src/llama_stack/core/server/routes.py
+++ b/src/llama_stack/core/server/routes.py
@@ -68,8 +68,9 @@ def get_all_api_routes(
             else:
                 http_method = hdrs.METH_POST
             routes.append(
-                (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
-            )  # setting endpoint to None since don't use a Router object
+                # setting endpoint to None since don't use a Router object
+                (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)  # type: ignore[arg-type]
+            )
 
         apis[api] = routes
 
@@ -98,7 +99,7 @@ def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | No
         impl = impls[api]
         func = getattr(impl, route.name)
         # Get the first (and typically only) method from the set, filtering out HEAD
-        available_methods = [m for m in route.methods if m != "HEAD"]
+        available_methods = [m for m in (route.methods or []) if m != "HEAD"]
         if not available_methods:
             continue  # Skip if only HEAD method is available
         method = available_methods[0].lower()
diff --git a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py
index 0cc5aec81..a87d77cc3 100644
--- a/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py
+++ b/src/llama_stack/models/llama/llama3/multimodal/encoder_utils.py
@@ -141,7 +141,7 @@ def build_encoder_attention_mask(
     """
     Build vision encoder attention mask that omits padding tokens.
""" - masks_list = [] + masks_list: list[torch.Tensor] = [] for arx in ar: mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype) mask_i[: arx[0] * arx[1], :ntok] = 0 diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index 79dc9c84c..7c4358b84 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -358,11 +358,10 @@ class ReferenceBatchesImpl(Batches): # TODO(SECURITY): do something about large files file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) - # Handle both bytes and memoryview types - body = file_content_response.body - if isinstance(body, memoryview): - body = bytes(body) - file_content = body.decode("utf-8") + # Handle both bytes and memoryview types - convert to bytes unconditionally + # (bytes(x) returns x if already bytes, creates new bytes from memoryview otherwise) + body_bytes = bytes(file_content_response.body) + file_content = body_bytes.decode("utf-8") for line_num, line in enumerate(file_content.strip().split("\n"), 1): if line.strip(): # skip empty lines try: @@ -419,8 +418,8 @@ class ReferenceBatchesImpl(Batches): ) valid = False - if (body := request.get("body")) and isinstance(body, dict): - if body.get("stream", False): + if (request_body := request.get("body")) and isinstance(request_body, dict): + if request_body.get("stream", False): errors.append( BatchError( code="streaming_unsupported", @@ -451,7 +450,7 @@ class ReferenceBatchesImpl(Batches): ] for param, expected_type, type_string in required_params: - if param not in body: + if param not in request_body: errors.append( BatchError( code="invalid_request", @@ -461,7 +460,7 @@ class ReferenceBatchesImpl(Batches): ) ) valid = False - elif not isinstance(body[param], expected_type): + elif not isinstance(request_body[param], expected_type): errors.append( BatchError( code="invalid_request", @@ -472,15 +471,15 @@ class ReferenceBatchesImpl(Batches): ) valid = False - if "model" in body and isinstance(body["model"], str): + if "model" in request_body and isinstance(request_body["model"], str): try: - await self.models_api.get_model(body["model"]) + await self.models_api.get_model(request_body["model"]) except Exception: errors.append( BatchError( code="model_not_found", line=line_num, - message=f"Model '{body['model']}' does not exist or is not supported", + message=f"Model '{request_body['model']}' does not exist or is not supported", param="body.model", ) ) @@ -488,14 +487,14 @@ class ReferenceBatchesImpl(Batches): if valid: assert isinstance(url, str), "URL must be a string" # for mypy - assert isinstance(body, dict), "Body must be a dictionary" # for mypy + assert isinstance(request_body, dict), "Body must be a dictionary" # for mypy requests.append( BatchRequest( line_num=line_num, url=url, method=request["method"], custom_id=request["custom_id"], - body=body, + body=request_body, ), ) except json.JSONDecodeError: diff --git a/src/llama_stack/providers/remote/inference/together/together.py b/src/llama_stack/providers/remote/inference/together/together.py index e31ebf7c5..4caa4004d 100644 --- a/src/llama_stack/providers/remote/inference/together/together.py +++ b/src/llama_stack/providers/remote/inference/together/together.py @@ -6,6 +6,7 @@ from collections.abc import Iterable +from typing import Any, cast from together import AsyncTogether from together.constants import 
@@ -81,10 +82,11 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         if params.dimensions is not None:
             raise ValueError("Together's embeddings endpoint does not support dimensions param.")
 
+        # Cast encoding_format to match OpenAI SDK's expected Literal type
         response = await self.client.embeddings.create(
             model=await self._get_provider_model_id(params.model),
             input=params.input,
-            encoding_format=params.encoding_format,
+            encoding_format=cast(Any, params.encoding_format),
         )
 
         response.model = (
@@ -97,6 +99,8 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
             logger.warning(
                 f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
             )
-            response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
+            # Cast to allow monkey-patching the response object
+            response.usage = cast(Any, OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1))
 
-        return response  # type: ignore[no-any-return]
+        # Together's CreateEmbeddingResponse is compatible with OpenAIEmbeddingsResponse after monkey-patching
+        return cast(OpenAIEmbeddingsResponse, response)
diff --git a/src/llama_stack/providers/utils/tools/mcp.py b/src/llama_stack/providers/utils/tools/mcp.py
index 48f07cb19..a271cb959 100644
--- a/src/llama_stack/providers/utils/tools/mcp.py
+++ b/src/llama_stack/providers/utils/tools/mcp.py
@@ -15,7 +15,7 @@ from mcp import types as mcp_types
 from mcp.client.sse import sse_client
 from mcp.client.streamable_http import streamablehttp_client
 
-from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
+from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem, _URLOrData
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
     ToolDef,
@@ -49,7 +49,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
         try:
             client = streamablehttp_client
             if strategy == MCPProtol.SSE:
-                client = sse_client
+                # sse_client and streamablehttp_client have different signatures, but both
+                # are called the same way here, so we cast to Any to avoid type errors
+                client = cast(Any, sse_client)
             async with client(endpoint, headers=headers) as client_streams:
                 async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
                     await session.initialize()
@@ -137,7 +139,7 @@ async def invoke_mcp_tool(
             if isinstance(item, mcp_types.TextContent):
                 content.append(TextContentItem(text=item.text))
             elif isinstance(item, mcp_types.ImageContent):
-                content.append(ImageContentItem(image=item.data))
+                content.append(ImageContentItem(image=_URLOrData(data=item.data)))
             elif isinstance(item, mcp_types.EmbeddedResource):
                 logger.warning(f"EmbeddedResource is not supported: {item}")
             else:
diff --git a/src/llama_stack/testing/api_recorder.py b/src/llama_stack/testing/api_recorder.py
index 84407223c..e0c80d63c 100644
--- a/src/llama_stack/testing/api_recorder.py
+++ b/src/llama_stack/testing/api_recorder.py
@@ -40,7 +40,9 @@ from openai.types.completion_choice import CompletionChoice
 from llama_stack.core.testing_context import get_test_context, is_debug_mode
 
 # update the "finish_reason" field, since its type definition is wrong (no None is accepted)
-CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
+CompletionChoice.model_fields["finish_reason"].annotation = cast(
+    type[Any] | None, Literal["stop", "length", "content_filter"] | None
+)
 CompletionChoice.model_rebuild()
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent