From 2ae8e4a7929e35a63ad106dfb46bfc856d1ef77c Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 31 Jul 2023 15:56:56 -0700
Subject: [PATCH] openai at parity

---
 litellm/main.py                  | 124 +++++++++++++++++++++++--------
 litellm/tests/test_completion.py | 117 +++++++++++++++++++++++++++++
 2 files changed, 209 insertions(+), 32 deletions(-)
 create mode 100644 litellm/tests/test_completion.py

diff --git a/litellm/main.py b/litellm/main.py
index c301a1eb0..c5ab46423 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -82,12 +82,69 @@ def client(original_function):
       raise e
   return wrapper
 
+
+def get_optional_params(
+    # 12 optional params
+    functions = [],
+    function_call = "",
+    temperature = 1,
+    top_p = 1,
+    n = 1,
+    stream = False,
+    stop = None,
+    max_tokens = float('inf'),
+    presence_penalty = 0,
+    frequency_penalty = 0,
+    logit_bias = {},
+    user = "",
+):
+  optional_params = {}
+  if functions != []:
+    optional_params["functions"] = functions
+  if function_call != "":
+    optional_params["function_call"] = function_call
+  if temperature != 1:
+    optional_params["temperature"] = temperature
+  if top_p != 1:
+    optional_params["top_p"] = top_p
+  if n != 1:
+    optional_params["n"] = n
+  if stream:
+    optional_params["stream"] = stream
+  if stop != None:
+    optional_params["stop"] = stop
+  if max_tokens != float('inf'):
+    optional_params["max_tokens"] = max_tokens
+  if presence_penalty != 0:
+    optional_params["presence_penalty"] = presence_penalty
+  if frequency_penalty != 0:
+    optional_params["frequency_penalty"] = frequency_penalty
+  if logit_bias != {}:
+    optional_params["logit_bias"] = logit_bias
+  if user != "":
+    optional_params["user"] = user
+  return optional_params
+
 ####### COMPLETION ENDPOINTS ################
 #############################################
 @client
 @func_set_timeout(60, allowOverride=True) ## https://pypi.org/project/func-timeout/ - timeouts, in case calls hang (e.g. Azure)
-def completion(model, messages, max_tokens=None, *, forceTimeout=60, azure=False, logger_fn=None): # ,*,.. requires optional params like forceTimeout, azure and logger_fn to be passed in as keyword arguments
+def completion(
+    model, messages, # required params
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    functions=[], function_call="", # optional params
+    temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
+    presence_penalty=0, frequency_penalty=0, logit_bias={}, user="",
+    # Optional liteLLM function params
+    azure=False, logger_fn=None, verbose=False
+  ):
   try:
+    # check if user passed in any of the OpenAI optional params
+    optional_params = get_optional_params(
+      functions=functions, function_call=function_call,
+      temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
+      presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user
+    )
     if azure == True:
       # azure configs
       openai.api_type = "azure"
