Krrish Dholakia 2024-12-11 01:03:57 -08:00
parent 5d1274cb6e
commit 06074bb13b
8 changed files with 197 additions and 62 deletions

View file

@ -126,7 +126,7 @@ from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@ -559,7 +559,9 @@ def mock_completion(
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
@ -568,7 +570,9 @@ def mock_completion(
):
raise litellm.RateLimitError(
message="this is a mock rate limit error",
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif (
@ -577,7 +581,9 @@ def mock_completion(
):
raise litellm.InternalServerError(
message="this is a mock internal server error",
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
llm_provider=getattr(
mock_response, "llm_provider", custom_llm_provider or "openai"
), # type: ignore
model=model,
)
elif isinstance(mock_response, str) and mock_response.startswith(
@ -2374,7 +2380,6 @@ def completion( # type: ignore # noqa: PLR0915
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
@ -2705,6 +2710,8 @@ def completion( # type: ignore # noqa: PLR0915
logging_obj=logging,
stream=stream,
acompletion=acompletion,
client=client,
litellm_params=litellm_params,
)
## RESPONSE OBJECT
@ -2944,7 +2951,9 @@ def completion_with_retries(*args, **kwargs):
)
num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop(
"retry_strategy", "constant_retry"
) # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
@ -3331,9 +3340,7 @@ def embedding( # noqa: PLR0915
max_retries=max_retries,
)
elif custom_llm_provider == "databricks":
api_base = (
api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
) # type: ignore
api_base = api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE") # type: ignore
# set API KEY
api_key = (
@ -3465,7 +3472,6 @@ def embedding( # noqa: PLR0915
aembedding=aembedding,
)
elif custom_llm_provider == "gemini":
gemini_api_key = (
api_key or get_secret_str("GEMINI_API_KEY") or litellm.api_key
)
@ -3960,7 +3966,11 @@ def text_completion( # noqa: PLR0915
optional_params["custom_llm_provider"] = custom_llm_provider
# get custom_llm_provider
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, # type: ignore
custom_llm_provider=custom_llm_provider,
api_base=api_base,
)
if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
@ -4212,7 +4222,6 @@ async def amoderation(
)
openai_client = kwargs.get("client", None)
if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
# call helper to get OpenAI client
# _get_openai_client maintains in-memory caching logic for OpenAI clients
_openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
@ -4322,7 +4331,11 @@ def image_generation( # noqa: PLR0915
headers.update(extra_headers)
model_response: ImageResponse = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, # type: ignore
custom_llm_provider=custom_llm_provider,
api_base=api_base,
)
else:
model = "dall-e-2"
custom_llm_provider = "openai" # default to dall-e-2 on openai
@ -4644,7 +4657,9 @@ def transcription(
model_response = litellm.utils.TranscriptionResponse()
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
) # type: ignore
if dynamic_api_key is not None:
api_key = dynamic_api_key
@ -4710,12 +4725,7 @@ def transcription(
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
api_key
or litellm.api_key
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
) # type: ignore
api_key = api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") # type: ignore
response = openai_audio_transcriptions.audio_transcriptions(
model=model,
audio_file=file,
@ -4802,7 +4812,9 @@ def speech(
proxy_server_request = kwargs.get("proxy_server_request", None)
extra_headers = kwargs.get("extra_headers", None)
model_info = kwargs.get("model_info", None)
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider, api_base=api_base
) # type: ignore
kwargs.pop("tags", [])
optional_params = {}
@ -4895,9 +4907,7 @@ def speech(
)
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_version = (
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
) # type: ignore
api_version = api_version or litellm.api_version or get_secret("AZURE_API_VERSION") # type: ignore
api_key = (
api_key
@ -5004,7 +5014,6 @@ async def ahealth_check( # noqa: PLR0915
"""
passed_in_mode: Optional[str] = None
try:
model: Optional[str] = model_params.get("model", None)
if model is None: