feat(assistants/main.py): add assistants api streaming support

Krrish Dholakia 2024-06-04 16:30:35 -07:00
parent 7b474ec267
commit f3d78532f9
9 changed files with 444 additions and 65 deletions


@@ -117,7 +117,13 @@ import pydantic
from litellm.proxy._types import *
from litellm.caching import DualCache, RedisCache
from litellm.proxy.health_check import perform_health_check
from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
from litellm.router import (
    LiteLLM_Params,
    Deployment,
    updateDeployment,
    ModelGroupInfo,
    AssistantsTypedDict,
)
from litellm.router import ModelInfo as RouterModelInfo
from litellm._logging import verbose_router_logger, verbose_proxy_logger
from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -2853,6 +2859,17 @@ class ProxyConfig:
if "ollama" in litellm_model_name and litellm_model_api_base is None:
run_ollama_serve()
## ASSISTANT SETTINGS
assistants_config: Optional[AssistantsTypedDict] = None
assistant_settings = config.get("assistant_settings", None)
if assistant_settings:
for k, v in assistant_settings["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
_v = v.replace("os.environ/", "")
v = os.getenv(_v)
assistant_settings["litellm_params"][k] = v
assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
## ROUTER SETTINGS (e.g. routing_strategy, ...)
router_settings = config.get("router_settings", None)
if router_settings and isinstance(router_settings, dict):
@@ -2868,7 +2885,9 @@ class ProxyConfig:
            for k, v in router_settings.items():
                if k in available_args:
                    router_params[k] = v
        router = litellm.Router(**router_params)  # type:ignore
        router = litellm.Router(
            **router_params, assistants_config=assistants_config
        )  # type:ignore
        return router, model_list, general_settings

    def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
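
For reference, a minimal sketch of how the new `assistant_settings` block is meant to flow from the proxy config into the router. Only the `litellm_params` key, the `os.environ/` indirection, and the `assistants_config` kwarg on `litellm.Router` come from the diff above; the config key names shown here are illustrative assumptions.

import os

import litellm
from litellm.router import AssistantsTypedDict

# Assumed shape of config.get("assistant_settings") after YAML parsing;
# the key names below are placeholders for illustration.
assistant_settings = {
    "custom_llm_provider": "openai",
    "litellm_params": {"api_key": "os.environ/OPENAI_API_KEY"},
}

# Resolve "os.environ/<VAR>" indirection, mirroring the loop added above.
for k, v in assistant_settings["litellm_params"].items():
    if isinstance(v, str) and v.startswith("os.environ/"):
        assistant_settings["litellm_params"][k] = os.getenv(v.replace("os.environ/", ""))

assistants_config = AssistantsTypedDict(**assistant_settings)  # type: ignore

# The proxy then builds its router with the resolved assistants config
# (model_list setup omitted here).
router = litellm.Router(assistants_config=assistants_config)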
@@ -3606,6 +3625,60 @@ def data_generator(response):
yield f"data: {json.dumps(chunk)}\n\n"
async def async_assistants_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
verbose_proxy_logger.debug("inside generator")
try:
start_time = time.time()
async with response as chunk:
### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict, response=chunk
)
# chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk:
c = c.model_dump_json(exclude_none=True)
try:
yield f"data: {c}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"
# Streaming is done, yield the [DONE] chunk
done_message = "[DONE]"
yield f"data: {done_message}\n\n"
except Exception as e:
traceback.print_exc()
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=request_data,
)
verbose_proxy_logger.debug(
f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
)
router_model_names = llm_router.model_names if llm_router is not None else []
if user_debug:
traceback.print_exc()
if isinstance(e, HTTPException):
raise e
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}\n\n{error_traceback}"
proxy_exception = ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
async def async_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
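
The new async_assistants_data_generator frames each assistant event as a server-sent-events chunk, ends the stream with a literal [DONE] sentinel, and serializes failures caught by the post-call failure hook into the stream as a `data: {"error": ...}` frame (the HTTP status is already committed by then). A rough sketch of what a consumer of this framing has to handle; the helper name and the error-key check are assumptions based only on the yields above:

import json

def iter_assistant_events(sse_lines):
    """Hypothetical parser for the frames yielded by async_assistants_data_generator."""
    for line in sse_lines:
        if not line.startswith("data: "):
            continue  # skip blank separator lines between frames
        payload = line[len("data: "):]
        if payload == "[DONE]":  # completion sentinel yielded after the event loop
            return
        event = json.loads(payload)
        if "error" in event:  # in-band error frame from the failure path
            raise RuntimeError(event["error"])
        yield event  # one model_dump_json()'d assistant stream event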
@@ -5347,7 +5420,6 @@ async def get_assistants(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
@@ -5405,9 +5477,7 @@ async def get_assistants(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_assistants(
            custom_llm_provider="openai", client=None, **data
        )
        response = await llm_router.aget_assistants(**data)

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
@@ -5479,7 +5549,6 @@ async def create_threads(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
@@ -5537,9 +5606,7 @@ async def create_threads(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.acreate_thread(
            custom_llm_provider="openai", client=None, **data
        )
        response = await llm_router.acreate_thread(**data)

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
@@ -5667,9 +5734,7 @@ async def get_thread(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_thread(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.aget_thread(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
@@ -5800,9 +5865,7 @@ async def add_messages(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.a_add_message(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.a_add_message(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
@@ -5929,9 +5992,7 @@ async def get_messages(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_messages(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.aget_messages(thread_id=thread_id, **data)

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting
@@ -6060,9 +6121,19 @@ async def run_thread(
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.arun_thread(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )
        response = await llm_router.arun_thread(thread_id=thread_id, **data)

        if (
            "stream" in data and data["stream"] == True
        ):  # use generate_responses to stream responses
            return StreamingResponse(
                async_assistants_data_generator(
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                    request_data=data,
                ),
                media_type="text/event-stream",
            )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting