Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 18:00:36 +00:00
fix(mypy): resolve provider utility and testing type issues (#3935)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 2s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 3s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Vector IO Integration Tests / test-matrix (push) Failing after 5s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 2s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 4s
Test Llama Stack Build / build-single-provider (push) Failing after 4s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 4s
Python Package Build Test / build (3.13) (push) Failing after 3s
Test llama stack list-deps / generate-matrix (push) Successful in 4s
Test llama stack list-deps / show-single-provider (push) Failing after 3s
API Conformance Tests / check-schema-compatibility (push) Successful in 11s
Test llama stack list-deps / list-deps-from-config (push) Failing after 4s
Test External API and Providers / test-external (venv) (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
Unit Tests / unit-tests (3.13) (push) Failing after 4s
Test llama stack list-deps / list-deps (push) Failing after 4s
Test Llama Stack Build / build (push) Failing after 7s
UI Tests / ui-tests (22) (push) Successful in 51s
Pre-commit / pre-commit (push) Successful in 2m0s
Fixes mypy type errors in provider utilities and testing infrastructure:

- `mcp.py`: Cast incompatible client types, wrap image data properly
- `batches.py`: Rename walrus variable to avoid shadowing
- `api_recorder.py`: Use cast for Pydantic field annotation

No functional changes.

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent fcf07790c8
commit d009dc29f7
6 changed files with 33 additions and 25 deletions
```diff
@@ -68,8 +68,9 @@ def get_all_api_routes(
             else:
                 http_method = hdrs.METH_POST
             routes.append(
-                (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
-            )  # setting endpoint to None since don't use a Router object
+                # setting endpoint to None since don't use a Router object
+                (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)  # type: ignore[arg-type]
+            )

         apis[api] = routes
```
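For context, a minimal sketch (not the repo's code) of why the targeted ignore is needed: Starlette annotates `Route`'s `endpoint` parameter as a callable, so passing `None` trips mypy's arg-type check even though it is harmless when the `Route` is only a route description.

```python
from starlette.routing import Route

# endpoint=None violates Starlette's annotation (it expects a callable), so a
# targeted ignore is needed; it is harmless because this Route is used only
# as a description and is never mounted on a Router.
route = Route(path="/health", methods=["GET"], name="health", endpoint=None)  # type: ignore[arg-type]

# Starlette silently adds HEAD alongside GET, which is exactly what the next
# hunk filters out. (Set order may vary.)
print(route.methods)  # {'GET', 'HEAD'}
```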
```diff
@@ -98,7 +99,7 @@ def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | No
         impl = impls[api]
         func = getattr(impl, route.name)
         # Get the first (and typically only) method from the set, filtering out HEAD
-        available_methods = [m for m in route.methods if m != "HEAD"]
+        available_methods = [m for m in (route.methods or []) if m != "HEAD"]
         if not available_methods:
             continue  # Skip if only HEAD method is available
         method = available_methods[0].lower()
```
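The `or []` guard exists because Starlette types `Route.methods` as `set[str] | None`. A minimal sketch of the pattern:

```python
# Iterating `methods or []` keeps the comprehension safe (and mypy quiet)
# when the route declares no methods at all.
methods: set[str] | None = None
available = [m for m in (methods or []) if m != "HEAD"]
assert available == []
```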
```diff
@@ -141,7 +141,7 @@ def build_encoder_attention_mask(
     """
     Build vision encoder attention mask that omits padding tokens.
     """
-    masks_list = []
+    masks_list: list[torch.Tensor] = []
     for arx in ar:
         mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
         mask_i[: arx[0] * arx[1], :ntok] = 0
```
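The annotation resolves mypy's "need type annotation" error on a bare empty list; a tiny sketch (assuming torch is installed):

```python
import torch

# mypy cannot infer the element type of a bare `[]`, so spell it out up front.
masks_list: list[torch.Tensor] = []
masks_list.append(torch.ones((2, 3), dtype=torch.float32))
print(masks_list[0].shape)  # torch.Size([2, 3])
```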
```diff
@@ -358,11 +358,10 @@ class ReferenceBatchesImpl(Batches):
         # TODO(SECURITY): do something about large files
         file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
-        # Handle both bytes and memoryview types
-        body = file_content_response.body
-        if isinstance(body, memoryview):
-            body = bytes(body)
-        file_content = body.decode("utf-8")
+        # Handle both bytes and memoryview types - convert to bytes unconditionally
+        # (bytes(x) returns x if already bytes, creates new bytes from memoryview otherwise)
+        body_bytes = bytes(file_content_response.body)
+        file_content = body_bytes.decode("utf-8")

         for line_num, line in enumerate(file_content.strip().split("\n"), 1):
             if line.strip():  # skip empty lines
                 try:
```
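The unconditional `bytes(...)` works because the constructor accepts both input types, as this standalone check shows:

```python
# bytes() returns an equal bytes object for bytes input and copies the
# underlying buffer for a memoryview, so no isinstance branch is needed.
data = b"hello"
view = memoryview(data)
assert bytes(data) == b"hello"
assert bytes(view) == b"hello"
print(bytes(view).decode("utf-8"))  # hello
```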
```diff
@@ -419,8 +418,8 @@ class ReferenceBatchesImpl(Batches):
                         )
                         valid = False

-                    if (body := request.get("body")) and isinstance(body, dict):
-                        if body.get("stream", False):
+                    if (request_body := request.get("body")) and isinstance(request_body, dict):
+                        if request_body.get("stream", False):
                             errors.append(
                                 BatchError(
                                     code="streaming_unsupported",
```
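The rename matters because a walrus target lives on in the enclosing scope, so the old name `body` collided with a differently-typed `body` used later in the same method. A minimal sketch with made-up data:

```python
# The walrus binding escapes the `if`, so give it a name that cannot clash.
request = {"body": {"model": "example-model", "stream": False}}
if (request_body := request.get("body")) and isinstance(request_body, dict):
    assert not request_body.get("stream", False)
print(request_body)  # still bound after the if-block
```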
```diff
@@ -451,7 +450,7 @@ class ReferenceBatchesImpl(Batches):
                         ]

                         for param, expected_type, type_string in required_params:
-                            if param not in body:
+                            if param not in request_body:
                                 errors.append(
                                     BatchError(
                                         code="invalid_request",
```
```diff
@@ -461,7 +460,7 @@ class ReferenceBatchesImpl(Batches):
                                     )
                                 )
                                 valid = False
-                            elif not isinstance(body[param], expected_type):
+                            elif not isinstance(request_body[param], expected_type):
                                 errors.append(
                                     BatchError(
                                         code="invalid_request",
```
```diff
@@ -472,15 +471,15 @@ class ReferenceBatchesImpl(Batches):
                                 )
                                 valid = False

-                        if "model" in body and isinstance(body["model"], str):
+                        if "model" in request_body and isinstance(request_body["model"], str):
                             try:
-                                await self.models_api.get_model(body["model"])
+                                await self.models_api.get_model(request_body["model"])
                             except Exception:
                                 errors.append(
                                     BatchError(
                                         code="model_not_found",
                                         line=line_num,
-                                        message=f"Model '{body['model']}' does not exist or is not supported",
+                                        message=f"Model '{request_body['model']}' does not exist or is not supported",
                                         param="body.model",
                                     )
                                 )
```
```diff
@@ -488,14 +487,14 @@ class ReferenceBatchesImpl(Batches):

                     if valid:
                         assert isinstance(url, str), "URL must be a string"  # for mypy
-                        assert isinstance(body, dict), "Body must be a dictionary"  # for mypy
+                        assert isinstance(request_body, dict), "Body must be a dictionary"  # for mypy
                         requests.append(
                             BatchRequest(
                                 line_num=line_num,
                                 url=url,
                                 method=request["method"],
                                 custom_id=request["custom_id"],
-                                body=body,
+                                body=request_body,
                             ),
                         )
                 except json.JSONDecodeError:
```
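The paired asserts are a standard mypy narrowing idiom; a self-contained sketch:

```python
# `assert isinstance(...)` narrows the static type at the point of use; the
# message documents that the assert exists for the type checker.
def send(url: object, body: object) -> None:
    assert isinstance(url, str), "URL must be a string"  # for mypy
    assert isinstance(body, dict), "Body must be a dictionary"  # for mypy
    print(url, len(body))

send("https://example.com/v1/chat/completions", {"model": "m"})
```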
```diff
@@ -6,6 +6,7 @@

 from collections.abc import Iterable
+from typing import Any, cast

 from together import AsyncTogether
 from together.constants import BASE_URL
```
```diff
@@ -81,10 +82,11 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         if params.dimensions is not None:
             raise ValueError("Together's embeddings endpoint does not support dimensions param.")

+        # Cast encoding_format to match OpenAI SDK's expected Literal type
         response = await self.client.embeddings.create(
             model=await self._get_provider_model_id(params.model),
             input=params.input,
-            encoding_format=params.encoding_format,
+            encoding_format=cast(Any, params.encoding_format),
         )

         response.model = (
```
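`cast(Any, ...)` is a zero-cost way to hand a plain `str` to a parameter annotated with a `Literal`; a sketch with a stand-in function (the real parameter belongs to the OpenAI SDK):

```python
from typing import Any, Literal, cast

def create(encoding_format: Literal["float", "base64"] = "float") -> str:
    return encoding_format

fmt: str = "float"  # arrives as a plain str, not the Literal mypy wants
print(create(encoding_format=cast(Any, fmt)))  # no runtime conversion happens
```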
```diff
@@ -97,6 +99,8 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
             logger.warning(
                 f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
             )
-            response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
+            # Cast to allow monkey-patching the response object
+            response.usage = cast(Any, OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1))

-        return response  # type: ignore[no-any-return]
+        # Together's CreateEmbeddingResponse is compatible with OpenAIEmbeddingsResponse after monkey-patching
+        return cast(OpenAIEmbeddingsResponse, response)
```
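The final `cast` relies on the two response classes being structurally compatible after the patch; `cast` changes only what mypy sees, never the object, as this sketch with toy classes (not the SDK's) shows:

```python
from typing import cast

class OpenAIStyleResponse:
    model: str = ""

class VendorResponse:
    model: str = "vendor/model"

vendor = VendorResponse()
# cast() performs no conversion or copy; it only retypes the reference.
compatible = cast(OpenAIStyleResponse, vendor)
assert compatible is vendor
```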
```diff
@@ -15,7 +15,7 @@ from mcp import types as mcp_types
 from mcp.client.sse import sse_client
 from mcp.client.streamable_http import streamablehttp_client

-from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
+from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem, _URLOrData
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
     ToolDef,
```
```diff
@@ -49,7 +49,9 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
     try:
         client = streamablehttp_client
         if strategy == MCPProtol.SSE:
-            client = sse_client
+            # sse_client and streamablehttp_client have different signatures, but both
+            # are called the same way here, so we cast to Any to avoid type errors
+            client = cast(Any, sse_client)
         async with client(endpoint, headers=headers) as client_streams:
             async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
                 await session.initialize()
```
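Casting the callable sidesteps mypy's complaint that the two client factories have different full signatures even though both are invoked identically here. A sketch with stand-in functions:

```python
from typing import Any, cast

def make_http(url: str, headers: dict[str, str], timeout: float = 30.0) -> str:
    return f"http:{url}"

def make_sse(url: str, headers: dict[str, str]) -> str:
    return f"sse:{url}"

client = make_http
use_sse = True
if use_sse:
    # Re-assigning a differently-shaped callable needs the cast; both are
    # called with the same (url, headers=...) shape below.
    client = cast(Any, make_sse)
print(client("https://example.com/mcp", headers={}))
```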
```diff
@@ -137,7 +139,7 @@ async def invoke_mcp_tool(
         if isinstance(item, mcp_types.TextContent):
             content.append(TextContentItem(text=item.text))
         elif isinstance(item, mcp_types.ImageContent):
-            content.append(ImageContentItem(image=item.data))
+            content.append(ImageContentItem(image=_URLOrData(data=item.data)))
         elif isinstance(item, mcp_types.EmbeddedResource):
             logger.warning(f"EmbeddedResource is not supported: {item}")
         else:
```
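The wrap is needed because `ImageContentItem.image` expects a URL-or-data wrapper model rather than a raw base64 string. A toy sketch with stand-in Pydantic models (field shapes assumed; the real ones live in `llama_stack.apis.common.content_types`):

```python
from pydantic import BaseModel

class _URLOrData(BaseModel):  # stand-in for the real wrapper model
    url: str | None = None
    data: str | None = None

class ImageContentItem(BaseModel):  # stand-in
    image: _URLOrData

item_data = "aGVsbG8="  # base64 payload, as an MCP ImageContent carries
print(ImageContentItem(image=_URLOrData(data=item_data)))
```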
```diff
@@ -40,7 +40,9 @@ from openai.types.completion_choice import CompletionChoice
 from llama_stack.core.testing_context import get_test_context, is_debug_mode

 # update the "finish_reason" field, since its type definition is wrong (no None is accepted)
-CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
+CompletionChoice.model_fields["finish_reason"].annotation = cast(
+    type[Any] | None, Literal["stop", "length", "content_filter"] | None
+)
 CompletionChoice.model_rebuild()

 REPO_ROOT = Path(__file__).parent.parent.parent.parent
```
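The `cast` is needed because Pydantic declares `FieldInfo.annotation` as `type[Any] | None`, which a `Literal[...] | None` union does not match statically. A toy sketch of the same runtime patch (assumes recent Pydantic v2; exact rebuild semantics may vary by version):

```python
from typing import Any, Literal, cast

from pydantic import BaseModel

class Choice(BaseModel):  # stand-in for openai's CompletionChoice
    finish_reason: Literal["stop", "length"]

# Loosen the field's annotation at runtime, then rebuild the schema so the
# validator accepts None; the cast matches FieldInfo.annotation's own type.
Choice.model_fields["finish_reason"].annotation = cast(type[Any] | None, Literal["stop", "length"] | None)
Choice.model_rebuild(force=True)

print(Choice(finish_reason=None))  # validates once None is allowed
```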