Merge branch 'main' into feature/watsonx-integration

Simon Sanchez Viloria 2024-05-06 17:27:14 +02:00
commit 9a95fa9348
144 changed files with 8872 additions and 2296 deletions


@@ -12,6 +12,7 @@ from typing import Any, Literal, Union, BinaryIO
from functools import partial
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
@@ -33,9 +34,12 @@ from litellm.utils import (
async_mock_completion_streaming_obj,
convert_to_model_response_object,
token_counter,
create_pretrained_tokenizer,
create_tokenizer,
Usage,
get_optional_params_embeddings,
get_optional_params_image_gen,
supports_httpx_timeout,
)
from .llms import (
anthropic_text,
@@ -75,6 +79,7 @@ from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
function_call_prompt,
map_system_message_pt,
)
import tiktoken
from concurrent.futures import ThreadPoolExecutor
@@ -363,7 +368,7 @@ def mock_completion(
model: str,
messages: List,
stream: Optional[bool] = False,
mock_response: str = "This is a mock request",
mock_response: Union[str, Exception] = "This is a mock request",
logging=None,
**kwargs,
):
@@ -390,6 +395,20 @@
- If 'stream' is True, it returns a response that mimics the behavior of a streaming completion.
"""
try:
## LOGGING
if logging is not None:
logging.pre_call(
input=messages,
api_key="mock-key",
)
if isinstance(mock_response, Exception):
raise litellm.APIError(
status_code=500, # type: ignore
message=str(mock_response),
llm_provider="openai", # type: ignore
model=model, # type: ignore
request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
model_response = ModelResponse(stream=stream)
if stream is True:
# don't try to access stream object,
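
The widened `mock_response: Union[str, Exception]` parameter lets callers simulate provider failures: when an `Exception` instance is passed, the new branch above raises it as a `litellm.APIError` instead of building a canned response. A hedged usage sketch (the model name is a placeholder and no real API call is made):

```python
import litellm

# A string mock_response still returns a canned ModelResponse.
ok = litellm.mock_completion(
    model="gpt-3.5-turbo",  # placeholder; mock_completion never calls a provider
    messages=[{"role": "user", "content": "hi"}],
    mock_response="pong",
)

# An Exception mock_response exercises the new branch: it is surfaced as a
# litellm.APIError (status_code=500), which is handy for testing retry logic.
try:
    litellm.mock_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("simulated provider outage"),
    )
except Exception as err:  # raised as litellm.APIError per the branch above
    print("caught mocked failure:", err)
```
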
@@ -436,7 +455,7 @@ def completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
timeout: Optional[Union[float, int]] = None,
timeout: Optional[Union[float, str, httpx.Timeout]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@@ -539,6 +558,7 @@ def completion(
eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
supports_system_message = kwargs.get("supports_system_message", None)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
@@ -604,6 +624,7 @@ def completion(
"model_list",
"num_retries",
"context_window_fallback_dict",
"retry_policy",
"roles",
"final_prompt_value",
"bos_token",
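
The `"retry_policy"` entry joins the list of litellm-internal kwargs, so it is stripped out before provider-specific params are collected further down (`non_default_params`). A standalone sketch of that filtering, with truncated stand-ins for the real parameter lists:

```python
# Truncated stand-ins for the openai_params / litellm_params lists in completion()
openai_params = ["temperature", "max_tokens", "stream"]
litellm_params = ["num_retries", "context_window_fallback_dict", "retry_policy"]

def provider_specific_params(kwargs: dict) -> dict:
    """Keep only kwargs litellm does not handle itself; these go straight to the provider."""
    default_params = openai_params + litellm_params
    # Same dict comprehension as in the diff below: anything litellm does not
    # recognise is treated as a model/provider-specific parameter.
    return {k: v for k, v in kwargs.items() if k not in default_params}

print(provider_specific_params({"retry_policy": "<RetryPolicy object>", "top_k": 40}))
# -> {'top_k': 40}  ("retry_policy" stays internal instead of leaking to the provider)
```
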
@@ -629,16 +650,27 @@
"no-log",
"base_model",
"stream_timeout",
"supports_system_message",
]
default_params = openai_params + litellm_params
non_default_params = {
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
if timeout is None:
timeout = (
kwargs.get("request_timeout", None) or 600
) # set timeout for 10 minutes by default
timeout = float(timeout)
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
try:
if base_url is not None:
api_base = base_url
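
With `timeout` now typed `Optional[Union[float, str, httpx.Timeout]]`, a caller can pass a granular `httpx.Timeout`; the block above collapses it to its `read` value (default 600s) for providers whose integration does not accept an `httpx.Timeout`, per `supports_httpx_timeout`. A hedged usage sketch (the model name is a placeholder):

```python
import httpx
import litellm

# 5s to establish the connection, up to 120s to read the (possibly streamed) response.
granular_timeout = httpx.Timeout(connect=5.0, read=120.0, write=10.0, pool=5.0)

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[{"role": "user", "content": "Summarise httpx timeouts."}],
    timeout=granular_timeout,  # collapsed to the read timeout for providers without httpx.Timeout support
)
```
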
@@ -733,6 +765,13 @@ def completion(
custom_prompt_dict[model]["bos_token"] = bos_token
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
and supports_system_message == False
):
messages = map_system_message_pt(messages=messages)
model_api_key = get_api_key(
llm_provider=custom_llm_provider, dynamic_api_key=api_key
) # get the api key from the environment if required for the model
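
The new `supports_system_message=False` kwarg rewrites the message list through `map_system_message_pt` before dispatch, so models that cannot take a `system` role still receive the instruction (folded into a user-role message, as I understand the helper). A hedged usage sketch with a placeholder model:

```python
import litellm

response = litellm.completion(
    model="huggingface/some-org/some-model",  # placeholder for a model without system-role support
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "What does this flag do?"},
    ],
    # New in this diff: signals that the target model cannot accept system
    # messages, so map_system_message_pt rewrites them before the provider call.
    supports_system_message=False,
)
```
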
@@ -859,7 +898,7 @@ def completion(
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
timeout=timeout, # type: ignore
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)
@@ -1000,7 +1039,7 @@
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
@@ -1085,7 +1124,7 @@
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
@@ -1459,7 +1498,7 @@ def completion(
acompletion=acompletion,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
timeout=timeout, # type: ignore
)
if (
"stream" in optional_params
@@ -1552,7 +1591,7 @@
logger_fn=logger_fn,
logging_obj=logging,
acompletion=acompletion,
timeout=timeout,
timeout=timeout, # type: ignore
)
## LOGGING
logging.post_call(
@@ -1832,6 +1871,7 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
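
This hunk forwards `extra_headers` into one more provider call; from the caller's side it remains the usual kwarg on `completion`. A minimal, hedged sketch (header name and value are made up):

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder; the provider touched by this hunk is not shown in the diff
    messages=[{"role": "user", "content": "ping"}],
    extra_headers={"X-Request-Source": "nightly-eval"},  # attached to the outgoing HTTP request
)
```
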
@@ -1875,7 +1915,7 @@
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
@@ -2261,7 +2301,7 @@ def batch_completion(
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
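
`max_tokens` on `batch_completion` is re-annotated from `float` to `int`, matching the integer token counts it actually carries. A hedged usage sketch (assuming, as in litellm's docs, that `messages` takes one message list per request):

```python
import litellm

responses = litellm.batch_completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=[
        [{"role": "user", "content": "Name one prime number."}],
        [{"role": "user", "content": "Name one even number."}],
    ],
    max_tokens=16,  # now Optional[int]; token limits are whole numbers
)
```
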
@@ -2655,6 +2695,7 @@ def embedding(
"model_list",
"num_retries",
"context_window_fallback_dict",
"retry_policy",
"roles",
"final_prompt_value",
"bos_token",
@@ -3525,6 +3566,7 @@ def image_generation(
"model_list",
"num_retries",
"context_window_fallback_dict",
"retry_policy",
"roles",
"final_prompt_value",
"bos_token",