fixes to streaming for ai21, baseten, and openai text completions

Krrish Dholakia 2023-08-28 09:38:40 -07:00
parent d11cb3e2ea
commit d542066d4b
9 changed files with 273 additions and 117 deletions
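
In short: OpenAI text completions dropped every optional parameter (so stream=True never reached the SDK), the AI21 branch had no streaming path at all, and Baseten streams were tagged with the wrong provider. After this change all three return an iterable CustomStreamWrapper. A hedged consumption sketch (the model name is illustrative; it assumes the matching provider key is set in the environment):

    import litellm

    response = litellm.completion(
        model="text-davinci-003",        # any of the three fixed providers works
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    for chunk in response:               # CustomStreamWrapper is iterable
        print(chunk)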

@@ -22,6 +22,7 @@ from litellm.utils import (
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
+from .llms.ai21 import AI21LLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
@@ -302,7 +303,11 @@ def completion(
                     headers=litellm.headers,
                 )
             else:
-                response = openai.Completion.create(model=model, prompt=prompt)
+                response = openai.Completion.create(model=model, prompt=prompt, **optional_params)
+                if "stream" in optional_params and optional_params["stream"] == True:
+                    response = CustomStreamWrapper(response, model)
+                    return response
             ## LOGGING
             logging.post_call(
                 input=prompt,
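
Forwarding **optional_params is what lets stream=True reach the OpenAI SDK at all; the early return then hands the raw generator to CustomStreamWrapper before the non-streaming post-call logging can try to read the stream object. As a rough illustration of the wrapper pattern (a hypothetical StreamWrapper, not litellm's actual class, assuming the openai 0.x text-completion chunk shape):

    class StreamWrapper:
        # wraps a provider's raw stream and yields normalized chunks
        def __init__(self, completion_stream, model):
            self.completion_stream = completion_stream
            self.model = model

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.completion_stream)  # StopIteration ends the loop
            text = chunk["choices"][0]["text"]    # openai 0.x text-completion delta
            return {"model": self.model, "choices": [{"delta": {"content": text}}]}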
@@ -661,32 +666,34 @@ def completion(
             model_response["model"] = model
             response = model_response
         elif model in litellm.ai21_models:
-            install_and_import("ai21")
-            import ai21
-            ai21.api_key = get_secret("AI21_API_KEY")
-            prompt = " ".join([message["content"] for message in messages])
-            ## LOGGING
-            logging.pre_call(input=prompt, api_key=ai21.api_key)
-            ai21_response = ai21.Completion.execute(
+            custom_llm_provider = "ai21"
+            ai21_key = (
+                api_key
+                or litellm.ai21_key
+                or os.environ.get("AI21_API_KEY")
+            )
+            ai21_client = AI21LLM(
+                encoding=encoding, api_key=ai21_key, logging_obj=logging
+            )
+            model_response = ai21_client.completion(
                 model=model,
-                prompt=prompt,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
             )
-            completion_response = ai21_response["completions"][0]["data"]["text"]
-            ## LOGGING
-            logging.post_call(
-                input=prompt,
-                api_key=ai21.api_key,
-                original_response=completion_response,
-            )
-            ## RESPONSE OBJECT
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = model
+            if "stream" in optional_params and optional_params["stream"] == True:
+                # don't try to access stream object,
+                response = CustomStreamWrapper(
+                    model_response, model, custom_llm_provider="ai21"
+                )
+                return response
             response = model_response
         elif custom_llm_provider == "ollama":
             endpoint = (
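
The rewritten AI21 branch resolves the key with an explicit precedence (per-call api_key, then litellm.ai21_key, then the AI21_API_KEY environment variable) and delegates the request to the new AI21LLM class, so streaming gets the same wrapper treatment as other providers. A hedged usage sketch (the Jurassic-2 model name j2-mid is assumed for illustration):

    import os
    import litellm

    os.environ["AI21_API_KEY"] = "..."   # lowest precedence
    litellm.ai21_key = "..."             # overrides the environment variable
    response = litellm.completion(
        model="j2-mid",                  # assumed AI21 model name
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,                     # now returns a CustomStreamWrapper
    )
    for chunk in response:
        print(chunk)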
@@ -725,7 +732,7 @@
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
                 response = CustomStreamWrapper(
-                    model_response, model, custom_llm_provider="huggingface"
+                    model_response, model, custom_llm_provider="baseten"
                 )
                 return response
             response = model_response
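
The Baseten fix is a one-word change: the stream was previously tagged "huggingface", so CustomStreamWrapper parsed Baseten chunks with the wrong provider logic. A hypothetical illustration of why the tag matters (not litellm's code; the chunk shapes are assumptions):

    def next_text(chunk, custom_llm_provider):
        # the wrapper dispatches on the provider tag, so a mislabeled stream
        # runs the wrong parser and extracts the wrong (or no) field
        if custom_llm_provider == "huggingface":
            return chunk["token"]["text"]          # assumed HF text-generation event
        if custom_llm_provider in ("baseten", "ai21"):
            return str(chunk)                      # assumed plain-text segments
        raise ValueError(f"no stream handler for {custom_llm_provider}")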