mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
feat(pass_through_endpoints.py): add pass-through support for all cohere endpoints
This commit is contained in:
parent
77177ff469
commit
1856ac585d
3 changed files with 44 additions and 2 deletions
|
@ -2321,7 +2321,7 @@ def get_standard_logging_object_payload(
|
||||||
model_map_value=_model_cost_information,
|
model_map_value=_model_cost_information,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
verbose_logger.warning(
|
verbose_logger.debug( # keep in debug otherwise it will trigger on every call
|
||||||
"Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
|
"Model is not mapped in model cost map. Defaulting to None model_cost_information for standard_logging_payload"
|
||||||
)
|
)
|
||||||
model_cost_information = StandardLoggingModelInformation(
|
model_cost_information = StandardLoggingModelInformation(
|
||||||
|
|
|
@ -320,7 +320,7 @@ async def pass_through_request(
|
||||||
call_type="pass_through_endpoint",
|
call_type="pass_through_endpoint",
|
||||||
)
|
)
|
||||||
|
|
||||||
async_client = httpx.AsyncClient()
|
async_client = httpx.AsyncClient(timeout=600)
|
||||||
|
|
||||||
# create logging object
|
# create logging object
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
|
@ -94,3 +94,45 @@ async def gemini_proxy_route(
|
||||||
)
|
)
|
||||||
|
|
||||||
return received_value
|
return received_value
|
||||||
|
|
||||||
|
|
||||||
|
@router.api_route("/cohere/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||||
|
async def cohere_proxy_route(
|
||||||
|
endpoint: str,
|
||||||
|
request: Request,
|
||||||
|
fastapi_response: Response,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
base_target_url = "https://api.cohere.com"
|
||||||
|
encoded_endpoint = httpx.URL(endpoint).path
|
||||||
|
|
||||||
|
# Ensure endpoint starts with '/' for proper URL construction
|
||||||
|
if not encoded_endpoint.startswith("/"):
|
||||||
|
encoded_endpoint = "/" + encoded_endpoint
|
||||||
|
|
||||||
|
# Construct the full target URL using httpx
|
||||||
|
base_url = httpx.URL(base_target_url)
|
||||||
|
updated_url = base_url.copy_with(path=encoded_endpoint)
|
||||||
|
|
||||||
|
# Add or update query parameters
|
||||||
|
cohere_api_key = litellm.utils.get_secret(secret_name="COHERE_API_KEY")
|
||||||
|
|
||||||
|
## check for streaming
|
||||||
|
is_streaming_request = False
|
||||||
|
if "stream" in str(updated_url):
|
||||||
|
is_streaming_request = True
|
||||||
|
|
||||||
|
## CREATE PASS-THROUGH
|
||||||
|
endpoint_func = create_pass_through_route(
|
||||||
|
endpoint=endpoint,
|
||||||
|
target=str(updated_url),
|
||||||
|
custom_headers={"Authorization": "Bearer {}".format(cohere_api_key)},
|
||||||
|
) # dynamically construct pass-through endpoint based on incoming path
|
||||||
|
received_value = await endpoint_func(
|
||||||
|
request,
|
||||||
|
fastapi_response,
|
||||||
|
user_api_key_dict,
|
||||||
|
stream=is_streaming_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
return received_value
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue