feat(openai.py): add support for openai assistants

v0 commit. Closes https://github.com/BerriAI/litellm/issues/2842
This commit is contained in:
Krrish Dholakia 2024-05-04 17:27:48 -07:00
parent b7ca9a53c9
commit 84c31a5528
3 changed files with 461 additions and 1 deletions

View file

@ -0,0 +1,2 @@
# What is this?
## Main file for assistants API logic

View file

@ -1,4 +1,14 @@
from typing import Optional, Union, Any, BinaryIO from typing import (
Optional,
Union,
Any,
BinaryIO,
Literal,
Annotated,
Iterable,
)
from typing_extensions import override
from pydantic import BaseModel
import types, time, json, traceback import types, time, json, traceback
import httpx import httpx
from .base import BaseLLM from .base import BaseLLM
@ -17,6 +27,73 @@ import aiohttp, requests
import litellm import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI from openai import OpenAI, AsyncOpenAI
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.message_create_params import Attachment
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
Message as OpenAICreateThreadParamsMessage,
)
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from typing import TypedDict, List, Optional
class NotGiven:
    """Sentinel type that marks a keyword argument as omitted.

    Distinguishes "caller did not pass this argument" from an explicit
    ``None`` (which may carry different semantics). For example:

    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
        ...

    get(timeout=1)  # 1s timeout
    get(timeout=None)  # No timeout
    get()  # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    @override
    def __repr__(self) -> str:
        return "NOT_GIVEN"

    def __bool__(self) -> Literal[False]:
        # Always falsy so `if value:` treats an omitted argument like None/empty.
        return False


# Shared singleton used as the default value for omitted keyword arguments.
NOT_GIVEN = NotGiven()
class MessageData(TypedDict):
    """Typed payload for creating a thread message.

    Unpacked via ``**message_data`` into
    ``openai_client.beta.threads.messages.create`` — keys must match the
    OpenAI message-create parameters.
    """

    role: Literal["user", "assistant"]  # author of the message
    content: str  # message text
    attachments: Optional[List[Attachment]]  # optional file attachments
    metadata: Optional[dict]  # arbitrary key-value metadata forwarded to the API
class Thread(BaseModel):
    """litellm's pydantic model of an OpenAI conversation thread."""

    id: str
    """The identifier, which can be referenced in API endpoints."""

    created_at: int
    """The Unix timestamp (in seconds) for when the thread was created."""

    metadata: Optional[object] = None
    """Set of 16 key-value pairs that can be attached to an object.

    This can be useful for storing additional information about the object in a
    structured format. Keys can be a maximum of 64 characters long and values can be
    a maximum of 512 characters long.
    """

    object: Literal["thread"]
    """The object type, which is always `thread`."""
