diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index 25d2433d7b..848a83e53c 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -1,11 +1,12 @@ # What is this? ## Main file for assistants API logic from typing import Iterable -import os +from functools import partial +import os, asyncio, contextvars import litellm -from openai import OpenAI +from openai import OpenAI, AsyncOpenAI from litellm import client -from litellm.utils import supports_httpx_timeout +from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider from ..llms.openai import OpenAIAssistantsAPI from ..types.llms.openai import * from ..types.router import * @@ -16,11 +17,49 @@ openai_assistants_api = OpenAIAssistantsAPI() ### ASSISTANTS ### +async def aget_assistants( + custom_llm_provider: Literal["openai"], + client: Optional[AsyncOpenAI] = None, + **kwargs, +) -> AsyncCursorPage[Assistant]: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["aget_assistants"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial(get_assistants, custom_llm_provider, client, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) + + def get_assistants( custom_llm_provider: Literal["openai"], client: Optional[OpenAI] = None, **kwargs, ) -> SyncCursorPage[Assistant]: + aget_assistants = kwargs.pop("aget_assistants", None) optional_params = GenericLiteLLMParams(**kwargs) ### TIMEOUT LOGIC ### @@ -67,6 +106,7 @@ def get_assistants( max_retries=optional_params.max_retries, organization=organization, client=client, + aget_assistants=aget_assistants, ) else: raise litellm.exceptions.BadRequestError( @@ -87,6 +127,39 @@ def get_assistants( ### THREADS ### +async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["acreate_thread"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial(create_thread, custom_llm_provider, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) + + def create_thread( custom_llm_provider: Literal["openai"], messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None, @@ -117,6 +190,7 @@ def create_thread( ) ``` """ + 
acreate_thread = kwargs.get("acreate_thread", None) optional_params = GenericLiteLLMParams(**kwargs) ### TIMEOUT LOGIC ### @@ -165,6 +239,7 @@ def create_thread( max_retries=optional_params.max_retries, organization=organization, client=client, + acreate_thread=acreate_thread, ) else: raise litellm.exceptions.BadRequestError( @@ -182,6 +257,44 @@ def create_thread( return response +async def aget_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[AsyncOpenAI] = None, + **kwargs, +) -> Thread: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["aget_thread"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) + + def get_thread( custom_llm_provider: Literal["openai"], thread_id: str, @@ -189,6 +302,7 @@ def get_thread( **kwargs, ) -> Thread: """Get the thread object, given a thread_id""" + aget_thread = kwargs.pop("aget_thread", None) optional_params = GenericLiteLLMParams(**kwargs) ### TIMEOUT LOGIC ### @@ -236,6 +350,7 @@ def get_thread( max_retries=optional_params.max_retries, organization=organization, client=client, + aget_thread=aget_thread, ) else: raise litellm.exceptions.BadRequestError( @@ -256,6 +371,59 @@ def get_thread( ### MESSAGES ### +async def a_add_message( + custom_llm_provider: Literal["openai"], + thread_id: str, + role: Literal["user", "assistant"], + content: str, + attachments: Optional[List[Attachment]] = None, + metadata: Optional[dict] = None, + client: Optional[AsyncOpenAI] = None, + **kwargs, +) -> OpenAIMessage: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["a_add_message"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial( + add_message, + custom_llm_provider, + thread_id, + role, + content, + attachments, + metadata, + client, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) + + def add_message( custom_llm_provider: Literal["openai"], thread_id: str, @@ -267,6 +435,7 @@ def add_message( **kwargs, ) -> OpenAIMessage: ### COMMON OBJECTS ### + a_add_message = kwargs.pop("a_add_message", None) 
message_data = MessageData( role=role, content=content, attachments=attachments, metadata=metadata ) @@ -318,6 +487,7 @@ def add_message( max_retries=optional_params.max_retries, organization=organization, client=client, + a_add_message=a_add_message, ) else: raise litellm.exceptions.BadRequestError( @@ -336,12 +506,58 @@ def add_message( return response +async def aget_messages( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[AsyncOpenAI] = None, + **kwargs, +) -> AsyncCursorPage[OpenAIMessage]: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["aget_messages"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial( + get_messages, + custom_llm_provider, + thread_id, + client, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) + + def get_messages( custom_llm_provider: Literal["openai"], thread_id: str, client: Optional[OpenAI] = None, **kwargs, ) -> SyncCursorPage[OpenAIMessage]: + aget_messages = kwargs.pop("aget_messages", None) optional_params = GenericLiteLLMParams(**kwargs) ### TIMEOUT LOGIC ### @@ -389,6 +605,7 @@ def get_messages( max_retries=optional_params.max_retries, organization=organization, client=client, + aget_messages=aget_messages, ) else: raise litellm.exceptions.BadRequestError( @@ -408,6 +625,63 @@ def get_messages( ### RUNS ### +async def arun_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str] = None, + instructions: Optional[str] = None, + metadata: Optional[dict] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + tools: Optional[Iterable[AssistantToolParam]] = None, + client: Optional[AsyncOpenAI] = None, + **kwargs, +) -> Run: + loop = asyncio.get_event_loop() + ### PASS ARGS TO GET ASSISTANTS ### + kwargs["arun_thread"] = True + try: + # Use a partial function to pass your keyword arguments + func = partial( + run_thread, + custom_llm_provider, + thread_id, + assistant_id, + additional_instructions, + instructions, + metadata, + model, + stream, + tools, + client, + **kwargs, + ) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore + model="", custom_llm_provider=custom_llm_provider + ) # type: ignore + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = init_response + return response # type: ignore + except Exception as e: + raise exception_type( + model="", + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs=kwargs, + ) 
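+
+# NOTE: each a* wrapper above follows the same dispatch recipe - set an
+# "a<name>" flag in kwargs, wrap the sync entrypoint with functools.partial,
+# copy the current contextvars, run it in the default executor, and await the
+# coroutine the provider implementation returns when the flag is set.
+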
def run_thread( @@ -424,6 +698,7 @@ def run_thread( **kwargs, ) -> Run: """Run a given thread + assistant.""" + arun_thread = kwargs.pop("arun_thread", None) optional_params = GenericLiteLLMParams(**kwargs) ### TIMEOUT LOGIC ### @@ -478,6 +753,7 @@ def run_thread( max_retries=optional_params.max_retries, organization=organization, client=client, + arun_thread=arun_thread, ) else: raise litellm.exceptions.BadRequestError( diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 84d9c773fc..6424d46586 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -6,7 +6,7 @@ from typing import ( Literal, Iterable, ) -from typing_extensions import override +from typing_extensions import override, overload from pydantic import BaseModel import types, time, json, traceback import httpx @@ -1846,8 +1846,85 @@ class OpenAIAssistantsAPI(BaseLLM): return openai_client + def async_get_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI] = None, + ) -> AsyncOpenAI: + received_args = locals() + if client is None: + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client": + pass + elif k == "api_base" and v is not None: + data["base_url"] = v + elif v is not None: + data[k] = v + openai_client = AsyncOpenAI(**data) # type: ignore + else: + openai_client = client + + return openai_client + ### ASSISTANTS ### + async def async_get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + ) -> AsyncCursorPage[Assistant]: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.assistants.list() + + return response + + # fmt: off + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + aget_assistants: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: + ... + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + aget_assistants: Optional[Literal[False]], + ) -> SyncCursorPage[Assistant]: + ... 
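+
+    # These overloads only refine the typing surface: with aget_assistants=True
+    # the call returns a coroutine resolving to an AsyncCursorPage[Assistant];
+    # otherwise the runtime implementation below returns a
+    # SyncCursorPage[Assistant] directly.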
+ + # fmt: on + def get_assistants( self, api_key: Optional[str], @@ -1855,8 +1932,18 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI], - ) -> SyncCursorPage[Assistant]: + client=None, + aget_assistants=None, + ): + if aget_assistants is not None and aget_assistants == True: + return self.async_get_assistants( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, @@ -1872,6 +1959,72 @@ class OpenAIAssistantsAPI(BaseLLM): ### MESSAGES ### + async def a_add_message( + self, + thread_id: str, + message_data: MessageData, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI] = None, + ) -> OpenAIMessage: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + # fmt: off + + @overload + def add_message( + self, + thread_id: str, + message_data: MessageData, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + a_add_message: Literal[True], + ) -> Coroutine[None, None, OpenAIMessage]: + ... + + @overload + def add_message( + self, + thread_id: str, + message_data: MessageData, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + a_add_message: Optional[Literal[False]], + ) -> OpenAIMessage: + ... 
+ + # fmt: on + def add_message( self, thread_id: str, @@ -1881,9 +2034,20 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI] = None, - ) -> OpenAIMessage: - + client=None, + a_add_message: Optional[bool] = None, + ): + if a_add_message is not None and a_add_message == True: + return self.a_add_message( + thread_id=thread_id, + message_data=message_data, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, @@ -1905,6 +2069,61 @@ class OpenAIAssistantsAPI(BaseLLM): response_obj = OpenAIMessage(**thread_message.dict()) return response_obj + async def async_get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI] = None, + ) -> AsyncCursorPage[OpenAIMessage]: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + # fmt: off + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + aget_messages: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: + ... + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + aget_messages: Optional[Literal[False]], + ) -> SyncCursorPage[OpenAIMessage]: + ... 
+ + # fmt: on + def get_messages( self, thread_id: str, @@ -1913,8 +2132,19 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI] = None, - ) -> SyncCursorPage[OpenAIMessage]: + client=None, + aget_messages=None, + ): + if aget_messages is not None and aget_messages == True: + return self.async_get_messages( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, @@ -1930,6 +2160,70 @@ class OpenAIAssistantsAPI(BaseLLM): ### THREADS ### + async def async_create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + ) -> Thread: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = await openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + # fmt: off + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AsyncOpenAI], + acreate_thread: Literal[True], + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[OpenAI], + acreate_thread: Optional[Literal[False]], + ) -> Thread: + ... 
+ + # fmt: on + def create_thread( self, metadata: Optional[dict], @@ -1938,9 +2232,10 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI], messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], - ) -> Thread: + client=None, + acreate_thread=None, + ): """ Here's an example: ``` @@ -1951,6 +2246,17 @@ class OpenAIAssistantsAPI(BaseLLM): openai_api.create_thread(messages=[message]) ``` """ + if acreate_thread is not None and acreate_thread == True: + return self.async_create_thread( + metadata=metadata, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + messages=messages, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, @@ -1970,6 +2276,61 @@ class OpenAIAssistantsAPI(BaseLLM): return Thread(**message_thread.dict()) + async def async_get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + ) -> Thread: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + # fmt: off + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + aget_thread: Literal[True], + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + aget_thread: Optional[Literal[False]], + ) -> Thread: + ... 
+ + # fmt: on + def get_thread( self, thread_id: str, @@ -1978,8 +2339,19 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI], - ) -> Thread: + client=None, + aget_thread=None, + ): + if aget_thread is not None and aget_thread == True: + return self.async_get_thread( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, @@ -1998,6 +2370,90 @@ class OpenAIAssistantsAPI(BaseLLM): ### RUNS ### + async def arun_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + ) -> Run: + openai_client = self.async_get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + ) + + return response + + # fmt: off + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[AsyncOpenAI], + arun_thread: Literal[True], + ) -> Coroutine[None, None, Run]: + ... + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[OpenAI], + arun_thread: Optional[Literal[False]], + ) -> Run: + ... 
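+
+    # With arun_thread=True the call is served by the AsyncOpenAI client's
+    # runs.create_and_poll, so the awaited Run has already been polled to a
+    # terminal status.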
+ + # fmt: on + def run_thread( self, thread_id: str, @@ -2013,8 +2469,26 @@ class OpenAIAssistantsAPI(BaseLLM): timeout: Union[float, httpx.Timeout], max_retries: Optional[int], organization: Optional[str], - client: Optional[OpenAI], - ) -> Run: + client=None, + arun_thread=None, + ): + if arun_thread is not None and arun_thread == True: + return self.arun_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + stream=stream, + tools=tools, + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) openai_client = self.get_openai_client( api_key=api_key, api_base=api_base, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 56aa1b35e8..2206f84f65 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -5078,6 +5078,765 @@ async def audio_transcriptions( ) +###################################################################### + +# /v1/assistant Endpoints + + +###################################################################### + + +@router.get( + "/v1/assistants", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.get( + "/assistants", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def get_assistants( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + body = await request.body() + data = orjson.loads(body) + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is 
None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.aget_assistants( + custom_llm_provider="openai", client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.post( + "/v1/threads", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.post( + "/threads", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def create_threads( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + body = await request.body() + data = orjson.loads(body) + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "litellm_metadata" not in data: + data["litellm_metadata"] = {} + data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key + data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["litellm_metadata"]["headers"] = _headers + data["litellm_metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["litellm_metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["litellm_metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["litellm_metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + 
team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["litellm_metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.acreate_thread( + custom_llm_provider="openai", client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.get( + "/v1/threads/{thread_id}", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.get( + "/threads/{thread_id}", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def get_thread( + request: Request, + thread_id: str, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = 
general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.aget_thread( + custom_llm_provider="openai", thread_id=thread_id, client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.post( + "/v1/threads/{thread_id}/messages", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.post( + "/threads/{thread_id}/messages", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def add_messages( + request: Request, + thread_id: str, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + # Use orjson to parse JSON data, orjson speeds up requests significantly + body = await request.body() + data = orjson.loads(body) + + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "litellm_metadata" not in data: + data["litellm_metadata"] = {} + data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key + data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = 
dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["litellm_metadata"]["headers"] = _headers + data["litellm_metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["litellm_metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["litellm_metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["litellm_metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["litellm_metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.a_add_message( + custom_llm_provider="openai", thread_id=thread_id, client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.get( + "/v1/threads/{thread_id}/messages", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.get( + "/threads/{thread_id}/messages", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def get_messages( + request: Request, + thread_id: str, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": 
copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.aget_messages( + custom_llm_provider="openai", thread_id=thread_id, client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +@router.get( + "/v1/threads/{thread_id}/runs", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +@router.get( + "/threads/{thread_id}/runs", + dependencies=[Depends(user_api_key_auth)], + tags=["assistants"], +) +async def run_thread( + request: Request, + thread_id: str, + fastapi_response: Response, + user_api_key_dict: 
UserAPIKeyAuth = Depends(user_api_key_auth), +): + global proxy_logging_obj + data: Dict = {} + try: + body = await request.body() + data = orjson.loads(body) + # Include original request and headers in the data + data["proxy_server_request"] = { # type: ignore + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + + if "litellm_metadata" not in data: + data["litellm_metadata"] = {} + data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key + data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["litellm_metadata"]["headers"] = _headers + data["litellm_metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["litellm_metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + data["litellm_metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["litellm_metadata"]["endpoint"] = str(request.url) + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["litellm_metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + + # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch + if llm_router is None: + raise HTTPException( + status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value} + ) + response = await llm_router.arun_thread( + custom_llm_provider="openai", thread_id=thread_id, client=None, **data + ) + + ### ALERTING ### + data["litellm_status"] = "success" # used for alerting + + ### RESPONSE HEADERS ### + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + model_region=getattr(user_api_key_dict, "allowed_model_region", ""), + ) + ) + + return response + except Exception as e: + data["litellm_status"] = "fail" # used for alerting + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + traceback.print_exc() + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_traceback = traceback.format_exc() + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, 
"type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + ###################################################################### # /v1/batches Endpoints diff --git a/litellm/router.py b/litellm/router.py index d7535a83ae..bdc7b5fa00 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -53,6 +53,16 @@ from litellm.types.router import ( ) from litellm.integrations.custom_logger import CustomLogger from litellm.llms.azure import get_azure_ad_token_from_oidc +from litellm.types.llms.openai import ( + AsyncCursorPage, + Assistant, + Thread, + Attachment, + OpenAIMessage, + Run, + AssistantToolParam, +) +from typing import Iterable class Router: @@ -1646,6 +1656,108 @@ class Router: self.fail_calls[model_name] += 1 raise e + #### ASSISTANTS API #### + + async def aget_assistants( + self, + custom_llm_provider: Literal["openai"], + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> AsyncCursorPage[Assistant]: + return await litellm.aget_assistants( + custom_llm_provider=custom_llm_provider, client=client, **kwargs + ) + + async def acreate_thread( + self, + custom_llm_provider: Literal["openai"], + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> Thread: + return await litellm.acreate_thread( + custom_llm_provider=custom_llm_provider, client=client, **kwargs + ) + + async def aget_thread( + self, + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> Thread: + return await litellm.aget_thread( + custom_llm_provider=custom_llm_provider, + thread_id=thread_id, + client=client, + **kwargs, + ) + + async def a_add_message( + self, + custom_llm_provider: Literal["openai"], + thread_id: str, + role: Literal["user", "assistant"], + content: str, + attachments: Optional[List[Attachment]] = None, + metadata: Optional[dict] = None, + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> OpenAIMessage: + return await litellm.a_add_message( + custom_llm_provider=custom_llm_provider, + thread_id=thread_id, + role=role, + content=content, + attachments=attachments, + metadata=metadata, + client=client, + **kwargs, + ) + + async def aget_messages( + self, + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> AsyncCursorPage[OpenAIMessage]: + return await litellm.aget_messages( + custom_llm_provider=custom_llm_provider, + thread_id=thread_id, + client=client, + **kwargs, + ) + + async def arun_thread( + self, + custom_llm_provider: Literal["openai"], + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str] = None, + instructions: Optional[str] = None, + metadata: Optional[dict] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + tools: Optional[Iterable[AssistantToolParam]] = None, + client: Optional[AsyncOpenAI] = None, + **kwargs, + ) -> Run: + return await litellm.arun_thread( + custom_llm_provider=custom_llm_provider, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + stream=stream, + tools=tools, + client=client, + **kwargs, + ) + + #### [END] ASSISTANTS API #### + async def async_function_with_fallbacks(self, *args, **kwargs): """ Try calling the function_with_retries diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 7f20a6df06..1e35806890 100644 --- a/litellm/tests/test_assistants.py +++ 
b/litellm/tests/test_assistants.py @@ -16,6 +16,7 @@ from litellm.llms.openai import ( MessageData, Thread, OpenAIMessage as Message, + AsyncCursorPage, ) """ @@ -26,26 +27,55 @@ V0 Scope: """ -def test_create_thread_litellm() -> Thread: +@pytest.mark.asyncio +async def test_async_get_assistants(): + assistants = await litellm.aget_assistants(custom_llm_provider="openai") + assert isinstance(assistants, AsyncCursorPage) + + +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_create_thread_litellm(sync_mode) -> Thread: message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - new_thread = create_thread( - custom_llm_provider="openai", - messages=[message], # type: ignore - ) + + if sync_mode: + new_thread = create_thread( + custom_llm_provider="openai", + messages=[message], # type: ignore + ) + else: + new_thread = await litellm.acreate_thread( + custom_llm_provider="openai", + messages=[message], # type: ignore + ) assert isinstance( new_thread, Thread ), f"type of thread={type(new_thread)}. Expected Thread-type" + return new_thread -def test_get_thread_litellm(): - new_thread = test_create_thread_litellm() +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_get_thread_litellm(sync_mode): + new_thread = test_create_thread_litellm(sync_mode) - received_thread = get_thread( - custom_llm_provider="openai", - thread_id=new_thread.id, - ) + if asyncio.iscoroutine(new_thread): + _new_thread = await new_thread + else: + _new_thread = new_thread + + if sync_mode: + received_thread = get_thread( + custom_llm_provider="openai", + thread_id=_new_thread.id, + ) + else: + received_thread = await litellm.aget_thread( + custom_llm_provider="openai", + thread_id=_new_thread.id, + ) assert isinstance( received_thread, Thread @@ -53,50 +83,90 @@ def test_get_thread_litellm(): return new_thread -def test_add_message_litellm(): +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_add_message_litellm(sync_mode): message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - new_thread = test_create_thread_litellm() + new_thread = test_create_thread_litellm(sync_mode) + if asyncio.iscoroutine(new_thread): + _new_thread = await new_thread + else: + _new_thread = new_thread # add message to thread message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - added_message = litellm.add_message( - thread_id=new_thread.id, custom_llm_provider="openai", **message - ) + if sync_mode: + added_message = litellm.add_message( + thread_id=_new_thread.id, custom_llm_provider="openai", **message + ) + else: + added_message = await litellm.a_add_message( + thread_id=_new_thread.id, custom_llm_provider="openai", **message + ) print(f"added message: {added_message}") assert isinstance(added_message, Message) -def test_run_thread_litellm(): +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_run_thread_litellm(sync_mode): """ - Get Assistants - Create thread - Create run w/ Assistants + Thread """ - assistants = litellm.get_assistants(custom_llm_provider="openai") + if sync_mode: + assistants = litellm.get_assistants(custom_llm_provider="openai") + else: + assistants = await litellm.aget_assistants(custom_llm_provider="openai") ## get the first assistant ### assistant_id = assistants.data[0].id - new_thread = test_create_thread_litellm() + new_thread = 
test_create_thread_litellm(sync_mode=sync_mode) - thread_id = new_thread.id + if asyncio.iscoroutine(new_thread): + _new_thread = await new_thread + else: + _new_thread = new_thread + + thread_id = _new_thread.id # add message to thread message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - added_message = litellm.add_message( - thread_id=new_thread.id, custom_llm_provider="openai", **message - ) - run = litellm.run_thread( - custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id - ) - - if run.status == "completed": - messages = litellm.get_messages( - thread_id=new_thread.id, custom_llm_provider="openai" + if sync_mode: + added_message = litellm.add_message( + thread_id=_new_thread.id, custom_llm_provider="openai", **message ) - assert isinstance(messages.data[0], Message) + + run = litellm.run_thread( + custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id + ) + + if run.status == "completed": + messages = litellm.get_messages( + thread_id=_new_thread.id, custom_llm_provider="openai" + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") + else: - pytest.fail("An unexpected error occurred when running the thread") + added_message = await litellm.a_add_message( + thread_id=_new_thread.id, custom_llm_provider="openai", **message + ) + + run = await litellm.arun_thread( + custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id + ) + + if run.status == "completed": + messages = await litellm.aget_messages( + thread_id=_new_thread.id, custom_llm_provider="openai" + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 57c199b61f..4045281ff7 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -17,11 +17,10 @@ from openai.types.beta.thread_create_params import ( from openai.types.beta.assistant_tool_param import AssistantToolParam from openai.types.beta.threads.run import Run from openai.types.beta.assistant import Assistant -from openai.pagination import SyncCursorPage +from openai.pagination import SyncCursorPage, AsyncCursorPage from os import PathLike from openai.types import FileObject, Batch from openai._legacy_response import HttpxBinaryResponseContent - from typing import TypedDict, List, Optional, Tuple, Mapping, IO FileContent = Union[IO[bytes], bytes, PathLike]
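
For reference, a minimal end-to-end sketch of the new async surface, mirroring the flow exercised in `test_run_thread_litellm` above. It assumes `OPENAI_API_KEY` is configured in the environment and that the account already has at least one assistant; everything else uses only the `litellm.a*` entrypoints added in this diff.

```python
import asyncio
import litellm


async def main():
    # List existing assistants and pick the first one (as the tests do)
    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
    assistant_id = assistants.data[0].id

    # Create a thread seeded with a user message
    thread = await litellm.acreate_thread(
        custom_llm_provider="openai",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )

    # Add another message, then run the assistant against the thread
    await litellm.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="Hey, how's it going?",
    )
    run = await litellm.arun_thread(
        custom_llm_provider="openai",
        thread_id=thread.id,
        assistant_id=assistant_id,
    )

    # arun_thread polls the run, so a terminal status is expected here
    if run.status == "completed":
        messages = await litellm.aget_messages(
            custom_llm_provider="openai", thread_id=thread.id
        )
        print(messages.data[0])


asyncio.run(main())
```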