diff --git a/litellm/llms/ai21.py b/litellm/llms/ai21.py
index 882e217e6b..f3f4a43427 100644
--- a/litellm/llms/ai21.py
+++ b/litellm/llms/ai21.py
@@ -1,4 +1,5 @@
-import os, json
+import os
+import json
 from enum import Enum
 import requests
 import time
@@ -13,115 +14,102 @@ class AI21Error(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
+def validate_environment(api_key):
+    if api_key is None:
+        raise ValueError(
+            "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
+        )
+    headers = {
+        "accept": "application/json",
+        "content-type": "application/json",
+        "Authorization": "Bearer " + api_key,
+    }
+    return headers
 
-class AI21LLM:
-    def __init__(
-        self, encoding, logging_obj, api_key=None
-    ):
-        self.encoding = encoding
-        self.completion_url_fragment_1 = "https://api.ai21.com/studio/v1/"
-        self.completion_url_fragment_2 = "/complete"
-        self.api_key = api_key
-        self.logging_obj = logging_obj
-        self.validate_environment(api_key=api_key)
-
-    def validate_environment(
-        self, api_key
-    ):  # set up the environment required to run the model
-        # set the api key
-        if self.api_key == None:
-            raise ValueError(
-                "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
-            )
-        self.api_key = api_key
-        self.headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "Authorization": "Bearer " + self.api_key,
-        }
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-    ):  # logic for parsing in - calling - parsing out model completion calls
-        model = model
-        prompt = ""
-        for message in messages:
-            if "role" in message:
-                if message["role"] == "user":
-                    prompt += (
-                        f"{message['content']}"
-                    )
-                else:
-                    prompt += (
-                        f"{message['content']}"
-                    )
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    headers = validate_environment(api_key)
+    model = model
+    prompt = ""
+    for message in messages:
+        if "role" in message:
+            if message["role"] == "user":
+                prompt += (
+                    f"{message['content']}"
+                )
             else:
-                prompt += f"{message['content']}"
-        data = {
-            "prompt": prompt,
-            # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            **optional_params,
-        }
+                prompt += (
+                    f"{message['content']}"
+                )
+        else:
+            prompt += f"{message['content']}"
+    data = {
+        "prompt": prompt,
+        # "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
+        **optional_params,
+    }
 
-        ## LOGGING
-        self.logging_obj.pre_call(
+    ## LOGGING
+    logging_obj.pre_call(
             input=prompt,
-            api_key=self.api_key,
+            api_key=api_key,
             additional_args={"complete_input_dict": data},
         )
-        ## COMPLETION CALL
-        response = requests.post(
-            self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
-        )
-        if "stream" in optional_params and optional_params["stream"] == True:
-            return response.iter_lines()
-        else:
-            ## LOGGING
-            self.logging_obj.post_call(
+    ## COMPLETION CALL
+    response = requests.post(
+        "https://api.ai21.com/studio/v1/" + model + "/complete", headers=headers, data=json.dumps(data)
+    )
+    if "stream" in optional_params and optional_params["stream"] == True:
+        return response.iter_lines()
+    else:
+        ## LOGGING
+        logging_obj.post_call(
             input=prompt,
-            api_key=self.api_key,
+            api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
         )
-            print_verbose(f"raw model_response: {response.text}")
-            ## RESPONSE OBJECT
-            completion_response = response.json()
-            if "error" in completion_response:
-                raise AI21Error(
-                    message=completion_response["error"],
-                    status_code=response.status_code,
-                )
-            else:
-                try:
-                    model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
-                except:
-                    raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
-
-            ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-            prompt_tokens = len(
-                self.encoding.encode(prompt)
-            )
-            completion_tokens = len(
-                self.encoding.encode(model_response["choices"][0]["message"]["content"])
+        print_verbose(f"raw model_response: {response.text}")
+        ## RESPONSE OBJECT
+        completion_response = response.json()
+        if "error" in completion_response:
+            raise AI21Error(
+                message=completion_response["error"],
+                status_code=response.status_code,
             )
+        else:
+            try:
+                model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
+            except:
+                raise AI21Error(message=json.dumps(completion_response), status_code=response.status_code)
 
-            model_response["created"] = time.time()
-            model_response["model"] = model
-            model_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
-            return model_response
+        ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        prompt_tokens = len(
+            encoding.encode(prompt)
+        )
+        completion_tokens = len(
+            encoding.encode(model_response["choices"][0]["message"]["content"])
+        )
 
-    def embedding(
-        self,
-    ):  # logic for parsing in - calling - parsing out model embedding calls
-        pass
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        model_response["usage"] = {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+        return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/main.py b/litellm/main.py
index aa684f5b16..c1c5be6b0e 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -21,9 +21,9 @@ from litellm.utils import (
 )
 from .llms import anthropic
 from .llms import together_ai
+from .llms import ai21
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
-from .llms.ai21 import AI21LLM
 from .llms.aleph_alpha import AlephAlphaLLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -657,12 +657,8 @@ def completion(
             api_key
             or litellm.ai21_key
             or os.environ.get("AI21_API_KEY")
-        )
-        ai21_client = AI21LLM(
-            encoding=encoding, api_key=ai21_key, logging_obj=logging
-        )
-
-        model_response = ai21_client.completion(
+        )
+        model_response = ai21.completion(
             model=model,
             messages=messages,
             model_response=model_response,
@@ -670,6 +666,9 @@ def completion(
             optional_params=optional_params,
             litellm_params=litellm_params,
             logger_fn=logger_fn,
+            encoding=encoding,
+            api_key=ai21_key,
+            logging_obj=logging
         )
         if "stream" in optional_params and optional_params["stream"] == True:
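Note on the new calling convention: below is a minimal sketch, not part of the patch, of invoking the refactored module-level ai21.completion directly. The StubLogger class and the dict-shaped model_response are illustrative stand-ins for litellm's real Logging object and ModelResponse (which main.py passes in), "j2-mid" is just an example AI21 Studio model name, and the tiktoken encoding is an assumption; a real AI21 API key is required for the request to succeed.

# Hypothetical usage sketch for the refactored module-level function.
import tiktoken
from litellm.llms import ai21

class StubLogger:
    # Stands in for litellm's logging object; only the two hooks that
    # ai21.completion actually calls are implemented.
    def pre_call(self, input, api_key, additional_args):
        pass

    def post_call(self, input, api_key, original_response, additional_args):
        pass

# Dict stand-in for ModelResponse: completion() writes into
# ["choices"][0]["message"]["content"], then adds created/model/usage keys.
model_response = {"choices": [{"message": {"content": ""}}]}

result = ai21.completion(
    model="j2-mid",  # example AI21 Studio model name
    messages=[{"role": "user", "content": "Hello"}],
    model_response=model_response,
    print_verbose=print,
    encoding=tiktoken.get_encoding("cl100k_base"),  # any object with .encode()
    api_key="YOUR_AI21_API_KEY",  # placeholder; a real key is required
    logging_obj=StubLogger(),
    optional_params={},  # must be a dict: completion() unpacks it and checks "stream"
    litellm_params={},
)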