mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

feat(assistants/main.py): add assistants api streaming support

This commit is contained in:
parent 7b474ec267
commit f3d78532f9

9 changed files with 444 additions and 65 deletions
@@ -117,7 +117,13 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
+from litellm.router import (
+    LiteLLM_Params,
+    Deployment,
+    updateDeployment,
+    ModelGroupInfo,
+    AssistantsTypedDict,
+)
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -2853,6 +2859,17 @@ class ProxyConfig:
                 if "ollama" in litellm_model_name and litellm_model_api_base is None:
                     run_ollama_serve()

+        ## ASSISTANT SETTINGS
+        assistants_config: Optional[AssistantsTypedDict] = None
+        assistant_settings = config.get("assistant_settings", None)
+        if assistant_settings:
+            for k, v in assistant_settings["litellm_params"].items():
+                if isinstance(v, str) and v.startswith("os.environ/"):
+                    _v = v.replace("os.environ/", "")
+                    v = os.getenv(_v)
+                    assistant_settings["litellm_params"][k] = v
+            assistants_config = AssistantsTypedDict(**assistant_settings)  # type: ignore
+
         ## ROUTER SETTINGS (e.g. routing_strategy, ...)
         router_settings = config.get("router_settings", None)
         if router_settings and isinstance(router_settings, dict):
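Note: the hunk above resolves any `os.environ/<NAME>` string in `assistant_settings["litellm_params"]` before building the `AssistantsTypedDict`. A minimal standalone sketch of that resolution step follows; the `assistant_settings` dict shown here (including the `custom_llm_provider` key) is a hypothetical example of what `config.get("assistant_settings")` might return, not something taken from the commit.

```python
import os

# Hypothetical parsed config block; only "litellm_params" is guaranteed by the
# code above, the other keys and values are illustrative placeholders.
assistant_settings = {
    "custom_llm_provider": "openai",
    "litellm_params": {"api_key": "os.environ/OPENAI_API_KEY"},
}

# Same resolution loop as in the hunk: any "os.environ/<NAME>" string value is
# swapped for the value of that environment variable.
for k, v in assistant_settings["litellm_params"].items():
    if isinstance(v, str) and v.startswith("os.environ/"):
        assistant_settings["litellm_params"][k] = os.getenv(v.replace("os.environ/", ""))

print(assistant_settings["litellm_params"]["api_key"])  # resolved secret, or None if unset
```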
@@ -2868,7 +2885,9 @@ class ProxyConfig:
             for k, v in router_settings.items():
                 if k in available_args:
                     router_params[k] = v
-        router = litellm.Router(**router_params)  # type:ignore
+        router = litellm.Router(
+            **router_params, assistants_config=assistants_config
+        )  # type:ignore
         return router, model_list, general_settings

     def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
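For orientation, a rough sketch of using the new `assistants_config` keyword when constructing a Router directly, outside the proxy. The `AssistantsTypedDict` field names other than `litellm_params` are assumptions, the model list entry is a placeholder, and the zero-argument `aget_assistants()` call simply mirrors how the proxy handlers below forward their request data.

```python
import asyncio
import os

import litellm
from litellm.router import AssistantsTypedDict


async def main() -> None:
    # Assumed field names; the proxy code above only guarantees "litellm_params".
    assistants_config = AssistantsTypedDict(
        custom_llm_provider="openai",
        litellm_params={"api_key": os.getenv("OPENAI_API_KEY")},
    )
    router = litellm.Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
        ],
        assistants_config=assistants_config,
    )
    # aget_assistants is the router method the proxy endpoints below delegate to.
    assistants = await router.aget_assistants()
    print(assistants)


if __name__ == "__main__":
    asyncio.run(main())
```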
@@ -3606,6 +3625,60 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"


+async def async_assistants_data_generator(
+    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+    verbose_proxy_logger.debug("inside generator")
+    try:
+        start_time = time.time()
+        async with response as chunk:
+
+            ### CALL HOOKS ### - modify outgoing data
+            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+                user_api_key_dict=user_api_key_dict, response=chunk
+            )
+
+            # chunk = chunk.model_dump_json(exclude_none=True)
+            async for c in chunk:
+                c = c.model_dump_json(exclude_none=True)
+                try:
+                    yield f"data: {c}\n\n"
+                except Exception as e:
+                    yield f"data: {str(e)}\n\n"
+
+        # Streaming is done, yield the [DONE] chunk
+        done_message = "[DONE]"
+        yield f"data: {done_message}\n\n"
+    except Exception as e:
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=request_data,
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = llm_router.model_names if llm_router is not None else []
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+
+        proxy_exception = ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
+
+
 async def async_data_generator(
     response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
 ):
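To make the streaming contract concrete, here is a stripped-down, self-contained sketch of the same SSE framing (one `data: <json>\n\n` frame per event, then a `[DONE]` sentinel) served through FastAPI's StreamingResponse. The toy event source and route are invented for illustration and none of the litellm hooks or error handling above are involved.

```python
import json
from typing import AsyncIterator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def toy_event_stream() -> AsyncIterator[str]:
    # Stand-in for the assistant run's streamed events.
    for i in range(3):
        yield f"data: {json.dumps({'step': i})}\n\n"
    # Same terminal sentinel as async_assistants_data_generator.
    yield "data: [DONE]\n\n"


@app.get("/toy-stream")
async def toy_stream() -> StreamingResponse:
    return StreamingResponse(toy_event_stream(), media_type="text/event-stream")
```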
@@ -5347,7 +5420,6 @@ async def get_assistants(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
@@ -5405,9 +5477,7 @@ async def get_assistants(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_assistants(
-            custom_llm_provider="openai", client=None, **data
-        )
+        response = await llm_router.aget_assistants(**data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
@@ -5479,7 +5549,6 @@ async def create_threads(
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
@@ -5537,9 +5606,7 @@ async def create_threads(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.acreate_thread(
-            custom_llm_provider="openai", client=None, **data
-        )
+        response = await llm_router.acreate_thread(**data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
@@ -5667,9 +5734,7 @@ async def get_thread(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_thread(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.aget_thread(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
@@ -5800,9 +5865,7 @@ async def add_messages(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.a_add_message(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.a_add_message(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
@@ -5929,9 +5992,7 @@ async def get_messages(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_messages(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.aget_messages(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
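The endpoint hunks above (get_assistants, create_threads, get_thread, add_messages, get_messages) all follow the same pattern: drop the hard-coded `custom_llm_provider="openai", client=None` arguments and let the router-level `assistants_config` decide. Since these routes are OpenAI-compatible, they can be exercised with the stock OpenAI SDK pointed at the proxy; the base URL and key below are placeholders, not values from the commit.

```python
import openai

client = openai.OpenAI(
    base_url="http://0.0.0.0:4000",  # assumed LiteLLM proxy address
    api_key="sk-1234",               # assumed proxy virtual key
)

assistants = client.beta.assistants.list()                 # -> get_assistants
thread = client.beta.threads.create()                      # -> create_threads
retrieved = client.beta.threads.retrieve(thread_id=thread.id)  # -> get_thread
message = client.beta.threads.messages.create(             # -> add_messages
    thread_id=thread.id, role="user", content="Hello!"
)
messages = client.beta.threads.messages.list(thread_id=thread.id)  # -> get_messages
print(assistants.data, retrieved.id, message.id, messages.data)
```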
@@ -6060,9 +6121,19 @@ async def run_thread(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.arun_thread(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.arun_thread(thread_id=thread_id, **data)
+
+        if (
+            "stream" in data and data["stream"] == True
+        ):  # use generate_responses to stream responses
+            return StreamingResponse(
+                async_assistants_data_generator(
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
+                    request_data=data,
+                ),
+                media_type="text/event-stream",
+            )

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting
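Finally, a client-side sketch (not part of the commit) of consuming the streaming run_thread path above: it posts with `"stream": true` and reads `data:` frames until the `[DONE]` sentinel emitted by async_assistants_data_generator. The route, payload fields, and credentials are assumptions for illustration.

```python
import asyncio
import json

import httpx


async def stream_run(base_url: str, api_key: str, thread_id: str, assistant_id: str) -> None:
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "POST",
            f"/v1/threads/{thread_id}/runs",  # assumed OpenAI-compatible proxy route
            headers={"Authorization": f"Bearer {api_key}"},
            json={"assistant_id": assistant_id, "stream": True},
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                payload = line[len("data: "):]
                if payload == "[DONE]":  # terminal sentinel from the generator
                    break
                print(json.loads(payload))


if __name__ == "__main__":
    asyncio.run(stream_run("http://0.0.0.0:4000", "sk-1234", "thread_abc", "asst_abc"))
```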