fix(main.py): support async streaming for text completions endpoint

Krrish Dholakia 2023-12-14 13:56:32 -08:00
parent 7df9c8e4d8
commit 1608dd7e0b
7 changed files with 175 additions and 68 deletions
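In short: `litellm.atext_completion(..., stream=True)` now returns an async generator of text-completion chunks, and the proxy's `/completions` routes await it end to end. A minimal sketch of the new call path (model, prompt, and parameters mirror the test added in this commit):

import asyncio
import litellm

async def main():
    # stream=True now yields an async generator instead of a sync iterator
    response = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())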

@@ -280,6 +280,7 @@ model_list = (
 provider_list: List = [
     "openai",
     "custom_openai",
+    "text-completion-openai",
     "cohere",
     "anthropic",
     "replicate",

@@ -521,12 +521,14 @@ class OpenAITextCompletion(BaseLLM):
         else:
             prompt = " ".join([message["content"] for message in messages]) # type: ignore
+        # don't send max retries to the api, if set
+        optional_params.pop("max_retries", None)
         data = {
             "model": model,
             "prompt": prompt,
             **optional_params
         }
         ## LOGGING
         logging_obj.pre_call(
             input=messages,

@@ -193,7 +193,6 @@ async def acompletion(*args, **kwargs):
         # Call the synchronous function using run_in_executor
         response = await loop.run_in_executor(None, func_with_context)
     if kwargs.get("stream", False): # return an async generator
-        print_verbose(f"ENTERS STREAMING FOR ACOMPLETION")
         return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
     else:
         return response
@@ -664,17 +663,6 @@ def completion(
             prompt = messages[0]["content"]
         else:
             prompt = " ".join([message["content"] for message in messages]) # type: ignore
-        ## LOGGING
-        logging.pre_call(
-            input=prompt,
-            api_key=api_key,
-            additional_args={
-                "openai_organization": litellm.organization,
-                "headers": headers,
-                "api_base": api_base,
-                "api_type": openai.api_type,
-            },
-        )
         ## COMPLETION CALL
         model_response = openai_text_completions.completion(
             model=model,
@@ -1991,6 +1979,59 @@ def embedding(
 ###### Text Completion ################
+async def atext_completion(*args, **kwargs):
+    """
+    Implemented to handle async streaming for the text completion endpoint
+    """
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["acompletion"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(text_completion, *args, **kwargs)
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+        if (custom_llm_provider == "openai"
+                or custom_llm_provider == "azure"
+                or custom_llm_provider == "custom_openai"
+                or custom_llm_provider == "anyscale"
+                or custom_llm_provider == "mistral"
+                or custom_llm_provider == "openrouter"
+                or custom_llm_provider == "deepinfra"
+                or custom_llm_provider == "perplexity"
+                or custom_llm_provider == "text-completion-openai"
+                or custom_llm_provider == "huggingface"
+                or custom_llm_provider == "ollama"
+                or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            if kwargs.get("stream", False):
+                response = text_completion(*args, **kwargs)
+            else:
+                # Await normally
+                init_response = await loop.run_in_executor(None, func_with_context)
+                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                    response = init_response
+                elif asyncio.iscoroutine(init_response):
+                    response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        if kwargs.get("stream", False): # return an async generator
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+        else:
+            return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )
+
 def text_completion(
     prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
     model: Optional[str]=None, # Optional: either `model` or `engine` can be set

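The provider gate above is a literal or-chain. A more compact set-membership equivalent (a hypothetical refactor shown for illustration, not part of this commit):

# hypothetical: same check as the or-chain in atext_completion
ASYNC_CAPABLE_PROVIDERS = {
    "openai", "azure", "custom_openai", "anyscale", "mistral",
    "openrouter", "deepinfra", "perplexity", "text-completion-openai",
    "huggingface", "ollama", "vertex_ai",
}

if custom_llm_provider in ASYNC_CAPABLE_PROVIDERS:
    ...  # take the native-async path
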
@@ -797,37 +797,6 @@ async def async_data_generator(response, user_api_key_dict):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"
-
-def litellm_completion(*args, **kwargs):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    call_type = kwargs.pop("call_type")
-    # override with user settings, these are params passed via cli
-    if user_temperature:
-        kwargs["temperature"] = user_temperature
-    if user_request_timeout:
-        kwargs["request_timeout"] = user_request_timeout
-    if user_max_tokens:
-        kwargs["max_tokens"] = user_max_tokens
-    if user_api_base:
-        kwargs["api_base"] = user_api_base
-    ## ROUTE TO CORRECT ENDPOINT ##
-    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-    try:
-        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
-            if call_type == "chat_completion":
-                response = llm_router.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = llm_router.text_completion(*args, **kwargs)
-        else:
-            if call_type == "chat_completion":
-                response = litellm.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = litellm.text_completion(*args, **kwargs)
-    except Exception as e:
-        raise e
-    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -907,7 +876,8 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     try:
         body = await request.body()
         body_str = body.decode()
@@ -925,17 +895,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         )
         if user_model:
             data["model"] = user_model
-        data["call_type"] = "text_completion"
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
-        return litellm_completion(
-            **data
-        )
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+        ### ROUTE THE REQUEST ###
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        else: # router is not set
+            response = await litellm.atext_completion(**data)
+        print(f"final response: {response}")
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         traceback.print_exc()
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
         try:

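With the proxy running, the new streaming path on `/v1/completions` can be exercised with any HTTP client that reads server-sent events. A sketch (base URL, port, model name, and API key are placeholders for your proxy config, not values from this commit):

import asyncio
import httpx

async def stream_completion():
    payload = {
        "model": "text-davinci-003",   # placeholder model name
        "prompt": "good morning",
        "stream": True,
        "max_tokens": 10,
    }
    headers = {"Authorization": "Bearer sk-placeholder"}  # placeholder key
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", "http://0.0.0.0:8000/v1/completions",
                                 json=payload, headers=headers) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):  # SSE frames from async_data_generator
                    print(line)

asyncio.run(stream_completion())
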
@@ -310,6 +310,45 @@ class Router:
             return self.function_with_retries(**kwargs)
         else:
             raise e
+
+    async def atext_completion(self,
+                    model: str,
+                    prompt: str,
+                    is_retry: Optional[bool] = False,
+                    is_fallback: Optional[bool] = False,
+                    is_async: Optional[bool] = False,
+                    **kwargs):
+        try:
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            messages = [{"role": "user", "content": prompt}]
+            # pick the one that is available (lowest TPM/RPM)
+            deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
+            data = deployment["litellm_params"].copy()
+            for k, v in self.default_litellm_params.items():
+                if k not in data: # prioritize model-specific params > default router params
+                    data[k] = v
+            ########## remove -ModelID-XXXX from model ##############
+            original_model_string = data["model"]
+            # Find the index of "ModelID" in the string
+            index_of_model_id = original_model_string.find("-ModelID")
+            # Remove everything after "-ModelID" if it exists
+            if index_of_model_id != -1:
+                data["model"] = original_model_string[:index_of_model_id]
+            else:
+                data["model"] = original_model_string
+            # call via litellm.atext_completion()
+            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+            return response
+        except Exception as e:
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_exception"] = e
+                kwargs["original_function"] = self.completion
+                return self.function_with_retries(**kwargs)
+            else:
+                raise e
+
     def embedding(self,
                   model: str,

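A minimal sketch of driving the new router method directly (the model-group name and underlying deployment are placeholders):

import asyncio
from litellm import Router

model_list = [{
    "model_name": "text-davinci-003",  # model group exposed by the router
    "litellm_params": {"model": "text-completion-openai/text-davinci-003"},
}]
router = Router(model_list=model_list)

async def main():
    response = await router.atext_completion(
        model="text-davinci-003",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())
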
@@ -1,4 +1,4 @@
-import sys, os
+import sys, os, asyncio
 import traceback
 from dotenv import load_dotenv
@@ -10,7 +10,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion, text_completion, completion_cost
+from litellm import embedding, completion, text_completion, completion_cost, atext_completion
 from litellm import RateLimitError
@@ -61,7 +61,7 @@ def test_completion_openai_engine():
         #print(response.choices[0].text)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai_engine()
+# test_completion_openai_engine()

 def test_completion_chatgpt_prompt():
@@ -163,8 +163,23 @@ def test_text_completion_stream():
             max_tokens=10,
         )
         for chunk in response:
-            print(chunk)
+            print(f"chunk: {chunk}")
     except Exception as e:
         pytest.fail(f"GOT exception for HF In streaming{e}")
-test_text_completion_stream()
+# test_text_completion_stream()
+
+async def test_text_completion_async_stream():
+    try:
+        response = await atext_completion(
+            model="text-completion-openai/text-davinci-003",
+            prompt="good morning",
+            stream=True,
+            max_tokens=10,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+    except Exception as e:
+        pytest.fail(f"Got exception during async streaming: {e}")
+
+asyncio.run(test_text_completion_async_stream())

@@ -5872,31 +5872,43 @@ class TextCompletionStreamWrapper:
     def __aiter__(self):
         return self
+
+    def convert_to_text_completion_object(self, chunk: ModelResponse):
+        response = TextCompletionResponse()
+        response["id"] = chunk.get("id", None)
+        response["object"] = "text_completion"
+        response["created"] = chunk.get("created", None)
+        response["model"] = chunk.get("model", None)
+        text_choices = TextChoices()
+        text_choices["text"] = chunk["choices"][0]["delta"]["content"]
+        text_choices["index"] = chunk["choices"][0]["index"]
+        text_choices["finish_reason"] = chunk["choices"][0]["finish_reason"]
+        response["choices"] = [text_choices]
+        return response
+
     def __next__(self):
         # model_response = ModelResponse(stream=True, model=self.model)
         response = TextCompletionResponse()
         try:
-            while True: # loop until a non-empty string is found
-                # return this for all models
-                chunk = next(self.completion_stream)
-                response["id"] = chunk.get("id", None)
-                response["object"] = "text_completion"
-                response["created"] = response.get("created", None)
-                response["model"] = response.get("model", None)
-                text_choices = TextChoices()
-                text_choices["text"] = chunk["choices"][0]["delta"]["content"]
-                text_choices["index"] = response["choices"][0]["index"]
-                text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
-                response["choices"] = [text_choices]
-                return response
+            for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopIteration
         except Exception as e:
             print(f"got exception {e}") # noqa
+
     async def __anext__(self):
         try:
-            return next(self)
+            async for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopAsyncIteration
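
For reference, `convert_to_text_completion_object` maps a chat-style streaming chunk (`choices[0].delta.content`) onto the text-completion wire shape (`choices[0].text`). The same mapping on plain dicts, as an illustrative standalone sketch (not library code):

def to_text_completion_dict(chunk: dict) -> dict:
    # chat-style streaming chunk -> text-completion response shape
    choice = chunk["choices"][0]
    return {
        "id": chunk.get("id"),
        "object": "text_completion",
        "created": chunk.get("created"),
        "model": chunk.get("model"),
        "choices": [{
            "text": choice["delta"]["content"],
            "index": choice.get("index", 0),
            "finish_reason": choice.get("finish_reason"),
        }],
    }

# fabricated example chunk, for illustration only
chunk = {
    "id": "cmpl-123", "created": 1700000000, "model": "text-davinci-003",
    "choices": [{"delta": {"content": "hi"}, "index": 0, "finish_reason": None}],
}
print(to_text_completion_dict(chunk))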