litellm-mirror/litellm/proxy/pass_through_endpoints/success_handler.py
Krish Dholakia 2ed593e052
Updated cohere v2 passthrough (#9997)
* Add cohere `/v2/chat` pass-through cost tracking support (#8235)

* feat(cohere_passthrough_handler.py): initial working commit with cohere passthrough cost tracking

* fix(v2_transformation.py): support cohere /v2/chat endpoint

* fix: fix linting errors

* fix: fix import

* fix(v2_transformation.py): fix linting error

* test: handle openai exception change
2025-04-14 19:51:01 -07:00

214 lines
7.3 KiB
Python

import json
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
# Module-level Cohere handler instance; PassThroughEndpointLogging calls its
# `passthrough_chat_handler` method for URLs matching a tracked Cohere route.
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
class PassThroughEndpointLogging:
    """Dispatch successful pass-through responses to provider-specific logging.

    The request URL is matched against the tracked Vertex, Anthropic, Cohere
    and AssemblyAI routes (in that order). The matching provider handler is
    called to build the object that gets logged; when no provider matches,
    the raw response text is wrapped in a ``StandardPassThroughResponseObject``.
    """

    def __init__(self):
        # Each TRACKED_*_ROUTES entry is matched as a plain substring of the
        # request URL (see the is_*_route helpers below).
        self.TRACKED_VERTEX_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predict",
        ]

        # Anthropic
        self.TRACKED_ANTHROPIC_ROUTES = ["/messages"]

        # Cohere
        self.TRACKED_COHERE_ROUTES = ["/v2/chat"]

        self.assemblyai_passthrough_logging_handler = (
            AssemblyAIPassthroughLoggingHandler()
        )

    async def _handle_logging(
        self,
        logging_obj: LiteLLMLoggingObj,
        standard_logging_response_object: Union[
            StandardPassThroughResponseObject,
            PassThroughEndpointLoggingResultValues,
            dict,
        ],
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """Helper function to handle both sync and async logging operations.

        Args:
            logging_obj: Logging object whose ``success_handler`` and
                ``async_success_handler`` are invoked.
            standard_logging_response_object: Response object to log.
            result: Raw result from the pass-through call.
            start_time: When the request started.
            end_time: When the request finished.
            cache_hit: Cache-hit flag forwarded to the sync handler.
            **kwargs: Extra keyword arguments forwarded to both handlers.
        """
        # Submit to thread pool for sync logging
        thread_pool_executor.submit(
            logging_obj.success_handler,
            standard_logging_response_object,
            start_time,
            end_time,
            cache_hit,
            **kwargs,
        )

        # Handle async logging.
        # `result` is annotated as str, but a dict is tolerated and
        # JSON-encoded; otherwise the response object itself is forwarded.
        if isinstance(result, dict):
            async_result = json.dumps(result)
        else:
            async_result = standard_logging_response_object

        # NOTE(review): the async handler always receives cache_hit=False
        # rather than the caller's `cache_hit` - confirm this is intentional.
        await logging_obj.async_success_handler(
            result=async_result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=False,
            **kwargs,
        )

    async def pass_through_async_success_handler(
        self,
        httpx_response: httpx.Response,
        response_body: Optional[dict],
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ):
        """Log a successful pass-through response.

        The provider is chosen from ``url_route``. For Vertex, Anthropic and
        Cohere, the provider handler returns a mapping whose ``"result"`` is
        the object to log and whose ``"kwargs"`` replaces ``kwargs`` before
        ``_handle_logging`` is awaited. For AssemblyAI, the dedicated handler
        is called and this method returns without calling ``_handle_logging``.

        Args:
            httpx_response: Upstream HTTP response.
            response_body: Parsed response body; ``None`` is passed on as
                ``{}`` to the handlers that accept it.
            logging_obj: Logging object for this request.
            url_route: Target URL of the pass-through request.
            result: Raw result of the pass-through call.
            start_time: When the request started.
            end_time: When the request finished.
            cache_hit: Cache-hit flag.
            request_body: Request payload (forwarded to the Cohere handler).
            **kwargs: Extra logging keyword arguments.
        """
        standard_logging_response_object: Optional[
            PassThroughEndpointLoggingResultValues
        ] = None

        if self.is_vertex_route(url_route):
            vertex_passthrough_logging_handler_result = (
                VertexPassthroughLoggingHandler.vertex_passthrough_handler(
                    httpx_response=httpx_response,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                vertex_passthrough_logging_handler_result["result"]
            )
            kwargs = vertex_passthrough_logging_handler_result["kwargs"]
        elif self.is_anthropic_route(url_route):
            anthropic_passthrough_logging_handler_result = (
                AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                anthropic_passthrough_logging_handler_result["result"]
            )
            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
        elif self.is_cohere_route(url_route):
            cohere_passthrough_logging_handler_result = (
                cohere_passthrough_logging_handler.passthrough_chat_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                cohere_passthrough_logging_handler_result["result"]
            )
            kwargs = cohere_passthrough_logging_handler_result["kwargs"]
        elif self.is_assemblyai_route(url_route):
            # Only HTTP methods accepted by _should_log_request are logged.
            if (
                AssemblyAIPassthroughLoggingHandler._should_log_request(
                    httpx_response.request.method
                )
                is not True
            ):
                return
            # The handler's return value is not used here; presumably it
            # performs its own logging - verify against its implementation.
            self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
            return

        # No provider matched (or the handler returned no result): fall back
        # to logging the raw response text.
        if standard_logging_response_object is None:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=httpx_response.text
            )

        await self._handle_logging(
            logging_obj=logging_obj,
            standard_logging_response_object=standard_logging_response_object,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            **kwargs,
        )

    def is_vertex_route(self, url_route: str) -> bool:
        """Return True if ``url_route`` contains any tracked Vertex route."""
        return any(route in url_route for route in self.TRACKED_VERTEX_ROUTES)

    def is_anthropic_route(self, url_route: str) -> bool:
        """Return True if ``url_route`` contains any tracked Anthropic route."""
        return any(route in url_route for route in self.TRACKED_ANTHROPIC_ROUTES)

    def is_cohere_route(self, url_route: str) -> bool:
        """Return True if ``url_route`` contains any tracked Cohere route.

        Always returns a bool: a non-matching URL yields ``False`` (not
        ``None``), consistent with the other ``is_*_route`` helpers.
        """
        return any(route in url_route for route in self.TRACKED_COHERE_ROUTES)

    def is_assemblyai_route(self, url_route: str) -> bool:
        """Return True if ``url_route`` targets AssemblyAI.

        A URL matches when its hostname is ``api.assemblyai.com`` or its path
        contains ``/transcript``.
        """
        parsed_url = urlparse(url_route)
        return (
            parsed_url.hostname == "api.assemblyai.com"
            or "/transcript" in parsed_url.path
        )