Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
fixes to streaming for ai21, baseten, and openai text completions

commit 3087c904eb (parent 66a0c7cf37)
9 changed files with 273 additions and 117 deletions
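For context, this is the user-facing behavior the commit targets. The sketch below assumes the public litellm.completion API of this era, where passing stream=True yields an iterator of incremental chunks; "j2-mid" is only an example AI21 model name, and the exact chunk schema may differ between litellm versions.

    # Sketch of the streaming call these fixes target (assumes litellm is
    # installed and AI21_API_KEY is set in the environment).
    import litellm

    response = litellm.completion(
        model="j2-mid",  # example AI21 model name, an assumption
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    for chunk in response:
        # with stream=True the call returns an iterator of incremental chunks
        print(chunk)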
@@ -22,6 +22,7 @@ from litellm.utils import (
 from .llms.anthropic import AnthropicLLM
 from .llms.huggingface_restapi import HuggingfaceRestAPILLM
 from .llms.baseten import BasetenLLM
+from .llms.ai21 import AI21LLM
 import tiktoken
 from concurrent.futures import ThreadPoolExecutor
 
@@ -302,7 +303,11 @@ def completion(
                 headers=litellm.headers,
             )
         else:
-            response = openai.Completion.create(model=model, prompt=prompt)
+            response = openai.Completion.create(model=model, prompt=prompt, **optional_params)
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                response = CustomStreamWrapper(response, model)
+                return response
         ## LOGGING
         logging.post_call(
             input=prompt,
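The hunk above fixes two things for OpenAI text completions: optional_params (including stream) are now forwarded to the provider call, and a streaming response is wrapped and returned early instead of falling through to the dict-style logging path, which would fail on a raw stream object. A minimal, self-contained sketch of that control flow follows; fake_provider_create and wrap_stream are illustrative stand-ins, not litellm internals.

    from typing import Any, Dict, Iterator

    def fake_provider_create(model: str, prompt: str, **params: Any):
        # stand-in for openai.Completion.create: a stream is just an
        # iterator of chunks, a non-stream call returns a dict
        if params.get("stream"):
            return iter(["Hel", "lo", "!"])
        return {"choices": [{"text": "Hello!"}]}

    def wrap_stream(raw: Iterator[str], model: str) -> Iterator[Dict[str, Any]]:
        # analogous role to CustomStreamWrapper: normalize raw chunks
        for chunk in raw:
            yield {"model": model, "delta": chunk}

    def text_completion(model: str, prompt: str, **optional_params: Any):
        response = fake_provider_create(model=model, prompt=prompt, **optional_params)
        if optional_params.get("stream") == True:
            # return early: a stream cannot go through dict-style
            # post-processing below
            return wrap_stream(response, model)
        return response

    for piece in text_completion("demo-model", "hi", stream=True):
        print(piece["delta"], end="")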
@@ -661,32 +666,34 @@ def completion(
         model_response["model"] = model
         response = model_response
     elif model in litellm.ai21_models:
-        install_and_import("ai21")
-        import ai21
-
-        ai21.api_key = get_secret("AI21_API_KEY")
-
-        prompt = " ".join([message["content"] for message in messages])
-        ## LOGGING
-        logging.pre_call(input=prompt, api_key=ai21.api_key)
-
-        ai21_response = ai21.Completion.execute(
+        custom_llm_provider = "ai21"
+        ai21_key = (
+            api_key
+            or litellm.ai21_key
+            or os.environ.get("AI21_API_KEY")
+        )
+        ai21_client = AI21LLM(
+            encoding=encoding, api_key=ai21_key, logging_obj=logging
+        )
+
+        model_response = ai21_client.completion(
             model=model,
             prompt=prompt,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
         )
-        completion_response = ai21_response["completions"][0]["data"]["text"]
-
-        ## LOGGING
-        logging.post_call(
-            input=prompt,
-            api_key=ai21.api_key,
-            original_response=completion_response,
-        )
-
+        if "stream" in optional_params and optional_params["stream"] == True:
+            # don't try to access stream object,
+            response = CustomStreamWrapper(
+                model_response, model, custom_llm_provider="ai21"
+            )
+            return response
+
         ## RESPONSE OBJECT
-        model_response["choices"][0]["message"]["content"] = completion_response
-        model_response["created"] = time.time()
-        model_response["model"] = model
         response = model_response
     elif custom_llm_provider == "ollama":
         endpoint = (
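This hunk moves the inline ai21 SDK calls behind the new AI21LLM handler and routes streaming through CustomStreamWrapper with a custom_llm_provider tag. The tag matters because each provider's raw stream chunks nest the generated text differently. The following standalone sketch, not litellm's actual class, shows why a provider-tagged wrapper is useful; the AI21 and baseten chunk shapes here are assumptions for illustration.

    def extract_text(chunk, custom_llm_provider):
        if custom_llm_provider == "ai21":
            # assumed AI21-style shape: text nested under completions/data
            return chunk["completions"][0]["data"]["text"]
        if custom_llm_provider == "baseten":
            return chunk.get("data", "")  # assumed baseten chunk shape
        return chunk  # fall back to passing the chunk through untouched

    class SimpleStreamWrapper:
        def __init__(self, raw_stream, model, custom_llm_provider=None):
            self.raw_stream = iter(raw_stream)
            self.model = model
            self.custom_llm_provider = custom_llm_provider

        def __iter__(self):
            return self

        def __next__(self):
            chunk = next(self.raw_stream)
            text = extract_text(chunk, self.custom_llm_provider)
            # emit OpenAI-style streaming chunks regardless of provider
            return {"model": self.model, "choices": [{"delta": {"content": text}}]}

    fake_ai21_stream = [{"completions": [{"data": {"text": "Hi"}}]}]
    for chunk in SimpleStreamWrapper(fake_ai21_stream, "j2-mid", custom_llm_provider="ai21"):
        print(chunk["choices"][0]["delta"]["content"])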
@@ -725,7 +732,7 @@ def completion(
         if "stream" in optional_params and optional_params["stream"] == True:
             # don't try to access stream object,
             response = CustomStreamWrapper(
-                model_response, model, custom_llm_provider="huggingface"
+                model_response, model, custom_llm_provider="baseten"
             )
             return response
         response = model_response
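The final hunk is a one-word fix in the baseten branch, which had been mis-tagging its stream as "huggingface". With a provider-dispatched parser, a wrong tag routes chunks through the wrong extraction logic. A toy, self-contained illustration follows; both chunk shapes here are assumptions, not the providers' documented formats.

    def parse_chunk(chunk, provider):
        if provider == "huggingface":
            return chunk["token"]["text"]  # assumed HF text-generation shape
        if provider == "baseten":
            return chunk["data"]           # assumed baseten shape
        raise ValueError(f"unknown provider: {provider}")

    baseten_chunk = {"data": "Hello from baseten"}
    print(parse_chunk(baseten_chunk, "baseten"))    # works
    try:
        parse_chunk(baseten_chunk, "huggingface")   # the pre-fix mis-tag
    except KeyError as err:
        print(f"wrong parser applied: missing key {err}")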