mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
164 lines
4.1 KiB
Python
164 lines
4.1 KiB
Python
# What is this?
|
|
## Unit Tests for OpenAI Assistants API
|
|
import sys, os, json
|
|
import traceback
|
|
from dotenv import load_dotenv
|
|
|
|
# Load variables such as OPENAI_API_KEY (read via os.getenv in the tests below)
# from a .env file, if one is found.
load_dotenv()

# Prepend the directory two levels above the *current working directory* (not
# this file's directory) to the import path.
# NOTE(review): presumably this is the repository root when pytest is launched
# from the tests folder, so that `import litellm` resolves to the local
# checkout rather than an installed copy -- confirm.
sys.path.insert(0, os.path.abspath("../.."))
|
|
import pytest, logging, asyncio
|
|
import litellm
|
|
from litellm.llms.openai import (
|
|
OpenAIAssistantsAPI,
|
|
MessageData,
|
|
Thread,
|
|
OpenAIMessage as Message,
|
|
)
|
|
|
|
"""
|
|
V0 Scope:
|
|
|
|
- Add Message -> `/v1/threads/{thread_id}/messages`
|
|
- Run Thread -> `/v1/threads/{thread_id}/runs`
|
|
"""
|
|
|
|
|
|
def test_create_thread() -> Thread:
    """Create a thread seeded with a single user message and verify its type.

    The created ``Thread`` is returned because the other tests in this module
    call this function to obtain a fresh thread to work with.
    """
    assistants_api = OpenAIAssistantsAPI()

    greeting: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    connection = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )
    thread = assistants_api.create_thread(
        messages=[greeting],  # type: ignore
        metadata={},
        **connection,  # type: ignore
    )

    print(f"new_thread: {thread}")
    print(f"type of thread: {type(thread)}")
    assert isinstance(
        thread, Thread
    ), f"type of thread={type(thread)}. Expected Thread-type"
    return thread
|
|
|
|
|
|
def test_add_message():
    """Append a user message to a freshly created thread and check the result type."""
    assistants_api = OpenAIAssistantsAPI()

    # A brand-new thread to post into.
    thread = test_create_thread()

    user_message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    connection = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )
    created = assistants_api.add_message(
        thread_id=thread.id,
        message_data=user_message,
        **connection,
    )

    print(f"added message: {created}")
    assert isinstance(created, Message)
|
|
|
|
|
|
def test_get_thread():
    """Create a thread, fetch it back by id, and check what was retrieved.

    The test checks both the response type and that the retrieved thread is
    the one that was requested.
    """
    openai_api = OpenAIAssistantsAPI()

    ## create a thread w/ message ###
    new_thread = test_create_thread()

    retrieved_thread = openai_api.get_thread(
        thread_id=new_thread.id,
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )

    assert isinstance(
        retrieved_thread, Thread
    ), f"type of thread={type(retrieved_thread)}. Expected Thread-type"
    # A type check alone would pass for *any* Thread; make sure the API handed
    # back the thread we asked for.
    assert (
        retrieved_thread.id == new_thread.id
    ), f"retrieved thread id={retrieved_thread.id}, expected {new_thread.id}"

    # NOTE(review): recent pytest versions warn when a test function returns a
    # non-None value; the return is kept only so any existing caller that relies
    # on it keeps working.
    return new_thread
|
|
|
|
|
|
def test_run_thread():
    """
    End-to-end run of an assistant on a new thread:

    - Get Assistants (the first one returned is used)
    - Create thread
    - Add a user message to the thread
    - Create run w/ Assistants + Thread
    - If the run completed, check the thread's messages come back as `Message`s
    """
    openai_api = OpenAIAssistantsAPI()

    # Connection settings shared by every API call below.
    request_kwargs = dict(
        api_key=os.getenv("OPENAI_API_KEY"),
        api_base=None,
        timeout=600,
        max_retries=2,
        organization=None,
        client=None,
    )

    assistants = openai_api.get_assistants(**request_kwargs)

    ## get the first assistant ###
    # Fail with an explicit reason instead of an opaque IndexError when the
    # account has no assistants to run.
    if not assistants.data:
        pytest.fail("No assistants found for this API key; cannot run a thread")
    assistant_id = assistants.data[0].id

    ## create a thread w/ message ###
    thread_id = test_create_thread().id

    # add message to thread
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    added_message = openai_api.add_message(
        thread_id=thread_id,
        message_data=message,
        **request_kwargs,
    )
    assert isinstance(added_message, Message)

    run = openai_api.run_thread(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=None,
        instructions=None,
        metadata=None,
        model=None,
        stream=None,
        tools=None,
        **request_kwargs,
    )

    print(f"run: {run}")

    # Surface the run's actual status (and error, if the object carries one)
    # rather than a generic failure message.
    if run.status != "completed":
        pytest.fail(
            f"Run did not complete: status={run.status!r}, "
            f"last_error={getattr(run, 'last_error', None)!r}"
        )

    messages = openai_api.get_messages(thread_id=thread_id, **request_kwargs)
    assert messages.data, "Expected at least one message in the thread"
    assert isinstance(messages.data[0], Message)
|