forked from phoenix/litellm-mirror

fix(main.py): support async streaming for text completions endpoint

parent 7df9c8e4d8, commit 1608dd7e0b
7 changed files with 175 additions and 68 deletions
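In short: callers can now stream text completions asynchronously. A minimal usage sketch of what this enables, mirroring the new test added further down in this diff (the model name is taken from that test and is only illustrative):

import asyncio
import litellm

async def main():
    # with stream=True the new atext_completion path returns an async iterator of chunks
    response = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",  # illustrative; any supported text-completion model
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())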
@@ -280,6 +280,7 @@ model_list = (
provider_list: List = [
    "openai",
    "custom_openai",
    "text-completion-openai",
    "cohere",
    "anthropic",
    "replicate",
@@ -521,12 +521,14 @@ class OpenAITextCompletion(BaseLLM):
            else:
                prompt = " ".join([message["content"] for message in messages]) # type: ignore

            # don't send max retries to the api, if set
            optional_params.pop("max_retries", None)

            data = {
                "model": model,
                "prompt": prompt,
                **optional_params
            }

            ## LOGGING
            logging_obj.pre_call(
                input=messages,
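For illustration only (the values below are invented examples, not taken from the diff): with a single user message, the payload assembled above comes out as a plain prompt string plus the remaining optional params.

messages = [{"role": "user", "content": "good morning"}]
optional_params = {"max_tokens": 10, "max_retries": 2}

prompt = " ".join([message["content"] for message in messages])  # -> "good morning"
optional_params.pop("max_retries", None)  # retries stay client-side, never sent to the API
data = {"model": "text-davinci-003", "prompt": prompt, **optional_params}
# data == {"model": "text-davinci-003", "prompt": "good morning", "max_tokens": 10}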
@@ -193,7 +193,6 @@ async def acompletion(*args, **kwargs):
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        if kwargs.get("stream", False): # return an async generator
            print_verbose(f"ENTERS STREAMING FOR ACOMPLETION")
            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
        else:
            return response
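For comparison, this is how the existing async chat path is consumed; the new text-completion path added below follows the same shape (a hedged sketch, model name illustrative):

import asyncio
import litellm

async def chat_stream():
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # illustrative
        messages=[{"role": "user", "content": "good morning"}],
        stream=True,
    )
    # when stream=True, acompletion hands back the _async_streaming generator
    async for chunk in response:
        print(chunk)

asyncio.run(chat_stream())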
@@ -664,17 +663,6 @@ def completion(
                prompt = messages[0]["content"]
            else:
                prompt = " ".join([message["content"] for message in messages]) # type: ignore
            ## LOGGING
            logging.pre_call(
                input=prompt,
                api_key=api_key,
                additional_args={
                    "openai_organization": litellm.organization,
                    "headers": headers,
                    "api_base": api_base,
                    "api_type": openai.api_type,
                },
            )
            ## COMPLETION CALL
            model_response = openai_text_completions.completion(
                model=model,
@@ -1991,6 +1979,59 @@ def embedding(


###### Text Completion ################
async def atext_completion(*args, **kwargs):
    """
    Implemented to handle async streaming for the text completion endpoint
    """
    loop = asyncio.get_event_loop()
    model = args[0] if len(args) > 0 else kwargs["model"]
    ### PASS ARGS TO COMPLETION ###
    kwargs["acompletion"] = True
    custom_llm_provider = None
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(text_completion, *args, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))

        if (custom_llm_provider == "openai"
            or custom_llm_provider == "azure"
            or custom_llm_provider == "custom_openai"
            or custom_llm_provider == "anyscale"
            or custom_llm_provider == "mistral"
            or custom_llm_provider == "openrouter"
            or custom_llm_provider == "deepinfra"
            or custom_llm_provider == "perplexity"
            or custom_llm_provider == "text-completion-openai"
            or custom_llm_provider == "huggingface"
            or custom_llm_provider == "ollama"
            or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
            if kwargs.get("stream", False):
                response = text_completion(*args, **kwargs)
            else:
                # Await normally
                init_response = await loop.run_in_executor(None, func_with_context)
                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
                    response = init_response
                elif asyncio.iscoroutine(init_response):
                    response = await init_response
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        if kwargs.get("stream", False): # return an async generator
            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
        else:
            return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
        )

def text_completion(
    prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
    model: Optional[str]=None, # Optional: either `model` or `engine` can be set
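The non-streaming branch of the function above resolves to a full TextCompletionResponse. A hedged sketch (model name illustrative; the response access mirrors the commented-out line in the existing test):

import asyncio
import litellm

async def main():
    response = await litellm.atext_completion(
        model="text-completion-openai/text-davinci-003",  # illustrative
        prompt="good morning",
        max_tokens=10,
    )
    print(response.choices[0].text)

asyncio.run(main())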
@@ -797,37 +797,6 @@ async def async_data_generator(response, user_api_key_dict):
    except:
        yield f"data: {json.dumps(chunk)}\n\n"

def litellm_completion(*args, **kwargs):
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    call_type = kwargs.pop("call_type")
    # override with user settings, these are params passed via cli
    if user_temperature:
        kwargs["temperature"] = user_temperature
    if user_request_timeout:
        kwargs["request_timeout"] = user_request_timeout
    if user_max_tokens:
        kwargs["max_tokens"] = user_max_tokens
    if user_api_base:
        kwargs["api_base"] = user_api_base
    ## ROUTE TO CORRECT ENDPOINT ##
    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
    try:
        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
            if call_type == "chat_completion":
                response = llm_router.completion(*args, **kwargs)
            elif call_type == "text_completion":
                response = llm_router.text_completion(*args, **kwargs)
        else:
            if call_type == "chat_completion":
                response = litellm.completion(*args, **kwargs)
            elif call_type == "text_completion":
                response = litellm.text_completion(*args, **kwargs)
    except Exception as e:
        raise e
    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
        return StreamingResponse(data_generator(response), media_type='text/event-stream')
    return response

def get_litellm_model_info(model: dict = {}):
    model_info = model.get("model_info", {})
    model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -907,7 +876,8 @@ def model_list():
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    try:
        body = await request.body()
        body_str = body.decode()
@@ -925,17 +895,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
        )
        if user_model:
            data["model"] = user_model
        data["call_type"] = "text_completion"
        if "metadata" in data:
            data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        else:
            data["metadata"] = {"user_api_key": user_api_key_dict.api_key}

        return litellm_completion(
            **data
        )
        # override with user settings, these are params passed via cli
        if user_temperature:
            data["temperature"] = user_temperature
        if user_request_timeout:
            data["request_timeout"] = user_request_timeout
        if user_max_tokens:
            data["max_tokens"] = user_max_tokens
        if user_api_base:
            data["api_base"] = user_api_base

        ### CALL HOOKS ### - modify incoming data before calling the model
        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

        ### ROUTE THE REQUEST ###
        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
        if llm_router is not None and data["model"] in router_model_names: # model in router model list
            response = await llm_router.atext_completion(**data)
        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
            response = await llm_router.atext_completion(**data, specific_deployment = True)
        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
            response = await llm_router.atext_completion(**data)
        else: # router is not set
            response = await litellm.atext_completion(**data)

        print(f"final response: {response}")
        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')

        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
        return response
    except Exception as e:
        print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
        traceback.print_exc()
        error_traceback = traceback.format_exc()
        error_msg = f"{str(e)}\n\n{error_traceback}"
        try:
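A hedged client-side sketch of exercising this proxy route once the server is running. The base URL, port, and API key below are assumptions, not taken from the diff; the OpenAI Python SDK is used only because the route is OpenAI-compatible:

from openai import OpenAI

# assumed local proxy address and placeholder key -- adjust for your deployment
client = OpenAI(api_key="sk-anything", base_url="http://0.0.0.0:8000")

stream = client.completions.create(
    model="text-davinci-003",  # illustrative model/alias configured on the proxy
    prompt="good morning",
    max_tokens=10,
    stream=True,  # served by the StreamingResponse branch above
)
for chunk in stream:
    print(chunk.choices[0].text, end="", flush=True)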
@@ -310,6 +310,45 @@ class Router:
            return self.function_with_retries(**kwargs)
        else:
            raise e

    async def atext_completion(self,
                               model: str,
                               prompt: str,
                               is_retry: Optional[bool] = False,
                               is_fallback: Optional[bool] = False,
                               is_async: Optional[bool] = False,
                               **kwargs):
        try:
            kwargs.setdefault("metadata", {}).update({"model_group": model})
            messages=[{"role": "user", "content": prompt}]
            # pick the one that is available (lowest TPM/RPM)
            deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))

            data = deployment["litellm_params"].copy()
            for k, v in self.default_litellm_params.items():
                if k not in data: # prioritize model-specific params > default router params
                    data[k] = v
            ########## remove -ModelID-XXXX from model ##############
            original_model_string = data["model"]
            # Find the index of "ModelID" in the string
            index_of_model_id = original_model_string.find("-ModelID")
            # Remove everything after "-ModelID" if it exists
            if index_of_model_id != -1:
                data["model"] = original_model_string[:index_of_model_id]
            else:
                data["model"] = original_model_string
            # call via litellm.atext_completion()
            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
            return response
        except Exception as e:
            if self.num_retries > 0:
                kwargs["model"] = model
                kwargs["messages"] = messages
                kwargs["original_exception"] = e
                kwargs["original_function"] = self.completion
                return self.function_with_retries(**kwargs)
            else:
                raise e

    def embedding(self,
                  model: str,
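A hedged sketch of calling the new Router method directly; the model_list entry and API key are placeholders, and the config shape follows litellm's usual Router setup rather than anything specified in this diff:

import asyncio
from litellm import Router

# illustrative router config -- one deployment behind a model group name
router = Router(model_list=[
    {
        "model_name": "text-davinci-003",  # model group name used in requests
        "litellm_params": {
            "model": "text-completion-openai/text-davinci-003",
            "api_key": "sk-...",  # placeholder
        },
    },
])

async def main():
    response = await router.atext_completion(
        model="text-davinci-003",
        prompt="good morning",
        stream=True,
        max_tokens=10,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())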
@@ -1,4 +1,4 @@
import sys, os
import sys, os, asyncio
import traceback
from dotenv import load_dotenv
@@ -10,7 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, text_completion, completion_cost
from litellm import embedding, completion, text_completion, completion_cost, atext_completion
from litellm import RateLimitError
@@ -61,7 +61,7 @@ def test_completion_openai_engine():
        #print(response.choices[0].text)
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
test_completion_openai_engine()
# test_completion_openai_engine()


def test_completion_chatgpt_prompt():
@@ -163,8 +163,23 @@ def test_text_completion_stream():
            max_tokens=10,
        )
        for chunk in response:
            print(chunk)
            print(f"chunk: {chunk}")
    except Exception as e:
        pytest.fail(f"GOT exception for HF In streaming{e}")

test_text_completion_stream()
# test_text_completion_stream()

async def test_text_completion_async_stream():
    try:
        response = await atext_completion(
            model="text-completion-openai/text-davinci-003",
            prompt="good morning",
            stream=True,
            max_tokens=10,
        )
        async for chunk in response:
            print(f"chunk: {chunk}")
    except Exception as e:
        pytest.fail(f"GOT exception for HF In streaming{e}")

asyncio.run(test_text_completion_async_stream())
@@ -5872,31 +5872,43 @@ class TextCompletionStreamWrapper:

    def __aiter__(self):
        return self

    def convert_to_text_completion_object(self, chunk: ModelResponse):
        response = TextCompletionResponse()
        response["id"] = chunk.get("id", None)
        response["object"] = "text_completion"
        response["created"] = response.get("created", None)
        response["model"] = response.get("model", None)
        text_choices = TextChoices()
        text_choices["text"] = chunk["choices"][0]["delta"]["content"]
        text_choices["index"] = response["choices"][0]["index"]
        text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
        response["choices"] = [text_choices]
        return response

    def __next__(self):
        # model_response = ModelResponse(stream=True, model=self.model)
        response = TextCompletionResponse()
        try:
            while True: # loop until a non-empty string is found
                # return this for all models
                chunk = next(self.completion_stream)
                response["id"] = chunk.get("id", None)
                response["object"] = "text_completion"
                response["created"] = response.get("created", None)
                response["model"] = response.get("model", None)
                text_choices = TextChoices()
                text_choices["text"] = chunk["choices"][0]["delta"]["content"]
                text_choices["index"] = response["choices"][0]["index"]
                text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
                response["choices"] = [text_choices]
                return response
            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception
                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
                return processed_chunk
            raise StopIteration
        except StopIteration:
            raise StopIteration
        except Exception as e:
            print(f"got exception {e}") # noqa

    async def __anext__(self):
        try:
            return next(self)
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception
                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
                return processed_chunk
            raise StopIteration
        except StopIteration:
            raise StopAsyncIteration
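For illustration only: a self-contained sketch, using plain dicts instead of litellm's ModelResponse/TextCompletionResponse classes, of the mapping convert_to_text_completion_object performs on each streamed chunk (the actual method seeds index and finish_reason from a fresh TextCompletionResponse):

# a streamed chat-completion chunk as the wrapper receives it (shape only, values invented)
chat_chunk = {
    "id": "chatcmpl-123",
    "choices": [{"index": 0, "delta": {"content": "good"}, "finish_reason": None}],
}

# the text_completion-shaped chunk the wrapper yields in its place
text_chunk = {
    "id": chat_chunk.get("id", None),
    "object": "text_completion",
    "choices": [{
        "text": chat_chunk["choices"][0]["delta"]["content"],  # delta content becomes completion text
        "index": 0,
        "finish_reason": None,
    }],
}
print(text_chunk)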