forked from phoenix/litellm-mirror
use helper class for pass through success handler
This commit is contained in:
parent
e1e1e2e566
commit
f50374e81d
2 changed files with 117 additions and 3 deletions
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
from datetime import datetime
|
||||||
from typing import AsyncIterable, List, Optional
|
from typing import AsyncIterable, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -20,6 +21,7 @@ from fastapi.responses import StreamingResponse
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
ConfigFieldInfo,
|
ConfigFieldInfo,
|
||||||
ConfigFieldUpdate,
|
ConfigFieldUpdate,
|
||||||
|
@ -30,8 +32,12 @@ from litellm.proxy._types import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
|
from .success_handler import PassThroughEndpointLogging
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
pass_through_endpoint_logging = PassThroughEndpointLogging()
|
||||||
|
|
||||||
|
|
||||||
async def set_env_variables_in_header(custom_headers: dict):
|
async def set_env_variables_in_header(custom_headers: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -330,7 +336,7 @@ async def pass_through_request(
|
||||||
async_client = httpx.AsyncClient(timeout=600)
|
async_client = httpx.AsyncClient(timeout=600)
|
||||||
|
|
||||||
# create logging object
|
# create logging object
|
||||||
start_time = time.time()
|
start_time = datetime.now()
|
||||||
logging_obj = Logging(
|
logging_obj = Logging(
|
||||||
model="unknown",
|
model="unknown",
|
||||||
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||||
|
@ -473,12 +479,15 @@ async def pass_through_request(
|
||||||
content = await response.aread()
|
content = await response.aread()
|
||||||
|
|
||||||
## LOG SUCCESS
|
## LOG SUCCESS
|
||||||
end_time = time.time()
|
end_time = datetime.now()
|
||||||
|
|
||||||
await logging_obj.async_success_handler(
|
await pass_through_endpoint_logging.pass_through_async_success_handler(
|
||||||
|
httpx_response=response,
|
||||||
|
url_route=str(url),
|
||||||
result="",
|
result="",
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
|
logging_obj=logging_obj,
|
||||||
cache_hit=False,
|
cache_hit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
105
litellm/proxy/pass_through_endpoints/success_handler.py
Normal file
105
litellm/proxy/pass_through_endpoints/success_handler.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||||
|
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||||
|
VertexLLM,
|
||||||
|
)
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
|
|
||||||
|
class PassThroughEndpointLogging:
|
||||||
|
def __init__(self):
|
||||||
|
self.TRACKED_VERTEX_ROUTES = [
|
||||||
|
"generateContent",
|
||||||
|
"streamGenerateContent",
|
||||||
|
"predict",
|
||||||
|
]
|
||||||
|
|
||||||
|
async def pass_through_async_success_handler(
|
||||||
|
self,
|
||||||
|
httpx_response: httpx.Response,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
url_route: str,
|
||||||
|
result: str,
|
||||||
|
start_time: datetime,
|
||||||
|
end_time: datetime,
|
||||||
|
cache_hit: bool,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if self.is_vertex_route(url_route):
|
||||||
|
await self.vertex_passthrough_handler(
|
||||||
|
httpx_response=httpx_response,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
url_route=url_route,
|
||||||
|
result=result,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
cache_hit=cache_hit,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await logging_obj.async_success_handler(
|
||||||
|
result="",
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
cache_hit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_vertex_route(self, url_route: str):
|
||||||
|
for route in self.TRACKED_VERTEX_ROUTES:
|
||||||
|
if route in url_route:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def extract_model_from_url(self, url: str) -> str:
|
||||||
|
pattern = r"/models/([^:]+)"
|
||||||
|
match = re.search(pattern, url)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
async def vertex_passthrough_handler(
|
||||||
|
self,
|
||||||
|
httpx_response: httpx.Response,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
url_route: str,
|
||||||
|
result: str,
|
||||||
|
start_time: datetime,
|
||||||
|
end_time: datetime,
|
||||||
|
cache_hit: bool,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if "generateContent" in url_route:
|
||||||
|
model = self.extract_model_from_url(url_route)
|
||||||
|
|
||||||
|
instance_of_vertex_llm = VertexLLM()
|
||||||
|
litellm_model_response: litellm.ModelResponse = (
|
||||||
|
instance_of_vertex_llm._process_response(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "no-message-pass-through-endpoint"}
|
||||||
|
],
|
||||||
|
response=httpx_response,
|
||||||
|
model_response=litellm.ModelResponse(),
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params={},
|
||||||
|
litellm_params={},
|
||||||
|
api_key="",
|
||||||
|
data={},
|
||||||
|
print_verbose=litellm.print_verbose,
|
||||||
|
encoding=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logging_obj.model = litellm_model_response.model
|
||||||
|
logging_obj.model_call_details["model"] = logging_obj.model
|
||||||
|
|
||||||
|
await logging_obj.async_success_handler(
|
||||||
|
result=litellm_model_response,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
cache_hit=cache_hit,
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue