mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

feat(assistants/main.py): add assistants api streaming support

parent 7b474ec267
commit f3d78532f9

9 changed files with 444 additions and 65 deletions
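
As context for the diff below, a minimal usage sketch of the new streaming entrypoint (the thread/assistant IDs and the handler are placeholders; the handler mirrors the EventHandler added in the tests):

    import litellm
    from litellm.types.llms.openai import AssistantEventHandler

    class MyHandler(AssistantEventHandler):
        def on_text_delta(self, delta, snapshot):
            print(delta.value, end="", flush=True)

    # run_thread_stream() forwards stream=True to run_thread and returns a
    # stream manager; entering it yields the handler, which runs until done.
    run = litellm.run_thread_stream(
        custom_llm_provider="openai",
        thread_id="<thread-id>",        # placeholder
        assistant_id="<assistant-id>",  # placeholder
        event_handler=MyHandler(),
    )
    with run as stream:
        stream.until_done()
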
@@ -905,6 +905,14 @@ async def arun_thread(
     )


+def run_thread_stream(
+    *,
+    event_handler: Optional[AssistantEventHandler] = None,
+    **kwargs,
+) -> AssistantStreamManager[AssistantEventHandler]:
+    return run_thread(stream=True, event_handler=event_handler, **kwargs)  # type: ignore
+
+
 def run_thread(
     custom_llm_provider: Literal["openai", "azure"],
     thread_id: str,

@@ -916,6 +924,7 @@ def run_thread(
     stream: Optional[bool] = None,
     tools: Optional[Iterable[AssistantToolParam]] = None,
     client: Optional[Any] = None,
+    event_handler: Optional[AssistantEventHandler] = None,  # for stream=True calls
     **kwargs,
 ) -> Run:
     """Run a given thread + assistant."""

@@ -959,6 +968,7 @@ def run_thread(
             or litellm.openai_key
             or os.getenv("OPENAI_API_KEY")
         )
+
         response = openai_assistants_api.run_thread(
             thread_id=thread_id,
             assistant_id=assistant_id,

@@ -975,6 +985,7 @@ def run_thread(
             organization=organization,
             client=client,
             arun_thread=arun_thread,
+            event_handler=event_handler,
         )
     elif custom_llm_provider == "azure":
         api_base = (

@@ -31,6 +31,10 @@ from ..types.llms.openai import (
     Thread,
     AssistantToolParam,
     Run,
+    AssistantEventHandler,
+    AsyncAssistantEventHandler,
+    AsyncAssistantStreamManager,
+    AssistantStreamManager,
 )


@@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
         )

         response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
-            thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
-            assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
-            additional_instructions=None,
-            instructions=None,
-            metadata=None,
-            model=None,
-            tools=None,
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
         )

-        # response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
-        #     thread_id=thread_id,
-        #     assistant_id=assistant_id,
-        #     additional_instructions=additional_instructions,
-        #     instructions=instructions,
-        #     metadata=metadata,
-        #     model=model,
-        #     tools=tools,
-        # )
-
         return response

+    async def async_run_thread_stream(
+        self,
+        client: AsyncAzureOpenAI,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        tools: Optional[Iterable[AssistantToolParam]],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+        data = {
+            "thread_id": thread_id,
+            "assistant_id": assistant_id,
+            "additional_instructions": additional_instructions,
+            "instructions": instructions,
+            "metadata": metadata,
+            "model": model,
+            "tools": tools,
+        }
+        if event_handler is not None:
+            data["event_handler"] = event_handler
+        return client.beta.threads.runs.stream(**data)  # type: ignore
+
+    def run_thread_stream(
+        self,
+        client: AzureOpenAI,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        tools: Optional[Iterable[AssistantToolParam]],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> AssistantStreamManager[AssistantEventHandler]:
+        data = {
+            "thread_id": thread_id,
+            "assistant_id": assistant_id,
+            "additional_instructions": additional_instructions,
+            "instructions": instructions,
+            "metadata": metadata,
+            "model": model,
+            "tools": tools,
+        }
+        if event_handler is not None:
+            data["event_handler"] = event_handler
+        return client.beta.threads.runs.stream(**data)  # type: ignore
+
     # fmt: off

     @overload

@@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
         max_retries: Optional[int],
         client=None,
         arun_thread=None,
+        event_handler: Optional[AssistantEventHandler] = None,
     ):
         if arun_thread is not None and arun_thread == True:
+            if stream is not None and stream == True:
+                azure_client = self.async_get_azure_client(
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    azure_ad_token=azure_ad_token,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    client=client,
+                )
+                return self.async_run_thread_stream(
+                    client=azure_client,
+                    thread_id=thread_id,
+                    assistant_id=assistant_id,
+                    additional_instructions=additional_instructions,
+                    instructions=instructions,
+                    metadata=metadata,
+                    model=model,
+                    tools=tools,
+                    event_handler=event_handler,
+                )
             return self.arun_thread(
                 thread_id=thread_id,
                 assistant_id=assistant_id,

@@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
             client=client,
         )

+        if stream is not None and stream == True:
+            return self.run_thread_stream(
+                client=openai_client,
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                tools=tools,
+                event_handler=event_handler,
+            )
+
         response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
             thread_id=thread_id,
             assistant_id=assistant_id,

@@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM):

         return response

+    async def async_run_thread_stream(
+        self,
+        client: AsyncOpenAI,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        tools: Optional[Iterable[AssistantToolParam]],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
+        data = {
+            "thread_id": thread_id,
+            "assistant_id": assistant_id,
+            "additional_instructions": additional_instructions,
+            "instructions": instructions,
+            "metadata": metadata,
+            "model": model,
+            "tools": tools,
+        }
+        if event_handler is not None:
+            data["event_handler"] = event_handler
+        return client.beta.threads.runs.stream(**data)  # type: ignore
+
+    def run_thread_stream(
+        self,
+        client: OpenAI,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        tools: Optional[Iterable[AssistantToolParam]],
+        event_handler: Optional[AssistantEventHandler],
+    ) -> AssistantStreamManager[AssistantEventHandler]:
+        data = {
+            "thread_id": thread_id,
+            "assistant_id": assistant_id,
+            "additional_instructions": additional_instructions,
+            "instructions": instructions,
+            "metadata": metadata,
+            "model": model,
+            "tools": tools,
+        }
+        if event_handler is not None:
+            data["event_handler"] = event_handler
+        return client.beta.threads.runs.stream(**data)  # type: ignore
+
     # fmt: off

     @overload

@@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         organization: Optional[str],
         client,
         arun_thread: Literal[True],
+        event_handler: Optional[AssistantEventHandler],
     ) -> Coroutine[None, None, Run]:
         ...


@@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM):
         organization: Optional[str],
         client,
         arun_thread: Optional[Literal[False]],
+        event_handler: Optional[AssistantEventHandler],
     ) -> Run:
         ...


@@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM):
         organization: Optional[str],
         client=None,
         arun_thread=None,
+        event_handler: Optional[AssistantEventHandler] = None,
     ):
         if arun_thread is not None and arun_thread == True:
+            if stream is not None and stream == True:
+                _client = self.async_get_openai_client(
+                    api_key=api_key,
+                    api_base=api_base,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                    client=client,
+                )
+                return self.async_run_thread_stream(
+                    client=_client,
+                    thread_id=thread_id,
+                    assistant_id=assistant_id,
+                    additional_instructions=additional_instructions,
+                    instructions=instructions,
+                    metadata=metadata,
+                    model=model,
+                    tools=tools,
+                    event_handler=event_handler,
+                )
             return self.arun_thread(
                 thread_id=thread_id,
                 assistant_id=assistant_id,

@@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM):
             client=client,
         )

+        if stream is not None and stream == True:
+            return self.run_thread_stream(
+                client=openai_client,
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                tools=tools,
+                event_handler=event_handler,
+            )
+
         response = openai_client.beta.threads.runs.create_and_poll(  # type: ignore
             thread_id=thread_id,
             assistant_id=assistant_id,

@@ -40,8 +40,14 @@ model_list:
       vertex_project: my-project-9d5c
       vertex_location: us-central1

+assistant_settings:
+  custom_llm_provider: openai
+  litellm_params:
+    api_key: os.environ/OPENAI_API_KEY
+
+
 router_settings:
   enable_pre_call_checks: true

 litellm_settings:
   success_callback: ["langfuse"]

@@ -117,7 +117,13 @@ import pydantic
 from litellm.proxy._types import *
 from litellm.caching import DualCache, RedisCache
 from litellm.proxy.health_check import perform_health_check
-from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo
+from litellm.router import (
+    LiteLLM_Params,
+    Deployment,
+    updateDeployment,
+    ModelGroupInfo,
+    AssistantsTypedDict,
+)
 from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler

@@ -2853,6 +2859,17 @@ class ProxyConfig:
                 if "ollama" in litellm_model_name and litellm_model_api_base is None:
                     run_ollama_serve()

+        ## ASSISTANT SETTINGS
+        assistants_config: Optional[AssistantsTypedDict] = None
+        assistant_settings = config.get("assistant_settings", None)
+        if assistant_settings:
+            for k, v in assistant_settings["litellm_params"].items():
+                if isinstance(v, str) and v.startswith("os.environ/"):
+                    _v = v.replace("os.environ/", "")
+                    v = os.getenv(_v)
+                    assistant_settings["litellm_params"][k] = v
+            assistants_config = AssistantsTypedDict(**assistant_settings)  # type: ignore
+
         ## ROUTER SETTINGS (e.g. routing_strategy, ...)
         router_settings = config.get("router_settings", None)
         if router_settings and isinstance(router_settings, dict):

@@ -2868,7 +2885,9 @@ class ProxyConfig:
            for k, v in router_settings.items():
                if k in available_args:
                    router_params[k] = v
-        router = litellm.Router(**router_params)  # type:ignore
+        router = litellm.Router(
+            **router_params, assistants_config=assistants_config
+        )  # type:ignore
         return router, model_list, general_settings

     def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:

@@ -3606,6 +3625,60 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"


+async def async_assistants_data_generator(
+    response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+    verbose_proxy_logger.debug("inside generator")
+    try:
+        start_time = time.time()
+        async with response as chunk:
+
+            ### CALL HOOKS ### - modify outgoing data
+            chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+                user_api_key_dict=user_api_key_dict, response=chunk
+            )
+
+            # chunk = chunk.model_dump_json(exclude_none=True)
+            async for c in chunk:
+                c = c.model_dump_json(exclude_none=True)
+                try:
+                    yield f"data: {c}\n\n"
+                except Exception as e:
+                    yield f"data: {str(e)}\n\n"
+
+        # Streaming is done, yield the [DONE] chunk
+        done_message = "[DONE]"
+        yield f"data: {done_message}\n\n"
+    except Exception as e:
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict,
+            original_exception=e,
+            request_data=request_data,
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = llm_router.model_names if llm_router is not None else []
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+
+        proxy_exception = ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+        error_returned = json.dumps({"error": proxy_exception.to_dict()})
+        yield f"data: {error_returned}\n\n"
+
+
 async def async_data_generator(
     response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
 ):

@@ -5347,7 +5420,6 @@ async def get_assistants(
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
-        data = orjson.loads(body)

         # Include original request and headers in the data
         data["proxy_server_request"] = {  # type: ignore

@@ -5405,9 +5477,7 @@ async def get_assistants(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_assistants(
-            custom_llm_provider="openai", client=None, **data
-        )
+        response = await llm_router.aget_assistants(**data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

@@ -5479,7 +5549,6 @@ async def create_threads(
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
-        data = orjson.loads(body)

         # Include original request and headers in the data
         data["proxy_server_request"] = {  # type: ignore

@@ -5537,9 +5606,7 @@ async def create_threads(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.acreate_thread(
-            custom_llm_provider="openai", client=None, **data
-        )
+        response = await llm_router.acreate_thread(**data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

@@ -5667,9 +5734,7 @@ async def get_thread(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_thread(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.aget_thread(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

@@ -5800,9 +5865,7 @@ async def add_messages(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.a_add_message(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.a_add_message(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

@@ -5929,9 +5992,7 @@ async def get_messages(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.aget_messages(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.aget_messages(thread_id=thread_id, **data)

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

@@ -6060,9 +6121,19 @@ async def run_thread(
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
             )
-        response = await llm_router.arun_thread(
-            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
-        )
+        response = await llm_router.arun_thread(thread_id=thread_id, **data)
+
+        if (
+            "stream" in data and data["stream"] == True
+        ):  # use generate_responses to stream responses
+            return StreamingResponse(
+                async_assistants_data_generator(
+                    user_api_key_dict=user_api_key_dict,
+                    response=response,
+                    request_data=data,
+                ),
+                media_type="text/event-stream",
+            )

         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

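A sketch of how a client might consume the new assistants SSE stream through the proxy, assuming the proxy runs locally on port 4000 and exposes the OpenAI-compatible threads routes (base URL, key, and IDs are placeholders):

    from openai import OpenAI

    # Point the OpenAI SDK at the LiteLLM proxy (placeholder URL and key).
    client = OpenAI(base_url="http://0.0.0.0:4000/v1", api_key="sk-1234")

    # stream=True makes the proxy respond with the text/event-stream chunks
    # produced by async_assistants_data_generator above.
    stream = client.beta.threads.runs.create(
        thread_id="<thread-id>",        # placeholder
        assistant_id="<assistant-id>",  # placeholder
        stream=True,
    )
    for event in stream:
        print(event)
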
@@ -51,6 +51,7 @@ from litellm.types.router import (
     AlertingConfig,
     DeploymentTypedDict,
     ModelGroupInfo,
+    AssistantsTypedDict,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc

@@ -78,6 +79,8 @@ class Router:
     def __init__(
         self,
         model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
+        ## ASSISTANTS API ##
+        assistants_config: Optional[AssistantsTypedDict] = None,
         ## CACHING ##
         redis_url: Optional[str] = None,
         redis_host: Optional[str] = None,

@@ -212,6 +215,7 @@ class Router:
         elif debug_level == "DEBUG":
             verbose_router_logger.setLevel(logging.DEBUG)

+        self.assistants_config = assistants_config
         self.deployment_names: List = (
             []
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2

@@ -1831,31 +1835,56 @@ class Router:

     async def aget_assistants(
         self,
-        custom_llm_provider: Literal["openai"],
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ) -> AsyncCursorPage[Assistant]:
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
+
         return await litellm.aget_assistants(
             custom_llm_provider=custom_llm_provider, client=client, **kwargs
         )

     async def acreate_thread(
         self,
-        custom_llm_provider: Literal["openai"],
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ) -> Thread:
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
         return await litellm.acreate_thread(
             custom_llm_provider=custom_llm_provider, client=client, **kwargs
         )

     async def aget_thread(
         self,
-        custom_llm_provider: Literal["openai"],
         thread_id: str,
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ) -> Thread:
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
         return await litellm.aget_thread(
             custom_llm_provider=custom_llm_provider,
             thread_id=thread_id,

@@ -1865,15 +1894,24 @@ class Router:

     async def a_add_message(
         self,
-        custom_llm_provider: Literal["openai"],
         thread_id: str,
         role: Literal["user", "assistant"],
         content: str,
         attachments: Optional[List[Attachment]] = None,
         metadata: Optional[dict] = None,
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ) -> OpenAIMessage:
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
+
         return await litellm.a_add_message(
             custom_llm_provider=custom_llm_provider,
             thread_id=thread_id,

@@ -1887,11 +1925,19 @@ class Router:

     async def aget_messages(
         self,
-        custom_llm_provider: Literal["openai"],
         thread_id: str,
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         client: Optional[AsyncOpenAI] = None,
         **kwargs,
     ) -> AsyncCursorPage[OpenAIMessage]:
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
         return await litellm.aget_messages(
             custom_llm_provider=custom_llm_provider,
             thread_id=thread_id,

@@ -1901,9 +1947,9 @@ class Router:

     async def arun_thread(
         self,
-        custom_llm_provider: Literal["openai"],
         thread_id: str,
         assistant_id: str,
+        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
         additional_instructions: Optional[str] = None,
         instructions: Optional[str] = None,
         metadata: Optional[dict] = None,

@@ -1913,6 +1959,16 @@ class Router:
         client: Optional[Any] = None,
         **kwargs,
     ) -> Run:
+
+        if custom_llm_provider is None:
+            if self.assistants_config is not None:
+                custom_llm_provider = self.assistants_config["custom_llm_provider"]
+                kwargs.update(self.assistants_config["litellm_params"])
+            else:
+                raise Exception(
+                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
+                )
+
         return await litellm.arun_thread(
             custom_llm_provider=custom_llm_provider,
             thread_id=thread_id,

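A sketch of the new Router-level assistants_config in use, following the AssistantsTypedDict shape added in litellm.types.router (the API key is a placeholder):

    import asyncio
    import litellm

    router = litellm.Router(
        model_list=[],
        assistants_config={
            "custom_llm_provider": "openai",
            "litellm_params": {"api_key": "sk-..."},  # placeholder
        },
    )

    async def main():
        # custom_llm_provider may now be omitted on the call; the router falls
        # back to assistants_config and merges its litellm_params into kwargs.
        assistants = await router.aget_assistants()
        print(assistants.data)

    asyncio.run(main())
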
@@ -18,7 +18,9 @@ from litellm.llms.openai import (
     OpenAIMessage as Message,
     AsyncCursorPage,
     SyncCursorPage,
+    AssistantEventHandler,
 )
+from typing_extensions import override

 """
 V0 Scope:

@@ -129,10 +131,46 @@ async def test_add_message_litellm(sync_mode, provider):
     assert isinstance(added_message, Message)


-@pytest.mark.parametrize("provider", ["openai", "azure"])
-@pytest.mark.parametrize("sync_mode", [True, False])
+class EventHandler(AssistantEventHandler):
+    @override
+    def on_text_created(self, text) -> None:
+        print(f"\nassistant > ", end="", flush=True)
+
+    @override
+    def on_text_delta(self, delta, snapshot):
+        print(delta.value, end="", flush=True)
+
+    def on_tool_call_created(self, tool_call):
+        print(f"\nassistant > {tool_call.type}\n", flush=True)
+
+    def on_tool_call_delta(self, delta, snapshot):
+        if delta.type == "code_interpreter":
+            if delta.code_interpreter.input:
+                print(delta.code_interpreter.input, end="", flush=True)
+            if delta.code_interpreter.outputs:
+                print(f"\n\noutput >", flush=True)
+                for output in delta.code_interpreter.outputs:
+                    if output.type == "logs":
+                        print(f"\n{output.logs}", flush=True)
+
+
+@pytest.mark.parametrize(
+    "provider",
+    [
+        "azure",
+        "openai",
+    ],
+)  #
+@pytest.mark.parametrize(
+    "sync_mode",
+    [False, True],
+)  #
+@pytest.mark.parametrize(
+    "is_streaming",
+    [True, False],
+)  #
 @pytest.mark.asyncio
-async def test_aarun_thread_litellm(sync_mode, provider):
+async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
     """
     - Get Assistants
     - Create thread

@@ -163,27 +201,48 @@ async def test_aarun_thread_litellm(sync_mode, provider):
         if sync_mode:
             added_message = litellm.add_message(**data)

-            run = litellm.run_thread(assistant_id=assistant_id, **data)
-
-            if run.status == "completed":
-                messages = litellm.get_messages(
-                    thread_id=_new_thread.id, custom_llm_provider=provider
-                )
-                assert isinstance(messages.data[0], Message)
-            else:
-                pytest.fail("An unexpected error occurred when running the thread")
+            if is_streaming:
+                run = litellm.run_thread_stream(
+                    assistant_id=assistant_id, event_handler=EventHandler(), **data
+                )
+                with run as run:
+                    assert isinstance(run, AssistantEventHandler)
+                    print(run)
+                    run.until_done()
+            else:
+                run = litellm.run_thread(
+                    assistant_id=assistant_id, stream=is_streaming, **data
+                )
+                if run.status == "completed":
+                    messages = litellm.get_messages(
+                        thread_id=_new_thread.id, custom_llm_provider=provider
+                    )
+                    assert isinstance(messages.data[0], Message)
+                else:
+                    pytest.fail("An unexpected error occurred when running the thread")

         else:
             added_message = await litellm.a_add_message(**data)

-            run = await litellm.arun_thread(
-                custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id
-            )
-
-            if run.status == "completed":
-                messages = await litellm.aget_messages(
-                    thread_id=_new_thread.id, custom_llm_provider=provider
-                )
-                assert isinstance(messages.data[0], Message)
-            else:
-                pytest.fail("An unexpected error occurred when running the thread")
+            if is_streaming:
+                run = litellm.run_thread_stream(
+                    assistant_id=assistant_id, event_handler=EventHandler(), **data
+                )
+                with run as run:
+                    assert isinstance(run, AssistantEventHandler)
+                    print(run)
+                    run.until_done()
+            else:
+                run = await litellm.arun_thread(
+                    custom_llm_provider=provider,
+                    thread_id=thread_id,
+                    assistant_id=assistant_id,
+                )
+
+                if run.status == "completed":
+                    messages = await litellm.aget_messages(
+                        thread_id=_new_thread.id, custom_llm_provider=provider
+                    )
+                    assert isinstance(messages.data[0], Message)
+                else:
+                    pytest.fail("An unexpected error occurred when running the thread")

@@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage
 from openai.types.beta.thread_create_params import (
     Message as OpenAICreateThreadParamsMessage,
 )
+from openai.lib.streaming._assistants import (
+    AssistantEventHandler,
+    AssistantStreamManager,
+    AsyncAssistantStreamManager,
+    AsyncAssistantEventHandler,
+)
 from openai.types.beta.assistant_tool_param import AssistantToolParam
 from openai.types.beta.threads.run import Run
 from openai.types.beta.assistant import Assistant

@@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel):
     supports_vision: bool = Field(default=False)
     supports_function_calling: bool = Field(default=False)
     supported_openai_params: List[str] = Field(default=[])
+
+
+class AssistantsTypedDict(TypedDict):
+    custom_llm_provider: Literal["azure", "openai"]
+    litellm_params: LiteLLMParamsTypedDict