# forked from phoenix/litellm-mirror
import io
import json
import os
import sys
import traceback

from dotenv import load_dotenv

# Load provider API keys (e.g. COHERE_API_KEY) from a local .env before any
# litellm import reads the environment.
load_dotenv()

# Add the repository root to the import path so the local `litellm` package
# shadows any installed copy. Must run before `import litellm` below.
sys.path.insert(0, os.path.abspath("../.."))

import pytest

import litellm
from litellm import RateLimitError, Timeout, completion, completion_cost, embedding

# Retry transient provider failures (rate limits, timeouts) for every
# litellm call made by tests in this module.
litellm.num_retries = 3
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.asyncio
async def test_chat_completion_cohere_citations(stream):
    """Verify Cohere chat returns citations for document-grounded completions.

    Sends a question together with `documents` to `cohere_chat/command-r` and
    asserts that citations come back — in at least one chunk when streaming,
    or on the final response object otherwise.

    Note: exceptions are deliberately NOT caught here. Wrapping the body in
    `except Exception: pytest.fail(...)` (the previous pattern) also swallows
    `AssertionError`, destroying pytest's assertion introspection and the
    real traceback.
    """
    litellm.set_verbose = True
    messages = [
        {
            "role": "user",
            "content": "Which penguins are the tallest?",
        },
    ]
    response = await litellm.acompletion(
        model="cohere_chat/command-r",
        messages=messages,
        documents=[
            {"title": "Tall penguins", "text": "Emperor penguins are the tallest."},
            {
                "title": "Penguin habitats",
                "text": "Emperor penguins only live in Antarctica.",
            },
        ],
        stream=stream,
    )

    if stream:
        # Streaming: citations arrive on (at least) one of the chunks.
        citations_chunk = False
        async for chunk in response:
            print("received chunk", chunk)
            if "citations" in chunk:
                citations_chunk = True
                break
        assert citations_chunk, "no streamed chunk contained citations"
    else:
        # Non-streaming: citations are attached to the response object.
        assert response.citations is not None, "response missing citations"