Merge branch 'main' into litellm_gemini_refactoring

This commit is contained in:
Krish Dholakia 2024-06-17 17:28:50 -07:00 committed by GitHub
commit a80520004e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1889 additions and 1035 deletions

View file

@ -107,6 +107,10 @@ from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.prompt_templates.factory import (
custom_prompt,
function_call_prompt,
@ -143,6 +147,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
@ -345,6 +350,8 @@ async def acompletion(
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface"
@ -374,9 +381,10 @@ async def acompletion(
else:
response = init_response # type: ignore
if custom_llm_provider == "text-completion-openai" and isinstance(
response, TextCompletionResponse
):
if (
custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "text-completion-codestral"
) and isinstance(response, TextCompletionResponse):
response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=response,
model_response_object=litellm.ModelResponse(),
@ -1069,6 +1077,7 @@ def completion(
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
@ -2021,6 +2030,46 @@ def completion(
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or "https://codestral.mistral.ai/v1/fim/completions"
)
api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
_model_response = codestral_text_completions.completion( # type: ignore
model=model,
messages=messages,
model_response=text_completion_model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
@ -3410,7 +3459,9 @@ def embedding(
###### Text Completion ################
@client
async def atext_completion(*args, **kwargs):
async def atext_completion(
*args, **kwargs
) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
"""
Implemented to handle async streaming for the text completion endpoint
"""
@ -3442,6 +3493,7 @@ async def atext_completion(*args, **kwargs):
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "text-completion-openai"
@ -3703,6 +3755,7 @@ def text_completion(
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "text-completion-openai"
)
and isinstance(prompt, list)