fix typing errors

This commit is contained in:
Ishaan Jaff 2025-03-18 12:31:44 -07:00
parent 3261c66b39
commit 17e0742334

View file

@ -25,6 +25,7 @@ from functools import partial
from typing import (
Any,
Callable,
Coroutine,
Dict,
List,
Literal,
@ -3288,7 +3289,7 @@ def embedding( # noqa: PLR0915
litellm_call_id=None,
logger_fn=None,
**kwargs,
) -> EmbeddingResponse:
) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]:
"""
Embedding function that calls an API to generate embeddings for the given input.
@ -3409,7 +3410,9 @@ def embedding( # noqa: PLR0915
if mock_response is not None:
return mock_embedding(model=model, mock_response=mock_response)
try:
response: Optional[EmbeddingResponse] = None
response: Optional[
Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]
] = None
if azure is True or custom_llm_provider == "azure":
# azure configs
@ -3901,7 +3904,11 @@ def embedding( # noqa: PLR0915
raise LiteLLMUnknownProvider(
model=model, custom_llm_provider=custom_llm_provider
)
if response is not None and hasattr(response, "_hidden_params"):
if (
response is not None
and hasattr(response, "_hidden_params")
and isinstance(response, EmbeddingResponse)
):
response._hidden_params["custom_llm_provider"] = custom_llm_provider
if response is None: