diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
index 957126ae1..25d2433d7 100644
--- a/litellm/assistants/main.py
+++ b/litellm/assistants/main.py
@@ -15,6 +15,75 @@ openai_assistants_api = OpenAIAssistantsAPI()
 
 ### ASSISTANTS ###
 
+
+def get_assistants(
+    custom_llm_provider: Literal["openai"],
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> SyncCursorPage[Assistant]:
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[SyncCursorPage[Assistant]] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.get_assistants(
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response
+
+
 ### THREADS ###
 
 
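A minimal usage sketch for the new top-level get_assistants() entrypoint
(illustrative, not part of the patch; assumes OPENAI_API_KEY is set in the
environment):

    import litellm

    # Returns a SyncCursorPage[Assistant] listing the account's assistants
    assistants = litellm.get_assistants(custom_llm_provider="openai")
    for assistant in assistants.data:
        print(assistant.id)
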
@@ -267,4 +336,160 @@ def add_message(
     return response
 
 
+def get_messages(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> SyncCursorPage[OpenAIMessage]:
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[SyncCursorPage[OpenAIMessage]] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.get_messages(
+            thread_id=thread_id,
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+
+    return response
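A minimal sketch of calling the new get_messages() helper (illustrative;
"thread_abc123" is a placeholder id, not part of the patch):

    import litellm

    # Returns a SyncCursorPage[OpenAIMessage] for the given thread
    messages = litellm.get_messages(
        thread_id="thread_abc123", custom_llm_provider="openai"
    )
    for msg in messages.data:
        print(msg.id, msg.role)
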
+
+
 ### RUNS ###
+
+
+def run_thread(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    assistant_id: str,
+    additional_instructions: Optional[str] = None,
+    instructions: Optional[str] = None,
+    metadata: Optional[dict] = None,
+    model: Optional[str] = None,
+    stream: Optional[bool] = None,
+    tools: Optional[Iterable[AssistantToolParam]] = None,
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> Run:
+    """Run a given thread + assistant."""
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[Run] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.run_thread(
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            stream=stream,
+            tools=tools,
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response
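A minimal sketch of the new run_thread() helper, mirroring the test added
below (illustrative; both ids are placeholders, not part of the patch):

    import litellm

    run = litellm.run_thread(
        custom_llm_provider="openai",
        thread_id="thread_abc123",
        assistant_id="asst_abc123",
    )
    if run.status == "completed":
        messages = litellm.get_messages(
            thread_id="thread_abc123", custom_llm_provider="openai"
        )
        print(messages.data[0])
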
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 16f4868f4..a95f83e99 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1282,10 +1282,10 @@ class OpenAIAssistantsAPI(BaseLLM):
 
     def get_assistants(
         self,
-        api_key: str,
+        api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
-        max_retries: int,
+        max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[OpenAI],
     ) -> SyncCursorPage[Assistant]:
@@ -1340,10 +1340,10 @@ class OpenAIAssistantsAPI(BaseLLM):
     def get_messages(
         self,
         thread_id: str,
-        api_key: str,
+        api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
-        max_retries: int,
+        max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[OpenAI] = None,
     ) -> SyncCursorPage[OpenAIMessage]:
@@ -1440,10 +1440,10 @@ class OpenAIAssistantsAPI(BaseLLM):
         model: Optional[str],
         stream: Optional[Literal[False]] | Literal[True],
         tools: Optional[Iterable[AssistantToolParam]],
-        api_key: str,
+        api_key: Optional[str],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
-        max_retries: int,
+        max_retries: Optional[int],
         organization: Optional[str],
         client: Optional[OpenAI],
     ) -> Run:
diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py
index ff83dd30c..940b874ff 100644
--- a/litellm/tests/test_assistants.py
+++ b/litellm/tests/test_assistants.py
@@ -68,6 +68,40 @@ def test_add_message_litellm():
     assert isinstance(added_message, Message)
 
 
+def test_run_thread_litellm():
+    """
+    - Get Assistants
+    - Create thread
+    - Create run w/ Assistants + Thread
+    """
+    assistants = litellm.get_assistants(custom_llm_provider="openai")
+
+    ## get the first assistant ###
+    assistant_id = assistants.data[0].id
+
+    new_thread = test_create_thread_litellm()
+
+    thread_id = new_thread.id
+
+    # add message to thread
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    added_message = litellm.add_message(
+        thread_id=new_thread.id, custom_llm_provider="openai", **message
+    )
+
+    run = litellm.run_thread(
+        custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
+    )
+
+    if run.status == "completed":
+        messages = litellm.get_messages(
+            thread_id=new_thread.id, custom_llm_provider="openai"
+        )
+        assert isinstance(messages.data[0], Message)
+    else:
+        pytest.fail("An unexpected error occurred when running the thread")
+
+
 def test_run_thread_openai_direct():
     """
     - Get Assistants
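
For reference, an end-to-end sketch that mirrors test_run_thread_litellm,
using the pre-existing create_thread()/add_message() helpers from the same
module (illustrative; assumes OPENAI_API_KEY is set and the account has at
least one assistant):

    import litellm

    # pick an existing assistant
    assistants = litellm.get_assistants(custom_llm_provider="openai")
    assistant_id = assistants.data[0].id

    # create a thread and add a user message to it
    thread = litellm.create_thread(custom_llm_provider="openai")
    litellm.add_message(
        thread_id=thread.id,
        custom_llm_provider="openai",
        role="user",
        content="Hey, how's it going?",
    )

    # run the assistant on the thread and check the outcome
    run = litellm.run_thread(
        custom_llm_provider="openai", thread_id=thread.id, assistant_id=assistant_id
    )
    print(run.status)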