fix: do not use mutable defaults in litellm/main.py

Author: Nilanjan De
Date: 2025-04-09 16:40:46 +04:00
parent f816630f08
commit cc769ea551
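
As context for the diff below: a default like `messages: List = []` is evaluated once, when the function is defined, so the same list object is shared by every call that omits the argument, and any in-place mutation leaks across callers. The commit switches to the usual `Optional[...] = None` sentinel plus a per-call fallback. A minimal, self-contained sketch of the pitfall and of that fix (illustrative names only, not litellm's code):

from typing import List, Optional

def broken(messages: List = []) -> List:
    # the SAME list object is reused on every call that omits the argument
    messages.append("hi")
    return messages

def fixed(messages: Optional[List] = None) -> List:
    # a fresh list is created per call; None (or any falsy value) becomes []
    messages = messages or []
    messages.append("hi")
    return messages

print(broken())  # ['hi']
print(broken())  # ['hi', 'hi']  <- state leaked from the first call
print(fixed())   # ['hi']
print(fixed())   # ['hi']        <- independent list each call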

litellm/main.py

@@ -313,7 +313,7 @@ class AsyncCompletions:
async def acompletion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
messages: Optional[List] = None,
functions: Optional[List] = None,
function_call: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
@@ -395,6 +395,7 @@ async def acompletion(
- The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
- If `stream` is True, the function returns an async generator that yields completion lines.
"""
messages = messages or []
fallbacks = kwargs.get("fallbacks", None)
mock_timeout = kwargs.get("mock_timeout", None)
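
The docstring note above about run_in_executor refers to the standard asyncio pattern of pushing a blocking call onto a thread pool so the event loop stays responsive. A rough sketch of that pattern under assumed, illustrative names (not litellm's actual wiring):

import asyncio
from functools import partial
from typing import Optional

def sync_completion(model: str, messages: Optional[list] = None) -> dict:
    # stands in for the blocking provider/SDK call
    messages = messages or []
    return {"model": model, "n_messages": len(messages)}

async def async_completion(model: str, messages: Optional[list] = None) -> dict:
    loop = asyncio.get_running_loop()
    # run the blocking function in the default thread pool executor
    return await loop.run_in_executor(None, partial(sync_completion, model, messages))

print(asyncio.run(async_completion("gpt-4o", [{"role": "user", "content": "hi"}])))
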
@@ -530,18 +531,14 @@ def _handle_mock_potential_exceptions(
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model,
)
elif (
@@ -550,9 +547,7 @@ def _handle_mock_potential_exceptions(
):
raise litellm.ContextWindowExceededError(
message="this is a mock context window exceeded error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model,
)
elif (
@@ -561,9 +556,7 @@ def _handle_mock_potential_exceptions(
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
@@ -778,7 +771,7 @@ def mock_completion(
def completion( # type: ignore # noqa: PLR0915
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
messages: Optional[List] = None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -863,6 +856,7 @@ def completion( # type: ignore # noqa: PLR0915
- It supports various optional parameters for customizing the completion behavior.
- If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
"""
messages = messages or []
### VALIDATE Request ###
if model is None:
raise ValueError("model param not passed in.")
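
As the docstring says, a provided mock_response short-circuits the real provider call, which is convenient in tests. A typical usage, assuming the documented mock_response kwarg behaves as described (no API key or network needed):

import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    mock_response="This is a mocked reply",
)
print(resp.choices[0].message.content)  # -> "This is a mocked reply"
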
@@ -945,7 +939,7 @@ def completion( # type: ignore # noqa: PLR0915
prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None))
### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
messages = get_completion_messages(
messages=messages,
messages=messages or [],
ensure_alternating_roles=ensure_alternating_roles or False,
user_continue_message=user_continue_message,
assistant_continue_message=assistant_continue_message,
@@ -1206,7 +1200,7 @@ def completion( # type: ignore # noqa: PLR0915
kwargs.pop("mock_timeout", None) # remove for any fallbacks triggered
return mock_completion(
model,
messages,
messages or [],
stream=stream,
n=n,
mock_response=mock_response,
@@ -1218,7 +1212,8 @@ def completion( # type: ignore # noqa: PLR0915
mock_timeout=mock_timeout,
timeout=timeout,
)
if messages is None:
messages = []
if custom_llm_provider == "azure":
# azure configs
## check dynamic params ##
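
Note that the patch guards the default in two slightly different ways: `messages = messages or []` replaces any falsy value, including an explicitly passed empty list, with a new list, while `if messages is None` only replaces None. Either way the function ends up with an empty list, so the behavior here is equivalent; the difference only matters if the caller expects its own (empty) list object to be used. A small sketch of the distinction:

def with_or(messages=None):
    messages = messages or []   # a passed-in [] is swapped for a new list
    return messages

def with_is_none(messages=None):
    if messages is None:        # only None is replaced; a passed-in [] is kept
        messages = []
    return messages

caller_list = []
print(with_or(caller_list) is caller_list)       # False - caller's list not reused
print(with_is_none(caller_list) is caller_list)  # True  - caller's empty list kept
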
@@ -3283,7 +3278,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
@client
def embedding( # noqa: PLR0915
model,
input=[],
input: Optional[List] = None,
# Optional params
dimensions: Optional[int] = None,
encoding_format: Optional[str] = None,
@@ -3325,6 +3320,7 @@ def embedding( # noqa: PLR0915
Raises:
- exception_type: If an exception occurs during the API call.
"""
input = input or []
azure = kwargs.get("azure", None)
client = kwargs.pop("client", None)
max_retries = kwargs.get("max_retries", None)
@@ -3520,12 +3516,7 @@ def embedding( # noqa: PLR0915
api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") # type: ignore
# set API KEY
api_key = (
api_key
or litellm.api_key
or litellm.databricks_key
or get_secret("DATABRICKS_API_KEY")
) # type: ignore
api_key = api_key or litellm.api_key or litellm.databricks_key or get_secret("DATABRICKS_API_KEY") # type: ignore
## EMBEDDING CALL
response = databricks_embedding.embedding(
@@ -3599,12 +3590,7 @@ def embedding( # noqa: PLR0915
client=client,
)
elif custom_llm_provider == "huggingface":
api_key = (
api_key
or litellm.huggingface_key
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
) # type: ignore
api_key = api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") or litellm.api_key # type: ignore
response = huggingface_embed.embedding(
model=model,
input=input,
@@ -3765,12 +3751,7 @@ def embedding( # noqa: PLR0915
api_key=api_key,
)
elif custom_llm_provider == "ollama":
api_base = (
litellm.api_base
or api_base
or get_secret_str("OLLAMA_API_BASE")
or "http://localhost:11434"
) # type: ignore
api_base = litellm.api_base or api_base or get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" # type: ignore
if isinstance(input, str):
input = [input]
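
The single-line reflows in the embedding hunks also make the resolution order of credentials and endpoints easier to scan: each `or` chain takes the first truthy value, typically the explicit argument, then a module-level setting, then an environment secret, then a hard-coded default. A generic sketch of that precedence with illustrative names (not litellm's helpers):

import os
from typing import Optional

module_api_base: Optional[str] = None  # stands in for a module-level setting like litellm.api_base

def resolve_api_base(api_base: Optional[str] = None) -> str:
    # first truthy value wins: argument > module setting > env var > default
    return (
        api_base
        or module_api_base
        or os.environ.get("OLLAMA_API_BASE")
        or "http://localhost:11434"
    )

print(resolve_api_base())                        # default unless OLLAMA_API_BASE is set
print(resolve_api_base("http://my-host:8080"))   # explicit argument wins
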
@@ -5146,12 +5127,7 @@ def transcription(
custom_llm_provider == "openai"
or custom_llm_provider in litellm.openai_compatible_providers
):
api_base = (
api_base
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1"
) # type: ignore
api_base = api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1" # type: ignore
openai.organization = (
litellm.organization
or get_secret("OPENAI_ORGANIZATION")