forked from phoenix/litellm-mirror
(fixes) gcs bucket key based logging (#6044)
* fixes for gcs bucket logging * fix StandardCallbackDynamicParams * fix - gcs logging when payload is not serializable * add test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket * working success callbacks * linting fixes * fix linting error * add type hints to functions * fixes for dynamic success and failure logging * fix for test_async_chat_openai_stream
This commit is contained in:
parent
793593e735
commit
670ecda4e2
9 changed files with 446 additions and 39 deletions
|
@ -82,7 +82,7 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
if logging_payload is None:
|
||||
raise ValueError("standard_logging_object not found in kwargs")
|
||||
|
||||
json_logged_payload = json.dumps(logging_payload)
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
@ -137,7 +137,7 @@ class GCSBucketLogger(GCSBucketBase):
|
|||
_litellm_params = kwargs.get("litellm_params") or {}
|
||||
metadata = _litellm_params.get("metadata") or {}
|
||||
|
||||
json_logged_payload = json.dumps(logging_payload)
|
||||
json_logged_payload = json.dumps(logging_payload, default=str)
|
||||
|
||||
# Get the current date
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
|
|
@ -192,16 +192,28 @@ class Logging:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
model: str,
|
||||
messages,
|
||||
stream,
|
||||
call_type,
|
||||
start_time,
|
||||
litellm_call_id,
|
||||
function_id,
|
||||
dynamic_success_callbacks=None,
|
||||
dynamic_failure_callbacks=None,
|
||||
dynamic_async_success_callbacks=None,
|
||||
litellm_call_id: str,
|
||||
function_id: str,
|
||||
dynamic_input_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
dynamic_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
dynamic_async_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
dynamic_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
dynamic_async_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None,
|
||||
kwargs: Optional[Dict] = None,
|
||||
):
|
||||
if messages is not None:
|
||||
|
@ -230,27 +242,117 @@ class Logging:
|
|||
[]
|
||||
) # for generating complete stream response
|
||||
self.model_call_details: Dict[Any, Any] = {}
|
||||
self.dynamic_input_callbacks: List[Any] = (
|
||||
[]
|
||||
) # [TODO] callbacks set for just that call
|
||||
self.dynamic_failure_callbacks = dynamic_failure_callbacks
|
||||
self.dynamic_success_callbacks = (
|
||||
dynamic_success_callbacks # callbacks set for just that call
|
||||
)
|
||||
self.dynamic_async_success_callbacks = (
|
||||
dynamic_async_success_callbacks # callbacks set for just that call
|
||||
)
|
||||
|
||||
# Initialize dynamic callbacks
|
||||
self.dynamic_input_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = dynamic_input_callbacks
|
||||
self.dynamic_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = dynamic_success_callbacks
|
||||
self.dynamic_async_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = dynamic_async_success_callbacks
|
||||
self.dynamic_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = dynamic_failure_callbacks
|
||||
self.dynamic_async_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = dynamic_async_failure_callbacks
|
||||
|
||||
# Process dynamic callbacks
|
||||
self.process_dynamic_callbacks()
|
||||
|
||||
## DYNAMIC LANGFUSE / GCS / logging callback KEYS ##
|
||||
self.standard_callback_dynamic_params: StandardCallbackDynamicParams = (
|
||||
self.initialize_standard_callback_dynamic_params(kwargs)
|
||||
)
|
||||
## TIME TO FIRST TOKEN LOGGING ##
|
||||
|
||||
## TIME TO FIRST TOKEN LOGGING ##
|
||||
self.completion_start_time: Optional[datetime.datetime] = None
|
||||
|
||||
def process_dynamic_callbacks(self):
|
||||
"""
|
||||
Initializes CustomLogger compatible callbacks in self.dynamic_* callbacks
|
||||
|
||||
If a callback is in litellm._known_custom_logger_compatible_callbacks, it needs to be intialized and added to the respective dynamic_* callback list.
|
||||
"""
|
||||
# Process input callbacks
|
||||
self.dynamic_input_callbacks = self._process_dynamic_callback_list(
|
||||
self.dynamic_input_callbacks, dynamic_callbacks_type="input"
|
||||
)
|
||||
|
||||
# Process failure callbacks
|
||||
self.dynamic_failure_callbacks = self._process_dynamic_callback_list(
|
||||
self.dynamic_failure_callbacks, dynamic_callbacks_type="failure"
|
||||
)
|
||||
|
||||
# Process async failure callbacks
|
||||
self.dynamic_async_failure_callbacks = self._process_dynamic_callback_list(
|
||||
self.dynamic_async_failure_callbacks, dynamic_callbacks_type="async_failure"
|
||||
)
|
||||
|
||||
# Process success callbacks
|
||||
self.dynamic_success_callbacks = self._process_dynamic_callback_list(
|
||||
self.dynamic_success_callbacks, dynamic_callbacks_type="success"
|
||||
)
|
||||
|
||||
# Process async success callbacks
|
||||
self.dynamic_async_success_callbacks = self._process_dynamic_callback_list(
|
||||
self.dynamic_async_success_callbacks, dynamic_callbacks_type="async_success"
|
||||
)
|
||||
|
||||
def _process_dynamic_callback_list(
|
||||
self,
|
||||
callback_list: Optional[List[Union[str, Callable, CustomLogger]]],
|
||||
dynamic_callbacks_type: Literal[
|
||||
"input", "success", "failure", "async_success", "async_failure"
|
||||
],
|
||||
) -> Optional[List[Union[str, Callable, CustomLogger]]]:
|
||||
"""
|
||||
Helper function to initialize CustomLogger compatible callbacks in self.dynamic_* callbacks
|
||||
|
||||
- If a callback is in litellm._known_custom_logger_compatible_callbacks,
|
||||
replace the string with the initialized callback class.
|
||||
- If dynamic callback is a "success" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_async_success_callbacks
|
||||
- If dynamic callback is a "failure" callback that is a known_custom_logger_compatible_callbacks then add it to dynamic_failure_callbacks
|
||||
"""
|
||||
if callback_list is None:
|
||||
return None
|
||||
|
||||
processed_list: List[Union[str, Callable, CustomLogger]] = []
|
||||
for callback in callback_list:
|
||||
if (
|
||||
isinstance(callback, str)
|
||||
and callback in litellm._known_custom_logger_compatible_callbacks
|
||||
):
|
||||
callback_class = _init_custom_logger_compatible_class(
|
||||
callback, internal_usage_cache=None, llm_router=None # type: ignore
|
||||
)
|
||||
if callback_class is not None:
|
||||
processed_list.append(callback_class)
|
||||
|
||||
# If processing dynamic_success_callbacks, add to dynamic_async_success_callbacks
|
||||
if dynamic_callbacks_type == "success":
|
||||
if self.dynamic_async_success_callbacks is None:
|
||||
self.dynamic_async_success_callbacks = []
|
||||
self.dynamic_async_success_callbacks.append(callback_class)
|
||||
elif dynamic_callbacks_type == "failure":
|
||||
if self.dynamic_async_failure_callbacks is None:
|
||||
self.dynamic_async_failure_callbacks = []
|
||||
self.dynamic_async_failure_callbacks.append(callback_class)
|
||||
else:
|
||||
processed_list.append(callback)
|
||||
return processed_list
|
||||
|
||||
def initialize_standard_callback_dynamic_params(
|
||||
self, kwargs: Optional[Dict] = None
|
||||
) -> StandardCallbackDynamicParams:
|
||||
"""
|
||||
Initialize the standard callback dynamic params from the kwargs
|
||||
|
||||
checks if langfuse_secret_key, gcs_bucket_name in kwargs and sets the corresponding attributes in StandardCallbackDynamicParams
|
||||
"""
|
||||
standard_callback_dynamic_params = StandardCallbackDynamicParams()
|
||||
if kwargs:
|
||||
_supported_callback_params = (
|
||||
|
@ -413,7 +515,7 @@ class Logging:
|
|||
|
||||
self.model_call_details["api_call_start_time"] = datetime.datetime.now()
|
||||
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
|
||||
callbacks = litellm.input_callback + self.dynamic_input_callbacks
|
||||
callbacks = litellm.input_callback + (self.dynamic_input_callbacks or [])
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if callback == "supabase" and supabaseClient is not None:
|
||||
|
@ -529,7 +631,7 @@ class Logging:
|
|||
)
|
||||
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
|
||||
|
||||
callbacks = litellm.input_callback + self.dynamic_input_callbacks
|
||||
callbacks = litellm.input_callback + (self.dynamic_input_callbacks or [])
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if callback == "sentry" and add_breadcrumb:
|
||||
|
@ -2004,8 +2106,25 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
callbacks = [] # init this to empty incase it's not created
|
||||
|
||||
if self.dynamic_async_failure_callbacks is not None and isinstance(
|
||||
self.dynamic_async_failure_callbacks, list
|
||||
):
|
||||
callbacks = self.dynamic_async_failure_callbacks
|
||||
## keep the internal functions ##
|
||||
for callback in litellm._async_failure_callback:
|
||||
if (
|
||||
isinstance(callback, CustomLogger)
|
||||
and "_PROXY_" in callback.__class__.__name__
|
||||
):
|
||||
callbacks.append(callback)
|
||||
else:
|
||||
callbacks = litellm._async_failure_callback
|
||||
|
||||
result = None # result sent to all loggers, init this to None incase it's not created
|
||||
for callback in litellm._async_failure_callback:
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger): # custom logger class
|
||||
await callback.async_log_failure_event(
|
||||
|
|
|
@ -12,7 +12,7 @@ from typing_extensions import Annotated, TypedDict
|
|||
|
||||
from litellm.integrations.SlackAlerting.types import AlertType
|
||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||
from litellm.types.utils import ProviderField
|
||||
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
@ -959,21 +959,38 @@ class BlockKeyRequest(LiteLLMBase):
|
|||
class AddTeamCallback(LiteLLMBase):
|
||||
callback_name: str
|
||||
callback_type: Literal["success", "failure", "success_and_failure"]
|
||||
# for now - only supported for langfuse
|
||||
callback_vars: Dict[
|
||||
Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str
|
||||
]
|
||||
callback_vars: Dict[str, str]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_callback_vars(cls, values):
|
||||
callback_vars = values.get("callback_vars", {})
|
||||
valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
|
||||
for key in callback_vars:
|
||||
if key not in valid_keys:
|
||||
raise ValueError(
|
||||
f"Invalid callback variable: {key}. Must be one of {valid_keys}"
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class TeamCallbackMetadata(LiteLLMBase):
|
||||
success_callback: Optional[List[str]] = []
|
||||
failure_callback: Optional[List[str]] = []
|
||||
# for now - only supported for langfuse
|
||||
callback_vars: Optional[
|
||||
Dict[
|
||||
Literal["langfuse_public_key", "langfuse_secret_key", "langfuse_host"], str
|
||||
]
|
||||
] = {}
|
||||
callback_vars: Optional[Dict[str, str]] = {}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_callback_vars(cls, values):
|
||||
callback_vars = values.get("callback_vars", {})
|
||||
valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
|
||||
for key in callback_vars:
|
||||
if key not in valid_keys:
|
||||
raise ValueError(
|
||||
f"Invalid callback variable: {key}. Must be one of {valid_keys}"
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class LiteLLM_TeamTable(TeamBase):
|
||||
|
|
|
@ -115,7 +115,7 @@ class PassThroughEndpointLogging:
|
|||
encoding=None,
|
||||
)
|
||||
)
|
||||
logging_obj.model = litellm_model_response.model
|
||||
logging_obj.model = litellm_model_response.model or model
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
|
|
|
@ -5,7 +5,6 @@ model_list:
|
|||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
general_settings:
|
||||
alerting: ["slack"]
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1384,8 +1384,12 @@ OPENAI_RESPONSE_HEADERS = [
|
|||
|
||||
|
||||
class StandardCallbackDynamicParams(TypedDict, total=False):
|
||||
# Langfuse dynamic params
|
||||
langfuse_public_key: Optional[str]
|
||||
langfuse_secret: Optional[str]
|
||||
langfuse_secret_key: Optional[str]
|
||||
langfuse_host: Optional[str]
|
||||
|
||||
# GCS dynamic params
|
||||
gcs_bucket_name: Optional[str]
|
||||
gcs_path_service_account: Optional[str]
|
||||
|
|
|
@ -58,6 +58,7 @@ import litellm.litellm_core_utils
|
|||
import litellm.litellm_core_utils.audio_utils.utils
|
||||
import litellm.litellm_core_utils.json_validation_rule
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.exception_mapping_utils import (
|
||||
_get_litellm_response_headers,
|
||||
|
@ -430,9 +431,18 @@ def function_setup(
|
|||
for index in reversed(removed_async_items):
|
||||
litellm.failure_callback.pop(index)
|
||||
### DYNAMIC CALLBACKS ###
|
||||
dynamic_success_callbacks = None
|
||||
dynamic_async_success_callbacks = None
|
||||
dynamic_failure_callbacks = None
|
||||
dynamic_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None
|
||||
dynamic_async_success_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None
|
||||
dynamic_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None
|
||||
dynamic_async_failure_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = None
|
||||
if kwargs.get("success_callback", None) is not None and isinstance(
|
||||
kwargs["success_callback"], list
|
||||
):
|
||||
|
@ -561,6 +571,7 @@ def function_setup(
|
|||
dynamic_success_callbacks=dynamic_success_callbacks,
|
||||
dynamic_failure_callbacks=dynamic_failure_callbacks,
|
||||
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
|
||||
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
|
|
|
@ -267,7 +267,7 @@ async def test_basic_gcs_logger_failure():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_gcs_logging_per_request():
|
||||
async def test_basic_gcs_logging_per_request_with_callback_set():
|
||||
"""
|
||||
Test GCS Bucket logging per request
|
||||
|
||||
|
@ -391,3 +391,128 @@ async def test_basic_gcs_logging_per_request():
|
|||
object_name=object_name,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_gcs_logging_per_request_with_no_litellm_callback_set():
|
||||
"""
|
||||
Test GCS Bucket logging per request
|
||||
|
||||
key difference: no litellm.callbacks set
|
||||
|
||||
Request 1 - pass gcs_bucket_name in kwargs
|
||||
Request 2 - don't pass gcs_bucket_name in kwargs - ensure 'litellm-testing-bucket'
|
||||
"""
|
||||
import logging
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
load_vertex_ai_credentials()
|
||||
gcs_logger = GCSBucketLogger()
|
||||
|
||||
GCS_BUCKET_NAME = "key-logging-project1"
|
||||
standard_callback_dynamic_params: StandardCallbackDynamicParams = (
|
||||
StandardCallbackDynamicParams(gcs_bucket_name=GCS_BUCKET_NAME)
|
||||
)
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0.7,
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
max_tokens=10,
|
||||
user="ishaan-2",
|
||||
gcs_bucket_name=GCS_BUCKET_NAME,
|
||||
success_callback=["gcs_bucket"],
|
||||
failure_callback=["gcs_bucket"],
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Get the current date
|
||||
# Get the current date
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
# Modify the object_name to include the date-based folder
|
||||
object_name = f"{current_date}%2F{response.id}"
|
||||
|
||||
print("object_name", object_name)
|
||||
|
||||
# Check if object landed on GCS
|
||||
object_from_gcs = await gcs_logger.download_gcs_object(
|
||||
object_name=object_name,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
print("object from gcs=", object_from_gcs)
|
||||
# convert object_from_gcs from bytes to DICT
|
||||
parsed_data = json.loads(object_from_gcs)
|
||||
print("object_from_gcs as dict", parsed_data)
|
||||
|
||||
print("type of object_from_gcs", type(parsed_data))
|
||||
|
||||
gcs_payload = StandardLoggingPayload(**parsed_data)
|
||||
|
||||
assert gcs_payload["model"] == "gpt-4o-mini"
|
||||
assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
|
||||
|
||||
assert gcs_payload["response_cost"] > 0.0
|
||||
|
||||
assert gcs_payload["status"] == "success"
|
||||
|
||||
# clean up the object from GCS
|
||||
await gcs_logger.delete_gcs_object(
|
||||
object_name=object_name,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
|
||||
# make a failure request - assert that failure callback is hit
|
||||
gcs_log_id = f"failure-test-{uuid.uuid4().hex}"
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-4o-mini",
|
||||
temperature=0.7,
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
max_tokens=10,
|
||||
user="ishaan-2",
|
||||
mock_response=litellm.BadRequestError(
|
||||
model="gpt-3.5-turbo",
|
||||
message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.",
|
||||
llm_provider="openai",
|
||||
),
|
||||
success_callback=["gcs_bucket"],
|
||||
failure_callback=["gcs_bucket"],
|
||||
gcs_bucket_name=GCS_BUCKET_NAME,
|
||||
metadata={
|
||||
"gcs_log_id": gcs_log_id,
|
||||
},
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# check if the failure object is logged in GCS
|
||||
object_from_gcs = await gcs_logger.download_gcs_object(
|
||||
object_name=gcs_log_id,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
print("object from gcs=", object_from_gcs)
|
||||
# convert object_from_gcs from bytes to DICT
|
||||
parsed_data = json.loads(object_from_gcs)
|
||||
print("object_from_gcs as dict", parsed_data)
|
||||
|
||||
gcs_payload = StandardLoggingPayload(**parsed_data)
|
||||
|
||||
assert gcs_payload["model"] == "gpt-4o-mini"
|
||||
assert gcs_payload["messages"] == [{"role": "user", "content": "This is a test"}]
|
||||
|
||||
assert gcs_payload["response_cost"] == 0
|
||||
assert gcs_payload["status"] == "failure"
|
||||
|
||||
# clean up the object from GCS
|
||||
await gcs_logger.delete_gcs_object(
|
||||
object_name=gcs_log_id,
|
||||
standard_callback_dynamic_params=standard_callback_dynamic_params,
|
||||
)
|
||||
|
|
|
@ -1389,6 +1389,138 @@ async def test_add_callback_via_key_litellm_pre_call_utils(
|
|||
assert new_data["failure_callback"] == expected_failure_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"callback_type, expected_success_callbacks, expected_failure_callbacks",
|
||||
[
|
||||
("success", ["gcs_bucket"], []),
|
||||
("failure", [], ["gcs_bucket"]),
|
||||
("success_and_failure", ["gcs_bucket"], ["gcs_bucket"]),
|
||||
],
|
||||
)
|
||||
async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
|
||||
prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
|
||||
):
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException, Request, Response
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
|
||||
proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
|
||||
|
||||
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
|
||||
request._url = URL(url="/chat/completions")
|
||||
|
||||
test_data = {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"messages": [
|
||||
{"role": "user", "content": "write 1 sentence poem"},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
"mock_response": "Hello world",
|
||||
"api_key": "my-fake-key",
|
||||
}
|
||||
|
||||
json_bytes = json.dumps(test_data).encode("utf-8")
|
||||
|
||||
request._body = json_bytes
|
||||
|
||||
data = {
|
||||
"data": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
|
||||
"max_tokens": 10,
|
||||
"mock_response": "Hello world",
|
||||
"api_key": "my-fake-key",
|
||||
},
|
||||
"request": request,
|
||||
"user_api_key_dict": UserAPIKeyAuth(
|
||||
token=None,
|
||||
key_name=None,
|
||||
key_alias=None,
|
||||
spend=0.0,
|
||||
max_budget=None,
|
||||
expires=None,
|
||||
models=[],
|
||||
aliases={},
|
||||
config={},
|
||||
user_id=None,
|
||||
team_id=None,
|
||||
max_parallel_requests=None,
|
||||
metadata={
|
||||
"logging": [
|
||||
{
|
||||
"callback_name": "gcs_bucket",
|
||||
"callback_type": callback_type,
|
||||
"callback_vars": {
|
||||
"gcs_bucket_name": "key-logging-project1",
|
||||
"gcs_path_service_account": "adroit-crow-413218-a956eef1a2a8.json",
|
||||
},
|
||||
}
|
||||
]
|
||||
},
|
||||
tpm_limit=None,
|
||||
rpm_limit=None,
|
||||
budget_duration=None,
|
||||
budget_reset_at=None,
|
||||
allowed_cache_controls=[],
|
||||
permissions={},
|
||||
model_spend={},
|
||||
model_max_budget={},
|
||||
soft_budget_cooldown=False,
|
||||
litellm_budget_table=None,
|
||||
org_id=None,
|
||||
team_spend=None,
|
||||
team_alias=None,
|
||||
team_tpm_limit=None,
|
||||
team_rpm_limit=None,
|
||||
team_max_budget=None,
|
||||
team_models=[],
|
||||
team_blocked=False,
|
||||
soft_budget=None,
|
||||
team_model_aliases=None,
|
||||
team_member_spend=None,
|
||||
team_metadata=None,
|
||||
end_user_id=None,
|
||||
end_user_tpm_limit=None,
|
||||
end_user_rpm_limit=None,
|
||||
end_user_max_budget=None,
|
||||
last_refreshed_at=None,
|
||||
api_key=None,
|
||||
user_role=None,
|
||||
allowed_model_region=None,
|
||||
parent_otel_span=None,
|
||||
),
|
||||
"proxy_config": proxy_config,
|
||||
"general_settings": {},
|
||||
"version": "0.0.0",
|
||||
}
|
||||
|
||||
new_data = await add_litellm_data_to_request(**data)
|
||||
print("NEW DATA: {}".format(new_data))
|
||||
|
||||
assert "gcs_bucket_name" in new_data
|
||||
assert new_data["gcs_bucket_name"] == "key-logging-project1"
|
||||
assert "gcs_path_service_account" in new_data
|
||||
assert (
|
||||
new_data["gcs_path_service_account"] == "adroit-crow-413218-a956eef1a2a8.json"
|
||||
)
|
||||
|
||||
if expected_success_callbacks:
|
||||
assert "success_callback" in new_data
|
||||
assert new_data["success_callback"] == expected_success_callbacks
|
||||
|
||||
if expected_failure_callbacks:
|
||||
assert "failure_callback" in new_data
|
||||
assert new_data["failure_callback"] == expected_failure_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_pass_through_endpoint():
|
||||
from starlette.datastructures import URL
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue