litellm-mirror/tests/local_testing/test_assistants.py
2025-04-19 08:09:45 -07:00

331 lines
10 KiB
Python

# What is this?
## Unit Tests for OpenAI Assistants API
import json
import os
import sys
import traceback
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import logging
import pytest
from openai.types.beta.assistant import Assistant
from typing_extensions import override
import litellm
from litellm import create_thread, get_thread
from litellm.llms.openai.openai import (
AssistantEventHandler,
AsyncAssistantEventHandler,
AsyncCursorPage,
MessageData,
OpenAIAssistantsAPI,
)
from litellm.llms.openai.openai import OpenAIMessage as Message
from litellm.llms.openai.openai import SyncCursorPage, Thread
"""
V0 Scope:
- Add Message -> `/v1/threads/{thread_id}/messages`
- Run Thread -> `/v1/threads/{thread_id}/run`
"""
def _add_azure_related_dynamic_params(data: dict) -> dict:
data["api_version"] = "2024-02-15-preview"
data["api_base"] = os.getenv("AZURE_ASSISTANTS_API_BASE")
data["api_key"] = os.getenv("AZURE_ASSISTANTS_API_KEY")
return data
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.asyncio
async def test_get_assistants(provider, sync_mode):
data = {
"custom_llm_provider": provider,
}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode == True:
assistants = litellm.get_assistants(**data)
assert isinstance(assistants, SyncCursorPage)
else:
assistants = await litellm.aget_assistants(**data)
assert isinstance(assistants, AsyncCursorPage)
@pytest.mark.parametrize("provider", ["azure", "openai"])
@pytest.mark.parametrize(
"sync_mode",
[True, False],
)
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
async def test_create_delete_assistants(provider, sync_mode):
litellm.ssl_verify = False
litellm._turn_on_debug()
data = {
"custom_llm_provider": provider,
"model": "gpt-4.5-preview",
"instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
"name": "Math Tutor",
"tools": [{"type": "code_interpreter"}],
}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode == True:
assistant = litellm.create_assistants(**data)
print("New assistants", assistant)
assert isinstance(assistant, Assistant)
assert (
assistant.instructions
== "You are a personal math tutor. When asked a question, write and run Python code to answer the question."
)
assert assistant.id is not None
# delete the created assistant
delete_data = {
"custom_llm_provider": provider,
"assistant_id": assistant.id,
}
if provider == "azure":
delete_data = _add_azure_related_dynamic_params(delete_data)
response = litellm.delete_assistant(**delete_data)
print("Response deleting assistant", response)
assert response.id == assistant.id
else:
assistant = await litellm.acreate_assistants(**data)
print("New assistants", assistant)
assert isinstance(assistant, Assistant)
assert (
assistant.instructions
== "You are a personal math tutor. When asked a question, write and run Python code to answer the question."
)
assert assistant.id is not None
# delete the created assistant
delete_data = {
"custom_llm_provider": provider,
"assistant_id": assistant.id,
}
if provider == "azure":
delete_data = _add_azure_related_dynamic_params(delete_data)
response = await litellm.adelete_assistant(**delete_data)
print("Response deleting assistant", response)
assert response.id == assistant.id
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_create_thread_litellm(sync_mode, provider) -> Thread:
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {
"custom_llm_provider": provider,
"message": [message],
}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode:
new_thread = create_thread(**data)
else:
new_thread = await litellm.acreate_thread(**data)
assert isinstance(
new_thread, Thread
), f"type of thread={type(new_thread)}. Expected Thread-type"
return new_thread
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_get_thread_litellm(provider, sync_mode):
new_thread = test_create_thread_litellm(sync_mode, provider)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
data = {
"custom_llm_provider": provider,
"thread_id": _new_thread.id,
}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode:
received_thread = get_thread(**data)
else:
received_thread = await litellm.aget_thread(**data)
assert isinstance(
received_thread, Thread
), f"type of thread={type(received_thread)}. Expected Thread-type"
return new_thread
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_add_message_litellm(sync_mode, provider):
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
new_thread = test_create_thread_litellm(sync_mode, provider)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode:
added_message = litellm.add_message(**data)
else:
added_message = await litellm.a_add_message(**data)
print(f"added message: {added_message}")
assert isinstance(added_message, Message)
@pytest.mark.parametrize(
"provider",
[
"azure",
"openai",
],
) #
@pytest.mark.parametrize(
"sync_mode",
[
True,
False,
],
)
@pytest.mark.parametrize(
"is_streaming",
[True, False],
) #
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
"""
- Get Assistants
- Create thread
- Create run w/ Assistants + Thread
"""
import openai
try:
get_assistants_data = {
"custom_llm_provider": provider,
}
if provider == "azure":
get_assistants_data = _add_azure_related_dynamic_params(get_assistants_data)
if sync_mode:
assistants = litellm.get_assistants(**get_assistants_data)
else:
assistants = await litellm.aget_assistants(**get_assistants_data)
## get the first assistant ###
try:
assistant_id = assistants.data[0].id
except IndexError:
pytest.skip("No assistants found")
new_thread = test_create_thread_litellm(sync_mode=sync_mode, provider=provider)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
thread_id = _new_thread.id
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode:
added_message = litellm.add_message(**data)
if is_streaming:
run = litellm.run_thread_stream(assistant_id=assistant_id, **data)
with run as run:
assert isinstance(run, AssistantEventHandler)
print(run)
run.until_done()
else:
run = litellm.run_thread(
assistant_id=assistant_id, stream=is_streaming, **data
)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)
else:
added_message = await litellm.a_add_message(**data)
if is_streaming:
run = litellm.arun_thread_stream(assistant_id=assistant_id, **data)
async with run as run:
print(f"run: {run}")
assert isinstance(
run,
AsyncAssistantEventHandler,
)
print(run)
await run.until_done()
else:
run = await litellm.arun_thread(
custom_llm_provider=provider,
thread_id=thread_id,
assistant_id=assistant_id,
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider=provider
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail(
"An unexpected error occurred when running the thread, {}".format(
run
)
)
except openai.APIError as e:
pass