fixes to streaming for ai21, baseten, and openai text completions

This commit is contained in:
Krrish Dholakia 2023-08-28 09:38:40 -07:00
parent d11cb3e2ea
commit d542066d4b
9 changed files with 273 additions and 117 deletions
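For context, a minimal sketch of the streaming call path these fixes target, mirroring the streaming tests in this commit; the model name and API key are placeholders:

import os
from litellm import completion

os.environ["AI21_API_KEY"] = "your-ai21-api-key"  # placeholder key

messages = [{"content": "Hello, how are you?", "role": "user"}]
response = completion(model="j2-ultra", messages=messages, stream=True)
for chunk in response:
    # CustomStreamWrapper yields {"choices": [{"delta": {"role": "assistant", "content": ...}}]}
    print(chunk["choices"][0]["delta"]["content"])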

View file

@@ -20,6 +20,7 @@ azure_key: Optional[str] = None
anthropic_key: Optional[str] = None
replicate_key: Optional[str] = None
cohere_key: Optional[str] = None
ai21_key: Optional[str] = None
openrouter_key: Optional[str] = None
huggingface_key: Optional[str] = None
vertex_project: Optional[str] = None
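The new ai21_key attribute mirrors the other provider keys, so it can also be set in code instead of through the environment; a small sketch (the key value is a placeholder):

import litellm

litellm.ai21_key = "your-ai21-api-key"  # read by completion() via `litellm.ai21_key`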

litellm/llms/ai21.py Normal file (127 additions)
View file

@@ -0,0 +1,127 @@
import os, json
from enum import Enum
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse
class AI21Error(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AI21LLM:
def __init__(
self, encoding, logging_obj, api_key=None
):
self.encoding = encoding
self.completion_url_fragment_1 = "https://api.ai21.com/studio/v1/"
self.completion_url_fragment_2 = "/complete"
self.api_key = api_key
self.logging_obj = logging_obj
self.validate_environment(api_key=api_key)
def validate_environment(
self, api_key
): # set up the environment required to run the model
# set the api key
        if api_key is None:
            raise ValueError(
                "Missing AI21 API Key - A call is being made to ai21 but no key is set either in the environment variables or via params"
            )
        self.api_key = api_key
self.headers = {
"accept": "application/json",
"content-type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
optional_params=None,
litellm_params=None,
logger_fn=None,
): # logic for parsing in - calling - parsing out model completion calls
        prompt = ""
        for message in messages:
            prompt += f"{message['content']}"
        data = {
            "prompt": prompt,
            **optional_params,
        }
## LOGGING
self.logging_obj.pre_call(
input=prompt,
api_key=self.api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
response = requests.post(
self.completion_url_fragment_1 + model + self.completion_url_fragment_2, headers=self.headers, data=json.dumps(data)
)
if "stream" in optional_params and optional_params["stream"] == True:
return response.iter_lines()
else:
## LOGGING
self.logging_obj.post_call(
input=prompt,
api_key=self.api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
completion_response = response.json()
if "error" in completion_response:
raise AI21Error(
message=completion_response["error"],
status_code=response.status_code,
)
else:
try:
model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["data"]["text"]
except:
raise ValueError(f"Unable to parse response. Original response: {response.text}")
            ## CALCULATING USAGE - estimate prompt + completion token counts with the local tokenizer
prompt_tokens = len(
self.encoding.encode(prompt)
)
completion_tokens = len(
self.encoding.encode(model_response["choices"][0]["message"]["content"])
)
model_response["created"] = time.time()
model_response["model"] = model
model_response["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
return model_response
def embedding(
self,
): # logic for parsing in - calling - parsing out model embedding calls
pass
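For reference, a rough sketch of how the completion() changes further down drive this handler; the no-op logging object here is a stand-in for the logging object main.py passes in, and the bare ModelResponse() construction is assumed:

import tiktoken
from litellm.llms.ai21 import AI21LLM
from litellm.utils import ModelResponse

class NoOpLogger:  # stand-in for litellm's logging object
    def pre_call(self, **kwargs): pass
    def post_call(self, **kwargs): pass

client = AI21LLM(
    encoding=tiktoken.get_encoding("cl100k_base"),
    logging_obj=NoOpLogger(),
    api_key="your-ai21-api-key",  # placeholder
)
resp = client.completion(
    model="j2-ultra",
    messages=[{"role": "user", "content": "Hello"}],
    model_response=ModelResponse(),
    print_verbose=print,
    optional_params={},  # pass {"stream": True} to get response.iter_lines() back instead
    litellm_params={},
    logger_fn=None,
)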

View file

@@ -62,7 +62,7 @@ class BasetenLLM:
data = {
"prompt": prompt,
# "instruction": prompt, # some baseten models require the prompt to be passed in via the 'instruction' kwarg
-            # **optional_params,
+            **optional_params,
}
## LOGGING

View file

@@ -22,6 +22,7 @@ from litellm.utils import (
from .llms.anthropic import AnthropicLLM
from .llms.huggingface_restapi import HuggingfaceRestAPILLM
from .llms.baseten import BasetenLLM
from .llms.ai21 import AI21LLM
import tiktoken
from concurrent.futures import ThreadPoolExecutor
@@ -302,7 +303,11 @@ def completion(
headers=litellm.headers,
)
else:
-            response = openai.Completion.create(model=model, prompt=prompt)
+            response = openai.Completion.create(model=model, prompt=prompt, **optional_params)
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = CustomStreamWrapper(response, model)
+                return response
## LOGGING
logging.post_call(
input=prompt,
@@ -661,32 +666,34 @@ def completion(
model_response["model"] = model
response = model_response
elif model in litellm.ai21_models:
-        install_and_import("ai21")
-        import ai21
-        ai21.api_key = get_secret("AI21_API_KEY")
-        prompt = " ".join([message["content"] for message in messages])
-        ## LOGGING
-        logging.pre_call(input=prompt, api_key=ai21.api_key)
-        ai21_response = ai21.Completion.execute(
+        custom_llm_provider = "ai21"
+        ai21_key = (
+            api_key
+            or litellm.ai21_key
+            or os.environ.get("AI21_API_KEY")
+        )
+        ai21_client = AI21LLM(
+            encoding=encoding, api_key=ai21_key, logging_obj=logging
+        )
+        model_response = ai21_client.completion(
            model=model,
-            prompt=prompt,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
        )
-        completion_response = ai21_response["completions"][0]["data"]["text"]
-        ## LOGGING
-        logging.post_call(
-            input=prompt,
-            api_key=ai21.api_key,
-            original_response=completion_response,
-        )
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="ai21"
+            )
+            return response
        ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        model_response["created"] = time.time()
-        model_response["model"] = model
        response = model_response
elif custom_llm_provider == "ollama":
endpoint = (
@@ -725,7 +732,7 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
-                model_response, model, custom_llm_provider="huggingface"
+                model_response, model, custom_llm_provider="baseten"
)
return response
response = model_response

View file

@@ -18,30 +18,85 @@ score = 0
def logger_fn(model_call_object: dict):
    return
    # print(f"model call details: {model_call_object}")
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
# test on baseten completion call
try:
response = completion(
model="wizard-lm", messages=messages, stream=True, logger_fn=logger_fn
)
print(f"response: {response}")
complete_response = ""
start_time = time.time()
for chunk in response:
chunk_time = time.time()
print(f"time since initial request: {chunk_time - start_time:.5f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
# test on openai completion call
# try:
# response = completion(
# model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn
# )
# complete_response = ""
# start_time = time.time()
# for chunk in response:
# chunk_time = time.time()
# print(f"time since initial request: {chunk_time - start_time:.5f}")
# print(chunk["choices"][0]["delta"])
# complete_response += chunk["choices"][0]["delta"]["content"]
# if complete_response == "":
# raise Exception("Empty response received")
# except:
# print(f"error occurred: {traceback.format_exc()}")
# pass
try:
response = completion(
model="text-davinci-003", messages=messages, stream=True, logger_fn=logger_fn
)
complete_response = ""
start_time = time.time()
for chunk in response:
chunk_time = time.time()
print(f"chunk: {chunk}")
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
# test on ai21 completion call
try:
response = completion(
model="j2-ultra", messages=messages, stream=True, logger_fn=logger_fn
)
print(f"response: {response}")
complete_response = ""
start_time = time.time()
for chunk in response:
chunk_time = time.time()
print(f"time since initial request: {chunk_time - start_time:.5f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
# test on openai completion call
try:
response = completion(
model="gpt-3.5-turbo", messages=messages, stream=True, logger_fn=logger_fn
)
complete_response = ""
start_time = time.time()
for chunk in response:
chunk_time = time.time()
print(f"time since initial request: {chunk_time - start_time:.5f}")
print(chunk["choices"][0]["delta"])
complete_response += chunk["choices"][0]["delta"]["content"]
if complete_response == "":
raise Exception("Empty response received")
except:
print(f"error occurred: {traceback.format_exc()}")
pass
# # test on azure completion call
@@ -63,25 +118,6 @@ messages = [{"content": user_message, "role": "user"}]
# pass
# test on anthropic completion call
-try:
-    response = completion(
-        model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn
-    )
-    complete_response = ""
-    start_time = time.time()
-    for chunk in response:
-        chunk_time = time.time()
-        print(f"time since initial request: {chunk_time - start_time:.5f}")
-        print(chunk["choices"][0]["delta"])
-        complete_response += chunk["choices"][0]["delta"]["content"]
-        if complete_response == "":
-            raise Exception("Empty response received")
-except:
-    print(f"error occurred: {traceback.format_exc()}")
-    pass
# # test on huggingface completion call
# try:
# start_time = time.time()
@@ -123,7 +159,7 @@ except:
print(f"error occurred: {traceback.format_exc()}")
pass
-# test on together ai completion call - starcoder
+# # test on together ai completion call - starcoder
try:
start_time = time.time()
response = completion(
@@ -148,57 +184,3 @@ try:
except:
print(f"error occurred: {traceback.format_exc()}")
pass
-# # test on azure completion call
-# try:
-#     response = completion(
-#         model="azure/chatgpt-test", messages=messages, stream=True, logger_fn=logger_fn
-#     )
-#     response = ""
-#     for chunk in response:
-#         chunk_time = time.time()
-#         print(f"time since initial request: {chunk_time - start_time:.2f}")
-#         print(chunk["choices"][0]["delta"])
-#         response += chunk["choices"][0]["delta"]
-#         if response == "":
-#             raise Exception("Empty response received")
-# except:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
-# # test on anthropic completion call
-# try:
-#     response = completion(
-#         model="claude-instant-1", messages=messages, stream=True, logger_fn=logger_fn
-#     )
-#     response = ""
-#     for chunk in response:
-#         chunk_time = time.time()
-#         print(f"time since initial request: {chunk_time - start_time:.2f}")
-#         print(chunk["choices"][0]["delta"])
-#         response += chunk["choices"][0]["delta"]
-#         if response == "":
-#             raise Exception("Empty response received")
-# except:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
-# # # test on huggingface completion call
-# # try:
-# #     response = completion(
-# #         model="meta-llama/Llama-2-7b-chat-hf",
-# #         messages=messages,
-# #         custom_llm_provider="huggingface",
-# #         custom_api_base="https://s7c7gytn18vnu4tw.us-east-1.aws.endpoints.huggingface.cloud",
-# #         stream=True,
-# #         logger_fn=logger_fn,
-# #     )
-# #     for chunk in response:
-# #         print(chunk["choices"][0]["delta"])
-# #     score += 1
-# # except:
-# #     print(f"error occurred: {traceback.format_exc()}")
-# #     pass

View file

@@ -648,6 +648,7 @@ def get_optional_params( # use the openai defaults
optional_params["top_k"] = top_k
elif custom_llm_provider == "baseten":
optional_params["temperature"] = temperature
optional_params["stream"] = stream
optional_params["top_p"] = top_p
optional_params["top_k"] = top_k
optional_params["num_beams"] = num_beams
@@ -1561,6 +1562,35 @@ class CustomStreamWrapper:
else:
return ""
return ""
def handle_ai21_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
try:
return data_json["completions"][0]["data"]["text"]
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_openai_text_completion_chunk(self, chunk):
try:
return chunk["choices"][0]["text"]
except:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
def handle_baseten_chunk(self, chunk):
chunk = chunk.decode("utf-8")
data_json = json.loads(chunk)
if "model_output" in data_json:
if isinstance(data_json["model_output"], dict) and "data" in data_json["model_output"] and isinstance(data_json["model_output"]["data"], list):
return data_json["model_output"]["data"][0]
elif isinstance(data_json["model_output"], str):
return data_json["model_output"]
elif "completion" in data_json and isinstance(data_json["completion"], str):
return data_json["completion"]
else:
raise ValueError(f"Unable to parse response. Original response: {chunk}")
else:
return ""
def __next__(self):
completion_obj = {"role": "assistant", "content": ""}
@@ -1584,6 +1614,15 @@
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_huggingface_chunk(chunk)
elif self.custom_llm_provider and self.custom_llm_provider == "baseten": # baseten doesn't provide streaming
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_baseten_chunk(chunk)
elif self.custom_llm_provider and self.custom_llm_provider == "ai21": #ai21 doesn't provide streaming
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_ai21_chunk(chunk)
elif self.model in litellm.open_ai_text_completion_models:
chunk = next(self.completion_stream)
completion_obj["content"] = self.handle_openai_text_completion_chunk(chunk)
# return this for all models
return {"choices": [{"delta": completion_obj}]}
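To make the new chunk handlers concrete, here are sample raw chunks each one is written to accept; the payload shapes are inferred from the parsing code above, not from provider documentation:

# ai21 and baseten chunks arrive as bytes and are decoded + json-parsed by the handlers
ai21_chunk = b'{"completions": [{"data": {"text": " world"}}]}'
baseten_dict_chunk = b'{"model_output": {"data": [" world"]}}'   # dict form with a "data" list
baseten_str_chunk = b'{"model_output": " world"}'                # plain string form
baseten_completion_chunk = b'{"completion": " world"}'           # top-level "completion" key
# openai text completion chunks are already-parsed objects from the SDK stream
openai_text_chunk = {"choices": [{"text": " world"}]}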