test fixes for azure assistants

This commit is contained in:
Ishaan Jaff 2025-04-19 07:36:34 -07:00
parent ef6ac42658
commit 5bf76f0bb1

View file

@ -37,6 +37,12 @@ V0 Scope:
- Run Thread -> `/v1/threads/{thread_id}/run`
"""
def _add_azure_related_dynamic_params(data: dict) -> dict:
data["api_version"] = "2024-02-15-preview"
data["api_base"] = os.getenv("AZURE_ASSISTANTS_API_BASE")
data["api_key"] = os.getenv("AZURE_ASSISTANTS_API_KEY")
return data
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize(
@ -49,7 +55,7 @@ async def test_get_assistants(provider, sync_mode):
"custom_llm_provider": provider,
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
data = _add_azure_related_dynamic_params(data)
if sync_mode == True:
assistants = litellm.get_assistants(**data)
@ -68,19 +74,18 @@ async def test_get_assistants(provider, sync_mode):
@pytest.mark.flaky(retries=3, delay=1)
async def test_create_delete_assistants(provider, sync_mode):
litellm.ssl_verify = False
model = "gpt-4-turbo"
data = {
"custom_llm_provider": provider,
"model": "gpt-4-turbo",
"instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
"name": "Math Tutor",
"tools": [{"type": "code_interpreter"}],
}
if provider == "azure":
os.environ["AZURE_API_VERSION"] = "2024-05-01-preview"
model = "chatgpt-v-3"
data = _add_azure_related_dynamic_params(data)
if sync_mode == True:
assistant = litellm.create_assistants(
custom_llm_provider=provider,
model=model,
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
name="Math Tutor",
tools=[{"type": "code_interpreter"}],
)
assistant = litellm.create_assistants(**data)
print("New assistants", assistant)
assert isinstance(assistant, Assistant)
@ -91,18 +96,19 @@ async def test_create_delete_assistants(provider, sync_mode):
assert assistant.id is not None
# delete the created assistant
response = litellm.delete_assistant(
custom_llm_provider=provider, assistant_id=assistant.id
)
delete_data = {
"custom_llm_provider": provider,
"assistant_id": assistant.id,
}
if provider == "azure":
delete_data = _add_azure_related_dynamic_params(delete_data)
response = litellm.delete_assistant(**delete_data)
print("Response deleting assistant", response)
assert response.id == assistant.id
else:
assistant = await litellm.acreate_assistants(
custom_llm_provider=provider,
model=model,
instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
name="Math Tutor",
tools=[{"type": "code_interpreter"}],
**data,
)
print("New assistants", assistant)
assert isinstance(assistant, Assistant)
@ -112,9 +118,14 @@ async def test_create_delete_assistants(provider, sync_mode):
)
assert assistant.id is not None
response = await litellm.adelete_assistant(
custom_llm_provider=provider, assistant_id=assistant.id
)
# delete the created assistant
delete_data = {
"custom_llm_provider": provider,
"assistant_id": assistant.id,
}
if provider == "azure":
delete_data = _add_azure_related_dynamic_params(delete_data)
response = await litellm.adelete_assistant(**delete_data)
print("Response deleting assistant", response)
assert response.id == assistant.id
@ -129,7 +140,7 @@ async def test_create_thread_litellm(sync_mode, provider) -> Thread:
"message": [message],
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
data = _add_azure_related_dynamic_params(data)
if sync_mode:
new_thread = create_thread(**data)
@ -159,7 +170,7 @@ async def test_get_thread_litellm(provider, sync_mode):
"thread_id": _new_thread.id,
}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
data = _add_azure_related_dynamic_params(data)
if sync_mode:
received_thread = get_thread(**data)
@ -188,7 +199,7 @@ async def test_add_message_litellm(sync_mode, provider):
data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
if provider == "azure":
data["api_version"] = "2024-02-15-preview"
data = _add_azure_related_dynamic_params(data)
if sync_mode:
added_message = litellm.add_message(**data)
else:
@ -252,6 +263,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming):
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message}
if provider == "azure":
data = _add_azure_related_dynamic_params(data)
if sync_mode:
added_message = litellm.add_message(**data)