add /rerank test

This commit is contained in:
Ishaan Jaff 2024-08-27 17:50:37 -07:00
parent 5f2f7aa754
commit c27640e6e4
3 changed files with 71 additions and 0 deletions

View file

@ -88,6 +88,8 @@ def _get_bearer_token(
api_key = api_key.replace("Bearer ", "") # extract the token
elif api_key.startswith("Basic "):
api_key = api_key.replace("Basic ", "") # handle langfuse input
elif api_key.startswith("bearer "):
api_key = api_key.replace("bearer ", "")
else:
api_key = ""
return api_key

View file

@ -4,6 +4,10 @@ model_list:
model: openai/fake
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railway.app/
- model_name: rerank-english-v3.0
litellm_params:
model: cohere/rerank-english-v3.0
api_key: os.environ/COHERE_API_KEY
litellm_settings:
cache: true

View file

@ -0,0 +1,65 @@
import pytest
import asyncio
import aiohttp, openai
from openai import OpenAI, AsyncOpenAI
from typing import Optional, List, Union
import uuid
async def make_rerank_curl_request(
    session,
    key,
    query,
    documents,
    model="rerank-english-v3.0",
    top_n=3,
    base_url="http://0.0.0.0:4000",
):
    """POST a rerank request to the proxy and return the decoded JSON body.

    Args:
        session: Object exposing an aiohttp-style
            ``post(url, headers=..., json=...)`` async context manager.
        key: API key sent as ``Authorization: Bearer <key>``.
        query: Query string placed in the request body.
        documents: Documents placed in the request body.
        model: Value of the ``model`` field in the request body.
        top_n: Value of the ``top_n`` field in the request body.
        base_url: Scheme/host/port of the proxy. Defaults to the local
            address the function previously hard-coded; a trailing slash
            is tolerated.

    Returns:
        Whatever ``response.json()`` yields for a 200 response.

    Raises:
        Exception: If the response status is not 200. The message carries
            both the status code and the response body text.
    """
    url = f"{base_url.rstrip('/')}/rerank"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "query": query,
        "documents": documents,
        "top_n": top_n,
    }

    async with session.post(url, headers=headers, json=payload) as response:
        status = response.status
        if status != 200:
            # Only read the raw text on failure; keep the status in the
            # message so a failing test shows *why* the proxy rejected it.
            response_text = await response.text()
            raise Exception(
                f"Rerank request failed with status {status}: {response_text}"
            )
        return await response.json()
@pytest.mark.asyncio
async def test_basic_rerank_on_proxy():
    """
    Smoke-test a single rerank call against the running proxy.

    One query and four candidate passages are sent through
    ``make_rerank_curl_request``; the test fails if that call raises.
    Per the original author's note, this is meant to exercise the proxy's
    own rerank handling and SHOULD NOT hit the pass-through endpoints.
    """
    candidate_passages = [
        "Carson City is the capital city of the American state of Nevada.",
        "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
        "Washington, D.C. is the capital of the United States.",
        "Capital punishment has existed in the United States since before it was a country.",
    ]
    user_query = "What is the capital of the United States?"
    proxy_key = "sk-1234"

    async with aiohttp.ClientSession() as http:
        try:
            reranked = await make_rerank_curl_request(
                http,
                proxy_key,
                query=user_query,
                documents=candidate_passages,
            )
            print("response=", reranked)
        except Exception as err:
            # Surface the underlying error in the captured output before
            # marking the test as failed.
            print(err)
            pytest.fail("Rerank request failed")