feat(pass_through_endpoints.py): add pass-through support for all cohere endpoints

This commit is contained in:
Krrish Dholakia 2024-08-17 16:57:40 -07:00
parent 77177ff469
commit 1856ac585d
3 changed files with 44 additions and 2 deletions

View file

@ -2321,7 +2321,7 @@ def get_standard_logging_object_payload(
model_map_value=_model_cost_information,
)
except Exception:
verbose_logger.warning(
verbose_logger.debug( # keep in debug otherwise it will trigger on every call
"Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
)
model_cost_information = StandardLoggingModelInformation(

View file

@ -320,7 +320,7 @@ async def pass_through_request(
call_type="pass_through_endpoint",
)
async_client = httpx.AsyncClient()
async_client = httpx.AsyncClient(timeout=600)
# create logging object
start_time = time.time()

View file

@ -94,3 +94,45 @@ async def gemini_proxy_route(
)
return received_value
@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def cohere_proxy_route(
endpoint: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
base_target_url = "https://api.cohere.com"
encoded_endpoint = httpx.URL(endpoint).path
# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint
# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
## check for streaming
is_streaming_request = False
if "stream" in str(updated_url):
is_streaming_request = True
## CREATE PASS-THROUGH
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
fastapi_response,
user_api_key_dict,
stream=is_streaming_request,
)
return received_value