address feedback

This commit is contained in:
Ashwin Bharambe 2025-10-28 09:39:43 -07:00
parent 9347e49414
commit 24667e43e0
3 changed files with 17 additions and 14 deletions

View file

@ -55,6 +55,9 @@ def _create_s3_client(config: S3FilesImplConfig) -> S3Client:
} }
) )
# Both cast and type:ignore are needed here:
# - cast tells mypy the return type for downstream usage (S3Client vs generic client)
# - type:ignore suppresses the call-overload error from boto3's complex overloaded signatures
return cast("S3Client", boto3.client("s3", **s3_config)) # type: ignore[call-overload] return cast("S3Client", boto3.client("s3", **s3_config)) # type: ignore[call-overload]
except (BotoCoreError, NoCredentialsError) as e: except (BotoCoreError, NoCredentialsError) as e:

View file

@ -37,21 +37,21 @@ class GeminiInferenceAdapter(OpenAIMixin):
Override embeddings method to handle Gemini's missing usage statistics. Override embeddings method to handle Gemini's missing usage statistics.
Gemini's embedding API doesn't return usage information, so we provide default values. Gemini's embedding API doesn't return usage information, so we provide default values.
""" """
# Build kwargs conditionally to avoid NotGiven/Omit type mismatch # Build request params conditionally to avoid NotGiven/Omit type mismatch
kwargs: dict[str, Any] = { request_params: dict[str, Any] = {
"model": await self._get_provider_model_id(params.model), "model": await self._get_provider_model_id(params.model),
"input": params.input, "input": params.input,
} }
if params.encoding_format is not None: if params.encoding_format is not None:
kwargs["encoding_format"] = params.encoding_format request_params["encoding_format"] = params.encoding_format
if params.dimensions is not None: if params.dimensions is not None:
kwargs["dimensions"] = params.dimensions request_params["dimensions"] = params.dimensions
if params.user is not None: if params.user is not None:
kwargs["user"] = params.user request_params["user"] = params.user
if params.model_extra: if params.model_extra:
kwargs["extra_body"] = params.model_extra request_params["extra_body"] = params.model_extra
response = await self.client.embeddings.create(**kwargs) response = await self.client.embeddings.create(**request_params)
data = [] data = []
for i, embedding_data in enumerate(response.data): for i, embedding_data in enumerate(response.data):

View file

@ -351,22 +351,22 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
""" """
Direct OpenAI embeddings API call. Direct OpenAI embeddings API call.
""" """
# Build kwargs conditionally to avoid NotGiven/Omit type mismatch # Build request params conditionally to avoid NotGiven/Omit type mismatch
# The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven # The OpenAI SDK uses Omit in signatures but NOT_GIVEN has type NotGiven
kwargs: dict[str, Any] = { request_params: dict[str, Any] = {
"model": await self._get_provider_model_id(params.model), "model": await self._get_provider_model_id(params.model),
"input": params.input, "input": params.input,
} }
if params.encoding_format is not None: if params.encoding_format is not None:
kwargs["encoding_format"] = params.encoding_format request_params["encoding_format"] = params.encoding_format
if params.dimensions is not None: if params.dimensions is not None:
kwargs["dimensions"] = params.dimensions request_params["dimensions"] = params.dimensions
if params.user is not None: if params.user is not None:
kwargs["user"] = params.user request_params["user"] = params.user
if params.model_extra: if params.model_extra:
kwargs["extra_body"] = params.model_extra request_params["extra_body"] = params.model_extra
response = await self.client.embeddings.create(**kwargs) response = await self.client.embeddings.create(**request_params)
data = [] data = []
for i, embedding_data in enumerate(response.data): for i, embedding_data in enumerate(response.data):