feat log request / response on pass through endpoints

This commit is contained in:
Ishaan Jaff 2024-09-04 16:26:32 -07:00
parent 935dba4470
commit 5e121660d5
4 changed files with 26 additions and 3 deletions

View file

@ -63,7 +63,7 @@ Removes any field with `user_api_key_*` from metadata.
## What gets logged? ## What gets logged?
Found under `kwargs["standard_logging_payload"]`. This is a standard payload, logged for every response. Found under `kwargs["standard_logging_object"]`. This is a standard payload, logged for every response.
```python ```python
class StandardLoggingPayload(TypedDict): class StandardLoggingPayload(TypedDict):

View file

@ -37,13 +37,20 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from .streaming_handler import chunk_processor from .streaming_handler import chunk_processor
from .success_handler import PassThroughEndpointLogging from .success_handler import PassThroughEndpointLogging
from .types import EndpointType from .types import EndpointType, PassthroughStandardLoggingObject
router = APIRouter() router = APIRouter()
pass_through_endpoint_logging = PassThroughEndpointLogging() pass_through_endpoint_logging = PassThroughEndpointLogging()
def get_response_body(response: httpx.Response):
try:
return response.json()
except Exception:
return response.text
async def set_env_variables_in_header(custom_headers: dict): async def set_env_variables_in_header(custom_headers: dict):
""" """
checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc
@ -359,6 +366,10 @@ async def pass_through_request(
litellm_call_id=str(uuid.uuid4()), litellm_call_id=str(uuid.uuid4()),
function_id="1245", function_id="1245",
) )
passthrough_logging_payload = PassthroughStandardLoggingObject(
url=str(url),
request_body=_parsed_body,
)
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
kwargs = { kwargs = {
@ -371,6 +382,7 @@ async def pass_through_request(
} }
}, },
"call_type": "pass_through_endpoint", "call_type": "pass_through_endpoint",
"passthrough_logging_payload": passthrough_logging_payload,
} }
logging_obj.update_environment_variables( logging_obj.update_environment_variables(
model="unknown", model="unknown",
@ -503,8 +515,8 @@ async def pass_through_request(
content = await response.aread() content = await response.aread()
## LOG SUCCESS ## LOG SUCCESS
passthrough_logging_payload["response_body"] = get_response_body(response)
end_time = datetime.now() end_time = datetime.now()
await pass_through_endpoint_logging.pass_through_async_success_handler( await pass_through_endpoint_logging.pass_through_async_success_handler(
httpx_response=response, httpx_response=response,
url_route=str(url), url_route=str(url),
@ -513,6 +525,7 @@ async def pass_through_request(
end_time=end_time, end_time=end_time,
logging_obj=logging_obj, logging_obj=logging_obj,
cache_hit=False, cache_hit=False,
**kwargs,
) )
return Response( return Response(

View file

@ -48,6 +48,7 @@ class PassThroughEndpointLogging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
cache_hit=False, cache_hit=False,
**kwargs,
) )
def is_vertex_route(self, url_route: str): def is_vertex_route(self, url_route: str):
@ -103,6 +104,7 @@ class PassThroughEndpointLogging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
cache_hit=cache_hit, cache_hit=cache_hit,
**kwargs,
) )
elif "predict" in url_route: elif "predict" in url_route:
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
@ -152,4 +154,5 @@ class PassThroughEndpointLogging:
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
cache_hit=cache_hit, cache_hit=cache_hit,
**kwargs,
) )

View file

@ -1,6 +1,13 @@
from enum import Enum from enum import Enum
from typing import Optional, TypedDict
class EndpointType(str, Enum): class EndpointType(str, Enum):
VERTEX_AI = "vertex-ai" VERTEX_AI = "vertex-ai"
GENERIC = "generic" GENERIC = "generic"
class PassthroughStandardLoggingObject(TypedDict, total=False):
url: str
request_body: Optional[dict]
response_body: Optional[dict]