Merge branch 'main' into feat/friendliai

Wonseok Lee (Jack) 2024-06-21 10:50:03 +09:00 committed by GitHub
commit c4c7d1b367
201 changed files with 22438 additions and 13694 deletions


@@ -7,107 +7,132 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan
import asyncio
import contextvars
import datetime
import inspect
import json
import os
import random
import sys
import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import (
Any,
BinaryIO,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
import dotenv
import httpx
import openai
import tiktoken
from typing_extensions import overload
import litellm
from litellm import ( # type: ignore
    Logging,
    client,
    exception_type,
    get_litellm_params,
    get_optional_params,
)
from litellm.utils import (
    CustomStreamWrapper,
    Usage,
    async_mock_completion_streaming_obj,
    completion_with_fallbacks,
    convert_to_model_response_object,
    create_pretrained_tokenizer,
    create_tokenizer,
    get_api_key,
    get_llm_provider,
    get_optional_params_embeddings,
    get_optional_params_image_gen,
    get_secret,
    mock_completion_streaming_obj,
    read_config_args,
    supports_httpx_timeout,
    token_counter,
)
from ._logging import verbose_logger
from .caching import disable_cache, enable_cache, update_cache
from .llms import (
    ai21,
    aleph_alpha,
    anthropic_text,
    baseten,
    bedrock,
    clarifai,
    cloudflare,
    cohere,
    cohere_chat,
    gemini,
    huggingface_restapi,
    maritalk,
    nlp_cloud,
    ollama,
    ollama_chat,
    oobabooga,
    openrouter,
    palm,
    petals,
    replicate,
    sagemaker,
    together_ai,
    triton,
    vertex_ai,
    vertex_ai_anthropic,
    vllm,
    watsonx,
)
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
    custom_prompt,
    function_call_prompt,
    map_system_message_pt,
    prompt_factory,
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_httpx import VertexLLM
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import ChatCompletionMessageToolCall
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
    Choices,
    CustomStreamWrapper,
    EmbeddingResponse,
    ImageResponse,
    Message,
    ModelResponse,
    TextChoices,
    TextCompletionResponse,
    TextCompletionStreamWrapper,
    TranscriptionResponse,
    get_secret,
    read_config_args,
)
####### ENVIRONMENT VARIABLES ###################
@@ -120,6 +145,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
@@ -322,6 +348,8 @@ async def acompletion(
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface"
@@ -329,6 +357,7 @@ async def acompletion(
or custom_llm_provider == "ollama_chat"
or custom_llm_provider == "replicate"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "vertex_ai_beta"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
@@ -350,9 +379,10 @@ async def acompletion(
else:
response = init_response # type: ignore
if (
custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "text-completion-codestral"
) and isinstance(response, TextCompletionResponse):
response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=response,
model_response_object=litellm.ModelResponse(),
@@ -367,7 +397,9 @@ async def acompletion(
return response
except Exception as e:
verbose_logger.error(
"litellm.acompletion(): Exception occured - {}".format(str(e))
"litellm.acompletion(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
verbose_logger.debug(traceback.format_exc())
custom_llm_provider = custom_llm_provider or "openai"
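With codestral and text-completion-codestral added to the provider list above, those calls now take acompletion's native async path instead of the threaded fallback. A minimal usage sketch, with an illustrative model id and prompt (not taken from this diff):

import asyncio

import litellm

async def main():
    # Providers in the list above are awaited directly rather than run in a thread pool.
    response = await litellm.acompletion(
        model="codestral/codestral-latest",  # assumed model id
        messages=[{"role": "user", "content": "Write a one-line docstring."}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())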
@@ -397,7 +429,9 @@ def mock_completion(
messages: List,
stream: Optional[bool] = False,
mock_response: Union[str, Exception] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
logging=None,
custom_llm_provider=None,
**kwargs,
):
"""
@@ -435,7 +469,7 @@ def mock_completion(
raise litellm.APIError(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(mock_response, "llm_provider", "openai"), # type: ignore
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
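A hedged sketch of the behaviour above: passing an Exception as mock_response makes mock_completion raise litellm.APIError, and llm_provider now falls back to the caller's custom_llm_provider instead of always "openai". Model and message are illustrative:

import litellm

try:
    litellm.completion(
        model="groq/llama3-8b-8192",  # illustrative provider/model
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("simulated outage"),
    )
except litellm.APIError as err:
    # status_code defaults to 500; llm_provider now reflects the mocked provider
    print(err.llm_provider, err.status_code)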
@@ -464,6 +498,12 @@ def mock_completion(
model_response["created"] = int(time.time())
model_response["model"] = model
if mock_tool_calls:
model_response["choices"][0]["message"]["tool_calls"] = [
ChatCompletionMessageToolCall(**tool_call)
for tool_call in mock_tool_calls
]
setattr(
model_response,
"usage",
@@ -577,6 +617,7 @@ def completion(
args = locals()
api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False)
@@ -895,15 +936,17 @@ def completion(
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response or mock_tool_calls:
return mock_completion(
model,
messages,
stream=stream,
mock_response=mock_response,
mock_tool_calls=mock_tool_calls,
logging=logging,
acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
custom_llm_provider=custom_llm_provider,
)
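The new mock_tool_calls kwarg routes through mock_completion even without a mock_response; each entry follows the OpenAI tool-call shape consumed by ChatCompletionMessageToolCall. A sketch with an illustrative function name and arguments:

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    mock_tool_calls=[
        {
            "id": "call_abc123",  # illustrative id
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "San Francisco"}',
            },
        }
    ],
)
print(response.choices[0].message.tool_calls[0].function.name)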
if custom_llm_provider == "azure":
# azure configs
@@ -1035,91 +1078,6 @@ def completion(
"api_base": api_base,
},
)
elif (
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
@@ -1203,6 +1161,93 @@ def completion(
additional_args={"headers": headers},
)
response = _response
elif (
model in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openai"
or custom_llm_provider == "together_ai"
or custom_llm_provider in litellm.openai_compatible_providers
or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo
): # allow user to make an openai call with a custom base
# note: if a user sets a custom base - we should ensure this works
# allow for the setting of dynamic and stateful api-bases
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
or litellm.api_base
or get_secret("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
openai.organization = (
organization
or litellm.organization
or get_secret("OPENAI_ORGANIZATION")
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
)
# set API KEY
api_key = (
api_key
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
or litellm.openai_key
or get_secret("OPENAI_API_KEY")
)
headers = headers or litellm.headers
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
## COMPLETION CALL
try:
response = openai_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
)
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
input=messages,
api_key=api_key,
original_response=str(e),
additional_args={"headers": headers},
)
raise e
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
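Because this branch serves every OpenAI-compatible provider, the base URL and key can be overridden per call, following the resolution order spelled out in the comments above (call kwarg, then litellm settings, then environment, then the default). A sketch against an assumed local OpenAI-compatible server:

import litellm

response = litellm.completion(
    model="openai/my-hosted-model",  # "openai/" routes through this branch
    messages=[{"role": "user", "content": "hello"}],
    api_base="http://localhost:8000/v1",  # illustrative vLLM-style endpoint
    api_key="sk-placeholder",  # assumed key for the local server
)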
elif (
"replicate" in model
or custom_llm_provider == "replicate"
@@ -1840,7 +1885,25 @@
)
return response
response = model_response
elif custom_llm_provider == "gemini":
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
gemini_api_key = (
api_key
or get_secret("GEMINI_API_KEY")
@@ -1848,34 +1911,28 @@
or litellm.api_key
)
# palm does not support streaming as yet :(
        new_params = deepcopy(optional_params)
        response = vertex_chat_completion.completion(  # type: ignore
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=new_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            vertex_location=vertex_ai_location,
            vertex_project=vertex_ai_project,
            vertex_credentials=vertex_credentials,
            gemini_api_key=gemini_api_key,
            logging_obj=logging,
            acompletion=acompletion,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            client=client,
            api_base=api_base,
        )
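With this change, "gemini" requests join "vertex_ai_beta" on the httpx-based VertexLLM handler (vertex_chat_completion) rather than the old .llms.gemini module, which also retires the CustomStreamWrapper shim deleted above. A minimal sketch, assuming GEMINI_API_KEY is set; the model id is illustrative:

import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # assumed Google AI Studio model id
    messages=[{"role": "user", "content": "Summarize unified diffs in one line."}],
)
print(response.choices[0].message.content)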
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
@@ -1894,6 +1951,7 @@ def completion(
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
new_params = deepcopy(optional_params)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
@@ -1982,6 +2040,46 @@ def completion(
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or "https://codestral.mistral.ai/v1/fim/completions"
)
api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
_model_response = codestral_text_completions.completion( # type: ignore
model=model,
messages=messages,
model_response=text_completion_model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] is True
@@ -3371,7 +3469,9 @@ def embedding(
###### Text Completion ################
@client
async def atext_completion(
*args, **kwargs
) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
"""
Implemented to handle async streaming for the text completion endpoint
"""
@@ -3403,6 +3503,7 @@ async def atext_completion(*args, **kwargs):
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek"
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "text-completion-openai"
@@ -3664,6 +3765,7 @@ def text_completion(
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "text-completion-openai"
)
and isinstance(prompt, list)
@@ -3680,6 +3782,12 @@ def text_completion(
)
kwargs.pop("prompt", None)
if model is not None and model.startswith(
"openai/"
): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
model = model.replace("openai/", "text-completion-openai/")
kwargs["text_completion"] = True
response = completion(
model=model,
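The block above rewrites "openai/<model>" to "text-completion-openai/<model>" so that OpenAI-compatible servers such as vLLM receive a native /v1/completions request. A sketch with an assumed local server and model:

import litellm

response = litellm.text_completion(
    model="openai/facebook/opt-125m",  # illustrative vLLM-served model
    prompt="The capital of France is",
    api_base="http://localhost:8000/v1",  # assumed endpoint
    api_key="none",
)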
@@ -3842,6 +3950,7 @@ def image_generation(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
client = kwargs.get("client", None)
model_response = litellm.utils.ImageResponse()
if model is not None or custom_llm_provider is not None:
@@ -3980,6 +4089,7 @@ def image_generation(
model_response=model_response,
api_version=api_version,
aimg_generation=aimg_generation,
client=client,
)
elif custom_llm_provider == "openai":
model_response = openai_chat_completions.image_generation(
@@ -3992,6 +4102,7 @@ def image_generation(
optional_params=optional_params,
model_response=model_response,
aimg_generation=aimg_generation,
client=client,
)
elif custom_llm_provider == "bedrock":
if model is None:
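The client kwarg threaded through image_generation above lets callers reuse a pre-configured SDK client (shared connection pool, custom httpx options) instead of litellm constructing one per call. A hedged sketch:

import litellm
from openai import OpenAI

sdk_client = OpenAI(api_key="sk-placeholder")  # illustrative key
image = litellm.image_generation(
    model="dall-e-3",
    prompt="a watercolor fox",
    client=sdk_client,  # reused for this call instead of a fresh client
)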