commit 6127d9f488
Author: Lucca Zenobio
Date:   2024-04-25 15:00:07 -03:00

339 changed files with 82761 additions and 7086 deletions


@@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
client,
@@ -38,7 +39,6 @@ from litellm.utils import (
get_optional_params_image_gen,
)
from .llms import (
anthropic,
anthropic_text,
together_ai,
ai21,
@@ -61,11 +61,14 @@ from .llms import (
palm,
gemini,
vertex_ai,
vertex_ai_anthropic,
maritalk,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -97,6 +100,8 @@ from litellm.utils import (
dotenv.load_dotenv() # Loading env variables using dotenv
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
anthropic_text_completions = AnthropicTextCompletion()
azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
@@ -115,24 +120,54 @@ class LiteLLM:
default_headers: Optional[Mapping[str, str]] = None,
):
self.params = locals()
self.chat = Chat(self.params)
self.chat = Chat(self.params, router_obj=None)
class Chat:
def __init__(self, params):
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
self.completions = Completions(self.params)
if self.params.get("acompletion", False) == True:
self.params.pop("acompletion")
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
self.params, router_obj=router_obj
)
else:
self.completions = Completions(self.params, router_obj=router_obj)
class Completions:
def __init__(self, params):
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
self.router_obj = router_obj
def create(self, messages, model=None, **kwargs):
for k, v in kwargs.items():
self.params[k] = v
model = model or self.params.get("model")
response = completion(model=model, messages=messages, **self.params)
if self.router_obj is not None:
response = self.router_obj.completion(
model=model, messages=messages, **self.params
)
else:
response = completion(model=model, messages=messages, **self.params)
return response
class AsyncCompletions:
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
self.router_obj = router_obj
async def create(self, messages, model=None, **kwargs):
for k, v in kwargs.items():
self.params[k] = v
model = model or self.params.get("model")
if self.router_obj is not None:
response = await self.router_obj.acompletion(
model=model, messages=messages, **self.params
)
else:
response = await acompletion(model=model, messages=messages, **self.params)
return response
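A note on the wrapper refactor above: Chat now decides between sync and async completion objects at construction time, and both forward to a Router when one is attached. A minimal usage sketch, assuming Chat is importable from the litellm package (main.py's names are re-exported) and no router is attached:

import asyncio
from litellm import Chat

async def main():
    # acompletion=True makes Chat build AsyncCompletions; the flag is popped
    # from params before they are forwarded to acompletion()
    chat = Chat({"acompletion": True}, router_obj=None)
    resp = await chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(resp)

asyncio.run(main())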
@@ -149,7 +184,7 @@ async def acompletion(
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
@@ -272,6 +307,7 @@ async def acompletion(
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "gemini"
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -283,6 +319,14 @@ async def acompletion(
response = await init_response
else:
response = init_response # type: ignore
if custom_llm_provider == "text-completion-openai" and isinstance(
response, TextCompletionResponse
):
response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=response,
model_response_object=litellm.ModelResponse(),
)
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) # type: ignore
@@ -298,6 +342,7 @@ async def acompletion(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=completion_kwargs,
extra_kwargs=kwargs,
)
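With "anthropic" added to the native-async branch above, acompletion() awaits the provider coroutine directly instead of running the sync path in a thread executor. A hedged example of the call shape (model name illustrative):

import asyncio
import litellm

async def main():
    resp = await litellm.acompletion(
        model="claude-3-haiku-20240307",
        messages=[{"role": "user", "content": "hi"}],
    )
    print(resp.choices[0].message.content)

asyncio.run(main())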
@@ -363,8 +408,10 @@ def mock_completion(
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.usage = Usage(
prompt_tokens=10, completion_tokens=20, total_tokens=30
setattr(
model_response,
"usage",
Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
try:
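The mock path above now sets usage via setattr() rather than assigning model_response.usage directly. A quick sketch of exercising it through the public API, assuming the documented mock_response kwarg:

import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="test reply",  # short-circuits to mock_completion()
)
print(resp.usage)  # Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)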
@@ -392,7 +439,7 @@ def completion(
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[float] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
@@ -489,6 +536,9 @@ def completion(
eos_token = kwargs.get("eos_token", None)
preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
### TEXT COMPLETION CALLS ###
text_completion = kwargs.get("text_completion", False)
atext_completion = kwargs.get("atext_completion", False)
### ASYNC CALLS ###
acompletion = kwargs.get("acompletion", False)
client = kwargs.get("client", None)
@@ -530,6 +580,8 @@ def completion(
litellm_params = [
"metadata",
"acompletion",
"atext_completion",
"text_completion",
"caching",
"mock_response",
"api_key",
@@ -559,6 +611,7 @@ def completion(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -571,6 +624,8 @@ def completion(
"ttl",
"cache",
"no-log",
"base_model",
"stream_timeout",
]
default_params = openai_params + litellm_params
non_default_params = {
@@ -600,6 +655,7 @@ def completion(
model
] # update the model to the actual value if an alias has been passed in
model_response = ModelResponse()
setattr(model_response, "usage", litellm.Usage())
if (
kwargs.get("azure", False) == True
): # don't remove flag check, to remain backwards compatible for repos like Codium
@@ -639,7 +695,7 @@ def completion(
elif (
input_cost_per_second is not None
): # time based pricing just needs cost in place
output_cost_per_second = output_cost_per_second or 0.0
output_cost_per_second = output_cost_per_second
litellm.register_model(
{
f"{custom_llm_provider}/{model}": {
@@ -1011,8 +1067,9 @@ def completion(
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
## COMPLETION CALL
model_response = openai_text_completions.completion(
_response = openai_text_completions.completion(
model=model,
messages=messages,
model_response=model_response,
@@ -1020,6 +1077,7 @@ def completion(
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
client=client, # pass AsyncOpenAI, OpenAI client
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
@@ -1027,15 +1085,25 @@ def completion(
timeout=timeout,
)
if (
optional_params.get("stream", False) == False
and acompletion == False
and text_completion == False
):
# convert to chat completion response
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=_response, model_response_object=model_response
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
original_response=_response,
additional_args={"headers": headers},
)
response = model_response
response = _response
elif (
"replicate" in model
or custom_llm_provider == "replicate"
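For text-completion-openai models, the sync non-streaming branch above now converts the raw TextCompletionResponse into a chat-shaped ModelResponse before returning. A hedged sketch (the instruct model is assumed to route to this provider):

import litellm

resp = litellm.completion(
    model="gpt-3.5-turbo-instruct",  # assumed to map to text-completion-openai
    messages=[{"role": "user", "content": "Say hi"}],
)
print(type(resp))  # litellm.ModelResponse, converted from the text response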
@@ -1105,10 +1173,11 @@ def completion(
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
response = anthropic_text.completion(
response = anthropic_text_completions.completion(
model=model,
messages=messages,
api_base=api_base,
acompletion=acompletion,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
@@ -1129,10 +1198,11 @@ def completion(
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages"
)
response = anthropic.completion(
response = anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
acompletion=acompletion,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
@@ -1144,19 +1214,6 @@ def completion(
logging_obj=logging,
headers=headers,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="anthropic",
logging_obj=logging,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
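The manual CustomStreamWrapper re-wrap was dropped here because the new AnthropicChatCompletion handler is expected to return an already-wrapped stream. A hedged streaming sketch:

import litellm

stream = litellm.completion(
    model="claude-3-sonnet-20240229",
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in stream:
    # chunks arrive as ModelResponse deltas from CustomStreamWrapper
    print(chunk.choices[0].delta)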
@@ -1625,21 +1682,44 @@ def completion(
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
model_response = vertex_ai.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
logging_obj=logging,
acompletion=acompletion,
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
new_params = deepcopy(optional_params)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
)
else:
model_response = vertex_ai.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
)
if (
"stream" in optional_params
@@ -1753,7 +1833,11 @@ def completion(
timeout=timeout,
)
if "stream" in optional_params and optional_params["stream"] == True:
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
@@ -1863,9 +1947,16 @@ def completion(
or "http://localhost:11434"
)
api_key = (
api_key
or litellm.ollama_key
or os.environ.get("OLLAMA_API_KEY")
or litellm.api_key
)
## LOGGING
generator = ollama_chat.get_ollama_response(
api_base,
api_key,
model,
messages,
optional_params,
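ollama_chat now also resolves an API key, useful when Ollama sits behind an authenticating proxy; it falls back through litellm.ollama_key, OLLAMA_API_KEY, and litellm.api_key. A hedged sketch:

import os
import litellm

os.environ["OLLAMA_API_KEY"] = "sk-placeholder"  # hypothetical proxy token
resp = litellm.completion(
    model="ollama_chat/llama2",
    messages=[{"role": "user", "content": "hi"}],
    api_base="http://localhost:11434",
)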
@@ -2061,6 +2152,7 @@ def completion(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@@ -2422,6 +2514,7 @@ async def aembedding(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@@ -2473,6 +2566,7 @@ def embedding(
client = kwargs.pop("client", None)
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
max_parallel_requests = kwargs.pop("max_parallel_requests", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", None)
encoding_format = kwargs.get("encoding_format", None)
@@ -2530,6 +2624,7 @@ def embedding(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"input_cost_per_second",
@@ -2731,6 +2826,11 @@ def embedding(
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
response = vertex_ai.embedding(
model=model,
@@ -2741,6 +2841,7 @@ def embedding(
model_response=EmbeddingResponse(),
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aembedding=aembedding,
print_verbose=print_verbose,
)
@@ -2755,28 +2856,25 @@ def embedding(
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "ollama":
ollama_input = None
if isinstance(input, list) and len(input) > 1:
raise litellm.BadRequestError(
message=f"Ollama Embeddings don't support batch embeddings",
model=model, # type: ignore
llm_provider="ollama", # type: ignore
)
if isinstance(input, list) and len(input) == 1:
ollama_input = "".join(input[0])
elif isinstance(input, str):
ollama_input = input
else:
api_base = (
litellm.api_base
or api_base
or get_secret("OLLAMA_API_BASE")
or "http://localhost:11434"
)
if isinstance(input, str):
input = [input]
if not all(isinstance(item, str) for item in input):
raise litellm.BadRequestError(
message=f"Invalid input for ollama embeddings. input={input}",
model=model, # type: ignore
llm_provider="ollama", # type: ignore
)
if aembedding == True:
if aembedding:
response = ollama.ollama_aembeddings(
api_base=api_base,
model=model,
prompt=ollama_input,
prompts=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
@@ -2860,7 +2958,10 @@ def embedding(
)
## Map to OpenAI Exception
raise exception_type(
model=model, original_exception=e, custom_llm_provider=custom_llm_provider
model=model,
original_exception=e,
custom_llm_provider=custom_llm_provider,
extra_kwargs=kwargs,
)
@@ -2890,6 +2991,7 @@ async def atext_completion(*args, **kwargs):
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
@@ -2921,7 +3023,31 @@ async def atext_completion(*args, **kwargs):
model=model,
)
else:
return response
transformed_logprobs = None
# only supported for TGI models
try:
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e:
print_verbose(f"LiteLLM non blocking exception: {e}")
## TRANSLATE CHAT TO TEXT FORMAT ##
if isinstance(response, TextCompletionResponse):
return response
text_completion_response = TextCompletionResponse()
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
text_completion_response["model"] = response.get("model", None)
text_choices = TextChoices()
text_choices["text"] = response["choices"][0]["message"]["content"]
text_choices["index"] = response["choices"][0]["index"]
text_choices["logprobs"] = transformed_logprobs
text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
text_completion_response["choices"] = [text_choices]
text_completion_response["usage"] = response.get("usage", None)
return text_completion_response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
raise exception_type(
@@ -2929,6 +3055,7 @@ async def atext_completion(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
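atext_completion() now mirrors the sync path: non-streaming chat-shaped responses are translated into a TextCompletionResponse, attaching TGI logprobs when they can be transformed. A hedged usage sketch:

import asyncio
import litellm

async def main():
    resp = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Say hi",
    )
    print(resp.choices[0].text)

asyncio.run(main())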
@@ -3105,7 +3232,7 @@ def text_completion(
concurrent.futures.as_completed(futures)
):
responses[i] = future.result()
text_completion_response.choices = responses
text_completion_response.choices = responses # type: ignore
return text_completion_response
# else:
@@ -3113,8 +3240,36 @@ def text_completion(
# these are the params supported by Completion() but not ChatCompletion
# default case, non OpenAI requests go through here
messages = [{"role": "system", "content": prompt}]
# handle prompt formatting if prompt is a string vs. list of strings
messages = []
if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
for p in prompt:
message = {"role": "user", "content": p}
messages.append(message)
elif isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif (
(
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text"
or custom_llm_provider == "text-completion-openai"
)
and isinstance(prompt, list)
and len(prompt) > 0
and isinstance(prompt[0], list)
):
verbose_logger.warning(
msg="List of lists being passed. If this is for tokens, then it might not work across all models."
)
messages = [{"role": "user", "content": prompt}] # type: ignore
else:
raise Exception(
f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues"
)
kwargs.pop("prompt", None)
kwargs["text_completion"] = True
response = completion(
model=model,
messages=messages,
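The prompt normalization above means text_completion() accepts either a plain string or a list of strings (each string becoming its own user message), while token-id lists of lists are passed through with a warning for the OpenAI-compatible providers. A hedged sketch:

import litellm

resp = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt=["first prompt", "second prompt"],  # each string becomes a message
)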
@@ -3134,6 +3289,10 @@ def text_completion(
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e:
print_verbose(f"LiteLLM non blocking exception: {e}")
if isinstance(response, TextCompletionResponse):
return response
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
@ -3145,6 +3304,7 @@ def text_completion(
text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
text_completion_response["choices"] = [text_choices]
text_completion_response["usage"] = response.get("usage", None)
return text_completion_response
@@ -3233,6 +3393,7 @@ async def aimage_generation(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@@ -3323,6 +3484,7 @@ def image_generation(
"client",
"rpm",
"tpm",
"max_parallel_requests",
"input_cost_per_token",
"output_cost_per_token",
"hf_model_name",
@@ -3432,6 +3594,7 @@ def image_generation(
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=locals(),
extra_kwargs=kwargs,
)
@@ -3481,6 +3644,7 @@ async def atranscription(*args, **kwargs):
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=args,
extra_kwargs=kwargs,
)
@@ -3501,6 +3665,7 @@ def transcription(
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
max_retries: Optional[int] = None,
litellm_logging_obj=None,
custom_llm_provider=None,
**kwargs,
@@ -3516,6 +3681,8 @@ def transcription(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
if max_retries is None:
max_retries = openai.DEFAULT_MAX_RETRIES
model_response = litellm.utils.TranscriptionResponse()
@@ -3559,6 +3726,7 @@ def transcription(
api_key=api_key,
api_version=api_version,
azure_ad_token=azure_ad_token,
max_retries=max_retries,
)
elif custom_llm_provider == "openai":
response = openai_chat_completions.audio_transcriptions(
@@ -3569,6 +3737,7 @@ def transcription(
atranscription=atranscription,
timeout=timeout,
logging_obj=litellm_logging_obj,
max_retries=max_retries,
)
return response
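transcription() now threads max_retries through to both the Azure and OpenAI handlers, defaulting to openai.DEFAULT_MAX_RETRIES when unset. A hedged sketch with a placeholder audio file:

import litellm

with open("speech.mp3", "rb") as audio_file:  # hypothetical local file
    resp = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        max_retries=3,
    )
print(resp.text)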
@@ -3656,6 +3825,9 @@ async def ahealth_check(
api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
if custom_llm_provider == "text-completion-openai":
mode = "completion"
response = await openai_chat_completions.ahealth_check(
model=model,
messages=model_params.get(
@@ -3689,11 +3861,15 @@ async def ahealth_check(
return response
except Exception as e:
traceback.print_exc()
stack_trace = traceback.format_exc()
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
if model not in litellm.model_cost and mode is None:
raise Exception(
"Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
)
return {"error": f"{str(e)}"}
error_to_return = str(e) + " stack trace: " + stack_trace
return {"error": error_to_return}
####### HELPER FUNCTIONS ################