mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix: do not use mutable defaults in litellm/main.py
This commit is contained in:
parent f816630f08
commit cc769ea551
1 changed file with 18 additions and 42 deletions
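The change replaces Python's mutable default arguments (e.g. `messages: List = []`, `input=[]`) with `None` defaults that are normalized inside the function body. A default like `[]` is evaluated once, when the function is defined, so every call that falls back on the default shares the same list object and mutations leak across calls. A minimal sketch of the pitfall and of the fix pattern used throughout this diff (illustrative only, not litellm code):

from typing import List, Optional


def broken_append(item: str, items: List = []) -> List:
    # The [] default is created once at definition time and reused on every
    # call, so results from earlier calls leak into later ones.
    items.append(item)
    return items


def fixed_append(item: str, items: Optional[List] = None) -> List:
    # Use None as the sentinel and normalize in the body, mirroring the
    # `messages = messages or []` / `input = input or []` lines in the diff.
    items = items or []
    items.append(item)
    return items


print(broken_append("a"))  # ['a']
print(broken_append("b"))  # ['a', 'b']  <- state leaked from the first call
print(fixed_append("a"))   # ['a']
print(fixed_append("b"))   # ['b']       <- fresh list on each call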
@@ -313,7 +313,7 @@ class AsyncCompletions:
 async def acompletion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
-    messages: List = [],
+    messages: Optional[List] = None,
     functions: Optional[List] = None,
     function_call: Optional[str] = None,
     timeout: Optional[Union[float, int]] = None,
@@ -395,6 +395,7 @@ async def acompletion(
         - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
         - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    messages = messages or []
     fallbacks = kwargs.get("fallbacks", None)
     mock_timeout = kwargs.get("mock_timeout", None)
@@ -530,18 +531,14 @@ def _handle_mock_potential_exceptions(
         raise litellm.MockException(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
     elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
         raise litellm.RateLimitError(
             message="this is a mock rate limit error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif (
@@ -550,9 +547,7 @@ def _handle_mock_potential_exceptions(
     ):
         raise litellm.ContextWindowExceededError(
             message="this is a mock context window exceeded error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif (
@@ -561,9 +556,7 @@ def _handle_mock_potential_exceptions(
     ):
         raise litellm.InternalServerError(
             message="this is a mock internal server error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif isinstance(mock_response, str) and mock_response.startswith(
@@ -778,7 +771,7 @@ def mock_completion(
 def completion(  # type: ignore # noqa: PLR0915
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
-    messages: List = [],
+    messages: Optional[List] = None,
     timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
@@ -863,6 +856,7 @@ def completion(  # type: ignore # noqa: PLR0915
         - It supports various optional parameters for customizing the completion behavior.
         - If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
     """
+    messages = messages or []
     ### VALIDATE Request ###
     if model is None:
         raise ValueError("model param not passed in.")
@@ -945,7 +939,7 @@ def completion(  # type: ignore # noqa: PLR0915
     prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None))
     ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
     messages = get_completion_messages(
-        messages=messages,
+        messages=messages or [],
         ensure_alternating_roles=ensure_alternating_roles or False,
         user_continue_message=user_continue_message,
         assistant_continue_message=assistant_continue_message,
@@ -1206,7 +1200,7 @@ def completion(  # type: ignore # noqa: PLR0915
         kwargs.pop("mock_timeout", None)  # remove for any fallbacks triggered
         return mock_completion(
             model,
-            messages,
+            messages or [],
             stream=stream,
             n=n,
             mock_response=mock_response,
@@ -1218,7 +1212,8 @@ def completion(  # type: ignore # noqa: PLR0915
             mock_timeout=mock_timeout,
             timeout=timeout,
         )
-
+    if messages is None:
+        messages = []
     if custom_llm_provider == "azure":
         # azure configs
         ## check dynamic params ##
@@ -3283,7 +3278,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
 @client
 def embedding(  # noqa: PLR0915
     model,
-    input=[],
+    input: Optional[List] = None,
     # Optional params
     dimensions: Optional[int] = None,
     encoding_format: Optional[str] = None,
@@ -3325,6 +3320,7 @@ def embedding(  # noqa: PLR0915
     Raises:
         - exception_type: If an exception occurs during the API call.
     """
+    input = input or []
     azure = kwargs.get("azure", None)
     client = kwargs.pop("client", None)
     max_retries = kwargs.get("max_retries", None)
@@ -3520,12 +3516,7 @@ def embedding(  # noqa: PLR0915
         api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore

         # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.databricks_key
-            or get_secret("DATABRICKS_API_KEY")
-        )  # type: ignore
+        api_key = api_key or litellm.api_key or litellm.databricks_key or get_secret("DATABRICKS_API_KEY")  # type: ignore

         ## EMBEDDING CALL
         response = databricks_embedding.embedding(
@@ -3599,12 +3590,7 @@ def embedding(  # noqa: PLR0915
             client=client,
         )
     elif custom_llm_provider == "huggingface":
-        api_key = (
-            api_key
-            or litellm.huggingface_key
-            or get_secret("HUGGINGFACE_API_KEY")
-            or litellm.api_key
-        )  # type: ignore
+        api_key = api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") or litellm.api_key  # type: ignore
         response = huggingface_embed.embedding(
             model=model,
             input=input,
@@ -3765,12 +3751,7 @@ def embedding(  # noqa: PLR0915
             api_key=api_key,
         )
     elif custom_llm_provider == "ollama":
-        api_base = (
-            litellm.api_base
-            or api_base
-            or get_secret_str("OLLAMA_API_BASE")
-            or "http://localhost:11434"
-        )  # type: ignore
+        api_base = litellm.api_base or api_base or get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"  # type: ignore

         if isinstance(input, str):
             input = [input]
@@ -5146,12 +5127,7 @@ def transcription(
         custom_llm_provider == "openai"
         or custom_llm_provider in litellm.openai_compatible_providers
     ):
-        api_base = (
-            api_base
-            or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
-            or "https://api.openai.com/v1"
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"  # type: ignore
         openai.organization = (
             litellm.organization
             or get_secret("OPENAI_ORGANIZATION")
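Two normalization idioms appear in the diff: `messages = messages or []` and `if messages is None: messages = []`. For list arguments they end up equivalent here, since the only falsy list is the empty list; the `or` form is just shorter, and the same `or`-chaining drives the collapsed single-line fallbacks such as `api_base = api_base or litellm.api_base or get_secret(...) or "<default>"`. A small sketch of both patterns (names and the environment variable are illustrative, not litellm's):

import os
from typing import List, Optional


def normalize_or(messages: Optional[List] = None) -> List:
    # None -> [], [] -> []; any other falsy value would also become [].
    return messages or []


def normalize_is_none(messages: Optional[List] = None) -> List:
    # Only None is replaced; an explicitly passed [] is returned as-is.
    return [] if messages is None else messages


def resolve_api_base(api_base: Optional[str] = None, module_default: Optional[str] = None) -> str:
    # First truthy value wins, mirroring the single-line fallback chains in
    # the diff; EXAMPLE_API_BASE is a hypothetical variable name.
    return api_base or module_default or os.environ.get("EXAMPLE_API_BASE") or "http://localhost:11434"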