@@ -95,11 +152,39 @@ def completion(model, messages, max_tokens=None, *, forceTimeout=60, azure=False
       openai.api_version = os.environ.get("AZURE_API_VERSION")
       openai.api_key = os.environ.get("AZURE_API_KEY")
       ## LOGGING
-      logging(model=model, input=input, azure=azure, logger_fn=logger_fn)
+      logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
       ## COMPLETION CALL
       response = openai.ChatCompletion.create(
           engine=model,
-          messages = messages
+          messages = messages,
+          **optional_params
+      )
+    elif model in open_ai_chat_completion_models:
+      openai.api_type = "openai"
+      openai.api_base = "https://api.openai.com/v1"
+      openai.api_version = None
+      openai.api_key = os.environ.get("OPENAI_API_KEY")
+      ## LOGGING
+      logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
+
+      ## COMPLETION CALL
+      response = openai.ChatCompletion.create(
+          model=model,
+          messages = messages,
+          **optional_params
+      )
+    elif model in open_ai_text_completion_models:
+      openai.api_type = "openai"
+      openai.api_base = "https://api.openai.com/v1"
+      openai.api_version = None
+      openai.api_key = os.environ.get("OPENAI_API_KEY")
+      prompt = " ".join([message["content"] for message in messages])
+      ## LOGGING
+      logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
+      ## COMPLETION CALL
+      response = openai.Completion.create(
+          model=model,
+          prompt = prompt
       )
     elif "replicate" in model:
       # replicate defaults to os.environ.get("REPLICATE_API_TOKEN")
@@ -109,7 +194,7 @@ def completion(model, messages, max_tokens=None, *, forceTimeout=60, azure=False
         os.environ["REPLICATE_API_TOKEN"] = replicate_api_token
       prompt = " ".join([message["content"] for message in messages])
       input = {"prompt": prompt}
-      if max_tokens:
+      if max_tokens != float('inf'):
         input["max_length"] = max_tokens # for t5 models
         input["max_new_tokens"] = max_tokens # for llama2 models
       ## LOGGING
@@ -147,9 +232,10 @@ def completion(model, messages, max_tokens=None, *, forceTimeout=60, azure=False
           prompt += f"{HUMAN_PROMPT}{message['content']}"
       prompt += f"{AI_PROMPT}"
       anthropic = Anthropic()
-      if max_tokens:
+      # only pass max_tokens if the user overrode the float('inf') default
+      if max_tokens != float('inf'):
         max_tokens_to_sample = max_tokens
-      else: 
+      else:
         max_tokens_to_sample = 300 # default in Anthropic docs https://docs.anthropic.com/claude/reference/client-libraries
       ## LOGGING
       logging(model=model, input=prompt, azure=azure, additional_args={"max_tokens": max_tokens}, logger_fn=logger_fn)
@@ -197,32 +283,6 @@ def completion(model, messages, max_tokens=None, *, forceTimeout=60, azure=False
         ],
       }
       response = new_response
-
-
-    elif model in open_ai_chat_completion_models:
-      openai.api_type = "openai"
-      openai.api_base = "https://api.openai.com/v1"
-      openai.api_version = None
-      openai.api_key = os.environ.get("OPENAI_API_KEY")
-      ## LOGGING
-      logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
-      ## COMPLETION CALL
-      response = openai.ChatCompletion.create(
-          model=model,
-          messages = messages
-      )
-    elif model in open_ai_text_completion_models:
-      openai.api_type = "openai"
-      openai.api_base = "https://api.openai.com/v1"
-      openai.api_version = None
-      openai.api_key = os.environ.get("OPENAI_API_KEY")
-      prompt = " ".join([message["content"] for message in messages])
-      ## LOGGING
-      logging(model=model, input=prompt, azure=azure, logger_fn=logger_fn)
-      ## COMPLETION CALL
-      response = openai.Completion.create(
-          model=model,
-          prompt = prompt
-    )
     else:
       logging(model=model, input=messages, azure=azure, logger_fn=logger_fn)
     return response
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
new file mode 100644
index 000000000..93da24aa8
--- /dev/null
+++ b/litellm/tests/test_completion.py
@@ -0,0 +1,117 @@
+import sys, os
+import traceback
+sys.path.append('..') # Adds the parent directory to the system path
+import main
+from main import completion
+
+main.set_verbose = True
+
+user_message = "Hello, how are you?"
+messages = [{ "content": user_message,"role": "user"}]
+
+################# Test 1 #################
+# test on openai completion call, with model and messages
+try:
+    response = completion(model="gpt-3.5-turbo", messages=messages)
+    print(response)
+except Exception as e:
+    print(f"error occurred: {traceback.format_exc()}")
+    raise e
+
+################# Test 1.1 #################
+# test on openai completion call, with model, messages and optional params
+try:
+    response = completion(model="gpt-3.5-turbo", messages=messages, temperature=0.5, top_p=0.1, user="ishaan_dev@berri.ai")
+    print(response)
+except Exception as e:
+    print(f"error occurred: {traceback.format_exc()}")
+    raise e
+
+################# Test 1.2 #################
+# test on openai completion call, with model, messages and all optional params
+try:
+    response = completion(model="gpt-3.5-turbo", messages=messages, temperature=0.5, top_p=0.1, n=2, max_tokens=150, presence_penalty=0.5, frequency_penalty=-0.5, logit_bias={123:5}, user="ishaan_dev@berri.ai")
+    print(response)
+except Exception as e:
+    print(f"error occurred: {traceback.format_exc()}")
+    raise e
+
+
+
+################# Test 1.3 #################
+# test on openai completion call, with stream=True
+try:
+    response = completion(model="gpt-3.5-turbo", messages=messages, temperature=0.5, top_p=0.1, n=2, max_tokens=150, presence_penalty=0.5, stream=True, frequency_penalty=-0.5, logit_bias={27000:5}, user="ishaan_dev@berri.ai")
+    print(response)
+except Exception as e:
+    print(f"error occurred: {traceback.format_exc()}")
+    raise e
+
+################# Test 2 #################
+# test on openai completion call, with functions
+function1 = [
+    {
+        "name": "get_current_weather",
+        "description": "Get the current weather in a given location",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The city and state, e.g. San Francisco, CA"
+                },
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"]
+                }
+            },
+            "required": ["location"]
+        }
+    }
+  ]
+user_message = "Hello, what's the weather in San Francisco?"
+messages = [{ "content": user_message,"role": "user"}] +try: + response = completion(model="gpt-3.5-turbo", messages=messages, functions=function1) + print(response) +except Exception as e: + print(f"error occurred: {traceback.format_exc()}") + raise e + + +################# Test 3 ################# +# test on Azure Openai Completion Call +try: + response = completion(model="chatgpt-test", messages=messages, azure=True) + print(response) +except Exception as e: + print(f"error occurred: {traceback.format_exc()}") + raise e + +################# Test 4 ################# +# test on Claude Completion Call +try: + response = completion(model="claude-instant-1", messages=messages) + print(response) +except Exception as e: + print(f"error occurred: {traceback.format_exc()}") + raise e + +################# Test 5 ################# +# test on Cohere Completion Call +try: + response = completion(model="command-nightly", messages=messages, max_tokens=500) + print(response) +except Exception as e: + print(f"error occurred: {traceback.format_exc()}") + raise e + +################# Test 6 ################# +# test on Replicate llama2 Completion Call +try: + model_name = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1" + response = completion(model=model_name, messages=messages, max_tokens=500) + print(response) +except Exception as e: + print(f"error occurred: {traceback.format_exc()}") + raise e