Litellm merge pr (#7161)

* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
Krish Dholakia 2024-12-10 22:49:26 -08:00 committed by GitHub
parent d5aae81c6d
commit 350cfc36f7
88 changed files with 3617 additions and 4421 deletions
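
For orientation before the diff: this commit shuffles internal provider handlers around, while the public entrypoint keeps the same shape. A minimal sketch of a call that exercises one of the touched routes; the model id is a placeholder, not from this commit:

```python
import litellm

# Same public API before and after this commit; only internal routing moved.
response = litellm.completion(
    model="replicate/meta/meta-llama-3-8b-instruct",  # placeholder model id
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```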

litellm/main.py

@@ -83,26 +83,13 @@ from litellm.utils import (
 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import (
-    aleph_alpha,
-    baseten,
-    maritalk,
-    nlp_cloud,
-    ollama_chat,
-    oobabooga,
-    openrouter,
-    palm,
-    petals,
-    replicate,
-)
-from .llms.ai21 import completion as ai21
+from .llms import aleph_alpha, baseten, maritalk, ollama_chat, petals
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
-from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
 from .llms.azure_ai.chat import AzureAIChatCompletion
+from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
-from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
@@ -111,13 +98,16 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
+from .llms.deprecated_providers import palm
 from .llms.groq.chat.handler import GroqChatCompletion
-from .llms.huggingface_restapi import Huggingface
+from .llms.huggingface.chat.handler import Huggingface
+from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
+from .llms.oobabooga.chat import oobabooga
 from .llms.ollama.completion import handler as ollama
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
-from .llms.openai.chat.o1_handler import OpenAIO1ChatCompletion
 from .llms.openai.completion.handler import OpenAITextCompletion
 from .llms.openai.openai import OpenAIChatCompletion
+from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.common_utils import get_completion_messages
@@ -131,6 +121,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
@@ -159,7 +150,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler im
 from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
     VertexAIModelGardenModels,
 )
-from .llms.vllm.completion import handler
+from .llms.vllm.completion import handler as vllm_handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (
@@ -196,12 +187,10 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
-openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
 groq_chat_completions = GroqChatCompletion()
 together_ai_text_completions = TogetherAITextCompletion()
-azure_ai_chat_completions = AzureAIChatCompletion()
 azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -228,6 +217,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
+openai_like_chat_completion = OpenAILikeChatHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
 sagemaker_chat_completion = SagemakerChatHandler()
@@ -449,6 +439,7 @@ async def acompletion(
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"
+            or custom_llm_provider == "ai21"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "codestral"
             or custom_llm_provider == "text-completion-codestral"
@@ -1316,7 +1307,7 @@ def completion( # type: ignore # noqa: PLR0915
             ## COMPLETION CALL
             try:
-                response = azure_ai_chat_completions.completion(
+                response = openai_chat_completions.completion(
                     model=model,
                     messages=messages,
                     headers=headers,
@@ -1513,9 +1504,7 @@ def completion( # type: ignore # noqa: PLR0915
             or custom_llm_provider == "nvidia_nim"
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
-            or custom_llm_provider == "ai21_chat"
             or custom_llm_provider == "volcengine"
-            or custom_llm_provider == "codestral"
             or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
@@ -1562,46 +1551,25 @@ def completion( # type: ignore # noqa: PLR0915
             ## COMPLETION CALL
             try:
-                if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
-                    response = openai_o1_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout, # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client, # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
-                else:
-                    response = openai_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout, # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client, # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout, # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client, # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
             except Exception as e:
                 ## LOGGING - log the original exception returned
                 logging.post_call(
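With the hunk above, the o1 special case is gone and o1 models flow through the standard OpenAI chat handler. A hedged usage sketch; the model id is an example, not taken from this diff:

```python
import litellm

# o1 models now take the same openai_chat_completions path as other
# OpenAI chat models; no separate o1 handler is consulted.
response = litellm.completion(
    model="o1-preview",  # example OpenAI o1 model id
    messages=[{"role": "user", "content": "Hello"}],
)
```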
@@ -1627,7 +1595,6 @@ def completion( # type: ignore # noqa: PLR0915
             or model in litellm.replicate_models
         ):
             # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
-            replicate_key = None
             replicate_key = (
                 api_key
                 or litellm.replicate_key
@@ -1645,7 +1612,7 @@ def completion( # type: ignore # noqa: PLR0915
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = replicate.completion( # type: ignore
+            model_response = replicate_chat_completion( # type: ignore
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1659,6 +1626,7 @@ def completion( # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
                 acompletion=acompletion,
+                headers=headers,
             )
             if optional_params.get("stream", False) is True:
@@ -1806,7 +1774,7 @@ def completion( # type: ignore # noqa: PLR0915
                 or "https://api.nlpcloud.io/v1/gpu/"
             )
-            response = nlp_cloud.completion(
+            response = nlp_cloud_chat_completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1969,10 +1937,10 @@ def completion( # type: ignore # noqa: PLR0915
                 api_base
                 or litellm.api_base
                 or get_secret("MARITALK_API_BASE")
-                or "https://chat.maritaca.ai/api/chat/inference"
+                or "https://chat.maritaca.ai/api"
             )
-            model_response = maritalk.completion(
+            model_response = openai_like_chat_completion.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1984,17 +1952,10 @@ def completion( # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 api_key=maritalk_key,
                 logging_obj=logging,
+                custom_llm_provider="maritalk",
+                custom_prompt_dict=custom_prompt_dict,
             )
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="maritalk",
-                    logging_obj=logging,
-                )
-                return response
             response = model_response
         elif custom_llm_provider == "huggingface":
             custom_llm_provider = "huggingface"
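The maritalk branch above now delegates to the generic OpenAI-like handler against the new base URL, with streaming handled inside that handler instead of by a local CustomStreamWrapper. A minimal sketch of the idea behind "OpenAI-like" routing, using the openai client directly; the API key and model id are placeholders:

```python
from openai import OpenAI

# An "OpenAI-like" provider speaks the OpenAI wire format from a different
# base_url -- here the new Maritalk base adopted in the hunk above.
client = OpenAI(
    base_url="https://chat.maritaca.ai/api",
    api_key="<MARITALK_API_KEY>",  # placeholder
)
response = client.chat.completions.create(
    model="sabia-3",  # example Maritalk model id
    messages=[{"role": "user", "content": "Olá"}],
)
print(response.choices[0].message.content)
```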
@@ -2012,7 +1973,7 @@ def completion( # type: ignore # noqa: PLR0915
                 model=model,
                 messages=messages,
                 api_base=api_base, # type: ignore
-                headers=hf_headers,
+                headers=hf_headers or {},
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -2024,6 +1985,7 @@ def completion( # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
                 timeout=timeout, # type: ignore
+                client=client,
             )
             if (
                 "stream" in optional_params
@@ -2146,7 +2108,7 @@ def completion( # type: ignore # noqa: PLR0915
                 headers = openrouter_headers
             ## Load Config
-            config = openrouter.OpenrouterConfig.get_config()
+            config = litellm.OpenrouterConfig.get_config()
             for k, v in config.items():
                 if k == "extra_body":
                     # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
@@ -2190,30 +2152,9 @@ def completion( # type: ignore # noqa: PLR0915
             """
             pass
         elif custom_llm_provider == "palm":
-            palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
-            # palm does not support streaming as yet :(
-            model_response = palm.completion(
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=palm_api_key,
-                logging_obj=logging,
+            raise ValueError(
+                "Palm was decommisioned on October 2024. Please use the `gemini/` route for Gemini Google AI Studio Models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
             )
-            # fake palm streaming
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # fake streaming for palm
-                resp_string = model_response["choices"][0]["message"]["content"]
-                response = CustomStreamWrapper(
-                    resp_string, model, custom_llm_provider="palm", logging_obj=logging
-                )
-                return response
-            response = model_response
         elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
             vertex_ai_project = (
                 optional_params.pop("vertex_project", None)
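Since the palm branch above now raises unconditionally, existing `palm/` callers have to move to the Google AI Studio route named in the error message. A hedged migration sketch; the Gemini model id is an example:

```python
import litellm

# Before (now raises ValueError): model="palm/chat-bison"
# After: use the gemini/ route for Google AI Studio models.
response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # example Gemini model id
    messages=[{"role": "user", "content": "Hello"}],
)
```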
@@ -2475,51 +2416,9 @@ def completion( # type: ignore # noqa: PLR0915
             ):
                 return _model_response
             response = _model_response
-        elif custom_llm_provider == "ai21":
-            custom_llm_provider = "ai21"
-            ai21_key = (
-                api_key
-                or litellm.ai21_key
-                or os.environ.get("AI21_API_KEY")
-                or litellm.api_key
-            )
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("AI21_API_BASE")
-                or "https://api.ai21.com/studio/v1/"
-            )
-            model_response = ai21.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=ai21_key,
-                logging_obj=logging,
-            )
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="ai21",
-                    logging_obj=logging,
-                )
-                return response
-            ## RESPONSE OBJECT
-            response = model_response
         elif custom_llm_provider == "sagemaker_chat":
             # boto3 reads keys from .env
-            response = sagemaker_chat_completion.completion(
+            model_response = sagemaker_chat_completion.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
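The bespoke ai21 branch above is deleted outright; together with the `or custom_llm_provider == "ai21"` additions in the acompletion and atext_completion hunks, ai21 now rides the OpenAI-compatible path. The public call shape is unchanged; the model id is an example:

```python
import litellm

# "ai21" is now dispatched like the other OpenAI-compatible providers
# instead of through the removed .llms.ai21 handler.
response = litellm.completion(
    model="ai21/jamba-1.5-mini",  # example AI21 model id
    messages=[{"role": "user", "content": "Hello"}],
)
```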
@@ -2531,9 +2430,13 @@ def completion( # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
+                headers=headers or {},
             )
-        elif custom_llm_provider == "sagemaker":
+            ## RESPONSE OBJECT
+            response = model_response
+        elif (
+            custom_llm_provider == "sagemaker"
+        ):
             # boto3 reads keys from .env
             model_response = sagemaker_llm.completion(
                 model=model,
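The sagemaker_chat branch above now assigns to model_response and falls through to a shared `response = model_response`, matching the other branches, so caller-facing behavior should be unchanged. A hedged usage sketch; the endpoint name is a placeholder:

```python
import litellm

# boto3 still reads AWS credentials from the environment, per the comment
# in the diff; only the internal response plumbing changed.
response = litellm.completion(
    model="sagemaker_chat/my-endpoint-name",  # placeholder endpoint name
    messages=[{"role": "user", "content": "Hello"}],
)
```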
@@ -2691,7 +2594,7 @@ def completion( # type: ignore # noqa: PLR0915
                 response = response
         elif custom_llm_provider == "vllm":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = handler.completion(
+            model_response = vllm_handler.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -3872,6 +3775,7 @@ async def atext_completion(
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"
+            or custom_llm_provider == "ai21"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "text-completion-codestral"
             or custom_llm_provider == "deepseek"
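The same `ai21` addition lands in the async text-completion path above, mirroring the acompletion hunk earlier. A hedged async sketch; the model id is an example:

```python
import asyncio

import litellm

async def main() -> None:
    # acompletion/atext_completion now accept "ai21" as a natively
    # async-routable provider, per the two hunks adding it.
    response = await litellm.acompletion(
        model="ai21/jamba-1.5-mini",  # example AI21 model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```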