mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Litellm merge pr (#7161)

* build: merge branch
* test: fix openai naming
* fix(main.py): fix openai renaming
* style: ignore function length for config factory
* fix(sagemaker/): fix routing logic
* fix: fix imports
* fix: fix override
This commit is contained in:
parent d5aae81c6d
commit 350cfc36f7
88 changed files with 3617 additions and 4421 deletions
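The merge reorganizes provider handlers into per-provider subpackages; the public entrypoints are unchanged. A minimal sketch of the stable surface, assuming OPENAI_API_KEY is set in the environment:

import litellm

# completion() keeps its signature; only the internal provider routing
# shown in the diff below changed.
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)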
litellm/main.py (200 changed lines)
@@ -83,26 +83,13 @@ from litellm.utils import (
 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
-from .llms import (
-    aleph_alpha,
-    baseten,
-    maritalk,
-    nlp_cloud,
-    ollama_chat,
-    oobabooga,
-    openrouter,
-    palm,
-    petals,
-    replicate,
-)
-from .llms.ai21 import completion as ai21
+from .llms import aleph_alpha, baseten, maritalk, ollama_chat, petals
 from .llms.anthropic.chat import AnthropicChatCompletion
 from .llms.azure.audio_transcriptions import AzureAudioTranscription
 from .llms.azure.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure.chat.o1_handler import AzureOpenAIO1ChatCompletion
-from .llms.azure_ai.chat import AzureAIChatCompletion
+from .llms.azure.completion.handler import AzureTextCompletion
 from .llms.azure_ai.embed import AzureAIEmbedding
-from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
@@ -111,13 +98,16 @@ from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat.handler import DatabricksChatCompletion
 from .llms.databricks.embed.handler import DatabricksEmbeddingHandler
+from .llms.deprecated_providers import palm
 from .llms.groq.chat.handler import GroqChatCompletion
-from .llms.huggingface_restapi import Huggingface
+from .llms.huggingface.chat.handler import Huggingface
+from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
+from .llms.oobabooga.chat import oobabooga
+from .llms.ollama.completion import handler as ollama
 from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
-from .llms.openai.chat.o1_handler import OpenAIO1ChatCompletion
 from .llms.openai.completion.handler import OpenAITextCompletion
 from .llms.openai.openai import OpenAIChatCompletion
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.common_utils import get_completion_messages
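The new deprecated_providers import keeps the palm module path importable while the provider itself is retired. A minimal sketch of that pattern (illustrative, not the actual module contents):

# llms/deprecated_providers/palm.py -- hypothetical sketch
def completion(*args, **kwargs):
    raise ValueError("palm was decommissioned; use the gemini/ route instead")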
@@ -131,6 +121,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
+from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
@@ -159,7 +150,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
 from .llms.vertex_ai_and_google_ai_studio.vertex_model_garden.main import (
     VertexAIModelGardenModels,
 )
-from .llms.vllm.completion import handler
+from .llms.vllm.completion import handler as vllm_handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
 from .types.llms.openai import (
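Aliasing the generic module name at import time (handler as vllm_handler, handler as ollama above) keeps call sites unambiguous when several providers each expose a module named handler:

# Both names come straight from the imports in this diff.
from litellm.llms.vllm.completion import handler as vllm_handler
from litellm.llms.ollama.completion import handler as ollama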
@@ -196,12 +187,10 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
-openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
 groq_chat_completions = GroqChatCompletion()
 together_ai_text_completions = TogetherAITextCompletion()
-azure_ai_chat_completions = AzureAIChatCompletion()
 azure_ai_embedding = AzureAIEmbedding()
 anthropic_chat_completions = AnthropicChatCompletion()
 azure_chat_completions = AzureChatCompletion()
@@ -228,6 +217,7 @@ watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
 openai_like_chat_completion = OpenAILikeChatHandler()
 databricks_embedding = DatabricksEmbeddingHandler()
 base_llm_http_handler = BaseLLMHTTPHandler()
+sagemaker_chat_completion = SagemakerChatHandler()
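main.py builds one handler object per provider at import time and shares it across requests, so a request only selects an instance rather than constructing one. A generic sketch of the pattern with hypothetical names:

class ChatHandler:
    # Handlers hold no per-request state, so one shared instance is safe.
    def completion(self, model: str, messages: list) -> dict:
        return {"model": model, "echo": messages[-1]["content"]}

chat_handler = ChatHandler()  # module level, like sagemaker_chat_completion above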
@@ -449,6 +439,7 @@ async def acompletion(
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"
+            or custom_llm_provider == "ai21"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "codestral"
             or custom_llm_provider == "text-completion-codestral"
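acompletion awaits the provider call directly only for providers whose handlers are natively async, and "ai21" now joins that list. An illustrative equivalent of the or-chain (the set name is mine, not litellm's):

NATIVE_ASYNC_PROVIDERS = {
    "cerebras", "sambanova", "ai21_chat", "ai21",
    "volcengine", "codestral", "text-completion-codestral",
}

def takes_async_fast_path(custom_llm_provider: str) -> bool:
    return custom_llm_provider in NATIVE_ASYNC_PROVIDERS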
@@ -1316,7 +1307,7 @@ def completion( # type: ignore # noqa: PLR0915

             ## COMPLETION CALL
             try:
-                response = azure_ai_chat_completions.completion(
+                response = openai_chat_completions.completion(
                     model=model,
                     messages=messages,
                     headers=headers,
@@ -1513,9 +1504,7 @@ def completion( # type: ignore # noqa: PLR0915
             or custom_llm_provider == "nvidia_nim"
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "codestral"
             or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
@@ -1562,46 +1551,25 @@ def completion( # type: ignore # noqa: PLR0915

             ## COMPLETION CALL
             try:
-                if litellm.openAIO1Config.is_model_o1_reasoning_model(model=model):
-                    response = openai_o1_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout,  # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client,  # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
-                else:
-                    response = openai_chat_completions.completion(
-                        model=model,
-                        messages=messages,
-                        headers=headers,
-                        model_response=model_response,
-                        print_verbose=print_verbose,
-                        api_key=api_key,
-                        api_base=api_base,
-                        acompletion=acompletion,
-                        logging_obj=logging,
-                        optional_params=optional_params,
-                        litellm_params=litellm_params,
-                        logger_fn=logger_fn,
-                        timeout=timeout,  # type: ignore
-                        custom_prompt_dict=custom_prompt_dict,
-                        client=client,  # pass AsyncOpenAI, OpenAI client
-                        organization=organization,
-                        custom_llm_provider=custom_llm_provider,
-                    )
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
             except Exception as e:
                 ## LOGGING - log the original exception returned
                 logging.post_call(
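With the openAIO1Config branch removed, o1 models flow through openai_chat_completions like every other OpenAI chat model, leaving a single call site to maintain. Illustrative usage, assuming OPENAI_API_KEY is set:

import litellm

response = litellm.completion(
    model="o1-preview",  # now routed through the same handler as other OpenAI models
    messages=[{"role": "user", "content": "What is 17 * 23?"}],
)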
@@ -1627,7 +1595,6 @@ def completion( # type: ignore # noqa: PLR0915
             or model in litellm.replicate_models
         ):
             # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN")
-            replicate_key = None
             replicate_key = (
                 api_key
                 or litellm.replicate_key
@@ -1645,7 +1612,7 @@ def completion( # type: ignore # noqa: PLR0915

             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            model_response = replicate.completion(  # type: ignore
+            model_response = replicate_chat_completion(  # type: ignore
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1659,6 +1626,7 @@ def completion( # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
                 acompletion=acompletion,
+                headers=headers,
             )

             if optional_params.get("stream", False) is True:
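Replicate requests now go through the functional replicate_chat_completion handler imported above, which also accepts headers and an acompletion flag. Illustrative usage, assuming REPLICATE_API_TOKEN is set and the model string is an example:

import litellm

response = litellm.completion(
    model="replicate/meta/meta-llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Hi"}],
)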
@@ -1806,7 +1774,7 @@ def completion( # type: ignore # noqa: PLR0915
                 or "https://api.nlpcloud.io/v1/gpu/"
             )

-            response = nlp_cloud.completion(
+            response = nlp_cloud_chat_completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1969,10 +1937,10 @@ def completion( # type: ignore # noqa: PLR0915
                 api_base
                 or litellm.api_base
                 or get_secret("MARITALK_API_BASE")
-                or "https://chat.maritaca.ai/api/chat/inference"
+                or "https://chat.maritaca.ai/api"
             )

-            model_response = maritalk.completion(
+            model_response = openai_like_chat_completion.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -1984,17 +1952,10 @@ def completion( # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 api_key=maritalk_key,
                 logging_obj=logging,
+                custom_llm_provider="maritalk",
+                custom_prompt_dict=custom_prompt_dict,
             )

-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="maritalk",
-                    logging_obj=logging,
-                )
-                return response
             response = model_response
         elif custom_llm_provider == "huggingface":
             custom_llm_provider = "huggingface"
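Maritalk is now served by the shared OpenAI-compatible handler against https://chat.maritaca.ai/api, so streaming comes from the common code path rather than the CustomStreamWrapper block deleted above. Illustrative usage, assuming MARITALK_API_KEY is set and the model name is an example:

import litellm

response = litellm.completion(
    model="maritalk/sabia-3",
    messages=[{"role": "user", "content": "Olá, tudo bem?"}],
    stream=True,  # handled by the shared OpenAI-like streaming path
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")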
@@ -2012,7 +1973,7 @@ def completion( # type: ignore # noqa: PLR0915
                 model=model,
                 messages=messages,
                 api_base=api_base,  # type: ignore
-                headers=hf_headers,
+                headers=hf_headers or {},
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=optional_params,
@@ -2024,6 +1985,7 @@ def completion( # type: ignore # noqa: PLR0915
                 logging_obj=logging,
                 custom_prompt_dict=custom_prompt_dict,
                 timeout=timeout,  # type: ignore
+                client=client,
             )
             if (
                 "stream" in optional_params
@@ -2146,7 +2108,7 @@ def completion( # type: ignore # noqa: PLR0915
             headers = openrouter_headers

             ## Load Config
-            config = openrouter.OpenrouterConfig.get_config()
+            config = litellm.OpenrouterConfig.get_config()
             for k, v in config.items():
                 if k == "extra_body":
                     # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
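Provider settings now resolve through the package namespace rather than the deleted provider module. The call shown in the hunk is usable as-is:

import litellm

config = litellm.OpenrouterConfig.get_config()
# OpenRouter-specific params (transforms, route, models) ride in the OpenAI
# client's extra_body, as the comment in the hunk notes.
print(config.get("extra_body"))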
@@ -2190,30 +2152,9 @@ def completion( # type: ignore # noqa: PLR0915
             """
             pass
         elif custom_llm_provider == "palm":
-            palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
-
-            # palm does not support streaming as yet :(
-            model_response = palm.completion(
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=palm_api_key,
-                logging_obj=logging,
+            raise ValueError(
+                "Palm was decommisioned on October 2024. Please use the `gemini/` route for Gemini Google AI Studio Models. Announcement: https://ai.google.dev/palm_docs/palm?hl=en"
             )
-            # fake palm streaming
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # fake streaming for palm
-                resp_string = model_response["choices"][0]["message"]["content"]
-                response = CustomStreamWrapper(
-                    resp_string, model, custom_llm_provider="palm", logging_obj=logging
-                )
-                return response
-            response = model_response
         elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
             vertex_ai_project = (
                 optional_params.pop("vertex_project", None)
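A palm/ request now fails fast with the ValueError above instead of attempting an HTTP call. A sketch of the resulting behavior:

import litellm

try:
    litellm.completion(
        model="palm/chat-bison",
        messages=[{"role": "user", "content": "hi"}],
    )
except ValueError as err:
    print(err)  # points callers at the gemini/ route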
@@ -2475,51 +2416,9 @@ def completion( # type: ignore # noqa: PLR0915
             ):
                 return _model_response
             response = _model_response
-        elif custom_llm_provider == "ai21":
-            custom_llm_provider = "ai21"
-            ai21_key = (
-                api_key
-                or litellm.ai21_key
-                or os.environ.get("AI21_API_KEY")
-                or litellm.api_key
-            )
-
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("AI21_API_BASE")
-                or "https://api.ai21.com/studio/v1/"
-            )
-
-            model_response = ai21.completion(
-                model=model,
-                messages=messages,
-                api_base=api_base,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                api_key=ai21_key,
-                logging_obj=logging,
-            )
-
-            if "stream" in optional_params and optional_params["stream"] is True:
-                # don't try to access stream object,
-                response = CustomStreamWrapper(
-                    model_response,
-                    model,
-                    custom_llm_provider="ai21",
-                    logging_obj=logging,
-                )
-                return response
-
-            ## RESPONSE OBJECT
-            response = model_response
         elif custom_llm_provider == "sagemaker_chat":
             # boto3 reads keys from .env
-            response = sagemaker_chat_completion.completion(
+            model_response = sagemaker_chat_completion.completion(
                 model=model,
                 messages=messages,
                 model_response=model_response,
@@ -2531,9 +2430,13 @@ def completion( # type: ignore # noqa: PLR0915
                 encoding=encoding,
                 logging_obj=logging,
                 acompletion=acompletion,
                 headers=headers or {},
             )
-        elif custom_llm_provider == "sagemaker":
+
+            ## RESPONSE OBJECT
+            response = model_response
+        elif (
+            custom_llm_provider == "sagemaker"
+        ):
             # boto3 reads keys from .env
             model_response = sagemaker_llm.completion(
                 model=model,
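The sagemaker_chat branch now stores the handler result in model_response and falls through to the shared ## RESPONSE OBJECT handling, the routing fix named in the commit message, so streaming and logging behave like other providers. Illustrative call, assuming boto3 can find AWS credentials and with a hypothetical endpoint name:

import litellm

response = litellm.completion(
    model="sagemaker_chat/my-llama-endpoint",
    messages=[{"role": "user", "content": "Hello"}],
)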
@@ -2691,7 +2594,7 @@ def completion( # type: ignore # noqa: PLR0915
             response = response
         elif custom_llm_provider == "vllm":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            model_response = handler.completion(
+            model_response = vllm_handler.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -3872,6 +3775,7 @@ async def atext_completion(
             or custom_llm_provider == "cerebras"
             or custom_llm_provider == "sambanova"
             or custom_llm_provider == "ai21_chat"
+            or custom_llm_provider == "ai21"
             or custom_llm_provider == "volcengine"
             or custom_llm_provider == "text-completion-codestral"
             or custom_llm_provider == "deepseek"
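"ai21" joins the async fast path for text completion as well. Illustrative usage, assuming OPENAI_API_KEY is set for the example model:

import asyncio

import litellm

async def main() -> None:
    resp = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say hello.",
    )
    print(resp.choices[0].text)

asyncio.run(main())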