Use a helper class for the pass-through success handler

This commit is contained in:
Ishaan Jaff 2024-08-30 15:52:47 -07:00
parent e1e1e2e566
commit f50374e81d
2 changed files with 117 additions and 3 deletions

View file

@ -3,6 +3,7 @@ import asyncio
import json import json
import traceback import traceback
from base64 import b64encode from base64 import b64encode
from datetime import datetime
from typing import AsyncIterable, List, Optional from typing import AsyncIterable, List, Optional
import httpx import httpx
@ -20,6 +21,7 @@ from fastapi.responses import StreamingResponse
import litellm import litellm
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import ( from litellm.proxy._types import (
ConfigFieldInfo, ConfigFieldInfo,
ConfigFieldUpdate, ConfigFieldUpdate,
@ -30,8 +32,12 @@ from litellm.proxy._types import (
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from .success_handler import PassThroughEndpointLogging
router = APIRouter() router = APIRouter()
pass_through_endpoint_logging = PassThroughEndpointLogging()
async def set_env_variables_in_header(custom_headers: dict): async def set_env_variables_in_header(custom_headers: dict):
""" """
@ -330,7 +336,7 @@ async def pass_through_request(
async_client = httpx.AsyncClient(timeout=600) async_client = httpx.AsyncClient(timeout=600)
# create logging object # create logging object
start_time = time.time() start_time = datetime.now()
logging_obj = Logging( logging_obj = Logging(
model="unknown", model="unknown",
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}], messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@ -473,12 +479,15 @@ async def pass_through_request(
content = await response.aread() content = await response.aread()
## LOG SUCCESS ## LOG SUCCESS
end_time = time.time() end_time = datetime.now()
await logging_obj.async_success_handler( await pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response,
url_route=str(url),
result="", result="",
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
logging_obj=logging_obj,
cache_hit=False, cache_hit=False,
) )

View file

@ -0,0 +1,105 @@
import re
from datetime import datetime
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
class PassThroughEndpointLogging:
    """Success-logging dispatcher for pass-through proxy endpoints.

    Responses from tracked Vertex AI routes are parsed into a
    ``litellm.ModelResponse`` (so model/usage details are attributed
    correctly) before being handed to the standard litellm success
    handler; every other route is logged via the generic handler as-is.
    """

    def __init__(self):
        # Substrings identifying Vertex AI routes whose responses we know
        # how to parse for logging purposes.
        self.TRACKED_VERTEX_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predict",
        ]

    async def pass_through_async_success_handler(
        self,
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """Log a successful pass-through request.

        Args:
            httpx_response: Raw upstream response received by the proxy.
            logging_obj: Per-request litellm logging object to record into.
            url_route: Full upstream URL the request was forwarded to.
            result: Result payload to log for non-Vertex routes.
            start_time: Request start timestamp.
            end_time: Request end timestamp.
            cache_hit: Whether the response was served from cache.
            **kwargs: Extra fields forwarded to the Vertex handler.
        """
        if self.is_vertex_route(url_route):
            await self.vertex_passthrough_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        else:
            # Fix: forward the caller-supplied result/cache_hit instead of
            # discarding them with hardcoded "" / False.
            await logging_obj.async_success_handler(
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
            )

    def is_vertex_route(self, url_route: str) -> bool:
        """Return True if *url_route* contains any tracked Vertex AI route."""
        return any(route in url_route for route in self.TRACKED_VERTEX_ROUTES)

    def extract_model_from_url(self, url: str) -> str:
        """Extract the model name from a ``.../models/<model>:<action>`` URL.

        Returns ``"unknown"`` when the URL does not match that pattern.
        """
        match = re.search(r"/models/([^:]+)", url)
        return match.group(1) if match else "unknown"

    async def vertex_passthrough_handler(
        self,
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """Parse a Vertex AI response and log it as a litellm ModelResponse.

        NOTE(review): "generateContent" is a substring of
        "streamGenerateContent", so streaming responses also take this
        branch and are parsed with the non-streaming parser; "predict"
        routes match is_vertex_route but are silently not logged here —
        confirm both are intended.
        """
        if "generateContent" in url_route:
            model = self.extract_model_from_url(url_route)
            instance_of_vertex_llm = VertexLLM()
            # Convert the raw httpx response into litellm's normalized
            # ModelResponse so downstream logging sees usage/model fields.
            litellm_model_response: litellm.ModelResponse = (
                instance_of_vertex_llm._process_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    data={},
                    print_verbose=litellm.print_verbose,
                    encoding=None,
                )
            )
            # The logging object was created before the model was known;
            # backfill it from the parsed response.
            logging_obj.model = litellm_model_response.model
            logging_obj.model_call_details["model"] = logging_obj.model

            await logging_obj.async_success_handler(
                result=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
            )