From 84c31a5528f5dec60e110217e7d0baa29336a30c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 17:27:48 -0700 Subject: [PATCH 01/11] feat(openai.py): add support for openai assistants v0 commit. Closes https://github.com/BerriAI/litellm/issues/2842 --- litellm/assistants/main.py | 2 + litellm/llms/openai.py | 296 ++++++++++++++++++++++++++++++- litellm/tests/test_assistants.py | 164 +++++++++++++++++ 3 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 litellm/assistants/main.py create mode 100644 litellm/tests/test_assistants.py diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py new file mode 100644 index 000000000..0d3216482 --- /dev/null +++ b/litellm/assistants/main.py @@ -0,0 +1,2 @@ +# What is this? +## Main file for assistants API logic diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 5a76605b3..a6d6f4109 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1,4 +1,14 @@ -from typing import Optional, Union, Any, BinaryIO +from typing import ( + Optional, + Union, + Any, + BinaryIO, + Literal, + Annotated, + Iterable, +) +from typing_extensions import override +from pydantic import BaseModel import types, time, json, traceback import httpx from .base import BaseLLM @@ -17,6 +27,73 @@ import aiohttp, requests import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI +from openai.types.beta.threads.message_content import MessageContent +from openai.types.beta.threads.message_create_params import Attachment +from openai.types.beta.threads.message import Message as OpenAIMessage +from openai.types.beta.thread_create_params import ( + Message as OpenAICreateThreadParamsMessage, +) +from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.threads.run import Run +from openai.types.beta.assistant import Assistant +from openai.pagination import SyncCursorPage + +from typing import TypedDict, List, Optional + + +class NotGiven: + """ + A sentinel singleton class used to distinguish omitted keyword arguments + from those passed in with the value None (which may have different behavior). + + For example: + + ```py + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: + ... + + + get(timeout=1) # 1s timeout + get(timeout=None) # No timeout + get() # Default timeout behavior, which may not be statically known at the method definition. + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + @override + def __repr__(self) -> str: + return "NOT_GIVEN" + + +NOT_GIVEN = NotGiven() + + +class MessageData(TypedDict): + role: Literal["user", "assistant"] + content: str + attachments: Optional[List[Attachment]] + metadata: Optional[dict] + + +class Thread(BaseModel): + id: str + """The identifier, which can be referenced in API endpoints.""" + + created_at: int + """The Unix timestamp (in seconds) for when the thread was created.""" + + metadata: Optional[object] = None + """Set of 16 key-value pairs that can be attached to an object. + + This can be useful for storing additional information about the object in a + structured format. Keys can be a maximum of 64 characters long and values can be + a maxium of 512 characters long. 
+ """ + + object: Literal["thread"] + """The object type, which is always `thread`.""" class OpenAIError(Exception): @@ -1236,3 +1313,220 @@ class OpenAITextCompletion(BaseLLM): async for transformed_chunk in streamwrapper: yield transformed_chunk + + +class OpenAIAssistantsAPI(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def get_openai_client( + self, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> OpenAI: + if client is None: + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, + organization=organization, + ) + else: + openai_client = client + + return openai_client + + ### ASSISTANTS ### + + def get_assistants( + self, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI], + ) -> SyncCursorPage[Assistant]: + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.assistants.list() + + return response + + ### MESSAGES ### + + def add_message( + self, + thread_id: str, + message_data: MessageData, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> OpenAIMessage: + + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( + thread_id, **message_data + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + def get_messages( + self, + thread_id: str, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI] = None, + ) -> SyncCursorPage[OpenAIMessage]: + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + ### THREADS ### + + def create_thread( + self, + metadata: dict, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI], + messages: Union[ + Iterable[OpenAICreateThreadParamsMessage], NotGiven + ] = NOT_GIVEN, + ) -> Thread: + """ + Here's an example: + ``` + from litellm.llms.openai import OpenAIAssistantsAPI, MessageData + + # create thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} + openai_api.create_thread(messages=[message]) + ``` + """ + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + message_thread = openai_client.beta.threads.create( + 
messages=messages, # type: ignore + metadata=metadata, + ) + + return Thread(**message_thread.dict()) + + def get_thread( + self, + thread_id: str, + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI], + ) -> Thread: + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + def delete_thread(self): + pass + + ### RUNS ### + + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[object], + model: Optional[str], + stream: Optional[Literal[False]] | Literal[True], + tools: Optional[Iterable[AssistantToolParam]], + api_key: str, + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: int, + organization: Optional[str], + client: Optional[OpenAI], + ) -> Run: + openai_client = self.get_openai_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + organization=organization, + client=client, + ) + + response = openai_client.beta.threads.runs.create_and_poll( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + ) + + return response diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py new file mode 100644 index 000000000..9b8585ec6 --- /dev/null +++ b/litellm/tests/test_assistants.py @@ -0,0 +1,164 @@ +# What is this? +## Unit Tests for OpenAI Assistants API +import sys, os, json +import traceback +from dotenv import load_dotenv + +load_dotenv() +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import pytest, logging, asyncio +import litellm +from litellm.llms.openai import ( + OpenAIAssistantsAPI, + MessageData, + Thread, + OpenAIMessage as Message, +) + +""" +V0 Scope: + +- Add Message -> `/v1/threads/{thread_id}/messages` +- Run Thread -> `/v1/threads/{thread_id}/run` +""" + + +def test_create_thread() -> Thread: + openai_api = OpenAIAssistantsAPI() + + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + new_thread = openai_api.create_thread( + messages=[message], # type: ignore + api_key=os.getenv("OPENAI_API_KEY"), # type: ignore + metadata={}, + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + print(f"new_thread: {new_thread}") + print(f"type of thread: {type(new_thread)}") + assert isinstance( + new_thread, Thread + ), f"type of thread={type(new_thread)}. 
Expected Thread-type" + return new_thread + + +def test_add_message(): + openai_api = OpenAIAssistantsAPI() + # create thread + new_thread = test_create_thread() + # add message to thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + added_message = openai_api.add_message( + thread_id=new_thread.id, + message_data=message, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + print(f"added message: {added_message}") + + assert isinstance(added_message, Message) + + +def test_get_thread(): + openai_api = OpenAIAssistantsAPI() + + ## create a thread w/ message ### + new_thread = test_create_thread() + + retrieved_thread = openai_api.get_thread( + thread_id=new_thread.id, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + assert isinstance( + retrieved_thread, Thread + ), f"type of thread={type(retrieved_thread)}. Expected Thread-type" + return new_thread + + +def test_run_thread(): + """ + - Get Assistants + - Create thread + - Create run w/ Assistants + Thread + """ + openai_api = OpenAIAssistantsAPI() + + assistants = openai_api.get_assistants( + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + ## get the first assistant ### + assistant_id = assistants.data[0].id + + ## create a thread w/ message ### + new_thread = test_create_thread() + + thread_id = new_thread.id + + # add message to thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + added_message = openai_api.add_message( + thread_id=new_thread.id, + message_data=message, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + run = openai_api.run_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=None, + instructions=None, + metadata=None, + model=None, + stream=None, + tools=None, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + + print(f"run: {run}") + + if run.status == "completed": + messages = openai_api.get_messages( + thread_id=new_thread.id, + api_key=os.getenv("OPENAI_API_KEY"), + api_base=None, + timeout=600, + max_retries=2, + organization=None, + client=None, + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") From 681a95e37b94043194d21b1afcb1a16f95a761dd Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 19:35:37 -0700 Subject: [PATCH 02/11] fix(assistants/main.py): support `litellm.create_thread()` call --- litellm/__init__.py | 2 +- litellm/assistants/main.py | 116 +++++++++++++++++++++++++++++++ litellm/llms/openai.py | 109 ++++++----------------------- litellm/tests/test_assistants.py | 25 +++++-- litellm/types/llms/__init__.py | 3 + litellm/types/llms/openai.py | 80 +++++++++++++++++++++ litellm/types/router.py | 67 +++++++++++++++++- 7 files changed, 308 insertions(+), 94 deletions(-) create mode 100644 litellm/types/llms/__init__.py create mode 100644 litellm/types/llms/openai.py diff --git a/litellm/__init__.py b/litellm/__init__.py index dc640f0e9..b05c1c910 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -605,7 +605,6 @@ all_embedding_models = ( ####### IMAGE GENERATION MODELS 
###################
 openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 
-
 from .timeout import timeout
 from .utils import (
     client,
@@ -694,3 +693,4 @@ from .exceptions import (
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
 from .router import Router
+from .assistants.main import *
diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
index 0d3216482..16a1f973c 100644
--- a/litellm/assistants/main.py
+++ b/litellm/assistants/main.py
@@ -1,2 +1,118 @@
 # What is this?
 ## Main file for assistants API logic
+from typing import Iterable
+import os
+import litellm
+from openai import OpenAI
+from litellm import client
+from litellm.utils import supports_httpx_timeout
+from ..llms.openai import OpenAIAssistantsAPI
+from ..types.llms.openai import *
+from ..types.router import *
+
+####### ENVIRONMENT VARIABLES ###################
+openai_assistants_api = OpenAIAssistantsAPI()
+
+### ASSISTANTS ###
+
+### THREADS ###
+
+
+def create_thread(
+    custom_llm_provider: Literal["openai"],
+    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
+    metadata: Optional[dict] = None,
+    tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
+    client: Optional[OpenAI] = None,
+    **kwargs
+) -> Thread:
+    """
+    - get the llm provider
+    - if openai - route it there
+    - pass through relevant params
+
+    ```
+    from litellm import create_thread
+
+    create_thread(
+        custom_llm_provider="openai",
+        ### OPTIONAL ###
+        messages=[{
+            "role": "user",
+            "content": "Hello, what is AI?"
+        },
+        {
+            "role": "user",
+            "content": "How does AI work? Explain it in simple terms."
+        }]
+    )
+    ```
+    """
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[Thread] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.create_thread(
+            messages=messages,
+            metadata=metadata,
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'create_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + +### MESSAGES ### + +### RUNS ### diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index a6d6f4109..9cc6d86bb 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -27,73 +27,7 @@ import aiohttp, requests import litellm from .prompt_templates.factory import prompt_factory, custom_prompt from openai import OpenAI, AsyncOpenAI -from openai.types.beta.threads.message_content import MessageContent -from openai.types.beta.threads.message_create_params import Attachment -from openai.types.beta.threads.message import Message as OpenAIMessage -from openai.types.beta.thread_create_params import ( - Message as OpenAICreateThreadParamsMessage, -) -from openai.types.beta.assistant_tool_param import AssistantToolParam -from openai.types.beta.threads.run import Run -from openai.types.beta.assistant import Assistant -from openai.pagination import SyncCursorPage - -from typing import TypedDict, List, Optional - - -class NotGiven: - """ - A sentinel singleton class used to distinguish omitted keyword arguments - from those passed in with the value None (which may have different behavior). - - For example: - - ```py - def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: - ... - - - get(timeout=1) # 1s timeout - get(timeout=None) # No timeout - get() # Default timeout behavior, which may not be statically known at the method definition. - ``` - """ - - def __bool__(self) -> Literal[False]: - return False - - @override - def __repr__(self) -> str: - return "NOT_GIVEN" - - -NOT_GIVEN = NotGiven() - - -class MessageData(TypedDict): - role: Literal["user", "assistant"] - content: str - attachments: Optional[List[Attachment]] - metadata: Optional[dict] - - -class Thread(BaseModel): - id: str - """The identifier, which can be referenced in API endpoints.""" - - created_at: int - """The Unix timestamp (in seconds) for when the thread was created.""" - - metadata: Optional[object] = None - """Set of 16 key-value pairs that can be attached to an object. - - This can be useful for storing additional information about the object in a - structured format. Keys can be a maximum of 64 characters long and values can be - a maxium of 512 characters long. 
- """ - - object: Literal["thread"] - """The object type, which is always `thread`.""" +from ..types.llms.openai import * class OpenAIError(Exception): @@ -1321,22 +1255,22 @@ class OpenAIAssistantsAPI(BaseLLM): def get_openai_client( self, - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI] = None, ) -> OpenAI: + received_args = locals() if client is None: - openai_client = OpenAI( - api_key=api_key, - base_url=api_base, - http_client=litellm.client_session, - timeout=timeout, - max_retries=max_retries, - organization=organization, - ) + data = {} + for k, v in received_args.items(): + if k == "self" or k == "client": + pass + elif v is not None: + data[k] = v + openai_client = OpenAI(**data) # type: ignore else: openai_client = client @@ -1428,16 +1362,14 @@ class OpenAIAssistantsAPI(BaseLLM): def create_thread( self, - metadata: dict, - api_key: str, + metadata: Optional[dict], + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI], - messages: Union[ - Iterable[OpenAICreateThreadParamsMessage], NotGiven - ] = NOT_GIVEN, + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], ) -> Thread: """ Here's an example: @@ -1458,10 +1390,13 @@ class OpenAIAssistantsAPI(BaseLLM): client=client, ) - message_thread = openai_client.beta.threads.create( - messages=messages, # type: ignore - metadata=metadata, - ) + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = openai_client.beta.threads.create(**data) # type: ignore return Thread(**message_thread.dict()) diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 9b8585ec6..58c8c4c1f 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -10,6 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest, logging, asyncio import litellm +from litellm import create_thread from litellm.llms.openai import ( OpenAIAssistantsAPI, MessageData, @@ -25,7 +26,23 @@ V0 Scope: """ -def test_create_thread() -> Thread: +def test_create_thread_litellm(): + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + new_thread = create_thread( + custom_llm_provider="openai", + messages=[message], + ) + + assert isinstance( + new_thread, Thread + ), f"type of thread={type(new_thread)}. 
Expected Thread-type" + return new_thread + + +test_create_thread_litellm() + + +def test_create_thread_openai_direct() -> Thread: openai_api = OpenAIAssistantsAPI() message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore @@ -48,7 +65,7 @@ def test_create_thread() -> Thread: return new_thread -def test_add_message(): +def test_add_message_openai_direct(): openai_api = OpenAIAssistantsAPI() # create thread new_thread = test_create_thread() @@ -70,7 +87,7 @@ def test_add_message(): assert isinstance(added_message, Message) -def test_get_thread(): +def test_get_thread_openai_direct(): openai_api = OpenAIAssistantsAPI() ## create a thread w/ message ### @@ -92,7 +109,7 @@ def test_get_thread(): return new_thread -def test_run_thread(): +def test_run_thread_openai_direct(): """ - Get Assistants - Create thread diff --git a/litellm/types/llms/__init__.py b/litellm/types/llms/__init__.py new file mode 100644 index 000000000..14952c9ae --- /dev/null +++ b/litellm/types/llms/__init__.py @@ -0,0 +1,3 @@ +__all__ = ["openai"] + +from . import openai diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py new file mode 100644 index 000000000..f9f7b3bf0 --- /dev/null +++ b/litellm/types/llms/openai.py @@ -0,0 +1,80 @@ +from typing import ( + Optional, + Union, + Any, + BinaryIO, + Literal, + Annotated, + Iterable, +) +from typing_extensions import override +from pydantic import BaseModel + +from openai.types.beta.threads.message_content import MessageContent +from openai.types.beta.threads.message_create_params import Attachment +from openai.types.beta.threads.message import Message as OpenAIMessage +from openai.types.beta.thread_create_params import ( + Message as OpenAICreateThreadParamsMessage, + ToolResources as OpenAICreateThreadParamsToolResources, +) +from openai.types.beta.assistant_tool_param import AssistantToolParam +from openai.types.beta.threads.run import Run +from openai.types.beta.assistant import Assistant +from openai.pagination import SyncCursorPage + +from typing import TypedDict, List, Optional + + +class NotGiven: + """ + A sentinel singleton class used to distinguish omitted keyword arguments + from those passed in with the value None (which may have different behavior). + + For example: + + ```py + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: + ... + + + get(timeout=1) # 1s timeout + get(timeout=None) # No timeout + get() # Default timeout behavior, which may not be statically known at the method definition. + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + @override + def __repr__(self) -> str: + return "NOT_GIVEN" + + +NOT_GIVEN = NotGiven() + + +class MessageData(TypedDict): + role: Literal["user", "assistant"] + content: str + attachments: Optional[List[Attachment]] + metadata: Optional[dict] + + +class Thread(BaseModel): + id: str + """The identifier, which can be referenced in API endpoints.""" + + created_at: int + """The Unix timestamp (in seconds) for when the thread was created.""" + + metadata: Optional[object] = None + """Set of 16 key-value pairs that can be attached to an object. + + This can be useful for storing additional information about the object in a + structured format. Keys can be a maximum of 64 characters long and values can be + a maxium of 512 characters long. 
+ """ + + object: Literal["thread"] + """The object type, which is always `thread`.""" diff --git a/litellm/types/router.py b/litellm/types/router.py index 068a99b00..d6b698f01 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -97,8 +97,11 @@ class ModelInfo(BaseModel): setattr(self, key, value) -class LiteLLM_Params(BaseModel): - model: str +class GenericLiteLLMParams(BaseModel): + """ + LiteLLM Params without 'model' arg (used across completion / assistants api) + """ + custom_llm_provider: Optional[str] = None tpm: Optional[int] = None rpm: Optional[int] = None @@ -121,6 +124,66 @@ class LiteLLM_Params(BaseModel): aws_secret_access_key: Optional[str] = None aws_region_name: Optional[str] = None + def __init__( + self, + custom_llm_provider: Optional[str] = None, + max_retries: Optional[Union[int, str]] = None, + tpm: Optional[int] = None, + rpm: Optional[int] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + timeout: Optional[Union[float, str]] = None, # if str, pass in as os.environ/ + stream_timeout: Optional[Union[float, str]] = ( + None # timeout when making stream=True calls, if str, pass in as os.environ/ + ), + organization: Optional[str] = None, # for openai orgs + ## VERTEX AI ## + vertex_project: Optional[str] = None, + vertex_location: Optional[str] = None, + ## AWS BEDROCK / SAGEMAKER ## + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + **params + ): + args = locals() + args.pop("max_retries", None) + args.pop("self", None) + args.pop("params", None) + args.pop("__class__", None) + if max_retries is not None and isinstance(max_retries, str): + max_retries = int(max_retries) # cast to int + super().__init__(max_retries=max_retries, **args, **params) + + class Config: + extra = "allow" + arbitrary_types_allowed = True + + def __contains__(self, key): + # Define custom behavior for the 'in' operator + return hasattr(self, key) + + def get(self, key, default=None): + # Custom .get() method to access attributes with a default value if the attribute doesn't exist + return getattr(self, key, default) + + def __getitem__(self, key): + # Allow dictionary-style access to attributes + return getattr(self, key) + + def __setitem__(self, key, value): + # Allow dictionary-style assignment of attributes + setattr(self, key, value) + + +class LiteLLM_Params(GenericLiteLLMParams): + """ + LiteLLM Params with 'model' requirement - used for completions + """ + + model: str + def __init__( self, model: str, From b7796c74872e393b5c7b0e60456cb9864e742ba6 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 19:56:11 -0700 Subject: [PATCH 03/11] feat(assistants/main.py): add 'add_message' endpoint --- litellm/assistants/main.py | 83 +++++++++++++++++++++++++++++++- litellm/llms/openai.py | 6 ++- litellm/tests/test_assistants.py | 20 ++++++-- 3 files changed, 102 insertions(+), 7 deletions(-) diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index 16a1f973c..b02a2bd59 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -24,7 +24,7 @@ def create_thread( metadata: Optional[dict] = None, tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None, client: Optional[OpenAI] = None, - **kwargs + **kwargs, ) -> Thread: """ - get the llm provider @@ -115,4 +115,85 @@ def create_thread( ### MESSAGES ### + +def add_message( + custom_llm_provider: Literal["openai"], + 
thread_id: str,
+    role: Literal["user", "assistant"],
+    content: str,
+    attachments: Optional[List[Attachment]] = None,
+    metadata: Optional[dict] = None,
+    client: Optional[OpenAI] = None,
+    **kwargs,
+) -> OpenAIMessage:
+    ### COMMON OBJECTS ###
+    message_data = MessageData(
+        role=role, content=content, attachments=attachments, metadata=metadata
+    )
+    optional_params = GenericLiteLLMParams(**kwargs)
+
+    ### TIMEOUT LOGIC ###
+    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
+    # set timeout for 10 minutes by default
+
+    if (
+        timeout is not None
+        and isinstance(timeout, httpx.Timeout)
+        and supports_httpx_timeout(custom_llm_provider) == False
+    ):
+        read_timeout = timeout.read or 600
+        timeout = read_timeout  # default 10 min timeout
+    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
+        timeout = float(timeout)  # type: ignore
+    elif timeout is None:
+        timeout = 600.0
+
+    response: Optional[OpenAIMessage] = None
+    if custom_llm_provider == "openai":
+        api_base = (
+            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+            or litellm.api_base
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+        response = openai_assistants_api.add_message(
+            thread_id=thread_id,
+            message_data=message_data,
+            api_base=api_base,
+            api_key=api_key,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            organization=organization,
+            client=client,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'add_message'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + + return response + + ### RUNS ### diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 9cc6d86bb..01d3bd2f2 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1268,6 +1268,8 @@ class OpenAIAssistantsAPI(BaseLLM): for k, v in received_args.items(): if k == "self" or k == "client": pass + elif k == "api_base" and v is not None: + data["base_url"] = v elif v is not None: data[k] = v openai_client = OpenAI(**data) # type: ignore @@ -1306,10 +1308,10 @@ class OpenAIAssistantsAPI(BaseLLM): self, thread_id: str, message_data: MessageData, - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI] = None, ) -> OpenAIMessage: diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 58c8c4c1f..a3acdae18 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -26,11 +26,11 @@ V0 Scope: """ -def test_create_thread_litellm(): +def test_create_thread_litellm() -> Thread: message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore new_thread = create_thread( custom_llm_provider="openai", - messages=[message], + messages=[message], # type: ignore ) assert isinstance( @@ -39,7 +39,19 @@ def test_create_thread_litellm(): return new_thread -test_create_thread_litellm() +def test_add_message_litellm(): + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + new_thread = test_create_thread_litellm() + + # add message to thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + added_message = litellm.add_message( + thread_id=new_thread.id, custom_llm_provider="openai", **message + ) + + print(f"added message: {added_message}") + + assert isinstance(added_message, Message) def test_create_thread_openai_direct() -> Thread: @@ -68,7 +80,7 @@ def test_create_thread_openai_direct() -> Thread: def test_add_message_openai_direct(): openai_api = OpenAIAssistantsAPI() # create thread - new_thread = test_create_thread() + new_thread = test_create_thread_openai_direct() # add message to thread message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore added_message = openai_api.add_message( From cad01fb5862f89da85afd5486441ea646b169ace Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:14:03 -0700 Subject: [PATCH 04/11] feat(assistants/main.py): support 'litellm.get_threads' --- litellm/assistants/main.py | 71 +++++++++++++++++++++++++++ litellm/llms/openai.py | 4 +- litellm/tests/test_assistants.py | 83 ++++++-------------------------- 3 files changed, 88 insertions(+), 70 deletions(-) diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index b02a2bd59..957126ae1 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -113,6 +113,77 @@ def create_thread( return response +def get_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[OpenAI] = None, + **kwargs, +) -> Thread: + """Get the thread object, given a thread_id""" + optional_params = 
GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[Thread] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.get_thread( + thread_id=thread_id, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + ### MESSAGES ### diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 01d3bd2f2..16f4868f4 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1405,10 +1405,10 @@ class OpenAIAssistantsAPI(BaseLLM): def get_thread( self, thread_id: str, - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI], ) -> Thread: diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index a3acdae18..ff83dd30c 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -10,7 +10,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import pytest, logging, asyncio import litellm -from litellm import create_thread +from litellm import create_thread, get_thread from litellm.llms.openai import ( OpenAIAssistantsAPI, MessageData, @@ -39,6 +39,20 @@ def test_create_thread_litellm() -> Thread: return new_thread +def test_get_thread_litellm(): + new_thread = test_create_thread_litellm() + + received_thread = get_thread( + custom_llm_provider="openai", + thread_id=new_thread.id, + ) + + assert isinstance( + received_thread, Thread + ), f"type of thread={type(received_thread)}. 
Expected Thread-type" + return new_thread + + def test_add_message_litellm(): message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore new_thread = test_create_thread_litellm() @@ -54,73 +68,6 @@ def test_add_message_litellm(): assert isinstance(added_message, Message) -def test_create_thread_openai_direct() -> Thread: - openai_api = OpenAIAssistantsAPI() - - message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - new_thread = openai_api.create_thread( - messages=[message], # type: ignore - api_key=os.getenv("OPENAI_API_KEY"), # type: ignore - metadata={}, - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - print(f"new_thread: {new_thread}") - print(f"type of thread: {type(new_thread)}") - assert isinstance( - new_thread, Thread - ), f"type of thread={type(new_thread)}. Expected Thread-type" - return new_thread - - -def test_add_message_openai_direct(): - openai_api = OpenAIAssistantsAPI() - # create thread - new_thread = test_create_thread_openai_direct() - # add message to thread - message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - added_message = openai_api.add_message( - thread_id=new_thread.id, - message_data=message, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - print(f"added message: {added_message}") - - assert isinstance(added_message, Message) - - -def test_get_thread_openai_direct(): - openai_api = OpenAIAssistantsAPI() - - ## create a thread w/ message ### - new_thread = test_create_thread() - - retrieved_thread = openai_api.get_thread( - thread_id=new_thread.id, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - assert isinstance( - retrieved_thread, Thread - ), f"type of thread={type(retrieved_thread)}. 
Expected Thread-type" - return new_thread - - def test_run_thread_openai_direct(): """ - Get Assistants From 8fe6c9b4016095bdd13f300c9d9f4e29ed02ee73 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:30:28 -0700 Subject: [PATCH 05/11] feat(assistants/main.py): support `litellm.get_assistants()` and `litellm.get_messages()` --- litellm/assistants/main.py | 225 +++++++++++++++++++++++++++++++ litellm/llms/openai.py | 12 +- litellm/tests/test_assistants.py | 34 +++++ 3 files changed, 265 insertions(+), 6 deletions(-) diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py index 957126ae1..25d2433d7 100644 --- a/litellm/assistants/main.py +++ b/litellm/assistants/main.py @@ -15,6 +15,75 @@ openai_assistants_api = OpenAIAssistantsAPI() ### ASSISTANTS ### + +def get_assistants( + custom_llm_provider: Literal["openai"], + client: Optional[OpenAI] = None, + **kwargs, +) -> SyncCursorPage[Assistant]: + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[SyncCursorPage[Assistant]] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.get_assistants( + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'get_assistants'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response + + ### THREADS ### @@ -267,4 +336,160 @@ def add_message( return response +def get_messages( + custom_llm_provider: Literal["openai"], + thread_id: str, + client: Optional[OpenAI] = None, + **kwargs, +) -> SyncCursorPage[OpenAIMessage]: + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[SyncCursorPage[OpenAIMessage]] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.get_messages( + thread_id=thread_id, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'get_messages'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + + return response + + ### RUNS ### + + +def run_thread( + custom_llm_provider: Literal["openai"], + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str] = None, + instructions: Optional[str] = None, + metadata: Optional[dict] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + tools: Optional[Iterable[AssistantToolParam]] = None, + client: Optional[OpenAI] = None, + **kwargs, +) -> Run: + """Run a given thread + assistant.""" + optional_params = GenericLiteLLMParams(**kwargs) + + ### TIMEOUT LOGIC ### + timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + + if ( + timeout is not None + and isinstance(timeout, httpx.Timeout) + and supports_httpx_timeout(custom_llm_provider) == False + ): + read_timeout = timeout.read or 600 + timeout = read_timeout # default 10 min timeout + elif timeout is not None and not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + elif timeout is None: + timeout = 600.0 + + response: Optional[Run] = None + if custom_llm_provider == "openai": + api_base = ( + optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + organization = ( + optional_params.organization + or litellm.organization + or os.getenv("OPENAI_ORGANIZATION", None) + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + optional_params.api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or os.getenv("OPENAI_API_KEY") + ) + response = openai_assistants_api.run_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + stream=stream, + tools=tools, + api_base=api_base, + api_key=api_key, + timeout=timeout, + max_retries=optional_params.max_retries, + organization=organization, + client=client, + ) + else: + raise litellm.exceptions.BadRequestError( + message="LiteLLM doesn't support {} for 'run_thread'. 
Only 'openai' is supported.".format( + custom_llm_provider + ), + model="n/a", + llm_provider=custom_llm_provider, + response=httpx.Response( + status_code=400, + content="Unsupported provider", + request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + ) + return response diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 16f4868f4..a95f83e99 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1282,10 +1282,10 @@ class OpenAIAssistantsAPI(BaseLLM): def get_assistants( self, - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI], ) -> SyncCursorPage[Assistant]: @@ -1340,10 +1340,10 @@ class OpenAIAssistantsAPI(BaseLLM): def get_messages( self, thread_id: str, - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI] = None, ) -> SyncCursorPage[OpenAIMessage]: @@ -1440,10 +1440,10 @@ class OpenAIAssistantsAPI(BaseLLM): model: Optional[str], stream: Optional[Literal[False]] | Literal[True], tools: Optional[Iterable[AssistantToolParam]], - api_key: str, + api_key: Optional[str], api_base: Optional[str], timeout: Union[float, httpx.Timeout], - max_retries: int, + max_retries: Optional[int], organization: Optional[str], client: Optional[OpenAI], ) -> Run: diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index ff83dd30c..940b874ff 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -68,6 +68,40 @@ def test_add_message_litellm(): assert isinstance(added_message, Message) +def test_run_thread_litellm(): + """ + - Get Assistants + - Create thread + - Create run w/ Assistants + Thread + """ + assistants = litellm.get_assistants(custom_llm_provider="openai") + + ## get the first assistant ### + assistant_id = assistants.data[0].id + + new_thread = test_create_thread_litellm() + + thread_id = new_thread.id + + # add message to thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore + added_message = litellm.add_message( + thread_id=new_thread.id, custom_llm_provider="openai", **message + ) + + run = litellm.run_thread( + custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id + ) + + if run.status == "completed": + messages = litellm.get_messages( + thread_id=new_thread.id, custom_llm_provider="openai" + ) + assert isinstance(messages.data[0], Message) + else: + pytest.fail("An unexpected error occurred when running the thread") + + def test_run_thread_openai_direct(): """ - Get Assistants From 5406205e4b9a128ea037a3ec8ed858ddb5afce64 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:31:07 -0700 Subject: [PATCH 06/11] test(test_assistants.py): cleanup tests --- litellm/tests/test_assistants.py | 72 -------------------------------- 1 file changed, 72 deletions(-) diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py index 940b874ff..7f20a6df0 100644 --- a/litellm/tests/test_assistants.py +++ b/litellm/tests/test_assistants.py @@ -100,75 +100,3 @@ def test_run_thread_litellm(): assert isinstance(messages.data[0], Message) else: pytest.fail("An unexpected error occurred when running the thread") - - -def test_run_thread_openai_direct(): - """ 
- - Get Assistants - - Create thread - - Create run w/ Assistants + Thread - """ - openai_api = OpenAIAssistantsAPI() - - assistants = openai_api.get_assistants( - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - ## get the first assistant ### - assistant_id = assistants.data[0].id - - ## create a thread w/ message ### - new_thread = test_create_thread() - - thread_id = new_thread.id - - # add message to thread - message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore - added_message = openai_api.add_message( - thread_id=new_thread.id, - message_data=message, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - run = openai_api.run_thread( - thread_id=thread_id, - assistant_id=assistant_id, - additional_instructions=None, - instructions=None, - metadata=None, - model=None, - stream=None, - tools=None, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - - print(f"run: {run}") - - if run.status == "completed": - messages = openai_api.get_messages( - thread_id=new_thread.id, - api_key=os.getenv("OPENAI_API_KEY"), - api_base=None, - timeout=600, - max_retries=2, - organization=None, - client=None, - ) - assert isinstance(messages.data[0], Message) - else: - pytest.fail("An unexpected error occurred when running the thread") From f2bf6411d8e607b4be0019496748df52fb314410 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:48:42 -0700 Subject: [PATCH 07/11] fix(openai.py): fix linting error --- litellm/llms/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index a95f83e99..25e22e184 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -1438,7 +1438,7 @@ class OpenAIAssistantsAPI(BaseLLM): instructions: Optional[str], metadata: Optional[object], model: Optional[str], - stream: Optional[Literal[False]] | Literal[True], + stream: Optional[bool], tools: Optional[Iterable[AssistantToolParam]], api_key: Optional[str], api_base: Optional[str], From 1195bf296bdb6013489f4e7f825d690bd47a0944 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:49:30 -0700 Subject: [PATCH 08/11] fix(openai.py): fix typing import for python 3.8 --- litellm/llms/openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py index 25e22e184..f007507c9 100644 --- a/litellm/llms/openai.py +++ b/litellm/llms/openai.py @@ -4,7 +4,6 @@ from typing import ( Any, BinaryIO, Literal, - Annotated, Iterable, ) from typing_extensions import override From 2deac08ff1bc2346dff0d608bfca52ab85ecd6a3 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 21:53:08 -0700 Subject: [PATCH 09/11] fix(types/openai.py): fix typing import --- litellm/types/llms/openai.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index f9f7b3bf0..8a37349d6 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -7,11 +7,10 @@ from typing import ( Annotated, Iterable, ) -from typing_extensions import override +from typing_extensions import override, Required from pydantic import BaseModel from openai.types.beta.threads.message_content import MessageContent -from 
openai.types.beta.threads.message_create_params import Attachment from openai.types.beta.threads.message import Message as OpenAIMessage from openai.types.beta.thread_create_params import ( Message as OpenAICreateThreadParamsMessage, @@ -54,6 +53,27 @@ class NotGiven: NOT_GIVEN = NotGiven() +class FileSearchToolParam(TypedDict, total=False): + type: Required[Literal["file_search"]] + """The type of tool being defined: `file_search`""" + + +class CodeInterpreterToolParam(TypedDict, total=False): + type: Required[Literal["code_interpreter"]] + """The type of tool being defined: `code_interpreter`""" + + +AttachmentTool = Union[CodeInterpreterToolParam, FileSearchToolParam] + + +class Attachment(TypedDict, total=False): + file_id: str + """The ID of the file to attach to the message.""" + + tools: Iterable[AttachmentTool] + """The tools to add this file to.""" + + class MessageData(TypedDict): role: Literal["user", "assistant"] content: str From 66129bc92117a4dc0a3eb5434bebe000ba5c022e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 22:02:43 -0700 Subject: [PATCH 10/11] fix(typing/openai.py): fix openai typing error (version-related) --- litellm/types/llms/openai.py | 51 +++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 8a37349d6..fe553c559 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -14,7 +14,6 @@ from openai.types.beta.threads.message_content import MessageContent from openai.types.beta.threads.message import Message as OpenAIMessage from openai.types.beta.thread_create_params import ( Message as OpenAICreateThreadParamsMessage, - ToolResources as OpenAICreateThreadParamsToolResources, ) from openai.types.beta.assistant_tool_param import AssistantToolParam from openai.types.beta.threads.run import Run @@ -53,6 +52,56 @@ class NotGiven: NOT_GIVEN = NotGiven() +class ToolResourcesCodeInterpreter(TypedDict, total=False): + file_ids: List[str] + """ + A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made + available to the `code_interpreter` tool. There can be a maximum of 20 files + associated with the tool. + """ + + +class ToolResourcesFileSearchVectorStore(TypedDict, total=False): + file_ids: List[str] + """ + A list of [file](https://platform.openai.com/docs/api-reference/files) IDs to + add to the vector store. There can be a maximum of 10000 files in a vector + store. + """ + + metadata: object + """Set of 16 key-value pairs that can be attached to a vector store. + + This can be useful for storing additional information about the vector store in + a structured format. Keys can be a maximum of 64 characters long and values can + be a maxium of 512 characters long. + """ + + +class ToolResourcesFileSearch(TypedDict, total=False): + vector_store_ids: List[str] + """ + The + [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + attached to this thread. There can be a maximum of 1 vector store attached to + the thread. + """ + + vector_stores: Iterable[ToolResourcesFileSearchVectorStore] + """ + A helper to create a + [vector store](https://platform.openai.com/docs/api-reference/vector-stores/object) + with file_ids and attach it to this thread. There can be a maximum of 1 vector + store attached to the thread. 
+ """ + + +class OpenAICreateThreadParamsToolResources(TypedDict, total=False): + code_interpreter: ToolResourcesCodeInterpreter + + file_search: ToolResourcesFileSearch + + class FileSearchToolParam(TypedDict, total=False): type: Required[Literal["file_search"]] """The type of tool being defined: `file_search`""" From 06ae5844736d45594242b9b2a4366ace523fb73d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 4 May 2024 22:04:17 -0700 Subject: [PATCH 11/11] fix(types/openai.py): fix python3.8 typing issue --- litellm/types/llms/openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index fe553c559..1c60ad6db 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -4,7 +4,6 @@ from typing import ( Any, BinaryIO, Literal, - Annotated, Iterable, ) from typing_extensions import override, Required