diff --git a/tests/local_testing/test_assistants.py b/tests/local_testing/test_assistants.py index d5755f2aba..8d10668720 100644 --- a/tests/local_testing/test_assistants.py +++ b/tests/local_testing/test_assistants.py @@ -37,6 +37,12 @@ V0 Scope: - Run Thread -> `/v1/threads/{thread_id}/run` """ +def _add_azure_related_dynamic_params(data: dict) -> dict: + data["api_version"] = "2024-02-15-preview" + data["api_base"] = os.getenv("AZURE_ASSISTANTS_API_BASE") + data["api_key"] = os.getenv("AZURE_ASSISTANTS_API_KEY") + return data + @pytest.mark.parametrize("provider", ["openai", "azure"]) @pytest.mark.parametrize( @@ -49,7 +55,7 @@ async def test_get_assistants(provider, sync_mode): "custom_llm_provider": provider, } if provider == "azure": - data["api_version"] = "2024-02-15-preview" + data = _add_azure_related_dynamic_params(data) if sync_mode == True: assistants = litellm.get_assistants(**data) @@ -68,19 +74,18 @@ async def test_get_assistants(provider, sync_mode): @pytest.mark.flaky(retries=3, delay=1) async def test_create_delete_assistants(provider, sync_mode): litellm.ssl_verify = False - model = "gpt-4-turbo" + data = { + "custom_llm_provider": provider, + "model": "gpt-4-turbo", + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "name": "Math Tutor", + "tools": [{"type": "code_interpreter"}], + } if provider == "azure": - os.environ["AZURE_API_VERSION"] = "2024-05-01-preview" - model = "chatgpt-v-3" + data = _add_azure_related_dynamic_params(data) if sync_mode == True: - assistant = litellm.create_assistants( - custom_llm_provider=provider, - model=model, - instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", - name="Math Tutor", - tools=[{"type": "code_interpreter"}], - ) + assistant = litellm.create_assistants(**data) print("New assistants", assistant) assert isinstance(assistant, Assistant) @@ -91,18 +96,19 @@ async def test_create_delete_assistants(provider, sync_mode): assert assistant.id is not None # delete the created assistant - response = litellm.delete_assistant( - custom_llm_provider=provider, assistant_id=assistant.id - ) + delete_data = { + "custom_llm_provider": provider, + "assistant_id": assistant.id, + } + if provider == "azure": + delete_data = _add_azure_related_dynamic_params(delete_data) + response = litellm.delete_assistant(**delete_data) print("Response deleting assistant", response) assert response.id == assistant.id else: assistant = await litellm.acreate_assistants( custom_llm_provider=provider, - model=model, - instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", - name="Math Tutor", - tools=[{"type": "code_interpreter"}], + **data, ) print("New assistants", assistant) assert isinstance(assistant, Assistant) @@ -112,9 +118,14 @@ async def test_create_delete_assistants(provider, sync_mode): ) assert assistant.id is not None - response = await litellm.adelete_assistant( - custom_llm_provider=provider, assistant_id=assistant.id - ) + # delete the created assistant + delete_data = { + "custom_llm_provider": provider, + "assistant_id": assistant.id, + } + if provider == "azure": + delete_data = _add_azure_related_dynamic_params(delete_data) + response = await litellm.adelete_assistant(**delete_data) print("Response deleting assistant", response) assert response.id == assistant.id @@ -129,7 +140,7 @@ async def test_create_thread_litellm(sync_mode, provider) -> Thread: "message": [message], } if provider == "azure": - data["api_version"] = "2024-02-15-preview" + data = _add_azure_related_dynamic_params(data) if sync_mode: new_thread = create_thread(**data) @@ -159,7 +170,7 @@ async def test_get_thread_litellm(provider, sync_mode): "thread_id": _new_thread.id, } if provider == "azure": - data["api_version"] = "2024-02-15-preview" + data = _add_azure_related_dynamic_params(data) if sync_mode: received_thread = get_thread(**data) @@ -188,7 +199,7 @@ async def test_add_message_litellm(sync_mode, provider): data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message} if provider == "azure": - data["api_version"] = "2024-02-15-preview" + data = _add_azure_related_dynamic_params(data) if sync_mode: added_message = litellm.add_message(**data) else: @@ -252,6 +263,8 @@ async def test_aarun_thread_litellm(sync_mode, provider, is_streaming): message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore data = {"custom_llm_provider": provider, "thread_id": _new_thread.id, **message} + if provider == "azure": + data = _add_azure_related_dynamic_params(data) if sync_mode: added_message = litellm.add_message(**data)