mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 10:44:24 +00:00

build: Squashed commit of https://github.com/BerriAI/litellm/pull/7170

Closes https://github.com/BerriAI/litellm/pull/7170

This commit is contained in:
parent 5d1274cb6e
commit 06074bb13b

8 changed files with 197 additions and 62 deletions
@@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
-from .llms.triton import TritonChatCompletion
+from .llms.triton.completion.handler import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -559,7 +559,9 @@ def mock_completion(
         raise litellm.MockException(
             status_code=getattr(mock_response, "status_code", 500),  # type: ignore
             message=getattr(mock_response, "text", str(mock_response)),
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,  # type: ignore
             request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
         )
@@ -568,7 +570,9 @@ def mock_completion(
     ):
         raise litellm.RateLimitError(
             message="this is a mock rate limit error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,
         )
     elif (
@@ -577,7 +581,9 @@ def mock_completion(
     ):
         raise litellm.InternalServerError(
             message="this is a mock internal server error",
-            llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
+            llm_provider=getattr(
+                mock_response, "llm_provider", custom_llm_provider or "openai"
+            ),  # type: ignore
             model=model,
         )
     elif isinstance(mock_response, str) and mock_response.startswith(
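
The three mock_completion hunks above only re-wrap the getattr(mock_response, "llm_provider", ...) fallback across lines; behavior is unchanged. For context, a minimal sketch of how these mock paths are typically driven through litellm.completion (the model name is a placeholder, and no provider is actually called):

    import litellm

    # String mock_response: returns a canned ModelResponse without hitting a provider.
    resp = litellm.completion(
        model="gpt-3.5-turbo",  # placeholder model name
        messages=[{"role": "user", "content": "Hi"}],
        mock_response="This is a mock response",
    )
    print(resp.choices[0].message.content)

    # Exception mock_response: routes through the raise-branches above and surfaces
    # as a litellm.MockException (llm_provider falls back to "openai").
    try:
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi"}],
            mock_response=ValueError("boom"),
        )
    except Exception as err:
        print(type(err).__name__, err)
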
@@ -2374,7 +2380,6 @@ def completion( # type: ignore # noqa: PLR0915
                 return _model_response
             response = _model_response
         elif custom_llm_provider == "text-completion-codestral":
-
             api_base = (
                 api_base
                 or optional_params.pop("api_base", None)
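
This hunk trims a blank line in the text-completion-codestral branch of completion() and shows its api_base fallbacks. A hedged usage sketch for that branch follows; the model identifier and the CODESTRAL_API_KEY env var name are assumptions, not part of this diff:

    import os
    import litellm

    os.environ["CODESTRAL_API_KEY"] = "..."  # assumed env var name, placeholder value

    # Routes through the `custom_llm_provider == "text-completion-codestral"` branch above.
    resp = litellm.completion(
        model="text-completion-codestral/codestral-latest",  # assumed model identifier
        messages=[{"role": "user", "content": "def fib(n):"}],
    )
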
@@ -2705,6 +2710,8 @@ def completion( # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 stream=stream,
                 acompletion=acompletion,
+                client=client,
+                litellm_params=litellm_params,
             )
 
             ## RESPONSE OBJECT
@@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
         )
 
     num_retries = kwargs.pop("num_retries", 3)
-    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry")  # type: ignore
+    retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
+        "retry_strategy", "constant_retry"
+    )  # type: ignore
     original_function = kwargs.pop("original_function", completion)
     if retry_strategy == "exponential_backoff_retry":
         retryer = tenacity.Retrying(
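
For reference, the two kwargs popped in the hunk above map directly onto the public call. A small sketch (the model name is a placeholder; mock_response just keeps the example offline and is not part of this hunk):

    import litellm

    resp = litellm.completion_with_retries(
        model="gpt-3.5-turbo",  # placeholder
        messages=[{"role": "user", "content": "Hi"}],
        num_retries=3,                               # default shown in the hunk
        retry_strategy="exponential_backoff_retry",  # or "constant_retry" (the default)
        mock_response="ok",
    )
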
@@ -3331,9 +3340,7 @@ def embedding( # noqa: PLR0915
             max_retries=max_retries,
         )
     elif custom_llm_provider == "databricks":
-        api_base = (
-            api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
-        )  # type: ignore
+        api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")  # type: ignore
 
         # set API KEY
         api_key = (
@@ -3465,7 +3472,6 @@ def embedding( # noqa: PLR0915
             aembedding=aembedding,
         )
     elif custom_llm_provider == "gemini":
-
         gemini_api_key = (
             api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
         )
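
Both embedding hunks keep the same credential resolution order (explicit argument, then the litellm-level setting, then the provider secret); only the line layout changes. A hedged sketch of the Databricks branch; the endpoint name, key value, and DATABRICKS_API_KEY env var name are assumptions:

    import os
    import litellm

    # api_base resolution in the hunk: api_base arg -> litellm.api_base -> DATABRICKS_API_BASE secret.
    os.environ["DATABRICKS_API_BASE"] = "https://example.cloud.databricks.com/serving-endpoints"  # placeholder
    os.environ["DATABRICKS_API_KEY"] = "dapi-..."  # assumed env var name, placeholder value

    resp = litellm.embedding(
        model="databricks/databricks-bge-large-en",  # assumed endpoint name
        input=["hello world"],
    )
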
@@ -3960,7 +3966,11 @@ def text_completion( # noqa: PLR0915
     optional_params["custom_llm_provider"] = custom_llm_provider
 
     # get custom_llm_provider
-    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model,  # type: ignore
+        custom_llm_provider=custom_llm_provider,
+        api_base=api_base,
+    )
 
     if custom_llm_provider == "huggingface":
         # if echo == True, for TGI llms we need to set top_n_tokens to 3
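
The text_completion hunk only splits the get_llm_provider(...) call over multiple lines. For context, a minimal sketch of the call it services (the model name is a placeholder; OPENAI_API_KEY is assumed in the environment):

    import litellm

    resp = litellm.text_completion(
        model="gpt-3.5-turbo-instruct",  # placeholder
        prompt="Say this is a test",
        max_tokens=16,
    )
    print(resp.choices[0].text)
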
@@ -4212,7 +4222,6 @@ async def amoderation(
         )
     openai_client = kwargs.get("client", None)
     if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
-
         # call helper to get OpenAI client
         # _get_openai_client maintains in-memory caching logic for OpenAI clients
         _openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client(  # type: ignore
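
The amoderation hunk just drops a blank line before the cached-client lookup. A hedged sketch of the async call that reaches this code; the model name is an assumption and OPENAI_API_KEY is assumed in the environment:

    import asyncio
    import litellm

    async def main():
        result = await litellm.amoderation(
            input="sample text to classify",
            model="omni-moderation-latest",  # assumed model name
        )
        print(result)

    asyncio.run(main())
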
@@ -4322,7 +4331,11 @@ def image_generation( # noqa: PLR0915
         headers.update(extra_headers)
     model_response: ImageResponse = litellm.utils.ImageResponse()
     if model is not None or custom_llm_provider is not None:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+            model=model,  # type: ignore
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+        )
     else:
         model = "dall-e-2"
         custom_llm_provider = "openai"  # default to dall-e-2 on openai
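
This hunk reformats the provider lookup; the else-branch shows the default of dall-e-2 on OpenAI when no model is passed. A small sketch of that default path (OPENAI_API_KEY assumed in the environment):

    import litellm

    # With model/custom_llm_provider omitted, the else-branch above applies:
    # model defaults to "dall-e-2" on OpenAI.
    img = litellm.image_generation(prompt="a watercolor fox")
    print(img.data[0].url)
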
@@ -4644,7 +4657,9 @@ def transcription(
 
     model_response = litellm.utils.TranscriptionResponse()
 
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
 
     if dynamic_api_key is not None:
         api_key = dynamic_api_key
@@ -4710,12 +4725,7 @@ def transcription(
             or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
         )
         # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
-        )  # type: ignore
+        api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")  # type: ignore
         response = openai_audio_transcriptions.audio_transcriptions(
             model=model,
             audio_file=file,
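
Both transcription hunks reshape the provider lookup and the OpenAI key fallback chain without changing it (explicit argument, then litellm.api_key, then litellm.openai_key, then OPENAI_API_KEY). A minimal sketch of the call (the file path is a placeholder):

    import litellm

    with open("speech.mp3", "rb") as audio_file:  # placeholder file
        transcript = litellm.transcription(model="whisper-1", file=audio_file)
    print(transcript.text)
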
@@ -4802,7 +4812,9 @@ def speech(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     extra_headers = kwargs.get("extra_headers", None)
     model_info = kwargs.get("model_info", None)
-    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
+        model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
+    )  # type: ignore
     kwargs.pop("tags", [])
 
     optional_params = {}
@@ -4895,9 +4907,7 @@ def speech(
         )
         api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")  # type: ignore
 
-        api_version = (
-            api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
-        )  # type: ignore
+        api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION")  # type: ignore
 
         api_key = (
             api_key
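
The two speech hunks likewise reflow the provider lookup and the AZURE_API_BASE / AZURE_API_VERSION fallbacks. A hedged sketch of the call path; the model and voice values are placeholders, and an "azure/<deployment>" model string with AZURE_API_* env vars would hit the Azure branch above:

    import litellm

    audio = litellm.speech(
        model="openai/tts-1",  # placeholder; "azure/<deployment>" reaches the Azure branch
        voice="alloy",
        input="The quick brown fox jumped over the lazy dog.",
    )
    audio.stream_to_file("speech.mp3")  # response object mirrors the OpenAI SDK here
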
@@ -5004,7 +5014,6 @@ async def ahealth_check( # noqa: PLR0915
     """
     passed_in_mode: Optional[str] = None
     try:
-
         model: Optional[str] = model_params.get("model", None)
 
         if model is None:
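
The last hunk only removes a blank line after try: in ahealth_check. A heavily hedged sketch of how it is invoked, inferred from the model_params lookup visible above; the mode value and the extra model_params keys are assumptions:

    import asyncio
    import litellm

    async def main():
        # ahealth_check reads "model" (and related settings) out of model_params.
        result = await litellm.ahealth_check(
            model_params={
                "model": "gpt-3.5-turbo",  # placeholder
                "messages": [{"role": "user", "content": "ping"}],
            },
            mode="chat",  # assumed mode value
        )
        print(result)

    asyncio.run(main())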