mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-11 19:56:03 +00:00
fix(mypy): resolve additional type errors in batches and together
- batches.py: Fix 6 cascading errors from body variable shadowing - together.py: Add type casts for Together API integration 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
0c95140ca7
commit
a7866dff48
2 changed files with 14 additions and 10 deletions
|
|
@ -451,7 +451,7 @@ class ReferenceBatchesImpl(Batches):
|
||||||
]
|
]
|
||||||
|
|
||||||
for param, expected_type, type_string in required_params:
|
for param, expected_type, type_string in required_params:
|
||||||
if param not in body:
|
if param not in request_body:
|
||||||
errors.append(
|
errors.append(
|
||||||
BatchError(
|
BatchError(
|
||||||
code="invalid_request",
|
code="invalid_request",
|
||||||
|
|
@ -461,7 +461,7 @@ class ReferenceBatchesImpl(Batches):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
elif not isinstance(body[param], expected_type):
|
elif not isinstance(request_body[param], expected_type):
|
||||||
errors.append(
|
errors.append(
|
||||||
BatchError(
|
BatchError(
|
||||||
code="invalid_request",
|
code="invalid_request",
|
||||||
|
|
@ -472,15 +472,15 @@ class ReferenceBatchesImpl(Batches):
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
if "model" in body and isinstance(body["model"], str):
|
if "model" in request_body and isinstance(request_body["model"], str):
|
||||||
try:
|
try:
|
||||||
await self.models_api.get_model(body["model"])
|
await self.models_api.get_model(request_body["model"])
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.append(
|
errors.append(
|
||||||
BatchError(
|
BatchError(
|
||||||
code="model_not_found",
|
code="model_not_found",
|
||||||
line=line_num,
|
line=line_num,
|
||||||
message=f"Model '{body['model']}' does not exist or is not supported",
|
message=f"Model '{request_body['model']}' does not exist or is not supported",
|
||||||
param="body.model",
|
param="body.model",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -488,14 +488,14 @@ class ReferenceBatchesImpl(Batches):
|
||||||
|
|
||||||
if valid:
|
if valid:
|
||||||
assert isinstance(url, str), "URL must be a string" # for mypy
|
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||||
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
assert isinstance(request_body, dict), "Body must be a dictionary" # for mypy
|
||||||
requests.append(
|
requests.append(
|
||||||
BatchRequest(
|
BatchRequest(
|
||||||
line_num=line_num,
|
line_num=line_num,
|
||||||
url=url,
|
url=url,
|
||||||
method=request["method"],
|
method=request["method"],
|
||||||
custom_id=request["custom_id"],
|
custom_id=request["custom_id"],
|
||||||
body=body,
|
body=request_body,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
|
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from together import AsyncTogether
|
from together import AsyncTogether
|
||||||
from together.constants import BASE_URL
|
from together.constants import BASE_URL
|
||||||
|
|
@ -81,10 +82,11 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
||||||
if params.dimensions is not None:
|
if params.dimensions is not None:
|
||||||
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
raise ValueError("Together's embeddings endpoint does not support dimensions param.")
|
||||||
|
|
||||||
|
# Cast encoding_format to match OpenAI SDK's expected Literal type
|
||||||
response = await self.client.embeddings.create(
|
response = await self.client.embeddings.create(
|
||||||
model=await self._get_provider_model_id(params.model),
|
model=await self._get_provider_model_id(params.model),
|
||||||
input=params.input,
|
input=params.input,
|
||||||
encoding_format=params.encoding_format,
|
encoding_format=cast(Any, params.encoding_format),
|
||||||
)
|
)
|
||||||
|
|
||||||
response.model = (
|
response.model = (
|
||||||
|
|
@ -97,6 +99,8 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
|
f"Together's embedding endpoint for {params.model} did not return usage information, substituting -1s."
|
||||||
)
|
)
|
||||||
response.usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
|
# Cast to allow monkey-patching the response object
|
||||||
|
response.usage = cast(Any, OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1))
|
||||||
|
|
||||||
return response # type: ignore[no-any-return]
|
# Together's CreateEmbeddingResponse is compatible with OpenAIEmbeddingsResponse after monkey-patching
|
||||||
|
return cast(OpenAIEmbeddingsResponse, response)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue