From b72d372aa7c2732627abc1b123b7e04c2abc68e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateo=20C=C3=A1mara?= Date: Wed, 20 Dec 2023 19:49:12 +0100 Subject: [PATCH 1/5] feat: added explicit args to acomplete --- litellm/main.py | 98 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 72 insertions(+), 26 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index b2ed72f7f..6581ae1a1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -117,7 +117,31 @@ class Completions(): return response @client -async def acompletion(*args, **kwargs): +async def acompletion( + model: str, + messages: List = [], + functions: Optional[List] = None, + function_call: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict] = None, + user: Optional[str] = None, + metadata: Optional[Dict] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[List] = None, + mock_response: Optional[str] = None, + force_timeout: Optional[int] = None, + custom_llm_provider: Optional[str] = None, + **kwargs, +): """ Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) @@ -138,7 +162,7 @@ async def acompletion(*args, **kwargs): frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. api_base (str, optional): Base URL for the API (default is None). api_version (str, optional): API version (default is None). api_key (str, optional): API key (default is None). @@ -157,22 +181,44 @@ async def acompletion(*args, **kwargs): - If `stream` is True, the function returns an async generator that yields completion lines. 
""" loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO COMPLETION ### - kwargs["acompletion"] = True - custom_llm_provider = None - try: + # Adjusted to use explicit arguments instead of *args and **kwargs + completion_kwargs = { + "model": model, + "messages": messages, + "functions": functions, + "function_call": function_call, + "temperature": temperature, + "top_p": top_p, + "n": n, + "stream": stream, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "user": user, + "metadata": metadata, + "api_base": api_base, + "api_version": api_version, + "api_key": api_key, + "model_list": model_list, + "mock_response": mock_response, + "force_timeout": force_timeout, + "custom_llm_provider": custom_llm_provider, + "acompletion": True # assuming this is a required parameter + } + try: # Use a partial function to pass your keyword arguments - func = partial(completion, *args, **kwargs) + func = partial(completion, **completion_kwargs) # Add the context to the function ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("api_base", None)) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if (custom_llm_provider == "openai" + or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" @@ -182,39 +228,39 @@ async def acompletion(*args, **kwargs): or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "huggingface" or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. - if kwargs.get("stream", False): - response = completion(*args, **kwargs) + or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. 
+ if completion_kwargs.get("stream", False): + response = completion(**completion_kwargs) else: # Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False): # return an async generator - return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args) - else: + if completion_kwargs.get("stream", False): # return an async generator + return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, completion_kwargs=completion_kwargs) + else: return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, + model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs, ) -async def _async_streaming(response, model, custom_llm_provider, args): - try: +async def _async_streaming(response, model, custom_llm_provider, completion_kwargs): + try: print_verbose(f"received response in _async_streaming: {response}") - async for line in response: + async for line in response: print_verbose(f"line in async streaming: {line}") yield line - except Exception as e: + except Exception as e: print_verbose(f"error raised _async_streaming: {traceback.format_exc()}") raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, + model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs, ) def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs): From 48b2f69c933f73acff5c67774b3a4c7042a80f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateo=20C=C3=A1mara?= Date: Tue, 9 Jan 2024 12:05:31 +0100 Subject: [PATCH 2/5] Added the new acompletion parameters based on CompletionRequest attributes --- litellm/main.py | 72 ++++++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 30 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 6581ae1a1..7c3ccd750 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -118,29 +118,37 @@ class Completions(): @client async def acompletion( - model: str, - messages: List = [], - functions: Optional[List] = None, - function_call: Optional[str] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[Dict] = None, - user: Optional[str] = None, - metadata: Optional[Dict] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[List] = None, - mock_response: Optional[str] = None, - force_timeout: Optional[int] = None, - custom_llm_provider: Optional[str] = None, - 
**kwargs, + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + functions: Optional[List] = None, + function_call: Optional[str] = None, + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stop=None, + max_tokens: Optional[float] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[dict] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, ): """ Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) @@ -187,24 +195,28 @@ async def acompletion( "messages": messages, "functions": functions, "function_call": function_call, + "timeout": timeout, "temperature": temperature, "top_p": top_p, "n": n, "stream": stream, - "stop": stop, + "stop": stop, "max_tokens": max_tokens, "presence_penalty": presence_penalty, "frequency_penalty": frequency_penalty, "logit_bias": logit_bias, "user": user, - "metadata": metadata, - "api_base": api_base, + "response_format": response_format, + "seed": seed, + "tools": tools, + "tool_choice": tool_choice, + "logprobs": logprobs, + "top_logprobs": top_logprobs, + "deployment_id": deployment_id, + "base_url": base_url, "api_version": api_version, "api_key": api_key, "model_list": model_list, - "mock_response": mock_response, - "force_timeout": force_timeout, - "custom_llm_provider": custom_llm_provider, "acompletion": True # assuming this is a required parameter } try: @@ -215,7 +227,7 @@ async def acompletion( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None)) if (custom_llm_provider == "openai" or custom_llm_provider == "azure" From bb06c51edeeaba3904a0655d07a090d51318c810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateo=20C=C3=A1mara?= Date: Tue, 9 Jan 2024 12:06:49 +0100 Subject: [PATCH 3/5] Added test to check if acompletion is using the same parameters as CompletionRequest attributes. Added functools to client decorator to expose acompletion parameters from outside. 
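
Why the functools change matters, as a minimal standalone sketch (the client/acompletion names mirror litellm's, but the bodies here are illustrative stubs, not the real implementation): without functools.wraps, inspect.signature on the decorated acompletion only reports (*args, **kwargs), so a signature-comparison test could never see the explicit parameters.

    import inspect
    from functools import wraps

    def client(original_function):
        @wraps(original_function)  # copies metadata and sets __wrapped__, keeping the real signature visible
        def wrapper(*args, **kwargs):
            return original_function(*args, **kwargs)
        return wrapper

    @client
    def acompletion(model: str, messages: list = [], **kwargs):
        """Illustrative stub only."""
        return None

    # inspect.signature follows __wrapped__ by default, so the explicit params show through:
    print(inspect.signature(acompletion))  # (model: str, messages: list = [], **kwargs)
    # Without the @wraps line above, this would print (*args, **kwargs) instead.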
--- litellm/tests/test_completion.py | 24 +++++++++++++++++++++++- litellm/utils.py | 3 +++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 80a51c480..6617f0530 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -10,7 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm import embedding, completion, completion_cost, Timeout +from litellm import embedding, completion, completion_cost, Timeout, acompletion from litellm import RateLimitError # litellm.num_retries = 3 @@ -859,6 +859,28 @@ def test_completion_azure_key_completion_arg(): # test_completion_azure_key_completion_arg() +def test_acompletion_params(): + import inspect + from litellm.types.completion import CompletionRequest + + acompletion_params_odict = inspect.signature(acompletion).parameters + acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()} + completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()} + + # remove kwargs + acompletion_params.pop("kwargs", None) + + keys_acompletion = set(acompletion_params.keys()) + keys_completion = set(completion_params.keys()) + + # Assert that the parameters are the same + if keys_acompletion != keys_completion: + pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.") + + +# test_acompletion_params() + + async def test_re_use_azure_async_client(): try: print("azure gpt-3.5 ASYNC with clie nttest\n\n") diff --git a/litellm/utils.py b/litellm/utils.py index 4520bee62..6223e7646 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -14,6 +14,7 @@ import subprocess, os import litellm, openai import itertools import random, uuid, requests +from functools import wraps import datetime, time import tiktoken import uuid @@ -1934,6 +1935,7 @@ def client(original_function): # [Non-Blocking Error] pass + @wraps(original_function) def wrapper(*args, **kwargs): start_time = datetime.datetime.now() result = None @@ -2128,6 +2130,7 @@ def client(original_function): e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" raise e + @wraps(original_function) async def wrapper_async(*args, **kwargs): start_time = datetime.datetime.now() result = None From 0ec976b3d1d8e2648877c95549a88c69abd7bbdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateo=20C=C3=A1mara?= Date: Tue, 9 Jan 2024 12:55:12 +0100 Subject: [PATCH 4/5] Reverted changes made by the IDE automatically --- litellm/main.py | 2111 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 1471 insertions(+), 640 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 7c3ccd750..bb7b9817a 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5,10 +5,10 @@ # | | # +-----------------------------------------------+ # -# Thank you ! We ❤️ you! - Krrish & Ishaan +# Thank you ! We ❤️ you! 
- Krrish & Ishaan import os, openai, sys, json, inspect, uuid, datetime, threading -from typing import Any +from typing import Any, Literal, Union from functools import partial import dotenv, traceback, random, asyncio, time, contextvars from copy import deepcopy @@ -29,12 +29,12 @@ from litellm.utils import ( completion_with_fallbacks, get_llm_provider, get_api_key, - mock_completion_streaming_obj, - convert_to_model_response_object, - token_counter, - Usage, + mock_completion_streaming_obj, + convert_to_model_response_object, + token_counter, + Usage, get_optional_params_embeddings, - get_optional_params_image_gen + get_optional_params_image_gen, ) from .llms import ( anthropic, @@ -49,20 +49,29 @@ from .llms import ( baseten, vllm, ollama, + ollama_chat, + cloudflare, cohere, petals, oobabooga, openrouter, palm, + gemini, vertex_ai, - maritalk) + maritalk, +) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion from .llms.huggingface_restapi import Huggingface -from .llms.prompt_templates.factory import prompt_factory, custom_prompt, function_call_prompt +from .llms.prompt_templates.factory import ( + prompt_factory, + custom_prompt, + function_call_prompt, +) import tiktoken from concurrent.futures import ThreadPoolExecutor from typing import Callable, List, Optional, Dict, Union, Mapping +from .caching import enable_cache, disable_cache, update_cache encoding = tiktoken.get_encoding("cl100k_base") from litellm.utils import ( @@ -74,8 +83,8 @@ from litellm.utils import ( TextChoices, EmbeddingResponse, read_config_args, - Choices, - Message + Choices, + Message, ) ####### ENVIRONMENT VARIABLES ################### @@ -86,35 +95,39 @@ azure_chat_completions = AzureChatCompletion() huggingface = Huggingface() ####### COMPLETION ENDPOINTS ################ + class LiteLLM: + def __init__( + self, + *, + api_key=None, + organization: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[float] = 600, + max_retries: Optional[int] = litellm.num_retries, + default_headers: Optional[Mapping[str, str]] = None, + ): + self.params = locals() + self.chat = Chat(self.params) - def __init__(self, *, - api_key=None, - organization: Optional[str] = None, - base_url: Optional[str]= None, - timeout: Optional[float] = 600, - max_retries: Optional[int] = litellm.num_retries, - default_headers: Optional[Mapping[str, str]] = None,): - self.params = locals() - self.chat = Chat(self.params) -class Chat(): +class Chat: + def __init__(self, params): + self.params = params + self.completions = Completions(self.params) - def __init__(self, params): - self.params = params - self.completions = Completions(self.params) - -class Completions(): - - def __init__(self, params): - self.params = params - def create(self, messages, model=None, **kwargs): - for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get('model') - response = completion(model=model, messages=messages, **self.params) - return response +class Completions: + def __init__(self, params): + self.params = params + + def create(self, messages, model=None, **kwargs): + for k, v in kwargs.items(): + self.params[k] = v + model = model or self.params.get("model") + response = completion(model=model, messages=messages, **self.params) + return response + @client async def acompletion( @@ -221,7 +234,7 @@ async def acompletion( } try: # Use a partial function to pass your keyword arguments - func = partial(completion, **completion_kwargs) + func = 
partial(completion, *args, **kwargs) # Add the context to the function ctx = contextvars.copy_context() @@ -229,7 +242,8 @@ async def acompletion( _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None)) - if (custom_llm_provider == "openai" + if ( + custom_llm_provider == "openai" or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" @@ -240,42 +254,57 @@ async def acompletion( or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "huggingface" or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. - if completion_kwargs.get("stream", False): - response = completion(**completion_kwargs) + or custom_llm_provider == "ollama_chat" + or custom_llm_provider == "vertex_ai" + ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response else: - # Await normally - init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO - response = init_response - elif asyncio.iscoroutine(init_response): - response = await init_response + response = init_response else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if completion_kwargs.get("stream", False): # return an async generator - return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, completion_kwargs=completion_kwargs) - else: - return response + response = await loop.run_in_executor(None, func_with_context) + # if kwargs.get("stream", False): # return an async generator + # return _async_streaming( + # response=response, + # model=model, + # custom_llm_provider=custom_llm_provider, + # args=args, + # ) + # else: + return response except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) -async def _async_streaming(response, model, custom_llm_provider, completion_kwargs): + +async def _async_streaming(response, model, custom_llm_provider, args): try: print_verbose(f"received response in _async_streaming: {response}") async for line in response: print_verbose(f"line in async streaming: {line}") yield line except Exception as e: - print_verbose(f"error raised _async_streaming: {traceback.format_exc()}") - raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs, - ) + raise e -def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs): + +def mock_completion( + model: str, + messages: List, + stream: Optional[bool] = False, + mock_response: str = "This is a mock request", + **kwargs, +): """ Generate a mock completion response for testing or debugging purposes. 
@@ -302,9 +331,11 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False, model_response = ModelResponse(stream=stream) if stream is True: # don't try to access stream object, - response = mock_completion_streaming_obj(model_response, mock_response=mock_response, model=model) + response = mock_completion_streaming_obj( + model_response, mock_response=mock_response, model=model + ) return response - + model_response["choices"][0]["message"]["content"] = mock_response model_response["created"] = int(time.time()) model_response["model"] = model @@ -314,13 +345,12 @@ def mock_completion(model: str, messages: List, stream: Optional[bool] = False, traceback.print_exc() raise Exception("Mock completion response failed") + @client def completion( model: str, # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create messages: List = [], - functions: Optional[List] = None, - function_call: Optional[str] = None, timeout: Optional[Union[float, int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, @@ -329,7 +359,7 @@ def completion( stop=None, max_tokens: Optional[float] = None, presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float]=None, + frequency_penalty: Optional[float] = None, logit_bias: Optional[dict] = None, user: Optional[str] = None, # openai v1.0+ new params @@ -337,13 +367,17 @@ def completion( seed: Optional[int] = None, tools: Optional[List] = None, tool_choice: Optional[str] = None, - deployment_id = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, # set api_base, api_version, api_key base_url: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. # Optional liteLLM function params **kwargs, ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -366,7 +400,9 @@ def completion( frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message + top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. api_base (str, optional): Base URL for the API (default is None). api_version (str, optional): API version (default is None). api_key (str, optional): API key (default is None). 
@@ -386,26 +422,26 @@ def completion( """ ######### unpacking kwargs ##################### args = locals() - api_base = kwargs.get('api_base', None) - mock_response = kwargs.get('mock_response', None) - force_timeout= kwargs.get('force_timeout', 600) ## deprecated - logger_fn = kwargs.get('logger_fn', None) - verbose = kwargs.get('verbose', False) - custom_llm_provider = kwargs.get('custom_llm_provider', None) - litellm_logging_obj = kwargs.get('litellm_logging_obj', None) - id = kwargs.get('id', None) - metadata = kwargs.get('metadata', None) - model_info = kwargs.get('model_info', None) - proxy_server_request = kwargs.get('proxy_server_request', None) - fallbacks = kwargs.get('fallbacks', None) + api_base = kwargs.get("api_base", None) + mock_response = kwargs.get("mock_response", None) + force_timeout = kwargs.get("force_timeout", 600) ## deprecated + logger_fn = kwargs.get("logger_fn", None) + verbose = kwargs.get("verbose", False) + custom_llm_provider = kwargs.get("custom_llm_provider", None) + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + id = kwargs.get("id", None) + metadata = kwargs.get("metadata", None) + model_info = kwargs.get("model_info", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + fallbacks = kwargs.get("fallbacks", None) headers = kwargs.get("headers", None) - num_retries = kwargs.get("num_retries", None) ## deprecated + num_retries = kwargs.get("num_retries", None) ## deprecated max_retries = kwargs.get("max_retries", None) context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - ### CUSTOM MODEL COST ### + ### CUSTOM MODEL COST ### input_cost_per_token = kwargs.get("input_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None) - ### CUSTOM PROMPT TEMPLATE ### + ### CUSTOM PROMPT TEMPLATE ### initial_prompt_value = kwargs.get("initial_prompt_value", None) roles = kwargs.get("roles", None) final_prompt_value = kwargs.get("final_prompt_value", None) @@ -413,104 +449,201 @@ def completion( eos_token = kwargs.get("eos_token", None) preset_cache_key = kwargs.get("preset_cache_key", None) hf_model_name = kwargs.get("hf_model_name", None) - ### ASYNC CALLS ### + ### ASYNC CALLS ### acompletion = kwargs.get("acompletion", False) client = kwargs.get("client", None) ######## end of unpacking kwargs ########### - openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"] - litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"] + openai_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + 
"logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "logprobs", + "top_logprobs", + ] + litellm_params = [ + "metadata", + "acompletion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + ] default_params = openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider if mock_response: - return mock_completion(model, messages, stream=stream, mock_response=mock_response) + return mock_completion( + model, messages, stream=stream, mock_response=mock_response + ) if timeout is None: - timeout = kwargs.get("request_timeout", None) or 600 # set timeout for 10 minutes by default + timeout = ( + kwargs.get("request_timeout", None) or 600 + ) # set timeout for 10 minutes by default timeout = float(timeout) try: - if base_url is not None: + if base_url is not None: api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) + if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) num_retries = max_retries logging = litellm_logging_obj - fallbacks = ( - fallbacks - or litellm.model_fallbacks - ) + fallbacks = fallbacks or litellm.model_fallbacks if fallbacks is not None: return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [m["litellm_params"] for m in model_list if m["model_name"] == model] + if model_list is not None: + deployments = [ + m["litellm_params"] for m in model_list if m["model_name"] == model + ] return batch_completion_models(deployments=deployments, **args) if litellm.model_alias_map and model in litellm.model_alias_map: model = litellm.model_alias_map[ model ] # update the model to the actual value if an alias has been passed in model_response = ModelResponse() - if kwargs.get('azure', False) == True: # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider="azure" - if deployment_id != None: # azure llms - model=deployment_id - custom_llm_provider="azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) + if ( + kwargs.get("azure", False) == True + ): # don't remove flag check, to remain backwards compatible for repos like Codium + custom_llm_provider = "azure" + if deployment_id != None: # azure llms + model = deployment_id + custom_llm_provider = "azure" + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + 
custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - litellm.register_model({ - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider + if input_cost_per_token is not None and output_cost_per_token is not None: + litellm.register_model( + { + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + } } - }) + ) ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token: + custom_prompt_dict = {} # type: ignore + if ( + initial_prompt_value + or roles + or final_prompt_value + or bos_token + or eos_token + ): custom_prompt_dict = {model: {}} if initial_prompt_value: custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: + if roles: custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: + if final_prompt_value: custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value if bos_token: custom_prompt_dict[model]["bos_token"] = bos_token if eos_token: custom_prompt_dict[model]["eos_token"] = eos_token - model_api_key = get_api_key(llm_provider=custom_llm_provider, dynamic_api_key=api_key) # get the api key from the environment if required for the model - if model_api_key and "sk-litellm" in model_api_key: - api_base = "https://proxy.litellm.ai" - custom_llm_provider = "openai" - api_key = model_api_key + model_api_key = get_api_key( + llm_provider=custom_llm_provider, dynamic_api_key=api_key + ) # get the api key from the environment if required for the model - if dynamic_api_key is not None: - api_key = dynamic_api_key + if dynamic_api_key is not None: + api_key = dynamic_api_key # check if user passed in any of the OpenAI optional params optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - **non_default_params + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + max_retries=max_retries, + logprobs=logprobs, + top_logprobs=top_logprobs, + **non_default_params, + ) + + if litellm.add_function_to_prompt and optional_params.get( + "functions_unsupported_model", None + ): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop( + "functions_unsupported_model" + ) + messages = function_call_prompt( + messages=messages, functions=functions_unsupported_model ) - - if 
litellm.add_function_to_prompt and optional_params.get("functions_unsupported_model", None): # if user opts to add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop("functions_unsupported_model") - messages = function_call_prompt(messages=messages, functions=functions_unsupported_model) # For logging - save the values of the litellm-specific params passed in litellm_params = get_litellm_params( @@ -521,53 +654,50 @@ def completion( verbose=verbose, custom_llm_provider=custom_llm_provider, api_base=api_base, - litellm_call_id=kwargs.get('litellm_call_id', None), + litellm_call_id=kwargs.get("litellm_call_id", None), model_alias_map=litellm.model_alias_map, completion_call_id=id, metadata=metadata, model_info=model_info, proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key + preset_cache_key=preset_cache_key, + ) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params, ) - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params) if custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_OPENAI_API_KEY") or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") ) - azure_ad_token = ( - optional_params.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.AzureOpenAIConfig.get_config() + config = litellm.AzureOpenAIConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v ## COMPLETION CALL @@ -585,10 +715,10 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, + logging_obj=logging, + acompletion=acompletion, timeout=timeout, - client=client # pass AsyncAzureOpenAI, AzureOpenAI client + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client ) if optional_params.get("stream", False) or acompletion == True: @@ -616,7 +746,7 @@ def completion( # note: if a user sets a custom base - we should ensure this works # allow for the setting of dynamic and stateful api-bases api_base = ( - api_base # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api base from there + api_base # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api base from there or litellm.api_base or get_secret("OPENAI_API_BASE") or "https://api.openai.com/v1" @@ -624,25 +754,24 @@ def completion( 
openai.organization = ( litellm.organization or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 ) # set API KEY api_key = ( - api_key or # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.OpenAIConfig.get_config() + config = litellm.OpenAIConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v ## COMPLETION CALL @@ -650,6 +779,7 @@ def completion( response = openai_chat_completions.completion( model=model, messages=messages, + headers=headers, model_response=model_response, print_verbose=print_verbose, api_key=api_key, @@ -661,7 +791,7 @@ def completion( logger_fn=logger_fn, timeout=timeout, custom_prompt_dict=custom_prompt_dict, - client=client # pass AsyncOpenAI, OpenAI client + client=client, # pass AsyncOpenAI, OpenAI client ) except Exception as e: ## LOGGING - log the original exception returned @@ -699,31 +829,34 @@ def completion( # set API KEY api_key = ( - api_key or - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) - headers = ( - headers or - litellm.headers - ) + headers = headers or litellm.headers ## LOAD CONFIG - if set - config=litellm.OpenAITextCompletionConfig.get_config() + config = litellm.OpenAITextCompletionConfig.get_config() for k, v in config.items(): - if k not in optional_params: # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in + if ( + k not in optional_params + ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v if litellm.organization: openai.organization = litellm.organization - if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list: + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] # https://platform.openai.com/docs/api-reference/completions/create prompt = messages[0]["content"] else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore + prompt = " ".join([message["content"] for message in messages]) # type: ignore ## COMPLETION CALL model_response = openai_text_completions.completion( model=model, @@ -736,9 +869,10 @@ def completion( logging_obj=logging, optional_params=optional_params, litellm_params=litellm_params, - logger_fn=logger_fn + logger_fn=logger_fn, + timeout=timeout, ) - + if optional_params.get("stream", False) or acompletion == 
True: ## LOGGING logging.post_call( @@ -749,16 +883,16 @@ def completion( ) response = model_response elif ( - "replicate" in model or - custom_llm_provider == "replicate" or - model in litellm.replicate_models + "replicate" in model + or custom_llm_provider == "replicate" + or model in litellm.replicate_models ): # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") replicate_key = None replicate_key = ( api_key or litellm.replicate_key - or litellm.api_key + or litellm.api_key or get_secret("REPLICATE_API_KEY") or get_secret("REPLICATE_API_TOKEN") ) @@ -770,10 +904,7 @@ def completion( or "https://api.replicate.com/v1" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = replicate.completion( model=model, @@ -784,14 +915,14 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens + encoding=encoding, # for calculating input/output tokens api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore + model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore if optional_params.get("stream", False) or acompletion == True: ## LOGGING @@ -803,12 +934,12 @@ def completion( response = model_response - elif custom_llm_provider=="anthropic": + elif custom_llm_provider == "anthropic": api_key = ( - api_key - or litellm.anthropic_key + api_key + or litellm.anthropic_key or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") + or os.environ.get("ANTHROPIC_API_KEY") ) api_base = ( api_base @@ -816,10 +947,7 @@ def completion( or get_secret("ANTHROPIC_API_BASE") or "https://api.anthropic.com/v1/complete" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict response = anthropic.completion( model=model, messages=messages, @@ -830,14 +958,19 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens + encoding=encoding, # for calculating input/output tokens api_key=api_key, - logging_obj=logging, + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging) - + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="anthropic", + logging_obj=logging, + ) + if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( @@ -848,7 +981,10 @@ def completion( response = response elif custom_llm_provider == "nlp_cloud": nlp_cloud_key = ( - api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key + api_key + or litellm.nlp_cloud_key + or get_secret("NLP_CLOUD_API_KEY") + or litellm.api_key ) api_base = ( @@ -869,13 +1005,18 @@ def completion( logger_fn=logger_fn, encoding=encoding, 
api_key=nlp_cloud_key, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging) - + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="nlp_cloud", + logging_obj=logging, + ) + if optional_params.get("stream", False) or acompletion == True: ## LOGGING logging.post_call( @@ -887,7 +1028,11 @@ def completion( response = response elif custom_llm_provider == "aleph_alpha": aleph_alpha_key = ( - api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + or get_secret("ALEPHALPHA_API_KEY") + or litellm.api_key ) api_base = ( @@ -909,12 +1054,17 @@ def completion( encoding=encoding, default_max_tokens_to_sample=litellm.max_tokens, api_key=aleph_alpha_key, - logging_obj=logging # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="aleph_alpha", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="aleph_alpha", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "cohere": @@ -932,7 +1082,7 @@ def completion( or get_secret("COHERE_API_BASE") or "https://api.cohere.ai/v1/generate" ) - + model_response = cohere.completion( model=model, messages=messages, @@ -944,12 +1094,17 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=cohere_key, - logging_obj=logging # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="cohere", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "maritalk": @@ -966,7 +1121,7 @@ def completion( or get_secret("MARITALK_API_BASE") or "https://chat.maritaca.ai/api/chat/inference" ) - + model_response = maritalk.completion( model=model, messages=messages, @@ -978,17 +1133,20 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=maritalk_key, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - response = CustomStreamWrapper(model_response, model, custom_llm_provider="maritalk", logging_obj=logging) + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="maritalk", + logging_obj=logging, + ) return response response = model_response - elif ( - custom_llm_provider == "huggingface" - ): + elif custom_llm_provider == "huggingface": custom_llm_provider = "huggingface" huggingface_key = ( 
api_key @@ -997,35 +1155,37 @@ def completion( or os.environ.get("HUGGINGFACE_API_KEY") or litellm.api_key ) - hf_headers = ( - headers - or litellm.headers - ) + hf_headers = headers or litellm.headers - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = huggingface.completion( model=model, messages=messages, - api_base=api_base, # type: ignore + api_base=api_base, # type: ignore headers=hf_headers, model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, + encoding=encoding, + api_key=huggingface_key, acompletion=acompletion, logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + custom_prompt_dict=custom_prompt_dict, + timeout=timeout ) - if "stream" in optional_params and optional_params["stream"] == True and acompletion is False: + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion is False + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="huggingface", logging_obj=logging + model_response, + model, + custom_llm_provider="huggingface", + logging_obj=logging, ) return response response = model_response @@ -1035,73 +1195,62 @@ def completion( model=model, messages=messages, model_response=model_response, - api_base=api_base, # type: ignore + api_base=api_base, # type: ignore print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, api_key=None, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="oobabooga", logging_obj=logging + model_response, + model, + custom_llm_provider="oobabooga", + logging_obj=logging, ) return response response = model_response elif custom_llm_provider == "openrouter": - api_base = ( - api_base - or litellm.api_base - or "https://openrouter.ai/api/v1" - ) + api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" api_key = ( - api_key or - litellm.api_key or - litellm.openrouter_key or - get_secret("OPENROUTER_API_KEY") or - get_secret("OR_API_KEY") + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") ) - openrouter_site_url = ( - get_secret("OR_SITE_URL") - or "https://litellm.ai" - ) + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - openrouter_app_name = ( - get_secret("OR_APP_NAME") - or "liteLLM" - ) + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" headers = ( - headers or - litellm.headers or - { + headers + or litellm.headers + or { "HTTP-Referer": openrouter_site_url, "X-Title": openrouter_app_name, } ) ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): + config = openrouter.OpenrouterConfig.get_config() + for k, v in config.items(): if k == "extra_body": # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: + if "extra_body" in optional_params: optional_params[k].update(v) else: optional_params[k] = v - elif k not in optional_params: + elif k not in optional_params: optional_params[k] = 
v - data = { - "model": model, - "messages": messages, - **optional_params - } + data = {"model": model, "messages": messages, **optional_params} ## COMPLETION CALL response = openai_chat_completions.completion( @@ -1115,15 +1264,19 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - logging_obj=logging, + logging_obj=logging, acompletion=acompletion, - timeout=timeout + timeout=timeout, ) ## LOGGING logging.post_call( input=messages, api_key=openai.api_key, original_response=response ) - elif custom_llm_provider == "together_ai" or ("togethercomputer" in model) or (model in litellm.together_ai_models): + elif ( + custom_llm_provider == "together_ai" + or ("togethercomputer" in model) + or (model in litellm.together_ai_models) + ): custom_llm_provider = "together_ai" together_ai_key = ( api_key @@ -1140,11 +1293,8 @@ def completion( or "https://api.together.xyz/inference" ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) - + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + model_response = together_ai.completion( model=model, messages=messages, @@ -1157,22 +1307,24 @@ def completion( encoding=encoding, api_key=together_ai_key, logging_obj=logging, - custom_prompt_dict=custom_prompt_dict + custom_prompt_dict=custom_prompt_dict, ) - if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + if ( + "stream_tokens" in optional_params + and optional_params["stream_tokens"] == True + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="together_ai", logging_obj=logging + model_response, + model, + custom_llm_provider="together_ai", + logging_obj=logging, ) return response response = model_response elif custom_llm_provider == "palm": - palm_api_key = ( - api_key - or get_secret("PALM_API_KEY") - or litellm.api_key - ) - + palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key + # palm does not support streaming as yet :( model_response = palm.completion( model=model, @@ -1184,7 +1336,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=palm_api_key, - logging_obj=logging + logging_obj=logging, ) # fake palm streaming if "stream" in optional_params and optional_params["stream"] == True: @@ -1195,11 +1347,35 @@ def completion( ) return response response = model_response + elif custom_llm_provider == "gemini": + gemini_api_key = ( + api_key + or get_secret("GEMINI_API_KEY") + or get_secret("PALM_API_KEY") # older palm api key should also work + or litellm.api_key + ) + + # palm does not support streaming as yet :( + model_response = gemini.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=gemini_api_key, + logging_obj=logging, + acompletion=acompletion, + custom_prompt_dict=custom_prompt_dict, + ) + response = model_response elif custom_llm_provider == "vertex_ai": - vertex_ai_project = (litellm.vertex_project - or get_secret("VERTEXAI_PROJECT")) - vertex_ai_location = (litellm.vertex_location - or get_secret("VERTEXAI_LOCATION")) + vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT") + vertex_ai_location = litellm.vertex_location or get_secret( + "VERTEXAI_LOCATION" + ) model_response = vertex_ai.completion( model=model, @@ -1212,14 +1388,21 @@ def completion( 
encoding=encoding, vertex_location=vertex_ai_location, vertex_project=vertex_ai_project, - logging_obj=logging, - acompletion=acompletion + logging_obj=logging, + acompletion=acompletion, ) - - if "stream" in optional_params and optional_params["stream"] == True and acompletion == False: + + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion == False + ): response = CustomStreamWrapper( - model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging - ) + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) return response response = model_response elif custom_llm_provider == "ai21": @@ -1237,7 +1420,7 @@ def completion( or get_secret("AI21_API_BASE") or "https://api.ai21.com/studio/v1/" ) - + model_response = ai21.completion( model=model, messages=messages, @@ -1249,16 +1432,19 @@ def completion( logger_fn=logger_fn, encoding=encoding, api_key=ai21_key, - logging_obj=logging + logging_obj=logging, ) - + if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="ai21", logging_obj=logging + model_response, + model, + custom_llm_provider="ai21", + logging_obj=logging, ) return response - + ## RESPONSE OBJECT response = model_response elif custom_llm_provider == "sagemaker": @@ -1274,18 +1460,23 @@ def completion( hf_model_name=hf_model_name, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"]==True: ## [BETA] + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] # sagemaker does not support streaming as of now so we're faking streaming: # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611 # "SageMaker is currently not supporting streaming responses." 
- + # fake streaming for sagemaker print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") resp_string = model_response["choices"][0]["message"]["content"] response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging + resp_string, + model, + custom_llm_provider="sagemaker", + logging_obj=logging, ) return response @@ -1293,10 +1484,7 @@ def completion( response = model_response elif custom_llm_provider == "bedrock": # boto3 reads keys from .env - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict - ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict response = bedrock.completion( model=model, messages=messages, @@ -1310,18 +1498,23 @@ def completion( logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"] == True: # don't try to access stream object, - if "ai21" in model: + if "ai21" in model: response = CustomStreamWrapper( - response, model, custom_llm_provider="bedrock", logging_obj=logging + response, + model, + custom_llm_provider="bedrock", + logging_obj=logging, ) else: response = CustomStreamWrapper( - iter(response), model, custom_llm_provider="bedrock", logging_obj=logging + iter(response), + model, + custom_llm_provider="bedrock", + logging_obj=logging, ) - + if optional_params.get("stream", False): ## LOGGING logging.post_call( @@ -1330,7 +1523,6 @@ def completion( original_response=response, ) - ## RESPONSE OBJECT response = response elif custom_llm_provider == "vllm": @@ -1343,13 +1535,18 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, encoding=encoding, - logging_obj=logging + logging_obj=logging, ) - if "stream" in optional_params and optional_params["stream"] == True: ## [BETA] + if ( + "stream" in optional_params and optional_params["stream"] == True + ): ## [BETA] # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="vllm", logging_obj=logging + model_response, + model, + custom_llm_provider="vllm", + logging_obj=logging, ) return response @@ -1357,27 +1554,27 @@ def completion( response = model_response elif custom_llm_provider == "ollama": api_base = ( - litellm.api_base or - api_base or - get_secret("OLLAMA_API_BASE") or - "http://localhost:11434" - - ) - custom_prompt_dict = ( - custom_prompt_dict - or litellm.custom_prompt_dict + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" ) + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details["initial_prompt_value"], + final_prompt_value=model_prompt_details["final_prompt_value"], + messages=messages, ) else: - prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider) + prompt = prompt_factory( + model=model, + messages=messages, + custom_llm_provider=custom_llm_provider, + ) if isinstance(prompt, dict): # for multimode models - ollama/llava prompt_factory returns a dict { # "prompt": prompt, @@ -1387,18 +1584,100 @@ def completion( optional_params["images"] = images ## LOGGING - 
generator = ollama.get_ollama_response(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding) + generator = ollama.get_ollama_response( + api_base, + model, + prompt, + optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) if acompletion is True or optional_params.get("stream", False) == True: return generator - + response = generator + elif custom_llm_provider == "ollama_chat": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) + + ## LOGGING + generator = ollama_chat.get_ollama_response( + api_base, + model, + messages, + optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) + if acompletion is True or optional_params.get("stream", False) == True: + return generator + + response = generator + elif custom_llm_provider == "cloudflare": + api_key = ( + api_key + or litellm.cloudflare_api_key + or litellm.api_key + or get_secret("CLOUDFLARE_API_KEY") + ) + account_id = get_secret("CLOUDFLARE_ACCOUNT_ID") + api_base = ( + api_base + or litellm.api_base + or get_secret("CLOUDFLARE_API_BASE") + or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/" + ) + + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = cloudflare.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + ) + if "stream" in optional_params and optional_params["stream"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="cloudflare", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response elif ( custom_llm_provider == "baseten" or litellm.api_base == "https://app.baseten.co" ): custom_llm_provider = "baseten" baseten_key = ( - api_key or litellm.baseten_key or os.environ.get("BASETEN_API_KEY") or litellm.api_key + api_key + or litellm.baseten_key + or os.environ.get("BASETEN_API_KEY") + or litellm.api_key ) model_response = baseten.completion( @@ -1409,25 +1688,24 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - api_key=baseten_key, - logging_obj=logging + encoding=encoding, + api_key=baseten_key, + logging_obj=logging, ) - if inspect.isgenerator(model_response) or ("stream" in optional_params and optional_params["stream"] == True): + if inspect.isgenerator(model_response) or ( + "stream" in optional_params and optional_params["stream"] == True + ): # don't try to access stream object, response = CustomStreamWrapper( - model_response, model, custom_llm_provider="baseten", logging_obj=logging + model_response, + model, + custom_llm_provider="baseten", + logging_obj=logging, ) return response response = model_response - elif ( - custom_llm_provider == "petals" - or model in litellm.petals_models - ): - api_base = ( - api_base or - litellm.api_base - ) + elif 
custom_llm_provider == "petals" or model in litellm.petals_models: + api_base = api_base or litellm.api_base custom_llm_provider = "petals" stream = optional_params.pop("stream", False) @@ -1440,29 +1718,28 @@ def completion( optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging + encoding=encoding, + logging_obj=logging, ) - if stream==True: ## [BETA] + if stream == True: ## [BETA] # Fake streaming for petals resp_string = model_response["choices"][0]["message"]["content"] response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="petals", logging_obj=logging + resp_string, + model, + custom_llm_provider="petals", + logging_obj=logging, ) return response response = model_response - elif ( - custom_llm_provider == "custom" - ): + elif custom_llm_provider == "custom": import requests - url = ( - litellm.api_base or - api_base or - "" - ) + url = litellm.api_base or api_base or "" if url == None or url == "": - raise ValueError("api_base not set. Set api_base or litellm.api_base for custom endpoints") + raise ValueError( + "api_base not set. Set api_base or litellm.api_base for custom endpoints" + ) """ assume input to custom LLM api bases follow this format: @@ -1481,17 +1758,20 @@ def completion( ) """ - prompt = " ".join([message["content"] for message in messages]) # type: ignore - resp = requests.post(url, json={ - 'model': model, - 'params': { - 'prompt': [prompt], - 'max_tokens': max_tokens, - 'temperature': temperature, - 'top_p': top_p, - 'top_k': kwargs.get('top_k', 40), - } - }) + prompt = " ".join([message["content"] for message in messages]) # type: ignore + resp = requests.post( + url, + json={ + "model": model, + "params": { + "prompt": [prompt], + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": kwargs.get("top_k", 40), + }, + }, + ) response_json = resp.json() """ assume all responses from custom api_bases of this format: @@ -1506,7 +1786,7 @@ def completion( ] } """ - string_response = response_json['data'][0]['output'][0] + string_response = response_json["data"][0]["output"][0] ## RESPONSE OBJECT model_response["choices"][0]["message"]["content"] = string_response model_response["created"] = int(time.time()) @@ -1520,8 +1800,11 @@ def completion( except Exception as e: ## Map to OpenAI Exception raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) def completion_with_retries(*args, **kwargs): @@ -1531,17 +1814,26 @@ def completion_with_retries(*args, **kwargs): try: import tenacity except Exception as e: - raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}") - + raise Exception( + f"tenacity import failed please run `pip install tenacity`. 
Error{e}" + ) + num_retries = kwargs.pop("num_retries", 3) retry_strategy = kwargs.pop("retry_strategy", "constant_retry") original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) return retryer(original_function, *args, **kwargs) + async def acompletion_with_retries(*args, **kwargs): """ Executes a litellm.completion() with 3 retries @@ -1549,19 +1841,26 @@ async def acompletion_with_retries(*args, **kwargs): try: import tenacity except Exception as e: - raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}") - + raise Exception( + f"tenacity import failed please run `pip install tenacity`. Error{e}" + ) + num_retries = kwargs.pop("num_retries", 3) retry_strategy = kwargs.pop("retry_strategy", "constant_retry") original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) return await retryer(original_function, *args, **kwargs) - def batch_completion( model: str, # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create @@ -1575,13 +1874,15 @@ def batch_completion( stop=None, max_tokens: Optional[float] = None, presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float]=None, + frequency_penalty: Optional[float] = None, logit_bias: Optional[dict] = None, user: Optional[str] = None, - deployment_id = None, + deployment_id=None, request_timeout: Optional[int] = None, + timeout: Optional[int] = 600, # Optional liteLLM function params - **kwargs): + **kwargs, +): """ Batch litellm.completion function for a given model. 
@@ -1630,15 +1931,22 @@ def batch_completion( user=user, # params to identify the model model=model, - custom_llm_provider=custom_llm_provider + custom_llm_provider=custom_llm_provider, ) - results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params) - # all non VLLM models for batch completion models + results = vllm.batch_completions( + model=model, + messages=batch_messages, + custom_prompt_dict=litellm.custom_prompt_dict, + optional_params=optional_params, + ) + # all non VLLM models for batch completion models else: + def chunks(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): - yield lst[i:i + n] + yield lst[i : i + n] + with ThreadPoolExecutor(max_workers=100) as executor: for sub_batch in chunks(batch_messages, 100): for message_list in sub_batch: @@ -1647,13 +1955,16 @@ def batch_completion( original_kwargs = {} if "kwargs" in kwargs_modified: original_kwargs = kwargs_modified.pop("kwargs") - future = executor.submit(completion, **kwargs_modified, **original_kwargs) + future = executor.submit( + completion, **kwargs_modified, **original_kwargs + ) completions.append(future) # Retrieve the results from the futures results = [future.result() for future in completions] return results + # send one request to multiple models # return as soon as one of the llms responds def batch_completion_models(*args, **kwargs): @@ -1675,6 +1986,7 @@ def batch_completion_models(*args, **kwargs): It sends requests concurrently and returns the response from the first model that responds. """ import concurrent + if "model" in kwargs: kwargs.pop("model") if "models" in kwargs: @@ -1683,21 +1995,29 @@ def batch_completion_models(*args, **kwargs): futures = {} with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor: for model in models: - futures[model] = executor.submit(completion, *args, model=model, **kwargs) + futures[model] = executor.submit( + completion, *args, model=model, **kwargs + ) - for model, future in sorted(futures.items(), key=lambda x: models.index(x[0])): + for model, future in sorted( + futures.items(), key=lambda x: models.index(x[0]) + ): if future.result() is not None: return future.result() - elif "deployments" in kwargs: + elif "deployments" in kwargs: deployments = kwargs["deployments"] kwargs.pop("deployments") kwargs.pop("model_list") nested_kwargs = kwargs.pop("kwargs", {}) futures = {} - with concurrent.futures.ThreadPoolExecutor(max_workers=len(deployments)) as executor: + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(deployments) + ) as executor: for deployment in deployments: - for key in kwargs.keys(): - if key not in deployment: # don't override deployment values e.g. model name, api base, etc. + for key in kwargs.keys(): + if ( + key not in deployment + ): # don't override deployment values e.g. model name, api base, etc. 
deployment[key] = kwargs[key] kwargs = {**deployment, **nested_kwargs} futures[deployment["model"]] = executor.submit(completion, **kwargs) @@ -1705,7 +2025,9 @@ def batch_completion_models(*args, **kwargs): while futures: # wait for the first returned future print_verbose("\n\n waiting for next result\n\n") - done, _ = concurrent.futures.wait(futures.values(), return_when=concurrent.futures.FIRST_COMPLETED) + done, _ = concurrent.futures.wait( + futures.values(), return_when=concurrent.futures.FIRST_COMPLETED + ) print_verbose(f"done list\n{done}") for future in done: try: @@ -1713,7 +2035,9 @@ def batch_completion_models(*args, **kwargs): return result except Exception as e: # if model 1 fails, continue with response from model 2, model3 - print_verbose(f"\n\ngot an exception, ignoring, removing from futures") + print_verbose( + f"\n\ngot an exception, ignoring, removing from futures" + ) print_verbose(futures) new_futures = {} for key, value in futures.items(): @@ -1726,12 +2050,12 @@ def batch_completion_models(*args, **kwargs): print_verbose(f"new futures{futures}") continue - print_verbose("\n\ndone looping through futures\n\n") print_verbose(futures) return None # If no response is received from any model + def batch_completion_models_all_responses(*args, **kwargs): """ Send a request to multiple language models concurrently and return a list of responses @@ -1773,6 +2097,7 @@ def batch_completion_models_all_responses(*args, **kwargs): return responses + ### EMBEDDING ENDPOINTS #################### @client async def aembedding(*args, **kwargs): @@ -1788,10 +2113,10 @@ async def aembedding(*args, **kwargs): """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Embedding ### + ### PASS ARGS TO Embedding ### kwargs["aembedding"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(embedding, *args, **kwargs) @@ -1799,49 +2124,63 @@ async def aembedding(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" + or custom_llm_provider == "xinference" + or custom_llm_provider == "voyage" + or custom_llm_provider == "mistral" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "openrouter" or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity"): # currently implemented aiohttp calls for just azure and openai, soon all. + or custom_llm_provider == "perplexity" + or custom_llm_provider == "ollama" + ): # currently implemented aiohttp calls for just azure and openai, soon all. 
# Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = await loop.run_in_executor(None, func_with_context) return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) + @client def embedding( - model, - input=[], + model, + input=[], # Optional params - timeout=600, # default to 10 minutes + timeout=600, # default to 10 minutes # set api_base, api_version, api_key api_base: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, api_type: Optional[str] = None, - caching: bool=False, - user: Optional[str]=None, + caching: bool = False, + user: Optional[str] = None, custom_llm_provider=None, - litellm_call_id=None, + litellm_call_id=None, litellm_logging_obj=None, - logger_fn=None, - **kwargs + logger_fn=None, + **kwargs, ): """ Embedding function that calls an API to generate embeddings for the given input. @@ -1875,43 +2214,118 @@ def embedding( encoding_format = kwargs.get("encoding_format", None) proxy_server_request = kwargs.get("proxy_server_request", None) aembedding = kwargs.get("aembedding", None) - openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"] - litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] + openai_params = [ + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "encoding_format", + ] + litellm_params = [ + "metadata", + "aembedding", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", 
+ ] default_params = openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider - - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key) - optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params) + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + optional_params = get_optional_params_embeddings( + user=user, + encoding_format=encoding_format, + custom_llm_provider=custom_llm_provider, + **non_default_params, + ) try: response = None logging = litellm_logging_obj - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}}) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": azure, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "aembedding": aembedding, + "preset_cache_key": None, + "stream_response": {}, + }, + ) if azure == True or custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) - azure_ad_token = ( - kwargs.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_API_KEY") ) ## EMBEDDING CALL response = azure_chat_completions.embedding( @@ -1923,12 +2337,14 @@ def embedding( azure_ad_token=azure_ad_token, logging_obj=logging, timeout=timeout, - model_response=EmbeddingResponse(), + model_response=EmbeddingResponse(), optional_params=optional_params, client=client, - aembedding=aembedding + aembedding=aembedding, ) - elif model in litellm.open_ai_embedding_models or custom_llm_provider == "openai": + elif ( + model in litellm.open_ai_embedding_models or custom_llm_provider == "openai" + ): api_base = ( api_base or litellm.api_base @@ -1938,19 +2354,18 @@ def embedding( openai.organization = ( litellm.organization or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + or None # default - 
https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 ) # set API KEY api_key = ( - api_key or - litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") ) api_type = "openai" api_version = None - ## EMBEDDING CALL response = openai_chat_completions.embedding( model=model, @@ -1959,7 +2374,7 @@ def embedding( api_key=api_key, logging_obj=logging, timeout=timeout, - model_response=EmbeddingResponse(), + model_response=EmbeddingResponse(), optional_params=optional_params, client=client, aembedding=aembedding, @@ -1979,8 +2394,7 @@ def embedding( encoding=encoding, api_key=cohere_key, logging_obj=logging, - model_response= EmbeddingResponse() - + model_response=EmbeddingResponse(), ) elif custom_llm_provider == "huggingface": api_key = ( @@ -1996,7 +2410,7 @@ def embedding( api_key=api_key, api_base=api_base, logging_obj=logging, - model_response= EmbeddingResponse() + model_response=EmbeddingResponse(), ) elif custom_llm_provider == "bedrock": response = bedrock.embedding( @@ -2005,17 +2419,90 @@ def embedding( encoding=encoding, logging_obj=logging, optional_params=optional_params, - model_response= EmbeddingResponse() + model_response=EmbeddingResponse(), ) - elif custom_llm_provider == "sagemaker": + elif custom_llm_provider == "oobabooga": + response = oobabooga.embedding( + model=model, + input=input, + encoding=encoding, + api_base=api_base, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) + elif custom_llm_provider == "ollama": + if aembedding == True: + response = ollama.ollama_aembeddings( + model=model, + prompt=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) + elif custom_llm_provider == "sagemaker": response = sagemaker.embedding( model=model, input=input, encoding=encoding, logging_obj=logging, optional_params=optional_params, - model_response= EmbeddingResponse(), - print_verbose=print_verbose + model_response=EmbeddingResponse(), + print_verbose=print_verbose, + ) + elif custom_llm_provider == "mistral": + api_key = api_key or litellm.api_key or get_secret("MISTRAL_API_KEY") + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "voyage": + api_key = api_key or litellm.api_key or get_secret("VOYAGE_API_KEY") + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "xinference": + api_key = ( + api_key + or litellm.api_key + or get_secret("XINFERENCE_API_KEY") + or "stub-xinference-key" + ) # xinference does not need an api key, pass a stub key if user did not set one + api_base = ( + api_base + or litellm.api_base + or get_secret("XINFERENCE_API_BASE") + or "http://127.0.0.1:9997/v1" + ) + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + 
optional_params=optional_params, + client=client, + aembedding=aembedding, ) else: args = locals() @@ -2037,16 +2524,17 @@ def embedding( ###### Text Completion ################ +@client async def atext_completion(*args, **kwargs): """ - Implemented to handle async streaming for the text completion endpoint + Implemented to handle async streaming for the text completion endpoint """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO COMPLETION ### + ### PASS ARGS TO COMPLETION ### kwargs["acompletion"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(text_completion, *args, **kwargs) @@ -2054,10 +2542,13 @@ async def atext_completion(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) - if (custom_llm_provider == "openai" - or custom_llm_provider == "azure" + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai" or custom_llm_provider == "anyscale" or custom_llm_provider == "mistral" @@ -2067,58 +2558,92 @@ async def atext_completion(*args, **kwargs): or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "huggingface" or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all. - if kwargs.get("stream", False): - response = text_completion(*args, **kwargs) - else: - # Await normally - response = await loop.run_in_executor(None, func_with_context) - if asyncio.iscoroutine(response): - response = await response - else: + or custom_llm_provider == "vertex_ai" + ): # currently implemented aiohttp calls for just azure and openai, soon all. + # Await normally + response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(response): + response = await response + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False): # return an async generator - return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args) - else: + response = await loop.run_in_executor(None, func_with_context) + if kwargs.get("stream", False) == True: # return an async generator + return TextCompletionStreamWrapper( + completion_stream=_async_streaming( + response=response, + model=model, + custom_llm_provider=custom_llm_provider, + args=args, + ), + model=model, + ) + else: return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) + +@client def text_completion( - prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for. - model: Optional[str]=None, # Optional: either `model` or `engine` can be set - best_of: Optional[int] = None, # Optional: Generates best_of completions server-side. 
- echo: Optional[bool] = None, # Optional: Echo back the prompt in addition to the completion. - frequency_penalty: Optional[float] = None, # Optional: Penalize new tokens based on their existing frequency. - logit_bias: Optional[Dict[int, int]] = None, # Optional: Modify the likelihood of specified tokens. - logprobs: Optional[int] = None, # Optional: Include the log probabilities on the most likely tokens. - max_tokens: Optional[int] = None, # Optional: The maximum number of tokens to generate in the completion. - n: Optional[int] = None, # Optional: How many completions to generate for each prompt. - presence_penalty: Optional[float] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. - stop: Optional[Union[str, List[str]]] = None, # Optional: Sequences where the API will stop generating further tokens. - stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. - suffix: Optional[str] = None, # Optional: The suffix that comes after a completion of inserted text. - temperature: Optional[float] = None, # Optional: Sampling temperature to use. - top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. - user: Optional[str] = None, # Optional: A unique identifier representing your end-user. - + prompt: Union[ + str, List[Union[str, List[Union[str, List[int]]]]] + ], # Required: The prompt(s) to generate completions for. + model: Optional[str] = None, # Optional: either `model` or `engine` can be set + best_of: Optional[ + int + ] = None, # Optional: Generates best_of completions server-side. + echo: Optional[ + bool + ] = None, # Optional: Echo back the prompt in addition to the completion. + frequency_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on their existing frequency. + logit_bias: Optional[ + Dict[int, int] + ] = None, # Optional: Modify the likelihood of specified tokens. + logprobs: Optional[ + int + ] = None, # Optional: Include the log probabilities on the most likely tokens. + max_tokens: Optional[ + int + ] = None, # Optional: The maximum number of tokens to generate in the completion. + n: Optional[ + int + ] = None, # Optional: How many completions to generate for each prompt. + presence_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. + stop: Optional[ + Union[str, List[str]] + ] = None, # Optional: Sequences where the API will stop generating further tokens. + stream: Optional[bool] = None, # Optional: Whether to stream back partial progress. + suffix: Optional[ + str + ] = None, # Optional: The suffix that comes after a completion of inserted text. + temperature: Optional[float] = None, # Optional: Sampling temperature to use. + top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. + user: Optional[ + str + ] = None, # Optional: A unique identifier representing your end-user. # set api_base, api_version, api_key api_base: Optional[str] = None, api_version: Optional[str] = None, api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. # Optional liteLLM function params custom_llm_provider: Optional[str] = None, - *args, - **kwargs + *args, + **kwargs, ): global print_verbose import copy + """ Generate text completions using the OpenAI API. @@ -2145,8 +2670,8 @@ def text_completion( Example: Your example of how to use this function goes here. 
""" - if "engine" in kwargs: - if model==None: + if "engine" in kwargs: + if model == None: # only use engine when model not passed model = kwargs["engine"] kwargs.pop("engine") @@ -2193,7 +2718,7 @@ def text_completion( optional_params["custom_llm_provider"] = custom_llm_provider # get custom_llm_provider - _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore if custom_llm_provider == "huggingface": # if echo == True, for TGI llms we need to set top_n_tokens to 3 @@ -2205,17 +2730,19 @@ def text_completion( # processing prompt - users can pass raw tokens to OpenAI Completion() if type(prompt) == list: import concurrent.futures + tokenizer = tiktoken.encoding_for_model("text-davinci-003") ## if it's a 2d list - each element in the list is a text_completion() request if len(prompt) > 0 and type(prompt[0]) == list: - responses = [None for x in prompt] # init responses + responses = [None for x in prompt] # init responses + def process_prompt(i, individual_prompt): decoded_prompt = tokenizer.decode(individual_prompt) all_params = {**kwargs, **optional_params} response = text_completion( model=model, prompt=decoded_prompt, - num_retries=3,# ensure this does not fail for the batch + num_retries=3, # ensure this does not fail for the batch *args, **all_params, ) @@ -2225,32 +2752,38 @@ def text_completion( text_completion_response["created"] = response.get("created", None) text_completion_response["model"] = response.get("model", None) return response["choices"][0] + with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(process_prompt, i, individual_prompt) for i, individual_prompt in enumerate(prompt)] - for i, future in enumerate(concurrent.futures.as_completed(futures)): + futures = [ + executor.submit(process_prompt, i, individual_prompt) + for i, individual_prompt in enumerate(prompt) + ] + for i, future in enumerate( + concurrent.futures.as_completed(futures) + ): responses[i] = future.result() - text_completion_response.choices = responses + text_completion_response.choices = responses return text_completion_response # else: - # check if non default values passed in for best_of, echo, logprobs, suffix + # check if non default values passed in for best_of, echo, logprobs, suffix # these are the params supported by Completion() but not ChatCompletion - + # default case, non OpenAI requests go through here messages = [{"role": "system", "content": prompt}] kwargs.pop("prompt", None) response = completion( - model = model, + model=model, messages=messages, *args, **kwargs, **optional_params, ) + if kwargs.get("acompletion", False) == True: + return response if stream == True or kwargs.get("stream", False) == True: response = TextCompletionStreamWrapper(completion_stream=response, model=model) return response - if kwargs.get("acompletion", False) == True: - return response transformed_logprobs = None # only supported for TGI models try: @@ -2271,22 +2804,21 @@ def text_completion( text_completion_response["usage"] = response.get("usage", None) return text_completion_response + ##### Moderation ####################### -def moderation(input: str, api_key: Optional[str]=None): +def moderation(input: str, api_key: Optional[str] = None): # only supports open ai for now api_key = ( - api_key or - 
litellm.api_key or - litellm.openai_key or - get_secret("OPENAI_API_KEY") - ) + api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + ) openai.api_key = api_key - openai.api_type = "open_ai" # type: ignore + openai.api_type = "open_ai" # type: ignore openai.api_version = None openai.base_url = "https://api.openai.com/v1/" response = openai.moderations.create(input=input) return response + ##### Image Generation ####################### @client async def aimage_generation(*args, **kwargs): @@ -2302,10 +2834,10 @@ async def aimage_generation(*args, **kwargs): """ loop = asyncio.get_event_loop() model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Image Generation ### + ### PASS ARGS TO Image Generation ### kwargs["aimg_generation"] = True custom_llm_provider = None - try: + try: # Use a partial function to pass your keyword arguments func = partial(image_generation, *args, **kwargs) @@ -2313,117 +2845,321 @@ async def aimage_generation(*args, **kwargs): ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None)) - + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + # Await normally init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO response = init_response elif asyncio.iscoroutine(init_response): response = await init_response - else: + else: # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) + response = await loop.run_in_executor(None, func_with_context) return response - except Exception as e: + except Exception as e: custom_llm_provider = custom_llm_provider or "openai" raise exception_type( - model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args, - ) - -@client -def image_generation(prompt: str, - model: Optional[str]=None, - n: Optional[int]=None, - quality: Optional[str]=None, - response_format: Optional[str]=None, - size: Optional[str]=None, - style: Optional[str]=None, - user: Optional[str]=None, - timeout=600, # default to 10 minutes - api_key: Optional[str]=None, - api_base: Optional[str]=None, - api_version: Optional[str] = None, - litellm_logging_obj=None, - custom_llm_provider=None, - **kwargs): - """ - Maps the https://api.openai.com/v1/images/generations endpoint. + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + ) - Currently supports just Azure + OpenAI. + +@client +def image_generation( + prompt: str, + model: Optional[str] = None, + n: Optional[int] = None, + quality: Optional[str] = None, + response_format: Optional[str] = None, + size: Optional[str] = None, + style: Optional[str] = None, + user: Optional[str] = None, + timeout=600, # default to 10 minutes + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + litellm_logging_obj=None, + custom_llm_provider=None, + **kwargs, +): + """ + Maps the https://api.openai.com/v1/images/generations endpoint. + + Currently supports just Azure + OpenAI. 
""" aimg_generation = kwargs.get("aimg_generation", False) litellm_call_id = kwargs.get("litellm_call_id", None) logger_fn = kwargs.get("logger_fn", None) - proxy_server_request = kwargs.get('proxy_server_request', None) + proxy_server_request = kwargs.get("proxy_server_request", None) model_info = kwargs.get("model_info", None) metadata = kwargs.get("metadata", {}) model_response = litellm.utils.ImageResponse() - if model is not None or custom_llm_provider is not None: - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - else: + if model is not None or custom_llm_provider is not None: + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + else: model = "dall-e-2" - custom_llm_provider = "openai" # default to dall-e-2 on openai - openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "n", "quality", "size", "style"] - litellm_params = ["metadata", "aimg_generation", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"] + custom_llm_provider = "openai" # default to dall-e-2 on openai + openai_params = [ + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "n", + "quality", + "size", + "style", + ] + litellm_params = [ + "metadata", + "aimg_generation", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + ] default_params = openai_params + litellm_params - non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider - optional_params = get_optional_params_image_gen(n=n, - quality=quality, - response_format=response_format, - size=size, - style=style, - user=user, - custom_llm_provider=custom_llm_provider, - **non_default_params) + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + optional_params = get_optional_params_image_gen( + n=n, + quality=quality, + response_format=response_format, + size=size, + style=style, + user=user, + 
custom_llm_provider=custom_llm_provider, + **non_default_params, + ) logging = litellm_logging_obj - logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": False, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "preset_cache_key": None, "stream_response": {}}) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": False, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "preset_cache_key": None, + "stream_response": {}, + }, + ) if custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" - api_base = ( - api_base - or litellm.api_base - or get_secret("AZURE_API_BASE") - ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_version = ( - api_version or - litellm.api_version or - get_secret("AZURE_API_VERSION") + api_version or litellm.api_version or get_secret("AZURE_API_VERSION") ) api_key = ( - api_key or - litellm.api_key or - litellm.azure_key or - get_secret("AZURE_OPENAI_API_KEY") or - get_secret("AZURE_API_KEY") + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") ) - azure_ad_token = ( - optional_params.pop("azure_ad_token", None) or - get_secret("AZURE_AD_TOKEN") + azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" ) - model_response = azure_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, api_version = api_version, aimg_generation=aimage_generation) + model_response = azure_chat_completions.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + api_version=api_version, + aimg_generation=aimg_generation, + ) elif custom_llm_provider == "openai": - model_response = openai_chat_completions.image_generation(model=model, prompt=prompt, timeout=timeout, api_key=api_key, api_base=api_base, logging_obj=litellm_logging_obj, optional_params=optional_params, model_response = model_response, aimg_generation=aimage_generation) + model_response = openai_chat_completions.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + aimg_generation=aimg_generation, + ) return model_response + +##### Health Endpoints ####################### + + +async def ahealth_check( + model_params: dict, + mode: Optional[Literal["completion", "embedding", "image_generation"]] = None, + prompt: Optional[str] = None, + input: Optional[List] = None, + default_timeout: float = 6000, +): + """ + Support health checks for different providers. Return remaining rate limit, etc. 
+ + For azure/openai -> completion.with_raw_response + For rest -> litellm.acompletion() + """ + try: + model: Optional[str] = model_params.get("model", None) + + if model is None: + raise Exception("model not set") + + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + mode = mode or "completion" # default to completion calls + + if custom_llm_provider == "azure": + api_key = ( + model_params.get("api_key") + or get_secret("AZURE_API_KEY") + or get_secret("AZURE_OPENAI_API_KEY") + ) + + api_base = ( + model_params.get("api_base") + or get_secret("AZURE_API_BASE") + or get_secret("AZURE_OPENAI_API_BASE") + ) + + api_version = ( + model_params.get("api_version") + or get_secret("AZURE_API_VERSION") + or get_secret("AZURE_OPENAI_API_VERSION") + ) + + timeout = ( + model_params.get("timeout") + or litellm.request_timeout + or default_timeout + ) + + response = await azure_chat_completions.ahealth_check( + model=model, + messages=model_params.get( + "messages", None + ), # Replace with your actual messages list + api_key=api_key, + api_base=api_base, + api_version=api_version, + timeout=timeout, + mode=mode, + prompt=prompt, + input=input, + ) + elif custom_llm_provider == "openai": + api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY") + + timeout = ( + model_params.get("timeout") + or litellm.request_timeout + or default_timeout + ) + + response = await openai_chat_completions.ahealth_check( + model=model, + messages=model_params.get( + "messages", None + ), # Replace with your actual messages list + api_key=api_key, + timeout=timeout, + mode=mode, + prompt=prompt, + input=input, + ) + else: + if mode == "embedding": + model_params.pop("messages", None) + model_params["input"] = input + await litellm.aembedding(**model_params) + response = {} + elif mode == "image_generation": + model_params.pop("messages", None) + model_params["prompt"] = prompt + await litellm.aimage_generation(**model_params) + response = {} + else: # default to completion calls + await acompletion(**model_params) + response = {} # args like remaining ratelimit etc. 
+ return response + except Exception as e: + return {"error": str(e)} + + ####### HELPER FUNCTIONS ################ ## Set verbose to true -> ```litellm.set_verbose = True``` def print_verbose(print_statement): try: if litellm.set_verbose: - print(print_statement) # noqa + print(print_statement) # noqa except: pass + def config_completion(**kwargs): if litellm.config_path != None: config_args = read_config_args(litellm.config_path) @@ -2434,12 +3170,79 @@ def config_completion(**kwargs): "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`" ) -def stream_chunk_builder(chunks: list, messages: Optional[list]=None): +def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]=None): id = chunks[0]["id"] object = chunks[0]["object"] created = chunks[0]["created"] model = chunks[0]["model"] system_fingerprint = chunks[0].get("system_fingerprint", None) + finish_reason = chunks[-1]["choices"][0]["finish_reason"] + logprobs = chunks[-1]["choices"][0]["logprobs"] + + response = { + "id": id, + "object": object, + "created": created, + "model": model, + "system_fingerprint": system_fingerprint, + "choices": [ + { + "text": None, + "index": 0, + "logprobs": logprobs, + "finish_reason": finish_reason + } + ], + "usage": { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None + } + } + content_list = [] + for chunk in chunks: + choices = chunk["choices"] + for choice in choices: + if choice is not None and hasattr(choice, "text") and choice.get("text") is not None: + _choice = choice.get("text") + content_list.append(_choice) + + # Combine the "content" strings into a single string || combine the 'function' strings into a single string + combined_content = "".join(content_list) + + # Update the "content" field within the response dictionary + response["choices"][0]["text"] = combined_content + + if len(combined_content) > 0: + completion_output = combined_content + else: + completion_output = "" + # # Update usage information if needed + try: + response["usage"]["prompt_tokens"] = token_counter( + model=model, messages=messages + ) + except: # don't allow this failing to block a complete streaming response from being returned + print_verbose(f"token_counter failed, assuming prompt tokens is 0") + response["usage"]["prompt_tokens"] = 0 + response["usage"]["completion_tokens"] = token_counter( + model=model, + text=combined_content, + count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages + ) + response["usage"]["total_tokens"] = ( + response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] + ) + return response + +def stream_chunk_builder(chunks: list, messages: Optional[list] = None): + id = chunks[0]["id"] + object = chunks[0]["object"] + created = chunks[0]["created"] + model = chunks[0]["model"] + system_fingerprint = chunks[0].get("system_fingerprint", None) + if isinstance(chunks[0]["choices"][0], litellm.utils.TextChoices): # route to the text completion logic + return stream_chunk_builder_text_completion(chunks=chunks, messages=messages) role = chunks[0]["choices"][0]["delta"]["role"] finish_reason = chunks[-1]["choices"][0]["finish_reason"] @@ -2453,18 +3256,15 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): "choices": [ { "index": 0, - "message": { - "role": role, - "content": "" - }, + "message": {"role": role, "content": ""}, "finish_reason": finish_reason, } 
], "usage": { "prompt_tokens": 0, # Modify as needed "completion_tokens": 0, # Modify as needed - "total_tokens": 0 # Modify as needed - } + "total_tokens": 0, # Modify as needed + }, } # Extract the "content" strings from the nested dictionaries within "choices" @@ -2472,7 +3272,10 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): combined_content = "" combined_arguments = "" - if "tool_calls" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None: + if ( + "tool_calls" in chunks[0]["choices"][0]["delta"] + and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None + ): argument_list = [] delta = chunks[0]["choices"][0]["delta"] message = response["choices"][0]["message"] @@ -2503,22 +3306,38 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): # Now, tool_calls is expected to be a dictionary arguments = tool_calls[0].function.arguments argument_list.append(arguments) - if tool_calls[0].function.name: + if tool_calls[0].function.name: name = tool_calls[0].function.name - if tool_calls[0].type: + if tool_calls[0].type: type = tool_calls[0].type - if curr_index != prev_index: # new tool call + if curr_index != prev_index: # new tool call combined_arguments = "".join(argument_list) - tool_calls_list.append({"id": prev_id, "index": prev_index, "function": {"arguments": combined_arguments, "name": name}, "type": type}) - argument_list = [] # reset + tool_calls_list.append( + { + "id": prev_id, + "index": prev_index, + "function": {"arguments": combined_arguments, "name": name}, + "type": type, + } + ) + argument_list = [] # reset prev_index = curr_index prev_id = curr_id combined_arguments = "".join(argument_list) - tool_calls_list.append({"id": id, "function": {"arguments": combined_arguments, "name": name}, "type": type}) - response["choices"][0]["message"]["content"] = None + tool_calls_list.append( + { + "id": id, + "function": {"arguments": combined_arguments, "name": name}, + "type": type, + } + ) + response["choices"][0]["message"]["content"] = None response["choices"][0]["message"]["tool_calls"] = tool_calls_list - elif "function_call" in chunks[0]["choices"][0]["delta"] and chunks[0]["choices"][0]["delta"]["function_call"] is not None: + elif ( + "function_call" in chunks[0]["choices"][0]["delta"] + and chunks[0]["choices"][0]["delta"]["function_call"] is not None + ): argument_list = [] delta = chunks[0]["choices"][0]["delta"] function_call = delta.get("function_call", "") @@ -2533,7 +3352,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): for choice in choices: delta = choice.get("delta", {}) function_call = delta.get("function_call", "") - + # Check if a function call is present if function_call: # Now, function_call is expected to be a dictionary @@ -2542,7 +3361,9 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): combined_arguments = "".join(argument_list) response["choices"][0]["message"]["content"] = None - response["choices"][0]["message"]["function_call"]["arguments"] = combined_arguments + response["choices"][0]["message"]["function_call"][ + "arguments" + ] = combined_arguments else: for chunk in chunks: choices = chunk["choices"] @@ -2550,7 +3371,7 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): delta = choice.get("delta", {}) content = delta.get("content", "") if content == None: - continue # openai v1.0.0 sets content = None for chunks + continue # openai v1.0.0 sets content = None for chunks 
content_list.append(content) # Combine the "content" strings into a single string || combine the 'function' strings into a single string @@ -2558,19 +3379,29 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None): # Update the "content" field within the response dictionary response["choices"][0]["message"]["content"] = combined_content - + if len(combined_content) > 0: completion_output = combined_content - elif len(combined_arguments) > 0: + elif len(combined_arguments) > 0: completion_output = combined_arguments - else: + else: completion_output = "" # # Update usage information if needed try: - response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages) - except: # don't allow this failing to block a complete streaming response from being returned + response["usage"]["prompt_tokens"] = token_counter( + model=model, messages=messages + ) + except: # don't allow this failing to block a complete streaming response from being returned print_verbose(f"token_counter failed, assuming prompt tokens is 0") response["usage"]["prompt_tokens"] = 0 - response["usage"]["completion_tokens"] = token_counter(model=model, text=completion_output) - response["usage"]["total_tokens"] = response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] - return convert_to_model_response_object(response_object=response, model_response_object=litellm.ModelResponse()) + response["usage"]["completion_tokens"] = token_counter( + model=model, + text=combined_content, + count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages + ) + response["usage"]["total_tokens"] = ( + response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"] + ) + return convert_to_model_response_object( + response_object=response, model_response_object=litellm.ModelResponse() + ) From 9aedd4e7949b913845089b2cb44f57377dcf4464 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateo=20C=C3=A1mara?= Date: Tue, 9 Jan 2024 13:02:12 +0100 Subject: [PATCH 5/5] Moved test to a new file --- litellm/tests/test_acompletion.py | 23 +++++++++++++++++++++++ litellm/tests/test_completion.py | 24 +----------------------- 2 files changed, 24 insertions(+), 23 deletions(-) create mode 100644 litellm/tests/test_acompletion.py diff --git a/litellm/tests/test_acompletion.py b/litellm/tests/test_acompletion.py new file mode 100644 index 000000000..e5c09b9b7 --- /dev/null +++ b/litellm/tests/test_acompletion.py @@ -0,0 +1,23 @@ +import pytest +from litellm import acompletion + + +def test_acompletion_params(): + import inspect + from litellm.types.completion import CompletionRequest + + acompletion_params_odict = inspect.signature(acompletion).parameters + acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()} + completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()} + + # remove kwargs + acompletion_params.pop("kwargs", None) + + keys_acompletion = set(acompletion_params.keys()) + keys_completion = set(completion_params.keys()) + + # Assert that the parameters are the same + if keys_acompletion != keys_completion: + pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.") + +# test_acompletion_params() diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 6617f0530..80a51c480 100644 --- a/litellm/tests/test_completion.py +++ 
b/litellm/tests/test_completion.py @@ -10,7 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest import litellm -from litellm import embedding, completion, completion_cost, Timeout, acompletion +from litellm import embedding, completion, completion_cost, Timeout from litellm import RateLimitError # litellm.num_retries = 3 @@ -859,28 +859,6 @@ def test_completion_azure_key_completion_arg(): # test_completion_azure_key_completion_arg() -def test_acompletion_params(): - import inspect - from litellm.types.completion import CompletionRequest - - acompletion_params_odict = inspect.signature(acompletion).parameters - acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()} - completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()} - - # remove kwargs - acompletion_params.pop("kwargs", None) - - keys_acompletion = set(acompletion_params.keys()) - keys_completion = set(completion_params.keys()) - - # Assert that the parameters are the same - if keys_acompletion != keys_completion: - pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.") - - -# test_acompletion_params() - - async def test_re_use_azure_async_client(): try: print("azure gpt-3.5 ASYNC with clie nttest\n\n")
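
The moved test asserts that acompletion() and CompletionRequest expose the same parameter names. For readers skimming the series, here is a minimal usage sketch of the resulting async API — not part of the patch; the model name and sampling values are placeholders for any litellm-supported chat model:

```python
import asyncio
import litellm


async def main():
    # Explicit keyword arguments now appear directly in acompletion's signature.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # placeholder model
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.2,
        max_tokens=32,
    )
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())
```

Running the relocated test on its own should be as simple as `pytest litellm/tests/test_acompletion.py::test_acompletion_params`.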
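
The hunks above also rework stream_chunk_builder(): text-completion chunks are routed to a dedicated stream_chunk_builder_text_completion(), and usage is filled in via token_counter(). A hedged sketch of how the builder is commonly driven — again illustrative only, with a placeholder model and prompt:

```python
import litellm

messages = [{"role": "user", "content": "Write one sentence about streaming."}]

# Stream a completion and keep every chunk as it arrives.
chunks = []
for chunk in litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model
    messages=messages,
    stream=True,
):
    chunks.append(chunk)

# Rebuild a single response object; passing messages lets prompt_tokens be counted.
rebuilt = litellm.stream_chunk_builder(chunks, messages=messages)
print(rebuilt["choices"][0]["message"]["content"])
print(rebuilt["usage"])  # prompt_tokens / completion_tokens / total_tokens
```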