diff --git a/litellm/assistants/main.py b/litellm/assistants/main.py
new file mode 100644
index 0000000000..0d32164829
--- /dev/null
+++ b/litellm/assistants/main.py
@@ -0,0 +1,2 @@
+# What is this?
+## Main file for assistants API logic
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 5a76605b3a..a6d6f4109f 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1,4 +1,14 @@
-from typing import Optional, Union, Any, BinaryIO
+from typing import (
+    Optional,
+    Union,
+    Any,
+    BinaryIO,
+    Literal,
+    Annotated,
+    Iterable,
+)
+from typing_extensions import override
+from pydantic import BaseModel
 import types, time, json, traceback
 import httpx
 from .base import BaseLLM
@@ -17,6 +27,77 @@ import aiohttp, requests
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI
+from openai.types.beta.threads.message_content import MessageContent
+from openai.types.beta.threads.message_create_params import Attachment
+from openai.types.beta.threads.message import Message as OpenAIMessage
+from openai.types.beta.thread_create_params import (
+    Message as OpenAICreateThreadParamsMessage,
+)
+from openai.types.beta.assistant_tool_param import AssistantToolParam
+from openai.types.beta.threads.run import Run
+from openai.types.beta.assistant import Assistant
+from openai.pagination import SyncCursorPage
+
+from typing import TypedDict, List, Optional
+
+
+class NotGiven:
+    """
+    A sentinel singleton class used to distinguish omitted keyword arguments
+    from those passed in with the value None (which may have different behavior).
+
+    For example:
+
+    ```py
+    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
+        ...
+
+
+    get(timeout=1)  # 1s timeout
+    get(timeout=None)  # No timeout
+    get()  # Default timeout behavior, which may not be statically known at the method definition.
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NOT_GIVEN = NotGiven()
+
+
+class MessageData(TypedDict):
+    """Keyword arguments that `add_message` forwards to `messages.create`."""
+
+    role: Literal["user", "assistant"]
+    content: str
+    attachments: Optional[List[Attachment]]
+    metadata: Optional[dict]
+
+
+class Thread(BaseModel):
+    """Thread record that `create_thread` / `get_thread` build from SDK responses."""
+
+    id: str
+    """The identifier, which can be referenced in API endpoints."""
+
+    created_at: int
+    """The Unix timestamp (in seconds) for when the thread was created."""
+
+    metadata: Optional[object] = None
+    """Set of 16 key-value pairs that can be attached to an object.
+
+    This can be useful for storing additional information about the object in a
+    structured format. Keys can be a maximum of 64 characters long and values can be
+    a maximum of 512 characters long.
+    """
+
+    object: Literal["thread"]
+    """The object type, which is always `thread`."""
 
 
 class OpenAIError(Exception):
@@ -1236,3 +1317,239 @@ class OpenAITextCompletion(BaseLLM):
 
         async for transformed_chunk in streamwrapper:
             yield transformed_chunk
+
+
+class OpenAIAssistantsAPI(BaseLLM):
+    """Synchronous wrapper around the OpenAI SDK's beta Assistants endpoints."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def get_openai_client(
+        self,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> OpenAI:
+        """Return `client` if one is passed in, else build a new `OpenAI` client."""
+        if client is not None:
+            return client
+
+        return OpenAI(
+            api_key=api_key,
+            base_url=api_base,
+            http_client=litellm.client_session,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+        )
+
+    ### ASSISTANTS ###
+
+    def get_assistants(
+        self,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI],
+    ) -> SyncCursorPage[Assistant]:
+        """Return the page produced by `beta.assistants.list()`."""
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        return openai_client.beta.assistants.list()
+
+    ### MESSAGES ###
+
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: MessageData,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> OpenAIMessage:
+        """Create a message on `thread_id` from `message_data` and return it.
+
+        A missing `status` on the created message is backfilled with "completed"
+        before the message is rebuilt as an `OpenAIMessage`.
+        """
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
+            thread_id, **message_data
+        )
+
+        # Only the status backfill is conditional; the rebuild happens either way.
+        if getattr(thread_message, "status", None) is None:
+            thread_message.status = "completed"
+        return OpenAIMessage(**thread_message.dict())
+
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI] = None,
+    ) -> SyncCursorPage[OpenAIMessage]:
+        """Return the page of messages listed for `thread_id`."""
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        return openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+    ### THREADS ###
+
+    def create_thread(
+        self,
+        metadata: dict,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        messages: Union[
+            Iterable[OpenAICreateThreadParamsMessage], NotGiven
+        ] = NOT_GIVEN,
+    ) -> Thread:
+        """
+        Create a thread, optionally seeded with `messages`, and return it as `Thread`.
+
+        Here's an example:
+        ```
+        from litellm.llms.openai import OpenAIAssistantsAPI, MessageData
+
+        # create thread
+        openai_api = OpenAIAssistantsAPI()
+        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
+        openai_api.create_thread(
+            messages=[message],
+            metadata={},
+            api_key=os.getenv("OPENAI_API_KEY"),
+            api_base=None,
+            timeout=600,
+            max_retries=2,
+            organization=None,
+            client=None,
+        )
+        ```
+        """
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        # NOT_GIVEN is this module's own sentinel, not one imported from the SDK,
+        # so leave `messages` out of the call instead of forwarding the sentinel.
+        create_kwargs: dict = {"metadata": metadata}
+        if not isinstance(messages, NotGiven):
+            create_kwargs["messages"] = messages
+
+        message_thread = openai_client.beta.threads.create(**create_kwargs)
+
+        return Thread(**message_thread.dict())
+
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI],
+    ) -> Thread:
+        """Retrieve `thread_id` and return it as `Thread`."""
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    def delete_thread(self):
+        """Placeholder; thread deletion is not implemented yet."""
+        pass
+
+    ### RUNS ###
+
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Union[Optional[Literal[False]], Literal[True]],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: str,
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: int,
+        organization: Optional[str],
+        client: Optional[OpenAI],
+    ) -> Run:
+        """Start a run of `assistant_id` on `thread_id` via `create_and_poll`.
+
+        `stream` is accepted but not forwarded to the SDK call.
+        """
+        openai_client = self.get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        return openai_client.beta.threads.runs.create_and_poll(
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
diff --git a/litellm/tests/test_assistants.py b/litellm/tests/test_assistants.py
new file mode 100644
index 0000000000..9b8585ec6d
--- /dev/null
+++ b/litellm/tests/test_assistants.py
@@ -0,0 +1,170 @@
+# What is this?
+## Unit Tests for OpenAI Assistants API
+import sys, os, json
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest, logging, asyncio
+import litellm
+from litellm.llms.openai import (
+    OpenAIAssistantsAPI,
+    MessageData,
+    Thread,
+    OpenAIMessage as Message,
+)
+
+"""
+V0 Scope:
+
+- Add Message -> `/v1/threads/{thread_id}/messages`
+- Run Thread -> `/v1/threads/{thread_id}/run`
+"""
+
+
+def _create_thread() -> Thread:
+    """Create a thread seeded with one user message and check its type."""
+    openai_api = OpenAIAssistantsAPI()
+
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    new_thread = openai_api.create_thread(
+        messages=[message],  # type: ignore
+        api_key=os.getenv("OPENAI_API_KEY"),  # type: ignore
+        metadata={},
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    print(f"new_thread: {new_thread}")
+    print(f"type of thread: {type(new_thread)}")
+    assert isinstance(
+        new_thread, Thread
+    ), f"type of thread={type(new_thread)}. Expected Thread-type"
+    return new_thread
+
+
+def test_create_thread():
+    # Tests must return None (pytest warns otherwise), so the thread-building
+    # logic lives in `_create_thread` and is shared with the tests below.
+    _create_thread()
+
+
+def test_add_message():
+    openai_api = OpenAIAssistantsAPI()
+    # create thread
+    new_thread = _create_thread()
+    # add message to thread
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    added_message = openai_api.add_message(
+        thread_id=new_thread.id,
+        message_data=message,
+        api_key=os.getenv("OPENAI_API_KEY"),
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    print(f"added message: {added_message}")
+
+    assert isinstance(added_message, Message)
+
+
+def test_get_thread():
+    openai_api = OpenAIAssistantsAPI()
+
+    ## create a thread w/ message ###
+    new_thread = _create_thread()
+
+    retrieved_thread = openai_api.get_thread(
+        thread_id=new_thread.id,
+        api_key=os.getenv("OPENAI_API_KEY"),
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    assert isinstance(
+        retrieved_thread, Thread
+    ), f"type of thread={type(retrieved_thread)}. Expected Thread-type"
+
+
+def test_run_thread():
+    """
+    - Get Assistants
+    - Create thread
+    - Create run w/ Assistants + Thread
+    """
+    openai_api = OpenAIAssistantsAPI()
+
+    assistants = openai_api.get_assistants(
+        api_key=os.getenv("OPENAI_API_KEY"),
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    ## get the first assistant ###
+    assistant_id = assistants.data[0].id
+
+    ## create a thread w/ message ###
+    new_thread = _create_thread()
+
+    thread_id = new_thread.id
+
+    # add message to thread
+    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
+    added_message = openai_api.add_message(
+        thread_id=new_thread.id,
+        message_data=message,
+        api_key=os.getenv("OPENAI_API_KEY"),
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    run = openai_api.run_thread(
+        thread_id=thread_id,
+        assistant_id=assistant_id,
+        additional_instructions=None,
+        instructions=None,
+        metadata=None,
+        model=None,
+        stream=None,
+        tools=None,
+        api_key=os.getenv("OPENAI_API_KEY"),
+        api_base=None,
+        timeout=600,
+        max_retries=2,
+        organization=None,
+        client=None,
+    )
+
+    print(f"run: {run}")
+
+    if run.status == "completed":
+        messages = openai_api.get_messages(
+            thread_id=new_thread.id,
+            api_key=os.getenv("OPENAI_API_KEY"),
+            api_base=None,
+            timeout=600,
+            max_retries=2,
+            organization=None,
+            client=None,
+        )
+        assert isinstance(messages.data[0], Message)
+    else:
+        pytest.fail("An unexpected error occurred when running the thread")