(test) router: add tests for azure completion, acompletion

This commit is contained in:
ishaan-jaff 2023-12-05 13:59:27 -08:00
parent 0d1b42eda5
commit bc70a6fba8

View file

@ -114,6 +114,9 @@ def test_reading_key_from_model_list():
] ]
) )
print("\n response", response) print("\n response", response)
str_response = response.choices[0].message.content
print("\n str_response", str_response)
assert len(str_response) > 0
print("\n Testing streaming response") print("\n Testing streaming response")
response = router.completion( response = router.completion(
@ -126,9 +129,13 @@ def test_reading_key_from_model_list():
], ],
stream=True stream=True
) )
completed_response = ""
for chunk in response: for chunk in response:
if chunk is not None: if chunk is not None:
print(chunk) print(chunk)
completed_response += chunk.choices[0].delta.content or ""
print("\n completed_response", completed_response)
assert len(completed_response) > 0
print("\n Passed Streaming") print("\n Passed Streaming")
os.environ["AZURE_API_KEY"] = old_api_key os.environ["AZURE_API_KEY"] = old_api_key
router.reset() router.reset()
@ -183,15 +190,18 @@ def test_router_azure_acompletion():
async def test1(): async def test1():
response = await router.acompletion( response: litellm.ModelResponse = await router.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[ messages=[
{ {
"role": "user", "role": "user",
"content": "hello this request will fail" "content": "hello this request will pass"
} }
] ]
) )
str_response = response.choices[0].message.content
print("\n str_response", str_response)
assert len(str_response) > 0
print("\n response", response) print("\n response", response)
asyncio.run(test1()) asyncio.run(test1())
@ -207,9 +217,13 @@ def test_router_azure_acompletion():
], ],
stream=True stream=True
) )
completed_response = ""
async for chunk in response: async for chunk in response:
if chunk is not None: if chunk is not None:
print(chunk) print(chunk)
completed_response += chunk.choices[0].delta.content or ""
print("\n completed_response", completed_response)
assert len(completed_response) > 0
asyncio.run(test2()) asyncio.run(test2())
print("\n Passed Streaming") print("\n Passed Streaming")
os.environ["AZURE_API_KEY"] = old_api_key os.environ["AZURE_API_KEY"] = old_api_key