feat(assistants/main.py): add assistants api streaming support

This commit is contained in:
Krrish Dholakia 2024-06-04 16:30:35 -07:00
parent 7b474ec267
commit f3d78532f9
9 changed files with 444 additions and 65 deletions

View file

@ -51,6 +51,7 @@ from litellm.types.router import (
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
AssistantsTypedDict,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@ -78,6 +79,8 @@ class Router:
def __init__(
self,
model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None,
## ASSISTANTS API ##
assistants_config: Optional[AssistantsTypedDict] = None,
## CACHING ##
redis_url: Optional[str] = None,
redis_host: Optional[str] = None,
@ -212,6 +215,7 @@ class Router:
elif debug_level == "DEBUG":
verbose_router_logger.setLevel(logging.DEBUG)
self.assistants_config = assistants_config
self.deployment_names: List = (
[]
) # names of models under litellm_params. ex. azure/chatgpt-v-2
@ -1831,31 +1835,56 @@ class Router:
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1865,15 +1894,24 @@ class Router:
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1887,11 +1925,19 @@ class Router:
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
@ -1901,9 +1947,9 @@ class Router:
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
@ -1913,6 +1959,16 @@ class Router:
client: Optional[Any] = None,
**kwargs,
) -> Run:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,