Add a Cohere rerank pass-through endpoint and an end-to-end test for it.

Summary of the three file changes below:
  * .circleci/config.yml: pass COHERE_API_KEY through with the other -e flags.
  * proxy_server_config.yaml: register /v1/rerank under
    general_settings.pass_through_endpoints, forwarding to
    https://api.cohere.com/v1/rerank with forward_headers enabled.
  * tests/test_passthrough_endpoints.py (new): POST a rerank request to
    http://localhost:4000/v1/rerank and assert the response carries
    "id" and "results".

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Indentation of the two YAML hunks is
reconstructed by convention; confirm it against the target files (or
regenerate the patch from the commit) before applying.

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 5712c71ca..f697be521 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -209,6 +209,7 @@ jobs:
               -e MISTRAL_API_KEY=$MISTRAL_API_KEY \
               -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
               -e GROQ_API_KEY=$GROQ_API_KEY \
+              -e COHERE_API_KEY=$COHERE_API_KEY \
               -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
               -e AWS_REGION_NAME=$AWS_REGION_NAME \
               -e AUTO_INFER_REGION=True \
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index dc943d59d..4912ebbbf 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -154,6 +154,14 @@ general_settings:
   database_connection_pool_limit: 10
   # database_url: "postgresql://:@:/" # [OPTIONAL] use for token-based auth to proxy
 
+  pass_through_endpoints:
+    - path: "/v1/rerank" # route you want to add to LiteLLM Proxy Server
+      target: "https://api.cohere.com/v1/rerank" # URL this route should forward requests to
+      headers: # headers to forward to this URL
+        content-type: application/json # (Optional) Extra Headers to pass to this endpoint
+        accept: application/json
+      forward_headers: True
+
 # environment_variables:
   # settings for using redis caching
   # REDIS_HOST: redis-16337.c322.us-east-1-2.ec2.cloud.redislabs.com
diff --git a/tests/test_passthrough_endpoints.py b/tests/test_passthrough_endpoints.py
new file mode 100644
index 000000000..69ce71371
--- /dev/null
+++ b/tests/test_passthrough_endpoints.py
@@ -0,0 +1,63 @@
+import pytest
+import asyncio
+import aiohttp, openai
+from openai import OpenAI, AsyncOpenAI
+from typing import Optional, List, Union
+
+import aiohttp
+import asyncio
+import json
+import os
+import dotenv
+
+
+dotenv.load_dotenv()
+
+
+async def cohere_rerank(session):
+    url = "http://localhost:4000/v1/rerank"
+    headers = {
+        "Authorization": f"bearer {os.getenv('COHERE_API_KEY')}",
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+    }
+    data = {
+        "model": "rerank-english-v3.0",
+        "query": "What is the capital of the United States?",
+        "top_n": 3,
+        "documents": [
+            "Carson City is the capital city of the American state of Nevada.",
+            "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
+            "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
+            "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
+            "Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
+        ],
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+        print(f"Status: {status}")
+        print(f"Response:\n{response_text}")
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_basic_passthrough():
+    """
+    - Make request to pass through endpoint
+
+    - This SHOULD not go through LiteLLM user_api_key_auth
+    - This should forward headers from request to pass through endpoint
+    """
+    async with aiohttp.ClientSession() as session:
+        response = await cohere_rerank(session)
+        print("response from cohere rerank", response)
+
+        assert response["id"] is not None
+        assert response["results"] is not None