Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)
fix(mypy): resolve OpenAI SDK and provider type issues (#3936)
## Summary

- Fix OpenAI SDK NotGiven/Omit type mismatches in embeddings calls
- Fix incorrect OpenAIChatCompletionChunk import in vllm provider
- Refactor to avoid type:ignore comments by using conditional kwargs

## Changes

**openai_mixin.py (9 errors fixed):**

- Build kwargs conditionally for embeddings.create() to avoid NotGiven/Omit mismatch
- Only include parameters when they have actual values (not None)

**gemini.py (9 errors fixed):**

- Apply same conditional kwargs pattern
- Add missing Any import

**vllm.py (2 errors fixed):**

- Use correct OpenAIChatCompletionChunk from llama_stack.apis.inference
- Remove incorrect alias from openai package

## Technical Notes

The OpenAI SDK has a type system quirk where `NOT_GIVEN` has type `NotGiven` but parameter signatures expect `Omit`. By only passing parameters with actual values, we avoid this mismatch entirely without needing `# type: ignore` comments.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
parent d009dc29f7 · commit 1d385b5b75
7 changed files with 60 additions and 41 deletions
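Before the per-file hunks, a minimal sketch of the conditional-kwargs pattern described in the Technical Notes above; the helper name and signature here are illustrative, not code from this PR:

```python
from typing import Any


def build_embedding_kwargs(
    model: str,
    input: str | list[str],
    dimensions: int | None = None,
    user: str | None = None,
) -> dict[str, Any]:
    # Only include keys that have real values, so NOT_GIVEN (type NotGiven)
    # never has to be passed where the SDK's signature expects Omit.
    kwargs: dict[str, Any] = {"model": model, "input": input}
    if dimensions is not None:
        kwargs["dimensions"] = dimensions
    if user is not None:
        kwargs["user"] = user
    return kwargs


# e.g. await client.embeddings.create(**build_embedding_kwargs("text-embedding-3-small", "hello"))
```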
**S3 files provider** — type the boto3 client via the `mypy_boto3_s3` stubs:

```diff
@@ -4,14 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from __future__ import annotations
+
 import uuid
 from datetime import UTC, datetime
-from typing import Annotated, Any
+from typing import TYPE_CHECKING, Annotated, Any, cast
 
 import boto3
 from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
 from fastapi import Depends, File, Form, Response, UploadFile
 
+if TYPE_CHECKING:
+    from mypy_boto3_s3.client import S3Client
+
 from llama_stack.apis.common.errors import ResourceNotFoundError
 from llama_stack.apis.common.responses import Order
 from llama_stack.apis.files import (
@@ -34,7 +39,7 @@ from .config import S3FilesImplConfig
 # TODO: provider data for S3 credentials
 
 
-def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
+def _create_s3_client(config: S3FilesImplConfig) -> S3Client:
     try:
         s3_config = {
             "region_name": config.region,
@@ -52,13 +57,16 @@ def _create_s3_client(config: S3FilesImplConfig) -> boto3.client:
             }
         )
 
-        return boto3.client("s3", **s3_config)
+        # Both cast and type:ignore are needed here:
+        # - cast tells mypy the return type for downstream usage (S3Client vs generic client)
+        # - type:ignore suppresses the call-overload error from boto3's complex overloaded signatures
+        return cast("S3Client", boto3.client("s3", **s3_config))  # type: ignore[call-overload]
     except (BotoCoreError, NoCredentialsError) as e:
         raise RuntimeError(f"Failed to initialize S3 client: {e}") from e
 
 
-async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None:
+async def _create_bucket_if_not_exists(client: S3Client, config: S3FilesImplConfig) -> None:
     try:
         client.head_bucket(Bucket=config.bucket_name)
     except ClientError as e:
@@ -76,7 +84,7 @@ async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImpl
         else:
             client.create_bucket(
                 Bucket=config.bucket_name,
-                CreateBucketConfiguration={"LocationConstraint": config.region},
+                CreateBucketConfiguration=cast(Any, {"LocationConstraint": config.region}),
             )
     except ClientError as create_error:
         raise RuntimeError(
@@ -128,7 +136,7 @@ class S3FilesImpl(Files):
     def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
         self._config = config
         self.policy = policy
-        self._client: boto3.client | None = None
+        self._client: S3Client | None = None
         self._sql_store: AuthorizedSqlStore | None = None
 
     def _now(self) -> int:
@@ -184,7 +192,7 @@ class S3FilesImpl(Files):
         pass
 
     @property
-    def client(self) -> boto3.client:
+    def client(self) -> S3Client:
         assert self._client is not None, "Provider not initialized"
         return self._client
 
```
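The S3 hunks rely on two idioms worth calling out: a `TYPE_CHECKING`-guarded import of the `mypy_boto3_s3` stubs (so boto3 stays an untyped runtime dependency) and a `cast` over the `boto3.client()` factory. A standalone sketch of the same pattern, assuming the stub package is installed; the function name is illustrative:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import boto3

if TYPE_CHECKING:
    # Stub-only import: mypy_boto3_s3 provides types and is never imported at runtime.
    from mypy_boto3_s3.client import S3Client


def make_s3_client(region: str) -> S3Client:
    # cast pins the return type for callers; the ignore silences mypy's
    # call-overload error on boto3's heavily overloaded client() factory.
    return cast("S3Client", boto3.client("s3", region_name=region))  # type: ignore[call-overload]
```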
**gemini.py** — drop `NOT_GIVEN` in favor of conditional kwargs:

```diff
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from openai import NOT_GIVEN
+from typing import Any
 
 from llama_stack.apis.inference import (
     OpenAIEmbeddingData,
@@ -37,21 +37,20 @@ class GeminiInferenceAdapter(OpenAIMixin):
         Override embeddings method to handle Gemini's missing usage statistics.
         Gemini's embedding API doesn't return usage information, so we provide default values.
         """
-        # Prepare request parameters
-        request_params = {
+        # Build request params conditionally to avoid NotGiven/Omit type mismatch
+        request_params: dict[str, Any] = {
             "model": await self._get_provider_model_id(params.model),
             "input": params.input,
-            "encoding_format": params.encoding_format if params.encoding_format is not None else NOT_GIVEN,
-            "dimensions": params.dimensions if params.dimensions is not None else NOT_GIVEN,
-            "user": params.user if params.user is not None else NOT_GIVEN,
         }
-
-        # Add extra_body if present
-        extra_body = params.model_extra
-        if extra_body:
-            request_params["extra_body"] = extra_body
-
+        if params.encoding_format is not None:
+            request_params["encoding_format"] = params.encoding_format
+        if params.dimensions is not None:
+            request_params["dimensions"] = params.dimensions
+        if params.user is not None:
+            request_params["user"] = params.user
+        if params.model_extra:
+            request_params["extra_body"] = params.model_extra
 
         # Call OpenAI embeddings API with properly typed parameters
         response = await self.client.embeddings.create(**request_params)
 
         data = []
```
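One detail in the gemini.py hunk that is easy to miss: the explicit `dict[str, Any]` annotation is what lets the later conditional assignments type-check, since mypy would otherwise infer a narrower value type from the initial dict literal. A toy illustration, with hypothetical variable names:

```python
from typing import Any

narrow = {"model": "gemini"}   # mypy infers dict[str, str] from the literal
# narrow["dimensions"] = 128   # mypy error: int is not assignable to str

wide: dict[str, Any] = {"model": "gemini"}
wide["dimensions"] = 128       # fine: values are typed Any
```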
**vllm.py** — import `OpenAIChatCompletionChunk` from `llama_stack.apis.inference` instead of aliasing the SDK type:

```diff
@@ -7,13 +7,11 @@ from collections.abc import AsyncIterator
 from urllib.parse import urljoin
 
 import httpx
-from openai.types.chat.chat_completion_chunk import (
-    ChatCompletionChunk as OpenAIChatCompletionChunk,
-)
 from pydantic import ConfigDict
 
 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
     ToolChoice,
 )
```
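The vllm.py fix matters because the SDK's `ChatCompletionChunk` and llama-stack's `OpenAIChatCompletionChunk` are distinct classes even where their fields line up; mypy types nominally, so annotating with one and returning the other fails. A toy reproduction of that failure mode, with hypothetical class names:

```python
class SdkChunk:  # stand-in for openai.types.chat.chat_completion_chunk.ChatCompletionChunk
    pass


class ApiChunk:  # stand-in for llama_stack.apis.inference.OpenAIChatCompletionChunk
    pass


def next_chunk() -> ApiChunk:
    return SdkChunk()  # mypy error: incompatible return value type
```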