Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
fix: do not use mutable defaults in litellm/main.py
This commit is contained in:
parent
f816630f08
commit
cc769ea551
1 changed file with 18 additions and 42 deletions
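The diff below replaces mutable default arguments (`messages: List = []`, `input=[]`) with a `None` sentinel plus an in-body fallback. The motivation: a mutable default is evaluated once at function-definition time and then shared by every call that omits the argument, so state can leak across calls. A minimal standalone sketch of the pitfall and of the sentinel fix (hypothetical functions, not taken from litellm):

    from typing import List, Optional

    def broken(messages: List = []) -> List:
        # the SAME list object is reused for every call that omits `messages`
        messages.append("hi")
        return messages

    def fixed(messages: Optional[List] = None) -> List:
        messages = messages or []  # fresh list when the caller passes nothing (or None)
        messages.append("hi")
        return messages

    print(broken())  # ['hi']
    print(broken())  # ['hi', 'hi']  <- state leaked from the first call
    print(fixed())   # ['hi']
    print(fixed())   # ['hi']        <- independent lists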
litellm/main.py

@@ -313,7 +313,7 @@ class AsyncCompletions:
 async def acompletion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
-    messages: List = [],
+    messages: Optional[List] = None,
     functions: Optional[List] = None,
     function_call: Optional[str] = None,
     timeout: Optional[Union[float, int]] = None,

@@ -395,6 +395,7 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    messages = messages or []
     fallbacks = kwargs.get("fallbacks", None)
     mock_timeout = kwargs.get("mock_timeout", None)
 
@@ -530,18 +531,14 @@ def _handle_mock_potential_exceptions(
         raise litellm.MockException(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
     elif isinstance(mock_response, str) and mock_response == "litellm.RateLimitError":
         raise litellm.RateLimitError(
             message="this is a mock rate limit error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif (

@@ -550,9 +547,7 @@ def _handle_mock_potential_exceptions(
     ):
         raise litellm.ContextWindowExceededError(
             message="this is a mock context window exceeded error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif (

@@ -561,9 +556,7 @@ def _handle_mock_potential_exceptions(
     ):
         raise litellm.InternalServerError(
             message="this is a mock internal server error",
-            llm_provider=getattr(
-                mock_response, "llm_provider", custom_llm_provider or "openai"
-            ),  # type: ignore
+            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
             model=model,
         )
     elif isinstance(mock_response, str) and mock_response.startswith(

@@ -778,7 +771,7 @@ def mock_completion(
 def completion(  # type: ignore # noqa: PLR0915
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
-    messages: List = [],
+    messages: Optional[List] = None,
     timeout: Optional[Union[float, str, httpx.Timeout]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,

@@ -863,6 +856,7 @@ def completion(  # type: ignore # noqa: PLR0915
     - It supports various optional parameters for customizing the completion behavior.
     - If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
     """
+    messages = messages or []
     ### VALIDATE Request ###
     if model is None:
         raise ValueError("model param not passed in.")

@@ -945,7 +939,7 @@ def completion(  # type: ignore # noqa: PLR0915
     prompt_variables = cast(Optional[dict], kwargs.get("prompt_variables", None))
     ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489
     messages = get_completion_messages(
-        messages=messages,
+        messages=messages or [],
         ensure_alternating_roles=ensure_alternating_roles or False,
         user_continue_message=user_continue_message,
         assistant_continue_message=assistant_continue_message,

@@ -1206,7 +1200,7 @@ def completion(  # type: ignore # noqa: PLR0915
             kwargs.pop("mock_timeout", None)  # remove for any fallbacks triggered
             return mock_completion(
                 model,
-                messages,
+                messages or [],
                 stream=stream,
                 n=n,
                 mock_response=mock_response,

@@ -1218,7 +1212,8 @@ def completion(  # type: ignore # noqa: PLR0915
                 mock_timeout=mock_timeout,
                 timeout=timeout,
             )
+        if messages is None:
+            messages = []
         if custom_llm_provider == "azure":
             # azure configs
             ## check dynamic params ##
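Two normalization styles appear above: `messages = messages or []` and the explicit `if messages is None: messages = []`. For a parameter that defaults to `None` they behave the same, but the `or` form also swaps a caller-supplied empty list for a new object, while the `is None` check preserves it. A standalone illustration of the difference (not part of the diff):

    caller_list: list = []
    kept = caller_list if caller_list is not None else []  # keeps the caller's object
    swapped = caller_list or []                             # substitutes a brand-new empty list
    print(kept is caller_list, swapped is caller_list)      # True False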
@@ -3283,7 +3278,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
 @client
 def embedding(  # noqa: PLR0915
     model,
-    input=[],
+    input: Optional[List] = None,
     # Optional params
     dimensions: Optional[int] = None,
     encoding_format: Optional[str] = None,

@@ -3325,6 +3320,7 @@ def embedding(  # noqa: PLR0915
     Raises:
     - exception_type: If an exception occurs during the API call.
     """
+    input = input or []
     azure = kwargs.get("azure", None)
     client = kwargs.pop("client", None)
     max_retries = kwargs.get("max_retries", None)

@@ -3520,12 +3516,7 @@ def embedding(  # noqa: PLR0915
             api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore
 
             # set API KEY
-            api_key = (
-                api_key
-                or litellm.api_key
-                or litellm.databricks_key
-                or get_secret("DATABRICKS_API_KEY")
-            )  # type: ignore
+            api_key = api_key or litellm.api_key or litellm.databricks_key or get_secret("DATABRICKS_API_KEY")  # type: ignore
 
             ## EMBEDDING CALL
             response = databricks_embedding.embedding(

@@ -3599,12 +3590,7 @@ def embedding(  # noqa: PLR0915
                 client=client,
             )
         elif custom_llm_provider == "huggingface":
-            api_key = (
-                api_key
-                or litellm.huggingface_key
-                or get_secret("HUGGINGFACE_API_KEY")
-                or litellm.api_key
-            )  # type: ignore
+            api_key = api_key or litellm.huggingface_key or get_secret("HUGGINGFACE_API_KEY") or litellm.api_key  # type: ignore
             response = huggingface_embed.embedding(
                 model=model,
                 input=input,

@@ -3765,12 +3751,7 @@ def embedding(  # noqa: PLR0915
                 api_key=api_key,
             )
         elif custom_llm_provider == "ollama":
-            api_base = (
-                litellm.api_base
-                or api_base
-                or get_secret_str("OLLAMA_API_BASE")
-                or "http://localhost:11434"
-            )  # type: ignore
+            api_base = litellm.api_base or api_base or get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"  # type: ignore
 
             if isinstance(input, str):
                 input = [input]

@@ -5146,12 +5127,7 @@ def transcription(
         custom_llm_provider == "openai"
         or custom_llm_provider in litellm.openai_compatible_providers
     ):
-        api_base = (
-            api_base
-            or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
-            or "https://api.openai.com/v1"
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1"  # type: ignore
         openai.organization = (
             litellm.organization
             or get_secret("OPENAI_ORGANIZATION")
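The api_key/api_base hunks above only reflow multi-line parenthesized fallback chains onto single lines; the operands and their order are unchanged, and `or` still short-circuits to the first truthy value. A standalone sketch of the idiom (hypothetical environment variable name):

    import os

    explicit_api_base = None  # e.g. a value passed in by the caller
    # first truthy operand wins; later operands are not evaluated once one succeeds
    api_base = explicit_api_base or os.getenv("EXAMPLE_API_BASE") or "https://api.openai.com/v1"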