class OpenAIError(Exception): class OpenAIError(Exception):
@ -1236,3 +1313,220 @@ class OpenAITextCompletion(BaseLLM):
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk
class OpenAIAssistantsAPI(BaseLLM):
    """Wrapper around the OpenAI Assistants (beta) API.

    Covers assistants, threads, messages, and runs. Every public method takes
    explicit credential/transport arguments (``api_key``, ``api_base``,
    ``timeout``, ``max_retries``, ``organization``) plus an optional pre-built
    ``client``; when ``client`` is None, a fresh ``OpenAI`` client is built.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_openai_client(
        self,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAI:
        """Return ``client`` if provided, otherwise construct a new ``OpenAI`` client.

        A newly constructed client reuses ``litellm.client_session`` as its
        HTTP transport so connection pooling is shared with the rest of litellm.
        """
        if client is not None:
            return client
        return OpenAI(
            api_key=api_key,
            base_url=api_base,
            http_client=litellm.client_session,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
        )

    ### ASSISTANTS ###

    def get_assistants(
        self,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI],
    ) -> SyncCursorPage[Assistant]:
        """List the assistants available to the given API key (paginated)."""
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        return openai_client.beta.assistants.list()

    ### MESSAGES ###

    def add_message(
        self,
        thread_id: str,
        message_data: MessageData,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> OpenAIMessage:
        """Append a message to the thread identified by ``thread_id``.

        Returns the created message re-validated as an ``OpenAIMessage``.
        """
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        thread_message: OpenAIMessage = openai_client.beta.threads.messages.create(
            thread_id, **message_data
        )
        # Some SDK responses omit `status` on a freshly created message; default
        # it so downstream consumers always see a populated field. (The original
        # if/else built the identical OpenAIMessage in both branches — collapsed.)
        if getattr(thread_message, "status", None) is None:
            thread_message.status = "completed"
        return OpenAIMessage(**thread_message.dict())

    def get_messages(
        self,
        thread_id: str,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI] = None,
    ) -> SyncCursorPage[OpenAIMessage]:
        """Return the (paginated) messages belonging to ``thread_id``."""
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        return openai_client.beta.threads.messages.list(thread_id=thread_id)

    ### THREADS ###

    def create_thread(
        self,
        metadata: dict,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI],
        messages: Union[
            Iterable[OpenAICreateThreadParamsMessage], NotGiven
        ] = NOT_GIVEN,
    ) -> Thread:
        """Create a new thread, optionally seeded with ``messages``.

        Here's an example:
        ```
        from litellm.llms.openai import OpenAIAssistantsAPI, MessageData

        # create thread
        message: MessageData = {"role": "user", "content": "Hey, how's it going?"}
        openai_api.create_thread(messages=[message])
        ```
        """
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        message_thread = openai_client.beta.threads.create(
            messages=messages,  # type: ignore
            metadata=metadata,
        )
        # Re-validate through litellm's own Thread model.
        return Thread(**message_thread.dict())

    def get_thread(
        self,
        thread_id: str,
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI],
    ) -> Thread:
        """Retrieve an existing thread by id, re-validated as a ``Thread``."""
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        response = openai_client.beta.threads.retrieve(thread_id=thread_id)
        return Thread(**response.dict())

    def delete_thread(self):
        # TODO: not implemented yet — intentionally a no-op placeholder.
        pass

    ### RUNS ###

    def run_thread(
        self,
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str],
        instructions: Optional[str],
        metadata: Optional[object],
        model: Optional[str],
        stream: Optional[Literal[False]] | Literal[True],
        tools: Optional[Iterable[AssistantToolParam]],
        api_key: str,
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: int,
        organization: Optional[str],
        client: Optional[OpenAI],
    ) -> Run:
        """Create a run for ``assistant_id`` on ``thread_id`` and poll it to a
        terminal state.

        NOTE(review): ``stream`` is accepted but never forwarded —
        ``create_and_poll`` is inherently blocking; streaming is unimplemented
        here. Kept in the signature for interface compatibility.
        """
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )
        return openai_client.beta.threads.runs.create_and_poll(
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            tools=tools,
        )

View file

@ -0,0 +1,164 @@
# What is this?
## Unit Tests for OpenAI Assistants API
import sys, os, json
import traceback
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest, logging, asyncio
import litellm
from litellm.llms.openai import (
OpenAIAssistantsAPI,
MessageData,
Thread,
OpenAIMessage as Message,
)
"""
V0 Scope:
- Add Message -> `/v1/threads/{thread_id}/messages`
- Run Thread -> `/v1/threads/{thread_id}/run`
"""
def test_create_thread() -> Thread:
    """Create a thread seeded with one user message; verify it is a Thread."""
    openai_api = OpenAIAssistantsAPI()

    seed_message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore

    new_thread = openai_api.create_thread(
        metadata={},
        messages=[seed_message],  # type: ignore
        api_key=os.getenv("OPENAI_API_KEY"),  # type: ignore
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )

    print(f"new_thread: {new_thread}")
    print(f"type of thread: {type(new_thread)}")

    assert isinstance(
        new_thread, Thread
    ), f"type of thread={type(new_thread)}. Expected Thread-type"

    return new_thread
def test_add_message():
    """Append a message to a freshly created thread; verify the result type."""
    openai_api = OpenAIAssistantsAPI()

    # create thread
    new_thread = test_create_thread()

    # add message to thread
    user_message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    added_message = openai_api.add_message(
        message_data=user_message,
        thread_id=new_thread.id,
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )

    print(f"added message: {added_message}")

    assert isinstance(added_message, Message)
def test_get_thread():
    """Round-trip: create a thread, fetch it by id, check it is a Thread."""
    openai_api = OpenAIAssistantsAPI()

    ## create a thread w/ message ###
    new_thread = test_create_thread()

    connection_kwargs = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )
    retrieved_thread = openai_api.get_thread(
        thread_id=new_thread.id, **connection_kwargs
    )

    assert isinstance(
        retrieved_thread, Thread
    ), f"type of thread={type(retrieved_thread)}. Expected Thread-type"

    return new_thread
def test_run_thread():
    """
    - Get Assistants
    - Create thread
    - Create run w/ Assistants + Thread
    """
    openai_api = OpenAIAssistantsAPI()

    # shared credential/transport arguments used by every call below
    connection_kwargs = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )

    assistants = openai_api.get_assistants(**connection_kwargs)

    ## get the first assistant ###
    assistant_id = assistants.data[0].id

    ## create a thread w/ message ###
    new_thread = test_create_thread()
    thread_id = new_thread.id

    # add message to thread
    user_message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    added_message = openai_api.add_message(
        thread_id=thread_id, message_data=user_message, **connection_kwargs
    )

    run = openai_api.run_thread(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=None,
        instructions=None,
        metadata=None,
        model=None,
        stream=None,
        tools=None,
        **connection_kwargs,
    )

    print(f"run: {run}")

    if run.status == "completed":
        messages = openai_api.get_messages(
            thread_id=thread_id, **connection_kwargs
        )
        assert isinstance(messages.data[0], Message)
    else:
        pytest.fail("An unexpected error occurred when running the thread")