Updated cohere v2 passthrough (#9997)

* Add cohere `/v2/chat` pass-through cost tracking support (#8235)

* feat(cohere_passthrough_handler.py): initial working commit with cohere passthrough cost tracking

* fix(v2_transformation.py): support cohere /v2/chat endpoint

* fix: fix linting errors

* fix: fix import

* fix(v2_transformation.py): fix linting error

* test: handle openai exception change
This commit is contained in:
Krish Dholakia 2025-04-14 19:51:01 -07:00 committed by GitHub
parent db857c74d4
commit 2ed593e052
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 742 additions and 20 deletions

View file

@ -16,10 +16,15 @@ from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
class PassThroughEndpointLogging:
def __init__(self):
@ -32,6 +37,8 @@ class PassThroughEndpointLogging:
# Anthropic
self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]
# Cohere
self.TRACKED_COHERE_ROUTES = ["/v2/chat"]
self.assemblyai_passthrough_logging_handler = (
AssemblyAIPassthroughLoggingHandler()
)
@ -84,6 +91,7 @@ class PassThroughEndpointLogging:
start_time: datetime,
end_time: datetime,
cache_hit: bool,
request_body: dict,
**kwargs,
):
standard_logging_response_object: Optional[
@ -125,6 +133,25 @@ class PassThroughEndpointLogging:
anthropic_passthrough_logging_handler_result["result"]
)
kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
elif self.is_cohere_route(url_route):
cohere_passthrough_logging_handler_result = (
cohere_passthrough_logging_handler.passthrough_chat_handler(
httpx_response=httpx_response,
response_body=response_body or {},
logging_obj=logging_obj,
url_route=url_route,
result=result,
start_time=start_time,
end_time=end_time,
cache_hit=cache_hit,
request_body=request_body,
**kwargs,
)
)
standard_logging_response_object = (
cohere_passthrough_logging_handler_result["result"]
)
kwargs = cohere_passthrough_logging_handler_result["kwargs"]
elif self.is_assemblyai_route(url_route):
if (
AssemblyAIPassthroughLoggingHandler._should_log_request(
@ -173,6 +200,11 @@ class PassThroughEndpointLogging:
return True
return False
def is_cohere_route(self, url_route: str):
for route in self.TRACKED_COHERE_ROUTES:
if route in url_route:
return True
def is_assemblyai_route(self, url_route: str):
parsed_url = urlparse(url_route)
if parsed_url.hostname == "api.assemblyai.com":