diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index 1486b8984..1ab85a5b1 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -905,6 +905,14 @@ async def arun_thread( ) +def run_thread_stream( + *, + event_handler: Optional[AssistantEventHandler] = None, + **kwargs, +) -> AssistantStreamManager[AssistantEventHandler]: + return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore + + def run_thread( custom_llm_provider: Literal["openai", "azure"], thread_id: str, @@ -916,6 +924,7 @@ def run_thread( stream: Optional[bool] = None, tools: Optional[Iterable[AssistantToolParam]] = None, client: Optional[Any] = None, + event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls **kwargs, ) -> Run: """Run a given thread + assistant.""" @@ -959,6 +968,7 @@ def run_thread( or litellm.openai_key or os.getenv("OPENAI_API_KEY") ) + response = openai_assistants_api.run_thread( thread_id=thread_id, assistant_id=assistant_id, @@ -975,6 +985,7 @@ def run_thread( organization=organization, client=client, arun_thread=arun_thread, + event_handler=event_handler, ) elif custom_llm_provider == "azure": api_base = ( diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py index 709385ef7..c907e3b0e 100644 --- a/litellm/llms/azure.py +++ b/litellm/llms/azure.py @@ -31,6 +31,10 @@ from ..types.llms.openai import ( Thread, AssistantToolParam, Run, + AssistantEventHandler, + AsyncAssistantEventHandler, + AsyncAssistantStreamManager, + AssistantStreamManager, ) @@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM): ) response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore - thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP", - assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T", - additional_instructions=None, - instructions=None, - metadata=None, - model=None, - tools=None, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, ) - # response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore - # thread_id=thread_id, - # assistant_id=assistant_id, - # additional_instructions=additional_instructions, - # instructions=instructions, - # metadata=metadata, - # model=model, - # tools=tools, - # ) - return response + async def async_run_thread_stream( + self, + client: AsyncAzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: + data = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + def run_thread_stream( + self, + client: AzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AssistantStreamManager[AssistantEventHandler]: + data = { + "thread_id": thread_id, + 
"assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + # fmt: off @overload @@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM): max_retries: Optional[int], client=None, arun_thread=None, + event_handler: Optional[AssistantEventHandler] = None, ): if arun_thread is not None and arun_thread == True: + if stream is not None and stream == True: + azure_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + return self.async_run_thread_stream( + client=azure_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) return self.arun_thread( thread_id=thread_id, assistant_id=assistant_id, @@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM): client=client, ) + if stream is not None and stream == True: + return self.run_thread_stream( + client=openai_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) + response = openai_client.beta.threads.runs.create_and_poll( # type: ignore thread_id=thread_id, assistant_id=assistant_id, diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 03657f0ee..69e510ae7 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -2534,6 +2534,56 @@ class OpenAIAssistantsAPI(BaseLLM): return response + async def async_run_thread_stream( + self, + client: AsyncOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: + data = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + def run_thread_stream( + self, + client: OpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + ) -> AssistantStreamManager[AssistantEventHandler]: + data = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + # fmt: off @overload @@ -2554,6 +2604,7 @@ class OpenAIAssistantsAPI(BaseLLM): organization: Optional[str], client, arun_thread: Literal[True], + 
event_handler: Optional[AssistantEventHandler], ) -> Coroutine[None, None, Run]: ... @@ -2575,6 +2626,7 @@ class OpenAIAssistantsAPI(BaseLLM): organization: Optional[str], client, arun_thread: Optional[Literal[False]], + event_handler: Optional[AssistantEventHandler], ) -> Run: ... @@ -2597,8 +2649,29 @@ class OpenAIAssistantsAPI(BaseLLM): organization: Optional[str], client=None, arun_thread=None, + event_handler: Optional[AssistantEventHandler] = None, ): if arun_thread is not None and arun_thread == True: + if stream is not None and stream == True: + _client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + return self.async_run_thread_stream( + client=_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) return self.arun_thread( thread_id=thread_id, assistant_id=assistant_id, @@ -2624,6 +2697,19 @@ class OpenAIAssistantsAPI(BaseLLM): client=client, ) + if stream is not None and stream == True: + return self.run_thread_stream( + client=openai_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + ) + response = openai_client.beta.threads.runs.create_and_poll( # type: ignore thread_id=thread_id, assistant_id=assistant_id, diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml index b6a6a06ad..f9a4d51fe 100644 --- a/litellm/proxy/_super_secret_config.yaml +++ b/litellm/proxy/_super_secret_config.yaml @@ -40,8 +40,14 @@ model_list: vertex_project: my-project-9d5c vertex_location: us-central1 +assistant_settings: + custom_llm_provider: openai + litellm_params: + api_key: os.environ/OPENAI_API_KEY + + router_settings: enable_pre_call_checks: true litellm_settings: - success_callback: ["langfuse"] \ No newline at end of file + success_callback: ["langfuse"] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index c154bd8e9..4751b6a6f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -117,7 +117,13 @@ import pydantic from litellm.proxy._types import * from litellm.caching import DualCache, RedisCache from litellm.proxy.health_check import perform_health_check -from litellm.router import LiteLLM_Params, Deployment, updateDeployment, ModelGroupInfo +from litellm.router import ( + LiteLLM_Params, + Deployment, + updateDeployment, + ModelGroupInfo, + AssistantsTypedDict, +) from litellm.router import ModelInfo as RouterModelInfo from litellm._logging import verbose_router_logger, verbose_proxy_logger from litellm.proxy.auth.handle_jwt import JWTHandler @@ -2853,6 +2859,17 @@ class ProxyConfig: if "ollama" in litellm_model_name and litellm_model_api_base is None: run_ollama_serve() + ## ASSISTANT SETTINGS + assistants_config: Optional[AssistantsTypedDict] = None + assistant_settings = config.get("assistant_settings", None) + if assistant_settings: + for k, v in assistant_settings["litellm_params"].items(): + if isinstance(v, str) and v.startswith("os.environ/"): + _v = v.replace("os.environ/", "") + v = os.getenv(_v) + assistant_settings["litellm_params"][k] = v + assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore + ## ROUTER 
SETTINGS (e.g. routing_strategy, ...) router_settings = config.get("router_settings", None) if router_settings and isinstance(router_settings, dict): @@ -2868,7 +2885,9 @@ class ProxyConfig: for k, v in router_settings.items(): if k in available_args: router_params[k] = v - router = litellm.Router(**router_params) # type:ignore + router = litellm.Router( + **router_params, assistants_config=assistants_config + ) # type:ignore return router, model_list, general_settings def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo: @@ -3606,6 +3625,60 @@ def data_generator(response): yield f"data: {json.dumps(chunk)}\n\n" +async def async_assistants_data_generator( + response, user_api_key_dict: UserAPIKeyAuth, request_data: dict +): + verbose_proxy_logger.debug("inside generator") + try: + start_time = time.time() + async with response as chunk: + + ### CALL HOOKS ### - modify outgoing data + chunk = await proxy_logging_obj.async_post_call_streaming_hook( + user_api_key_dict=user_api_key_dict, response=chunk + ) + + # chunk = chunk.model_dump_json(exclude_none=True) + async for c in chunk: + c = c.model_dump_json(exclude_none=True) + try: + yield f"data: {c}\n\n" + except Exception as e: + yield f"data: {str(e)}\n\n" + + # Streaming is done, yield the [DONE] chunk + done_message = "[DONE]" + yield f"data: {done_message}\n\n" + except Exception as e: + traceback.print_exc() + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=request_data, + ) + verbose_proxy_logger.debug( + f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`" + ) + router_model_names = llm_router.model_names if llm_router is not None else [] + if user_debug: + traceback.print_exc() + + if isinstance(e, HTTPException): + raise e + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}\n\n{error_traceback}" + + proxy_exception = ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + error_returned = json.dumps({"error": proxy_exception.to_dict()}) + yield f"data: {error_returned}\n\n" + + async def async_data_generator( response, user_api_key_dict: UserAPIKeyAuth, request_data: dict ): @@ -5347,7 +5420,6 @@ async def get_assistants( try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() - data = orjson.loads(body) # Include original request and headers in the data data["proxy_server_request"] = { # type: ignore @@ -5405,9 +5477,7 @@ async def get_assistants( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.aget_assistants( - custom_llm_provider="openai", client=None, **data - ) + response = await llm_router.aget_assistants(**data) ### ALERTING ### data["litellm_status"] = "success" # used for alerting @@ -5479,7 +5549,6 @@ async def create_threads( try: # Use orjson to parse JSON data, orjson speeds up requests significantly body = await request.body() - data = orjson.loads(body) # Include original request and headers in the data data["proxy_server_request"] = { # type: ignore @@ -5537,9 +5606,7 @@ async def create_threads( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.acreate_thread( - custom_llm_provider="openai", client=None, 
**data - ) + response = await llm_router.acreate_thread(**data) ### ALERTING ### data["litellm_status"] = "success" # used for alerting @@ -5667,9 +5734,7 @@ async def get_thread( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.aget_thread( - custom_llm_provider="openai", thread_id=thread_id, client=None, **data - ) + response = await llm_router.aget_thread(thread_id=thread_id, **data) ### ALERTING ### data["litellm_status"] = "success" # used for alerting @@ -5800,9 +5865,7 @@ async def add_messages( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.a_add_message( - custom_llm_provider="openai", thread_id=thread_id, client=None, **data - ) + response = await llm_router.a_add_message(thread_id=thread_id, **data) ### ALERTING ### data["litellm_status"] = "success" # used for alerting @@ -5929,9 +5992,7 @@ async def get_messages( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.aget_messages( - custom_llm_provider="openai", thread_id=thread_id, client=None, **data - ) + response = await llm_router.aget_messages(thread_id=thread_id, **data) ### ALERTING ### data["litellm_status"] = "success" # used for alerting @@ -6060,9 +6121,19 @@ async def run_thread( raise HTTPException( status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} ) - response = await llm_router.arun_thread( - custom_llm_provider="openai", thread_id=thread_id, client=None, **data - ) + response = await llm_router.arun_thread(thread_id=thread_id, **data) + + if ( + "stream" in data and data["stream"] == True + ): # use generate_responses to stream responses + return StreamingResponse( + async_assistants_data_generator( + user_api_key_dict=user_api_key_dict, + response=response, + request_data=data, + ), + media_type="text/event-stream", + ) ### ALERTING ### data["litellm_status"] = "success" # used for alerting diff --git a/litellm/router.py b/litellm/router.py index a5de577e3..e3fed496f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -51,6 +51,7 @@ from litellm.types.router import ( AlertingConfig, DeploymentTypedDict, ModelGroupInfo, + AssistantsTypedDict, ) from litellm.integrations.custom_logger import CustomLogger from litellm.llms.azure import get_azure_ad_token_from_oidc @@ -78,6 +79,8 @@ class Router: def __init__( self, model_list: Optional[List[Union[DeploymentTypedDict, Dict]]] = None, + ## ASSISTANTS API ## + assistants_config: Optional[AssistantsTypedDict] = None, ## CACHING ## redis_url: Optional[str] = None, redis_host: Optional[str] = None, @@ -212,6 +215,7 @@ class Router: elif debug_level == "DEBUG": verbose_router_logger.setLevel(logging.DEBUG) + self.assistants_config = assistants_config self.deployment_names: List = ( [] ) # names of models under litellm_params. ex. azure/chatgpt-v-2 @@ -1831,31 +1835,56 @@ class Router: async def aget_assistants( self, - custom_llm_provider: Literal["openai"], + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ) -> AsyncCursorPage[Assistant]: + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. 
Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) + return await litellm.aget_assistants( custom_llm_provider=custom_llm_provider, client=client, **kwargs ) async def acreate_thread( self, - custom_llm_provider: Literal["openai"], + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ) -> Thread: + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) return await litellm.acreate_thread( custom_llm_provider=custom_llm_provider, client=client, **kwargs ) async def aget_thread( self, - custom_llm_provider: Literal["openai"], thread_id: str, + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ) -> Thread: + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) return await litellm.aget_thread( custom_llm_provider=custom_llm_provider, thread_id=thread_id, @@ -1865,15 +1894,24 @@ class Router: async def a_add_message( self, - custom_llm_provider: Literal["openai"], thread_id: str, role: Literal["user", "assistant"], content: str, attachments: Optional[List[Attachment]] = None, metadata: Optional[dict] = None, + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ) -> OpenAIMessage: + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) + return await litellm.a_add_message( custom_llm_provider=custom_llm_provider, thread_id=thread_id, @@ -1887,11 +1925,19 @@ class Router: async def aget_messages( self, - custom_llm_provider: Literal["openai"], thread_id: str, + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, client: Optional[AsyncOpenAI] = None, **kwargs, ) -> AsyncCursorPage[OpenAIMessage]: + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. 
Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) return await litellm.aget_messages( custom_llm_provider=custom_llm_provider, thread_id=thread_id, @@ -1901,9 +1947,9 @@ class Router: async def arun_thread( self, - custom_llm_provider: Literal["openai"], thread_id: str, assistant_id: str, + custom_llm_provider: Optional[Literal["openai", "azure"]] = None, additional_instructions: Optional[str] = None, instructions: Optional[str] = None, metadata: Optional[dict] = None, @@ -1913,6 +1959,16 @@ class Router: client: Optional[Any] = None, **kwargs, ) -> Run: + + if custom_llm_provider is None: + if self.assistants_config is not None: + custom_llm_provider = self.assistants_config["custom_llm_provider"] + kwargs.update(self.assistants_config["litellm_params"]) + else: + raise Exception( + "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`" + ) + return await litellm.arun_thread( custom_llm_provider=custom_llm_provider, thread_id=thread_id, diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 4577b5e6f..cf4adf0fa 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -18,7 +18,9 @@ from litellm.llms.openai import ( OpenAIMessage as Message, AsyncCursorPage, SyncCursorPage, + AssistantEventHandler, ) +from typing_extensions import override """ V0 Scope: @@ -129,10 +131,46 @@ async def test_add_message_litellm(sync_mode, provider): assert isinstance(added_message, Message) -@pytest.mark.parametrize("provider", ["openai", "azure"]) -@pytest.mark.parametrize("sync_mode", [True, False]) +class EventHandler(AssistantEventHandler): + @override + def on_text_created(self, text) -> None: + print(f"\nassistant > ", end="", flush=True) + + @override + def on_text_delta(self, delta, snapshot): + print(delta.value, end="", flush=True) + + def on_tool_call_created(self, tool_call): + print(f"\nassistant > {tool_call.type}\n", flush=True) + + def on_tool_call_delta(self, delta, snapshot): + if delta.type == "code_interpreter": + if delta.code_interpreter.input: + print(delta.code_interpreter.input, end="", flush=True) + if delta.code_interpreter.outputs: + print(f"\n\noutput >", flush=True) + for output in delta.code_interpreter.outputs: + if output.type == "logs": + print(f"\n{output.logs}", flush=True) + + +@pytest.mark.parametrize( + "provider", + [ + "azure", + "openai", + ], +) # +@pytest.mark.parametrize( + "sync_mode", + [False, True], +) # +@pytest.mark.parametrize( + "is_streaming", + [True, False], +) # @pytest.mark.asyncio -async def test_aarun_thread_litellm(sync_mode, provider): +async def test_aarun_thread_litellm(sync_mode, provider, is_streaming): """ - Get Assistants - Create thread @@ -163,27 +201,48 @@ async def test_aarun_thread_litellm(sync_mode, provider): if sync_mode: added_message = litellm.add_message(**data) - run = litellm.run_thread(assistant_id=assistant_id, **data) - - if run.status == "completed": - messages = litellm.get_messages( - thread_id=_new_thread.id, custom_llm_provider=provider + if is_streaming: + run = litellm.run_thread_stream( + assistant_id=assistant_id, event_handler=EventHandler(), **data ) - assert isinstance(messages.data[0], Message) + with run as run: + assert isinstance(run, AssistantEventHandler) + print(run) + run.until_done() else: - pytest.fail("An unexpected error occurred when running the thread") + run = 
litellm.run_thread( + assistant_id=assistant_id, stream=is_streaming, **data + ) + if run.status == "completed": + messages = litellm.get_messages( + thread_id=_new_thread.id, custom_llm_provider=provider + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") else: added_message = await litellm.a_add_message(**data) - run = await litellm.arun_thread( - custom_llm_provider=provider, thread_id=thread_id, assistant_id=assistant_id - ) - - if run.status == "completed": - messages = await litellm.aget_messages( - thread_id=_new_thread.id, custom_llm_provider=provider + if is_streaming: + run = litellm.run_thread_stream( + assistant_id=assistant_id, event_handler=EventHandler(), **data ) - assert isinstance(messages.data[0], Message) + with run as run: + assert isinstance(run, AssistantEventHandler) + print(run) + run.until_done() else: - pytest.fail("An unexpected error occurred when running the thread") + run = await litellm.arun_thread( + custom_llm_provider=provider, + thread_id=thread_id, + assistant_id=assistant_id, + ) + + if run.status == "completed": + messages = await litellm.aget_messages( + thread_id=_new_thread.id, custom_llm_provider=provider + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 885ed6053..bc0c82434 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -13,6 +13,12 @@ from openai.types.beta.threads.message import Message as OpenAIMessage from openai.types.beta.thread_create_params import ( Message as OpenAICreateThreadParamsMessage, ) +from openai.lib.streaming._assistants import ( + AssistantEventHandler, + AssistantStreamManager, + AsyncAssistantStreamManager, + AsyncAssistantEventHandler, +) from openai.types.beta.assistant_tool_param import AssistantToolParam from openai.types.beta.threads.run import Run from openai.types.beta.assistant import Assistant diff --git a/litellm/types/router.py b/litellm/types/router.py index 38ddef361..4398ab24f 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -446,3 +446,8 @@ class ModelGroupInfo(BaseModel): supports_vision: bool = Field(default=False) supports_function_calling: bool = Field(default=False) supported_openai_params: List[str] = Field(default=[]) + + +class AssistantsTypedDict(TypedDict): + custom_llm_provider: Literal["azure", "openai"] + litellm_params: LiteLLMParamsTypedDict
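
Taken together, the pieces above add a streaming path to the assistants flow: `litellm.run_thread_stream()` forwards `stream=True` plus an optional `event_handler` to `run_thread()`, which hands off to the provider-specific `run_thread_stream()` helpers returning an `AssistantStreamManager`. A minimal usage sketch, adapted from the `test_aarun_thread_litellm` changes above; the provider, message text, and the two handler overrides are illustrative placeholders:

```python
import litellm
from litellm.llms.openai import AssistantEventHandler
from typing_extensions import override


class MyEventHandler(AssistantEventHandler):
    """Receives streamed assistant events (same interface as the test's EventHandler)."""

    @override
    def on_text_created(self, text) -> None:
        print("\nassistant > ", end="", flush=True)

    @override
    def on_text_delta(self, delta, snapshot):
        print(delta.value, end="", flush=True)


# Set up an assistant + thread with the existing (non-streaming) helpers.
assistants = litellm.get_assistants(custom_llm_provider="openai")
assistant_id = assistants.data[0].id

thread = litellm.create_thread(
    custom_llm_provider="openai",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# New entrypoint from this diff: returns an AssistantStreamManager, so it is
# consumed as a context manager and drained with until_done().
run = litellm.run_thread_stream(
    custom_llm_provider="openai",
    thread_id=thread.id,
    assistant_id=assistant_id,
    event_handler=MyEventHandler(),
)
with run as stream:
    stream.until_done()
```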
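On the proxy/router side, the new `assistants_config` option (typed by `AssistantsTypedDict`) lets the assistants endpoints omit `custom_llm_provider` on every call; the proxy builds it from the `assistant_settings` block shown in `_super_secret_config.yaml`. A sketch of the equivalent direct `Router` usage, assuming an OpenAI-backed assistant already exists; the `model_list` entry and message content are placeholders:

```python
import asyncio
import os

from litellm import Router


async def main() -> None:
    router = Router(
        model_list=[
            {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}},
        ],
        # New in this diff: default provider + params for all assistants calls.
        assistants_config={
            "custom_llm_provider": "openai",
            "litellm_params": {"api_key": os.getenv("OPENAI_API_KEY")},
        },
    )

    # No custom_llm_provider needed - it is resolved from assistants_config,
    # and litellm_params (here, the api_key) are merged into each call's kwargs.
    assistants = await router.aget_assistants()
    thread = await router.acreate_thread(
        messages=[{"role": "user", "content": "Hey, how's it going?"}]
    )
    run = await router.arun_thread(
        thread_id=thread.id, assistant_id=assistants.data[0].id
    )
    print(run.status)


asyncio.run(main())
```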