diff --git a/litellm/__init__.py b/litellm/__init__.py
index baa3aa4af..227a7d662 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -20,6 +20,7 @@ azure_key: Optional[str] = None
 anthropic_key: Optional[str] = None
 replicate_key: Optional[str] = None
 cohere_key: Optional[str] = None
+ai21_key: Optional[str] = None
 openrouter_key: Optional[str] = None
 huggingface_key: Optional[str] = None
 vertex_project: Optional[str] = None
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index df608813d..b15b88021 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index 58ce173b3..ed8d32b96 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 773d47432..8d8cbae60 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py
new file mode 100644
index 000000000..9b856be4c
--- /dev/null
+++ b/litellm/llms/ai21.py
@@ -0,0 +1,127 @@
+import os, json
+from enum import Enum
+import requests
+import time
+from typing import Callable
+from litellm.utils import ModelResponse
+
+class AI21Error(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class AI21LLM:
+    def __init__(
+        self, encoding, logging_obj, api_key=None
+    ):
+        self.encoding = encoding
+        self.completion_url_fragment_1 = "https://api.ai21.com/studio/v1/"
+        self.completion_url_fragment_2 = "/complete"
+        self.api_key = api_key
+        self.logging_obj = logging_obj
+        self.validate_environment(api_key=api_key)
+
+    def validate_environment(
+        self, api_key
+    ):  # set up the environment required to run the model
+        # set the api key
+        if self.api_key == None:
+            raise ValueError(
+                "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
+            )
+        self.api_key = api_key
+        self.headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "Authorization": "Bearer " + self.api_key,
+        }
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        optional_params=None,
+        litellm_params=None,
+        logger_fn=None,
+    ):  # logic for parsing in - calling - parsing out model completion calls
+        model = model
+        prompt = ""
+        for message in messages:
+            if "role" in message:
+                if message["role"] == "user":
+                    prompt += (
+                        f"{message['content']}"
+                    )
+                else:
+                    prompt += (
+                        f"{message['content']}"
+                    )
+            else:
+                prompt += f"{message['content']}"
+        data = {
+            "prompt": prompt,
+            **optional_params,
+        }
+
+        ## LOGGING
+        self.logging_obj.pre_call(
+            input=prompt,
+            api_key=self.api_key,
+            additional_args={"complete_input_dict": data},
+        )
+        ## COMPLETION CALL
+        response = requests.post(
+            self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
+        )
+        if "stream" in optional_params and optional_params["stream"] == True:
optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + self.logging_obj.post_call( + input=prompt, + api_key=self.api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + ## RESPONSE OBJECT + completion_response = response.json() + if "error" in completion_response: + raise AI21Error( + message=completion_response["error"], + status_code=response.status_code, + ) + else: + try: + model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"] + except: + raise ValueError(f"Unable to parse response. Original response: {response.text}") + + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len( + self.encoding.encode(prompt) + ) + completion_tokens = len( + self.encoding.encode(model_response["choices"][0]["message"]["content"]) + ) + + model_response["created"] = time.time() + model_response["model"] = model + model_response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + return model_response + + def embedding( + self, + ): # logic for parsing in - calling - parsing out model embedding calls + pass diff --git a/litellm/llms/baseten.py b/litellm/llms/baseten.py index cc0fcec8d..49753d67b 100644 --- a/litellm/llms/baseten.py +++ b/litellm/llms/baseten.py @@ -62,7 +62,7 @@ class BasetenLLM: data = { "prompt": prompt, # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg - # **optional_params, + **optional_params, } ## LOGGING diff --git a/litellm/main.py b/litellm/main.py index fa8cc847b..f20f173ff 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -22,6 +22,7 @@ from litellm.utils import ( from .llms.anthropic import AnthropicLLM from .llms.huggingface_restapi import HuggingfaceRestAPILLM from .llms.baseten import BasetenLLM +from .llms.ai21 import AI21LLM import tiktoken from concurrent.futures import ThreadPoolExecutor @@ -302,7 +303,11 @@ def completion( headers=litellm.headers, ) else: - response = openai.Completion.create(model=model, prompt=prompt) + response = openai.Completion.create(model=model, prompt=prompt, **optional_params) + + if "stream" in optional_params and optional_params["stream"] == True: + response = CustomStreamWrapper(response, model) + return response ## LOGGING logging.post_call( input=prompt, @@ -661,32 +666,34 @@ def completion( model_response["model"] = model response = model_response elif model in litellm.ai21_models: - install_and_import("ai21") - import ai21 - - ai21.api_key = get_secret("AI21_API_KEY") - - prompt = " ".join([message["content"] for message in messages]) - ## LOGGING - logging.pre_call(input=prompt, api_key=ai21.api_key) - - ai21_response = ai21.Completion.execute( + custom_llm_provider = "ai21" + ai21_key = ( + api_key + or litellm.ai21_key + or os.environ.get("AI21_API_KEY") + ) + ai21_client = AI21LLM( + encoding=encoding, api_key=ai21_key, logging_obj=logging + ) + + model_response = ai21_client.completion( model=model, - prompt=prompt, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, ) - completion_response = ai21_response["completions"][0]["data"]["text"] - - ## LOGGING - logging.post_call( - input=prompt, - 
-                api_key=ai21.api_key,
-                original_response=completion_response,
-            )
-
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="ai21"
+                )
+                return response
+
             ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
            response = model_response
        elif custom_llm_provider == "ollama":
            endpoint = (
@@ -725,7 +732,7 @@ def completion(
            if "stream" in optional_params and optional_params["stream"] == True:
                # don't try to access stream object,
                response = CustomStreamWrapper(
-                    model_response, model, custom_llm_provider="huggingface"
+                    model_response, model, custom_llm_provider="baseten"
                )
                return response
            response = model_response
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index f6a676177..306a317eb 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -18,30 +18,85 @@ score = 0
 
 
 def logger_fn(model_call_object: dict):
-    return print(f"model call details: {model_call_object}")
 
 
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 
 
+# test on baseten completion call
+try:
+    response = completion(
+        model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
+    )
+    print(f"response: {response}")
+    complete_response = ""
+    start_time = time.time()
+    for chunk in response:
+        chunk_time = time.time()
+        print(f"time since initial request: {chunk_time - start_time:.5f}")
+        print(chunk["choices"][0]["delta"])
+        complete_response += chunk["choices"][0]["delta"]["content"]
+    if complete_response == "":
+        raise Exception("Empty response received")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
+
 # test on openai completion call
-# try:
-#     response = completion(
-#         model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn
-#     )
-#     complete_response = ""
-#     start_time = time.time()
-#     for chunk in response:
-#         chunk_time = time.time()
-#         print(f"time since initial request: {chunk_time - start_time:.5f}")
-#         print(chunk["choices"][0]["delta"])
-#         complete_response += chunk["choices"][0]["delta"]["content"]
-#     if complete_response == "":
-#         raise Exception("Empty response received")
-# except:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
+try:
+    response = completion(
+        model="text-davinci-003", messages=messages, stream=True, logger_fn=logger_fn
+    )
+    complete_response = ""
+    start_time = time.time()
+    for chunk in response:
+        chunk_time = time.time()
+        print(f"chunk: {chunk}")
+        complete_response += chunk["choices"][0]["delta"]["content"]
+    if complete_response == "":
+        raise Exception("Empty response received")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
+
+# # test on ai21 completion call
+try:
+    response = completion(
+        model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
+    )
+    print(f"response: {response}")
+    complete_response = ""
+    start_time = time.time()
+    for chunk in response:
+        chunk_time = time.time()
+        print(f"time since initial request: {chunk_time - start_time:.5f}")
+        print(chunk["choices"][0]["delta"])
+        complete_response += chunk["choices"][0]["delta"]["content"]
+    if complete_response == "":
+        raise Exception("Empty response received")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
+
+
+# test on openai completion call
+try:
+    response = completion(
+        model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn
+    )
+    complete_response = ""
+    start_time = time.time()
+    for chunk in response:
+        chunk_time = time.time()
+        print(f"time since initial request: {chunk_time - start_time:.5f}")
+        print(chunk["choices"][0]["delta"])
+        complete_response += chunk["choices"][0]["delta"]["content"]
+    if complete_response == "":
+        raise Exception("Empty response received")
+except:
+    print(f"error occurred: {traceback.format_exc()}")
+    pass
 
 
 # # test on azure completion call
@@ -63,25 +118,6 @@ messages = [{"content": user_message, "role": "user"}]
 #     pass
 
 
-# test on anthropic completion call
-try:
-    response = completion(
-        model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-    if complete_response == "":
-        raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
-
-
 # # test on huggingface completion call
 # try:
 #     start_time = time.time()
@@ -123,7 +159,7 @@ except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
 
-# test on together ai completion call - starcoder
+# # test on together ai completion call - starcoder
 try:
     start_time = time.time()
     response = completion(
@@ -148,57 +184,3 @@ try:
 except:
     print(f"error occurred: {traceback.format_exc()}")
     pass
-
-
-# # test on azure completion call
-# try:
-#     response = completion(
-#         model="azure/chatgpt-test", messages=messages, stream=True, logger_fn=logger_fn
-#     )
-#     response = ""
-#     for chunk in response:
-#         chunk_time = time.time()
-#         print(f"time since initial request: {chunk_time - start_time:.2f}")
-#         print(chunk["choices"][0]["delta"])
-#         response += chunk["choices"][0]["delta"]
-#     if response == "":
-#         raise Exception("Empty response received")
-# except:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
-
-
-# # test on anthropic completion call
-# try:
-#     response = completion(
-#         model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn
-#     )
-#     response = ""
-#     for chunk in response:
-#         chunk_time = time.time()
-#         print(f"time since initial request: {chunk_time - start_time:.2f}")
-#         print(chunk["choices"][0]["delta"])
-#         response += chunk["choices"][0]["delta"]
-#     if response == "":
-#         raise Exception("Empty response received")
-# except:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
-
-
-# # # test on huggingface completion call
-# # try:
-# #     response = completion(
-# #         model="meta-llama/Llama-2-7b-chat-hf",
-# #         messages=messages,
-# #         custom_llm_provider="huggingface",
-# #         custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud",
-# #         stream=True,
-# #         logger_fn=logger_fn,
-# #     )
-# #     for chunk in response:
-# #         print(chunk["choices"][0]["delta"])
-# #         score += 1
-# # except:
-# #     print(f"error occurred: {traceback.format_exc()}")
-# #     pass
diff --git a/litellm/utils.py b/litellm/utils.py
index 7e85decf5..fb136a330 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -648,6 +648,7 @@ def get_optional_params(  # use the openai defaults
        optional_params["top_k"] = top_k
    elif custom_llm_provider == "baseten":
        optional_params["temperature"] = temperature
+        optional_params["stream"] = stream
        optional_params["top_p"] = top_p
        optional_params["top_k"] = top_k
        optional_params["num_beams"] = num_beams
@@ -1561,6 +1562,35 @@ class CustomStreamWrapper:
            else:
                return ""
        return ""
+
+    def handle_ai21_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        try:
+            return data_json["completions"][0]["data"]["text"]
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
+    def handle_openai_text_completion_chunk(self, chunk):
+        try:
+            return chunk["choices"][0]["text"]
+        except:
+            raise ValueError(f"Unable to parse response. Original response: {chunk}")
+
+    def handle_baseten_chunk(self, chunk):
+        chunk = chunk.decode("utf-8")
+        data_json = json.loads(chunk)
+        if "model_output" in data_json:
+            if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
+                return data_json["model_output"]["data"][0]
+            elif isinstance(data_json["model_output"], str):
+                return data_json["model_output"]
+            elif "completion" in data_json and isinstance(data_json["completion"], str):
+                return data_json["completion"]
+            else:
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+        else:
+            return ""
 
    def __next__(self):
        completion_obj = {"role": "assistant", "content": ""}
@@ -1584,6 +1614,15 @@ class CustomStreamWrapper:
        elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
            chunk = next(self.completion_stream)
            completion_obj["content"] = self.handle_huggingface_chunk(chunk)
+        elif self.custom_llm_provider and self.custom_llm_provider == "baseten":  # baseten doesn't provide streaming
+            chunk = next(self.completion_stream)
+            completion_obj["content"] = self.handle_baseten_chunk(chunk)
+        elif self.custom_llm_provider and self.custom_llm_provider == "ai21":  # ai21 doesn't provide streaming
+            chunk = next(self.completion_stream)
+            completion_obj["content"] = self.handle_ai21_chunk(chunk)
+        elif self.model in litellm.open_ai_text_completion_models:
+            chunk = next(self.completion_stream)
+            completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
        # return this for all models
        return {"choices": [{"delta": completion_obj}]}
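
Illustrative usage sketch (not part of the diff): how the AI21 path added above would be exercised, following the new test in litellm/tests/test_streaming.py. The "j2-ultra" model name, the litellm.ai21_key setting, and the AI21_API_KEY environment variable all come from this patch; everything else is an assumption about a standard litellm setup.

    # sketch only - assumes litellm with this patch applied and a valid AI21 key
    import os
    import litellm
    from litellm import completion

    # key resolution order in the patch: api_key param, litellm.ai21_key, AI21_API_KEY env var
    litellm.ai21_key = os.environ.get("AI21_API_KEY")
    messages = [{"content": "Hello, how are you?", "role": "user"}]

    # non-streaming: returns a ModelResponse with choices and usage populated
    response = completion(model="j2-ultra", messages=messages)
    print(response["choices"][0]["message"]["content"])

    # streaming: returns a CustomStreamWrapper yielding chunks shaped like
    # {"choices": [{"delta": {"role": "assistant", "content": ...}}]}
    for chunk in completion(model="j2-ultra", messages=messages, stream=True):
        print(chunk["choices"][0]["delta"])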