diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py
index 38e87c254..7bc97ca59 100644
--- a/tests/test_openai_endpoints.py
+++ b/tests/test_openai_endpoints.py
@@ -4,6 +4,7 @@ import pytest
 import asyncio
 import aiohttp, openai
 from openai import OpenAI, AsyncOpenAI
+from typing import Optional, List, Union
 
 
 def response_header_check(response):
@@ -71,7 +72,7 @@ async def new_user(session):
         return await response.json()
 
 
-async def chat_completion(session, key, model="gpt-4"):
+async def chat_completion(session, key, model: Union[str, List[str]] = "gpt-4"):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
         "Authorization": f"Bearer {key}",
@@ -409,3 +410,27 @@ async def test_openai_wildcard_chat_completion():
 
         # call chat/completions with a model that the key was not created for + the model is not on the config.yaml
         await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")
+
+
+@pytest.mark.asyncio
+async def test_batch_chat_completions():
+    """
+    - Make a single chat completion call, passing a list of models
+    - Expect a list back, with one entry per model in the request
+    """
+    async with aiohttp.ClientSession() as session:
+
+        # call chat/completions once with a list of models instead of a single model name
+        response = await chat_completion(
+            session=session,
+            key="sk-1234",
+            model=[
+                "gpt-3.5-turbo",
+                "fake-openai-endpoint",
+            ],
+        )
+
+        print(f"response: {response}")
+
+        assert isinstance(response, list)
+        assert len(response) == 2