forked from phoenix/litellm-mirror
fix(proxy_server.py): support for streaming
commit 6ef0e8485e
parent 0c8b8200b8
4 changed files with 219 additions and 142 deletions
litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,17 +1,32 @@
 from typing import Optional
+import litellm
 from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException

-async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+class MaxParallelRequestsHandler(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement) # noqa
+
+
+    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
         if api_key is None:
             return

         if max_parallel_requests is None:
             return

+        self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
+
         # CHECK IF REQUEST ALLOWED
         request_count_api_key = f"{api_key}_request_count"
         current = user_api_key_cache.get_cache(key=request_count_api_key)
+        self.print_verbose(f"current: {current}")
         if current is None:
             user_api_key_cache.set_cache(request_count_api_key, 1)
         elif int(current) < max_parallel_requests:
@@ -21,13 +36,35 @@ async def max_parallel_request_allow_request(max_parallel_requests: Optional[int
             raise HTTPException(status_code=429, detail="Max parallel request limit reached.")


-async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            if user_api_key is None:
+                return
+
+            request_count_api_key = f"{user_api_key}_request_count"
+            # check if it has collected an entire stream response
+            self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
+            if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
+                # Decrease count for this token
+                current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+                new_val = current - 1
+                self.print_verbose(f"updated_value in success call: {new_val}")
+                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(e) # noqa

+    async def async_log_failure_call(self, api_key, user_api_key_cache):
+        try:
             if api_key is None:
                 return

             request_count_api_key = f"{api_key}_request_count"
             # Decrease count for this token
-    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
-    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
-
-    return
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+            new_val = current - 1
+            self.print_verbose(f"updated_value in failure call: {new_val}")
+            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(f"An exception occurred - {str(e)}") # noqa
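Read together, the two hunks above give the limiter its streaming-aware behaviour: the per-key counter is bumped in max_parallel_request_allow_request, a second concurrent request for the same key gets a 429, and the counter is only released in async_log_success_event once the complete streaming response has been assembled (or immediately for non-streaming calls). A minimal sketch of that lifecycle, using only the methods shown above; the key value is a placeholder and the bare DualCache() construction is an assumption about its defaults:

import asyncio
from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler

async def demo():
    cache = DualCache()                      # assumed: defaults to an in-memory cache
    handler = MaxParallelRequestsHandler()
    key = "sk-test"                          # placeholder api key

    # 1st request: no counter yet -> set to 1, request allowed
    await handler.max_parallel_request_allow_request(
        max_parallel_requests=1, api_key=key, user_api_key_cache=cache)

    # 2nd concurrent request: counter == limit -> 429
    try:
        await handler.max_parallel_request_allow_request(
            max_parallel_requests=1, api_key=key, user_api_key_cache=cache)
    except HTTPException as e:
        print(e.status_code, e.detail)       # 429 Max parallel request limit reached.

    # Success callback while the stream is still running: counter untouched
    kwargs = {"stream": True, "litellm_params": {"metadata": {"user_api_key": key}}}
    await handler.async_log_success_event(kwargs, None, None, None)
    print(cache.get_cache(key=f"{key}_request_count"))   # still 1

    # Final callback carries the assembled stream -> counter released
    kwargs["complete_streaming_response"] = {}
    await handler.async_log_success_event(kwargs, None, None, None)
    print(cache.get_cache(key=f"{key}_request_count"))   # 0

asyncio.run(demo())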
litellm/proxy/proxy_server.py
@@ -1,4 +1,4 @@
-import sys, os, platform, time, copy, re, asyncio
+import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta
@@ -94,7 +94,6 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks,
     ProxyLogging
 )
 import pydantic
@@ -198,8 +197,8 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
-proxy_logging_obj: Optional[ProxyLogging] = None
+### INITIALIZE GLOBAL LOGGING OBJECT ###
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         )

 def prisma_setup(database_url: Optional[str]):
-    global prisma_client, proxy_logging_obj
-    ### INITIALIZE GLOBAL LOGGING OBJECT ###
-    proxy_logging_obj = ProxyLogging()
+    global prisma_client, proxy_logging_obj, user_api_key_cache

+    proxy_logging_obj._init_litellm_callbacks()
     if database_url is not None:
         try:
             prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
@@ -390,6 +388,10 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
             print(f"type of input_text: {type(input_text)}")
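These added lines attribute the cost of a streamed call to the calling key: the truncated call just above them appears to be litellm's completion_cost helper fed with the prompt and the reassembled output text, and the result is then written to the database for that key. A hedged sketch of the pricing step in isolation (model name and texts are made up, and the keyword form of completion_cost is an assumption based on the visible arguments):

import litellm

# stand-ins for what track_cost_callback reassembles from the streamed chunks
input_text = "this is a test request, write a short poem"
output_text = "Roses are red, violets are blue..."

response_cost = litellm.completion_cost(
    model="gpt-3.5-turbo",   # example model; the proxy uses whichever deployment served the call
    prompt=input_text,
    completion=output_text,
)
print("streaming response_cost", response_cost)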
@@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):

    print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")

-    ## ROUTER CONFIG
-    cache_responses = False
-
    ## ENVIRONMENT VARIABLES
    environment_variables = config.get('environment_variables', None)
    if environment_variables:
@@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             else:
                 setattr(litellm, key, value)

+
+
    ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
    general_settings = config.get("general_settings", {})
    if general_settings is None:
@@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         use_background_health_checks = general_settings.get("background_health_checks", False)
         health_check_interval = general_settings.get("health_check_interval", 300)

+    router_params: dict = {
+        "num_retries": 3,
+        "cache_responses": litellm.cache != None # cache if user passed in cache values
+    }
    ## MODEL LIST
    model_list = config.get('model_list', None)
    if model_list:
-        router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
+        router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
-            if "ollama" in litellm_model_name:
+            litellm_model_api_base = model["litellm_params"].get("api_base", None)
+            if "ollama" in litellm_model_name and litellm_model_api_base is None:
                 run_ollama_serve()

-    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
+    ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+    router_settings = config.get("router_settings", None)
+    if router_settings and isinstance(router_settings, dict):
+        arg_spec = inspect.getfullargspec(litellm.Router)
+        # model list already set
+        exclude_args = {
+            "self",
+            "model_list",
+        }
+
+        available_args = [
+            x for x in arg_spec.args if x not in exclude_args
+        ]
+
+        for k, v in router_settings.items():
+            if k in available_args:
+                router_params[k] = v
+
+    router = litellm.Router(**router_params) # type:ignore
    return router, model_list, general_settings

 async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
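The new router_settings block forwards arbitrary keys from the config straight into the litellm.Router constructor, using inspect.getfullargspec to drop anything the constructor does not accept (model_list is excluded because it is set separately). The same filtering pattern, shown standalone against a stand-in class so it runs on its own; the class and settings below are illustrative, not litellm's real Router:

import inspect

class Router:  # stand-in for litellm.Router, only here to show the filtering
    def __init__(self, model_list=None, routing_strategy="simple-shuffle",
                 num_retries=0, timeout=None, cache_responses=False):
        self.routing_strategy = routing_strategy
        self.num_retries = num_retries

router_params = {"num_retries": 3, "model_list": [{"model_name": "gpt-3.5-turbo"}]}
router_settings = {
    "routing_strategy": "least-busy",   # kept: it is a constructor argument
    "redis_password": "hunter2",        # dropped: not in the constructor signature
}

arg_spec = inspect.getfullargspec(Router)
exclude_args = {"self", "model_list"}           # model list is already set above
available_args = [x for x in arg_spec.args if x not in exclude_args]

for k, v in router_settings.items():
    if k in available_args:
        router_params[k] = v

router = Router(**router_params)
print(router.routing_strategy)  # least-busy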
@@ -772,10 +796,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"

 async def async_data_generator(response, user_api_key_dict):
-    global call_hooks
-
     print_verbose("inside generator")
     async for chunk in response:
+        # try:
+        #     await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
+        # except Exception as e:
+        #     print(f"An exception occurred - {str(e)}")
+
         print_verbose(f"returned chunk: {chunk}")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
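This generator is the streaming half of the commit: each chunk coming back from the model is serialized onto its own "data:" line and pushed to the client as a server-sent-events stream via StreamingResponse. A self-contained sketch of the same shape, with a fake chunk source standing in for the model response (the real proxy yields chunk.dict() because its chunks are pydantic objects):

import asyncio, json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_model_response():
    # stand-in for the streamed response returned by the router / litellm
    for token in ["Hello", ",", " world"]:
        await asyncio.sleep(0)
        yield {"choices": [{"delta": {"content": token}}]}

async def async_data_generator(response):
    # mirror of the proxy's generator: one SSE "data:" line per chunk
    async for chunk in response:
        yield f"data: {json.dumps(chunk)}\n\n"

@app.post("/chat/completions")
async def chat_completion():
    response = fake_model_response()
    return StreamingResponse(async_data_generator(response), media_type="text/event-stream")

Served under uvicorn, a client POSTing to /chat/completions receives one "data:" line per chunk, which is the wire format the proxy's streaming endpoint produces.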
@@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug, call_hooks
+    global general_settings, user_debug, proxy_logging_obj
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             data["api_base"] = user_api_base

         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')

-        ### CALL HOOKS ### - modify outgoing response
-        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
-
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
-        print(f"Exception received: {str(e)}")
-        raise e
-        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global call_hooks
+    global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     data["input"] = input_list
                     break

-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")

         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL

-        ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
-
         return response
     except Exception as e:
-        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e

@@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     # data = await request.json()
-    data_json = json.loads(data.json())  # type: ignore
+    data_json = data.json()  # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

litellm/proxy/utils.py
@@ -3,7 +3,7 @@ import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler

 def print_verbose(print_statement):
     if litellm.set_verbose:
@@ -11,20 +11,23 @@ def print_verbose(print_statement):
 ### LOGGING ###
 class ProxyLogging:
     """
-    Logging for proxy.
+    Logging/Custom Handlers for proxy.

-    Implemented mainly to log successful/failed db read/writes.
-
-    Currently just logs this to a provided sentry integration.
+    Implemented mainly to:
+    - log successful/failed db read/writes
+    - support the max parallel request integration
     """

-    def __init__(self,):
+    def __init__(self, user_api_key_cache: DualCache):
         ## INITIALIZE LITELLM CALLBACKS ##
-        self._init_litellm_callbacks()
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         pass

     def _init_litellm_callbacks(self):
         if len(litellm.callbacks) > 0:
+            litellm.callbacks.append(self.max_parallel_request_limiter)
             for callback in litellm.callbacks:
                 if callback not in litellm.input_callback:
                     litellm.input_callback.append(callback)
@@ -53,6 +56,30 @@ class ProxyLogging:
                 callback_list=callback_list
             )

+    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        """
+        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await self.max_parallel_request_limiter.max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
+
     async def success_handler(self, *args, **kwargs):
         """
         Log successful db read/writes
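From the endpoint's side (see the chat_completion and embeddings hunks earlier), the contract of these hooks is: call pre_call_hook before routing the request, and on any exception call post_call_failure_hook so a failed call releases its slot, except when the exception is the 429 produced by the limiter itself. A toy, self-contained sketch of that control flow; the classes here are simplified stand-ins, not the proxy's real UserAPIKeyAuth or ProxyLogging:

import asyncio
from fastapi import HTTPException

class Key:                       # stand-in for UserAPIKeyAuth
    api_key = "sk-test"
    max_parallel_requests = 1

class Hooks:                     # stand-in for ProxyLogging, counter kept in a plain attribute
    def __init__(self):
        self.count = 0
    async def pre_call_hook(self, user_api_key_dict, data, call_type):
        if self.count >= user_api_key_dict.max_parallel_requests:
            raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
        self.count += 1
        return data
    async def post_call_failure_hook(self, original_exception, user_api_key_dict):
        if getattr(original_exception, "status_code", None) != 429:
            self.count -= 1      # release the slot when the model call itself failed

async def chat_completion(hooks, key, data):
    try:
        data = await hooks.pre_call_hook(user_api_key_dict=key, data=data, call_type="completion")
        raise RuntimeError("model call failed")    # stand-in for the router / litellm call
    except Exception as e:
        await hooks.post_call_failure_hook(original_exception=e, user_api_key_dict=key)
        raise

hooks = Hooks()
try:
    asyncio.run(chat_completion(hooks, Key(), {"model": "gpt-3.5-turbo"}))
except RuntimeError:
    print(hooks.count)           # 0 -> the failed call released its slot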
@@ -68,6 +95,27 @@ class ProxyLogging:
        if litellm.utils.capture_exception:
            litellm.utils.capture_exception(error=original_exception)

+    async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        """
+        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await self.max_parallel_request_limiter.async_log_failure_call(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
+

 ### DB CONNECTOR ###
 # Define the retry decorator with backoff strategy
@@ -290,65 +338,4 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e

-### CALL HOOKS ###
-class CallHooks:
-    """
-    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
-
-    Covers:
-    1. /chat/completions
-    2. /embeddings
-    """
-
-    def __init__(self, user_api_key_cache: DualCache):
-        self.call_details: dict = {}
-        self.call_details["user_api_key_cache"] = user_api_key_cache
-
-    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
-        self.call_details["litellm_settings"] = litellm_settings
-        self.call_details["general_settings"] = general_settings
-        self.call_details["model_list"] = model_list
-
-    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
-        try:
-            self.call_details["data"] = data
-            self.call_details["call_type"] = call_type
-
-            ## check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## if set, check if request allowed
-                await max_parallel_request_allow_request(
-                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return data
-        except Exception as e:
-            raise e
-
-    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        try:
-            # check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## decrement call, once complete
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return response
-        except Exception as e:
-            raise e
-
-    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
-        # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
-            ## decrement call count if call failed
-            if (hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)):
-                pass # ignore failed calls due to max limit being reached
-            else:
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-        return
@@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
 from fastapi import FastAPI
 from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
-config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
+config_fp = f"{filepath}/test_configs/test_config.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router)  # Include your router in the test app
@@ -100,3 +100,37 @@ def test_add_new_key_max_parallel_limit(client):
         _run_in_parallel()
     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
+
+def test_add_new_key_max_parallel_limit_streaming(client):
+    try:
+        # Your test data
+        test_data = {"duration": "20m", "max_parallel_requests": 1}
+        # Your bearer token
+        token = os.getenv('PROXY_MASTER_KEY')
+
+        headers = {
+            "Authorization": f"Bearer {token}"
+        }
+        response = client.post("/key/generate", json=test_data, headers=headers)
+        print(f"response: {response.text}")
+        assert response.status_code == 200
+        result = response.json()
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}], "stream": True}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+        def _run_in_parallel():
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                future1 = executor.submit(_post_data)
+                future2 = executor.submit(_post_data)
+
+                # Obtain the results from the futures
+                response1 = future1.result()
+                response2 = future2.result()
+                if response1.status_code == 429 or response2.status_code == 429:
+                    pass
+                else:
+                    raise Exception()
+        _run_in_parallel()
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
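The new test drives two streaming requests in parallel under a key limited to one parallel request and expects one of them to be rejected with a 429. For reference, this is roughly what a single successful streaming call looks like from a plain HTTP client's point of view; the proxy URL, key and model name are placeholders, not values taken from this repo:

import json
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",                 # placeholder proxy address
    headers={"Authorization": "Bearer sk-generated-key"},   # placeholder key from /key/generate
    json={
        "model": "azure-model",
        "messages": [{"role": "user", "content": "write a short poem"}],
        "stream": True,
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line:
        continue
    payload = line.decode()
    if payload.startswith("data: "):
        chunk = json.loads(payload[len("data: "):])
        print(chunk["choices"][0].get("delta", {}).get("content", ""), end="")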