fix(main.py): support async streaming for text completions endpoint

parent 7df9c8e4d8 · commit 1608dd7e0b
7 changed files with 175 additions and 68 deletions
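The end-to-end usage this commit enables is exercised by the new async test added near the bottom of the diff. As a minimal, self-contained sketch of consuming the async text-completion stream (the model name and prompt are illustrative, and an OpenAI API key is assumed to be configured in the environment):

    import asyncio
    from litellm import atext_completion

    async def main():
        # stream=True yields an async generator of text-completion chunks
        response = await atext_completion(
            model="text-completion-openai/text-davinci-003",
            prompt="good morning",
            stream=True,
            max_tokens=10,
        )
        async for chunk in response:
            print(f"chunk: {chunk}")

    asyncio.run(main())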
@@ -280,6 +280,7 @@ model_list = (
 provider_list: List = [
     "openai",
     "custom_openai",
+    "text-completion-openai",
     "cohere",
     "anthropic",
     "replicate",
@@ -521,12 +521,14 @@ class OpenAITextCompletion(BaseLLM):
         else:
             prompt = " ".join([message["content"] for message in messages]) # type: ignore

+        # don't send max retries to the api, if set
+        optional_params.pop("max_retries", None)

         data = {
             "model": model,
             "prompt": prompt,
             **optional_params
         }

         ## LOGGING
         logging_obj.pre_call(
             input=messages,
@@ -193,7 +193,6 @@ async def acompletion(*args, **kwargs):
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
         if kwargs.get("stream", False): # return an async generator
-            print_verbose(f"ENTERS STREAMING FOR ACOMPLETION")
             return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
         else:
             return response
@@ -664,17 +663,6 @@ def completion(
             prompt = messages[0]["content"]
         else:
             prompt = " ".join([message["content"] for message in messages]) # type: ignore
-        ## LOGGING
-        logging.pre_call(
-            input=prompt,
-            api_key=api_key,
-            additional_args={
-                "openai_organization": litellm.organization,
-                "headers": headers,
-                "api_base": api_base,
-                "api_type": openai.api_type,
-            },
-        )
         ## COMPLETION CALL
         model_response = openai_text_completions.completion(
             model=model,
@@ -1991,6 +1979,59 @@ def embedding(


 ###### Text Completion ################
+
+async def atext_completion(*args, **kwargs):
+    """
+    Implemented to handle async streaming for the text completion endpoint
+    """
+    loop = asyncio.get_event_loop()
+    model = args[0] if len(args) > 0 else kwargs["model"]
+    ### PASS ARGS TO COMPLETION ###
+    kwargs["acompletion"] = True
+    custom_llm_provider = None
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(text_completion, *args, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+        if (custom_llm_provider == "openai"
+            or custom_llm_provider == "azure"
+            or custom_llm_provider == "custom_openai"
+            or custom_llm_provider == "anyscale"
+            or custom_llm_provider == "mistral"
+            or custom_llm_provider == "openrouter"
+            or custom_llm_provider == "deepinfra"
+            or custom_llm_provider == "perplexity"
+            or custom_llm_provider == "text-completion-openai"
+            or custom_llm_provider == "huggingface"
+            or custom_llm_provider == "ollama"
+            or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            if kwargs.get("stream", False):
+                response = text_completion(*args, **kwargs)
+            else:
+                # Await normally
+                init_response = await loop.run_in_executor(None, func_with_context)
+                if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+                    response = init_response
+                elif asyncio.iscoroutine(init_response):
+                    response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = await loop.run_in_executor(None, func_with_context)
+        if kwargs.get("stream", False): # return an async generator
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+        else:
+            return response
+    except Exception as e:
+        custom_llm_provider = custom_llm_provider or "openai"
+        raise exception_type(
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+        )
+
 def text_completion(
     prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
     model: Optional[str]=None, # Optional: either `model` or `engine` can be set
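The core of the new atext_completion is the sync-to-async bridge: the blocking text_completion call is bound with functools.partial, wrapped in a copied contextvars.Context, and handed to the event loop's default executor. A standalone sketch of that technique under illustrative names (nothing below is a litellm API):

    import asyncio
    import contextvars
    from functools import partial

    def blocking_call(prompt: str) -> str:
        # stand-in for a synchronous SDK call such as text_completion
        return f"completion for: {prompt}"

    async def run_without_blocking(prompt: str) -> str:
        loop = asyncio.get_event_loop()
        # bind the arguments, then bind the current context so context variables survive the executor hop
        func = partial(blocking_call, prompt)
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        # run the blocking function on the default thread-pool executor without stalling the event loop
        return await loop.run_in_executor(None, func_with_context)

    print(asyncio.run(run_without_blocking("good morning")))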
@@ -797,37 +797,6 @@ async def async_data_generator(response, user_api_key_dict):
         except:
             yield f"data: {json.dumps(chunk)}\n\n"

-def litellm_completion(*args, **kwargs):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    call_type = kwargs.pop("call_type")
-    # override with user settings, these are params passed via cli
-    if user_temperature:
-        kwargs["temperature"] = user_temperature
-    if user_request_timeout:
-        kwargs["request_timeout"] = user_request_timeout
-    if user_max_tokens:
-        kwargs["max_tokens"] = user_max_tokens
-    if user_api_base:
-        kwargs["api_base"] = user_api_base
-    ## ROUTE TO CORRECT ENDPOINT ##
-    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-    try:
-        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
-            if call_type == "chat_completion":
-                response = llm_router.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = llm_router.text_completion(*args, **kwargs)
-        else:
-            if call_type == "chat_completion":
-                response = litellm.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = litellm.text_completion(*args, **kwargs)
-    except Exception as e:
-        raise e
-    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
-
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -907,7 +876,8 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     try:
         body = await request.body()
         body_str = body.decode()
@@ -925,17 +895,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         )
         if user_model:
             data["model"] = user_model
-        data["call_type"] = "text_completion"
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}

-        return litellm_completion(
-            **data
-        )
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        else: # router is not set
+            response = await litellm.atext_completion(**data)
+
+        print(f"final response: {response}")
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+
+        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
+        traceback.print_exc()
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
         try:
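With this routing in place, the proxy's /completions endpoints stream results back as server-sent events (the data: ... lines produced by async_data_generator above). A rough client-side sketch, assuming the proxy listens on http://0.0.0.0:8000 and accepts a bearer proxy key; the address, key, and model name are assumptions, not taken from this diff:

    import json
    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/v1/completions",          # assumed local proxy address
        headers={"Authorization": "Bearer sk-1234"},   # assumed proxy API key
        json={
            "model": "text-davinci-003",               # illustrative model name
            "prompt": "good morning",
            "stream": True,
            "max_tokens": 10,
        },
        stream=True,
    )
    for line in resp.iter_lines():
        if not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        try:
            print(json.loads(payload))
        except json.JSONDecodeError:
            # non-JSON control payloads (if any) are skipped in this sketch
            continue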
@@ -311,6 +311,45 @@ class Router:
             else:
                 raise e

+    async def atext_completion(self,
+                               model: str,
+                               prompt: str,
+                               is_retry: Optional[bool] = False,
+                               is_fallback: Optional[bool] = False,
+                               is_async: Optional[bool] = False,
+                               **kwargs):
+        try:
+            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            messages=[{"role": "user", "content": prompt}]
+            # pick the one that is available (lowest TPM/RPM)
+            deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
+
+            data = deployment["litellm_params"].copy()
+            for k, v in self.default_litellm_params.items():
+                if k not in data: # prioritize model-specific params > default router params
+                    data[k] = v
+            ########## remove -ModelID-XXXX from model ##############
+            original_model_string = data["model"]
+            # Find the index of "ModelID" in the string
+            index_of_model_id = original_model_string.find("-ModelID")
+            # Remove everything after "-ModelID" if it exists
+            if index_of_model_id != -1:
+                data["model"] = original_model_string[:index_of_model_id]
+            else:
+                data["model"] = original_model_string
+            # call via litellm.atext_completion()
+            response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+            return response
+        except Exception as e:
+            if self.num_retries > 0:
+                kwargs["model"] = model
+                kwargs["messages"] = messages
+                kwargs["original_exception"] = e
+                kwargs["original_function"] = self.completion
+                return self.function_with_retries(**kwargs)
+            else:
+                raise e
+
     def embedding(self,
                   model: str,
                   input: Union[str, List],
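The router method above resolves an available deployment, strips any internal -ModelID-XXXX suffix, and delegates to litellm.atext_completion. A minimal sketch of calling it, assuming a single-deployment Router; the model names are illustrative and credentials are assumed to come from the environment (e.g. OPENAI_API_KEY):

    import asyncio
    from litellm import Router

    model_list = [
        {
            "model_name": "text-davinci-003",  # the group name callers use
            "litellm_params": {
                # underlying litellm model for this deployment
                "model": "text-completion-openai/text-davinci-003",
            },
        }
    ]
    router = Router(model_list=model_list)

    async def main():
        response = await router.atext_completion(
            model="text-davinci-003",
            prompt="good morning",
            max_tokens=10,
        )
        print(response)

    asyncio.run(main())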
@@ -1,4 +1,4 @@
-import sys, os
+import sys, os, asyncio
 import traceback
 from dotenv import load_dotenv

@@ -10,7 +10,7 @@ sys.path.insert(
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion, text_completion, completion_cost
+from litellm import embedding, completion, text_completion, completion_cost, atext_completion
 from litellm import RateLimitError

@@ -61,7 +61,7 @@ def test_completion_openai_engine():
         #print(response.choices[0].text)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai_engine()
+# test_completion_openai_engine()


 def test_completion_chatgpt_prompt():
@@ -163,8 +163,23 @@ def test_text_completion_stream():
             max_tokens=10,
         )
         for chunk in response:
-            print(chunk)
+            print(f"chunk: {chunk}")
     except Exception as e:
         pytest.fail(f"GOT exception for HF In streaming{e}")

-test_text_completion_stream()
+# test_text_completion_stream()
+
+async def test_text_completion_async_stream():
+    try:
+        response = await atext_completion(
+            model="text-completion-openai/text-davinci-003",
+            prompt="good morning",
+            stream=True,
+            max_tokens=10,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+    except Exception as e:
+        pytest.fail(f"GOT exception for HF In streaming{e}")
+
+asyncio.run(test_text_completion_async_stream())
@@ -5873,13 +5873,8 @@ class TextCompletionStreamWrapper:
     def __aiter__(self):
         return self

-    def __next__(self):
-        # model_response = ModelResponse(stream=True, model=self.model)
+    def convert_to_text_completion_object(self, chunk: ModelResponse):
         response = TextCompletionResponse()
-        try:
-            while True: # loop until a non-empty string is found
-                # return this for all models
-                chunk = next(self.completion_stream)
         response["id"] = chunk.get("id", None)
         response["object"] = "text_completion"
         response["created"] = response.get("created", None)
@@ -5890,13 +5885,30 @@ class TextCompletionStreamWrapper:
         text_choices["finish_reason"] = response["choices"][0]["finish_reason"]
         response["choices"] = [text_choices]
         return response

+    def __next__(self):
+        # model_response = ModelResponse(stream=True, model=self.model)
+        response = TextCompletionResponse()
+        try:
+            for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopIteration
         except Exception as e:
             print(f"got exception {e}") # noqa

     async def __anext__(self):
         try:
-            return next(self)
+            async for chunk in self.completion_stream:
+                if chunk == "None" or chunk is None:
+                    raise Exception
+                processed_chunk = self.convert_to_text_completion_object(chunk=chunk)
+                return processed_chunk
+            raise StopIteration
         except StopIteration:
             raise StopAsyncIteration
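In both the sync and async paths, the wrapper now funnels every streamed chunk through convert_to_text_completion_object, reshaping a chat-style chunk into an OpenAI-style text-completion chunk. A rough, simplified illustration of that mapping; the delta-to-text handling below is an assumption, only the id, object, created, choices, and finish_reason fields appear in the diff above:

    def to_text_completion_dict(chunk: dict) -> dict:
        # simplified stand-in for TextCompletionStreamWrapper.convert_to_text_completion_object
        choice = chunk["choices"][0]
        return {
            "id": chunk.get("id"),
            "object": "text_completion",
            "created": chunk.get("created"),
            "choices": [{
                "text": choice.get("delta", {}).get("content", ""),  # assumed mapping from the chat-style delta
                "index": 0,
                "finish_reason": choice.get("finish_reason"),
            }],
        }

    example_chunk = {
        "id": "cmpl-123",
        "created": 1700000000,
        "choices": [{"delta": {"content": "good"}, "finish_reason": None}],
    }
    print(to_text_completion_dict(example_chunk))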