diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
index b02a2bd59..957126ae1 100644
--- a/litellm/assistants/main.py
+++ b/litellm/assistants/main.py
@@ -113,6 +113,77 @@ def create_thread(
     return response
 
 
+def get_thread(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> Thread:
+    """Get the thread object, given a thread_id"""
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[Thread] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.get_thread(
+            thread_id=thread_id,
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="get_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response
+
+
 ### MESSAGES ###
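For reviewers, a minimal usage sketch of the new litellm.get_thread helper (not part of the patch). It assumes OPENAI_API_KEY is set in the environment and that litellm.create_thread accepts a messages list, mirroring test_get_thread_litellm in the test diff below; the message payload is illustrative.

from litellm import create_thread, get_thread

# Create a thread, then fetch it back by id (assumes OPENAI_API_KEY is set).
new_thread = create_thread(
    custom_llm_provider="openai",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],  # illustrative payload
)
received_thread = get_thread(custom_llm_provider="openai", thread_id=new_thread.id)
assert received_thread.id == new_thread.id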
Expected Thread-type" - return new_thread - - -def test_add_message_openai_direct(): - openai_api = OpenAIAssistantsAPI() - # create thread - new_thread = test_create_thread_openai_direct() - # add message to thread - message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - added_message = openai_api.add_message( - thread_id=new_thread.id, - message_data=message, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - print(f"added message: {added_message}") - - assert isinstance(added_message, Message) - - -def test_get_thread_openai_direct(): - openai_api = OpenAIAssistantsAPI() - - ## create a thread w/ message ### - new_thread = test_create_thread() - - retrieved_thread = openai_api.get_thread( - thread_id=new_thread.id, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - assert isinstance( - retrieved_thread, Thread - ), f"type of thread={type(retrieved_thread)}. Expected Thread-type" - return new_thread - - def test_run_thread_openai_direct(): """ - Get Assistants