fix(proxy_server.py): support for streaming

Krrish Dholakia 2023-12-09 16:22:53 -08:00
parent 0c8b8200b8
commit 6ef0e8485e
4 changed files with 219 additions and 142 deletions

View file

@@ -1,17 +1,32 @@
 from typing import Optional
+import litellm
 from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 
-async def max_parallel_request_allow_request(max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+class MaxParallelRequestsHandler(CustomLogger):
+    # Class variables or attributes
+    def __init__(self):
+        pass
+
+    def print_verbose(self, print_statement):
+        if litellm.set_verbose is True:
+            print(print_statement) # noqa
+
+    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
         if api_key is None:
             return
 
         if max_parallel_requests is None:
             return
 
+        self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
+
         # CHECK IF REQUEST ALLOWED
         request_count_api_key = f"{api_key}_request_count"
         current = user_api_key_cache.get_cache(key=request_count_api_key)
+        self.print_verbose(f"current: {current}")
         if current is None:
             user_api_key_cache.set_cache(request_count_api_key, 1)
         elif int(current) < max_parallel_requests:
@@ -21,13 +36,35 @@ async def max_parallel_request_allow_request(max_parallel_requests: Optional[int
             raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
 
-async def max_parallel_request_update_count(api_key: Optional[str], user_api_key_cache: DualCache):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            self.print_verbose(f"INSIDE ASYNC SUCCESS LOGGING")
+            user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
+            if user_api_key is None:
+                return
+
+            request_count_api_key = f"{user_api_key}_request_count"
+            # check if it has collected an entire stream response
+            self.print_verbose(f"'complete_streaming_response' is in kwargs: {'complete_streaming_response' in kwargs}")
+            if "complete_streaming_response" in kwargs or kwargs["stream"] != True:
+                # Decrease count for this token
+                current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+                new_val = current - 1
+                self.print_verbose(f"updated_value in success call: {new_val}")
+                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(e) # noqa
+
+    async def async_log_failure_call(self, api_key, user_api_key_cache):
+        try:
             if api_key is None:
                 return
 
             request_count_api_key = f"{api_key}_request_count"
             # Decrease count for this token
-    current = user_api_key_cache.get_cache(key=request_count_api_key) or 1
-    user_api_key_cache.set_cache(request_count_api_key, int(current) - 1)
-    return
+            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+            new_val = current - 1
+            self.print_verbose(f"updated_value in failure call: {new_val}")
+            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+        except Exception as e:
+            self.print_verbose(f"An exception occurred - {str(e)}") # noqa

View file

@@ -1,4 +1,4 @@
-import sys, os, platform, time, copy, re, asyncio
+import sys, os, platform, time, copy, re, asyncio, inspect
 import threading, ast
 import shutil, random, traceback, requests
 from datetime import datetime, timedelta
@@ -94,7 +94,6 @@ import litellm
 from litellm.proxy.utils import (
     PrismaClient,
     get_instance_fn,
-    CallHooks,
     ProxyLogging
 )
 import pydantic
@@ -198,8 +197,8 @@ user_custom_auth = None
 use_background_health_checks = None
 health_check_interval = None
 health_check_results = {}
-call_hooks = CallHooks(user_api_key_cache=user_api_key_cache)
-proxy_logging_obj: Optional[ProxyLogging] = None
+### INITIALIZE GLOBAL LOGGING OBJECT ###
+proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
 async_result = None
 celery_app_conn = None
@@ -309,10 +308,9 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
     )
 
 def prisma_setup(database_url: Optional[str]):
-    global prisma_client, proxy_logging_obj
-    ### INITIALIZE GLOBAL LOGGING OBJECT ###
-    proxy_logging_obj = ProxyLogging()
-    proxy_logging_obj._init_litellm_callbacks()
+    global prisma_client, proxy_logging_obj, user_api_key_cache
 
     if database_url is not None:
         try:
             prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
@@ -390,6 +388,10 @@ async def track_cost_callback(
                 completion=output_text
             )
             print("streaming response_cost", response_cost)
+            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
+            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
+            if user_api_key and prisma_client:
+                await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
             input_text = kwargs.get("messages", "")
             print(f"type of input_text: {type(input_text)}")
@@ -475,9 +477,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     print_verbose(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
 
-    ## ROUTER CONFIG
-    cache_responses = False
-
     ## ENVIRONMENT VARIABLES
     environment_variables = config.get('environment_variables', None)
     if environment_variables:
@@ -554,6 +553,8 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             else:
                 setattr(litellm, key, value)
 
-    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
     general_settings = config.get("general_settings", {})
     if general_settings is None:
@@ -589,18 +590,41 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     use_background_health_checks = general_settings.get("background_health_checks", False)
     health_check_interval = general_settings.get("health_check_interval", 300)
 
+    router_params: dict = {
+        "num_retries": 3,
+        "cache_responses": litellm.cache != None # cache if user passed in cache values
+    }
     ## MODEL LIST
     model_list = config.get('model_list', None)
     if model_list:
-        router = litellm.Router(model_list=model_list, num_retries=3, cache_responses=cache_responses)
+        router_params["model_list"] = model_list
         print(f"\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m")
         for model in model_list:
             print(f"\033[32m {model.get('model_name', '')}\033[0m")
             litellm_model_name = model["litellm_params"]["model"]
-            if "ollama" in litellm_model_name:
+            litellm_model_api_base = model["litellm_params"].get("api_base", None)
+            if "ollama" in litellm_model_name and litellm_model_api_base is None:
                 run_ollama_serve()
-    call_hooks.update_router_config(litellm_settings=litellm_settings, model_list=model_list, general_settings=general_settings)
+
+    ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+    router_settings = config.get("router_settings", None)
+    if router_settings and isinstance(router_settings, dict):
+        arg_spec = inspect.getfullargspec(litellm.Router)
+        # model list already set
+        exclude_args = {
+            "self",
+            "model_list",
+        }
+        available_args = [
+            x for x in arg_spec.args if x not in exclude_args
+        ]
+        for k, v in router_settings.items():
+            if k in available_args:
+                router_params[k] = v
+    router = litellm.Router(**router_params) # type:ignore
     return router, model_list, general_settings
 
 async def generate_key_helper_fn(duration: Optional[str], models: list, aliases: dict, config: dict, spend: float, token: Optional[str]=None, user_id: Optional[str]=None, max_parallel_requests: Optional[int]=None):
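The config loader now collects Router arguments into router_params and only forwards router_settings keys that litellm.Router() actually accepts. A hedged sketch of that filtering step in isolation (the router_settings values are illustrative, not taken from the diff):

import inspect
import litellm

router_settings = {"routing_strategy": "least-busy", "timeout": 30, "not_a_router_arg": True}
router_params = {"num_retries": 3, "cache_responses": litellm.cache != None}

arg_spec = inspect.getfullargspec(litellm.Router)
available_args = [x for x in arg_spec.args if x not in {"self", "model_list"}]
for k, v in router_settings.items():
    if k in available_args:
        router_params[k] = v  # "not_a_router_arg" is silently dropped here

print(router_params)  # these kwargs are then passed to litellm.Router(**router_params)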
@@ -772,10 +796,13 @@ def data_generator(response):
         yield f"data: {json.dumps(chunk)}\n\n"
 
 async def async_data_generator(response, user_api_key_dict):
-    global call_hooks
-
     print_verbose("inside generator")
     async for chunk in response:
+        # try:
+        #     await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
+        # except Exception as e:
+        #     print(f"An exception occurred - {str(e)}")
         print_verbose(f"returned chunk: {chunk}")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
@@ -946,7 +973,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
 @router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
 @router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global general_settings, user_debug, call_hooks
+    global general_settings, user_debug, proxy_logging_obj
     try:
         data = {}
         data = await request.json() # type: ignore
@@ -992,7 +1019,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
                 data["api_base"] = user_api_base
 
         ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
 
         ### ROUTE THE REQUEST ###
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
@@ -1009,15 +1036,10 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
             return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
 
-        ### CALL HOOKS ### - modify outgoing response
-        response = await call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="completion")
-
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
         return response
     except Exception as e:
-        print(f"Exception received: {str(e)}")
-        raise e
-        await call_hooks.post_call_failure(original_exception=e, user_api_key_dict=user_api_key_dict)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data.get("model", "") in router_model_names:
@@ -1052,7 +1074,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
 async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
-    global call_hooks
+    global proxy_logging_obj
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
         body = await request.body()
@@ -1095,8 +1117,8 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
                     data["input"] = input_list
                     break
 
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await call_hooks.pre_call(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
 
         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
@@ -1107,12 +1129,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
 
-        ### CALL HOOKS ### - modify outgoing response
-        data = call_hooks.post_call_success(user_api_key_dict=user_api_key_dict, response=response, call_type="embeddings")
-
         return response
     except Exception as e:
-        await call_hooks.post_call_failure(user_api_key_dict=user_api_key_dict, original_exception=e)
+        await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
         traceback.print_exc()
         raise e
@@ -1139,7 +1158,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     # data = await request.json()
-    data_json = json.loads(data.json()) # type: ignore
+    data_json = data.json() # type: ignore
     response = await generate_key_helper_fn(**data_json)
     return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

View file

@@ -3,7 +3,7 @@ import os, subprocess, hashlib, importlib, asyncio
 import litellm, backoff
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache
-from litellm.proxy.hooks.parallel_request_limiter import max_parallel_request_allow_request, max_parallel_request_update_count
+from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
 
 def print_verbose(print_statement):
     if litellm.set_verbose:
@@ -11,20 +11,23 @@ def print_verbose(print_statement):
 
 ### LOGGING ###
 class ProxyLogging:
     """
-    Logging for proxy.
-
-    Implemented mainly to log successful/failed db read/writes.
-
-    Currently just logs this to a provided sentry integration.
+    Logging/Custom Handlers for proxy.
+
+    Implemented mainly to:
+    - log successful/failed db read/writes
+    - support the max parallel request integration
     """
 
-    def __init__(self,):
+    def __init__(self, user_api_key_cache: DualCache):
         ## INITIALIZE LITELLM CALLBACKS ##
-        self._init_litellm_callbacks()
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.max_parallel_request_limiter = MaxParallelRequestsHandler()
         pass
 
     def _init_litellm_callbacks(self):
-        if len(litellm.callbacks) > 0:
+        litellm.callbacks.append(self.max_parallel_request_limiter)
         for callback in litellm.callbacks:
             if callback not in litellm.input_callback:
                 litellm.input_callback.append(callback)
@@ -53,6 +56,30 @@ class ProxyLogging:
             callback_list=callback_list
         )
 
+    async def pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
+        """
+        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        try:
+            self.call_details["data"] = data
+            self.call_details["call_type"] = call_type
+
+            ## check if max parallel requests set
+            if user_api_key_dict.max_parallel_requests is not None:
+                ## if set, check if request allowed
+                await self.max_parallel_request_limiter.max_parallel_request_allow_request(
+                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+
+            return data
+        except Exception as e:
+            raise e
+
     async def success_handler(self, *args, **kwargs):
         """
         Log successful db read/writes
@@ -68,6 +95,27 @@ class ProxyLogging:
         if litellm.utils.capture_exception:
             litellm.utils.capture_exception(error=original_exception)
 
+    async def post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
+        """
+        Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        """
+        # check if max parallel requests set
+        if user_api_key_dict.max_parallel_requests is not None:
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                await self.max_parallel_request_limiter.async_log_failure_call(
+                    api_key=user_api_key_dict.api_key,
+                    user_api_key_cache=self.call_details["user_api_key_cache"])
+        return
+
 ### DB CONNECTOR ###
 # Define the retry decorator with backoff strategy
@@ -290,65 +338,4 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
     except Exception as e:
         raise e
 
-### CALL HOOKS ###
-
-class CallHooks:
-    """
-    Allows users to modify the incoming request / output to the proxy, without having to deal with parsing Request body.
-
-    Covers:
-    1. /chat/completions
-    2. /embeddings
-    """
-
-    def __init__(self, user_api_key_cache: DualCache):
-        self.call_details: dict = {}
-        self.call_details["user_api_key_cache"] = user_api_key_cache
-
-    def update_router_config(self, litellm_settings: dict, general_settings: dict, model_list: list):
-        self.call_details["litellm_settings"] = litellm_settings
-        self.call_details["general_settings"] = general_settings
-        self.call_details["model_list"] = model_list
-
-    async def pre_call(self, user_api_key_dict: UserAPIKeyAuth, data: dict, call_type: Literal["completion", "embeddings"]):
-        try:
-            self.call_details["data"] = data
-            self.call_details["call_type"] = call_type
-
-            ## check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## if set, check if request allowed
-                await max_parallel_request_allow_request(
-                    max_parallel_requests=user_api_key_dict.max_parallel_requests,
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return data
-        except Exception as e:
-            raise e
-
-    async def post_call_success(self, user_api_key_dict: UserAPIKeyAuth, response: Optional[Any]=None, call_type: Optional[Literal["completion", "embeddings"]]=None, chunk: Optional[Any]=None):
-        try:
-            # check if max parallel requests set
-            if user_api_key_dict.max_parallel_requests is not None:
-                ## decrement call, once complete
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-
-            return response
-        except Exception as e:
-            raise e
-
-    async def post_call_failure(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
-        # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
-            ## decrement call count if call failed
-            if (hasattr(original_exception, "status_code")
-                and original_exception.status_code == 429
-                and "Max parallel request limit reached" in str(original_exception)):
-                pass # ignore failed calls due to max limit being reached
-            else:
-                await max_parallel_request_update_count(
-                    api_key=user_api_key_dict.api_key,
-                    user_api_key_cache=self.call_details["user_api_key_cache"])
-        return
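Taken together, the new flow is: pre_call_hook increments the per-key counter via MaxParallelRequestsHandler (raising a 429 once the limit is hit), litellm's async success callback decrements it after the call completes (for streaming, only once complete_streaming_response has been assembled), and post_call_failure_hook decrements it on failures other than the 429 itself. A minimal sketch of that wiring, assuming UserAPIKeyAuth accepts these fields as keyword arguments (values are illustrative):

import asyncio
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging

async def demo():
    cache = DualCache()
    proxy_logging = ProxyLogging(user_api_key_cache=cache)
    user = UserAPIKeyAuth(api_key="sk-test", max_parallel_requests=1)

    data = {"model": "azure-model", "messages": [], "stream": True}
    data = await proxy_logging.pre_call_hook(user_api_key_dict=user, data=data, call_type="completion")
    print(cache.get_cache(key="sk-test_request_count"))  # -> 1, slot is held until the stream finishes

    # a second concurrent request on the same key would now raise HTTPException(429)

asyncio.run(demo())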

View file

@@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
 from fastapi import FastAPI
 from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
 filepath = os.path.dirname(os.path.abspath(__file__))
-config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
+config_fp = f"{filepath}/test_configs/test_config.yaml"
 save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router)  # Include your router in the test app
@@ -100,3 +100,37 @@ def test_add_new_key_max_parallel_limit(client):
         _run_in_parallel()
     except Exception as e:
         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
+
+def test_add_new_key_max_parallel_limit_streaming(client):
+    try:
+        # Your test data
+        test_data = {"duration": "20m", "max_parallel_requests": 1}
+        # Your bearer token
+        token = os.getenv('PROXY_MASTER_KEY')
+
+        headers = {
+            "Authorization": f"Bearer {token}"
+        }
+        response = client.post("/key/generate", json=test_data, headers=headers)
+        print(f"response: {response.text}")
+        assert response.status_code == 200
+        result = response.json()
+        def _post_data():
+            json_data = {'model': 'azure-model', "messages": [{"role": "user", "content": f"this is a test request, write a short poem {time.time()}"}], "stream": True}
+            response = client.post("/chat/completions", json=json_data, headers={"Authorization": f"Bearer {result['key']}"})
+            return response
+        def _run_in_parallel():
+            with ThreadPoolExecutor(max_workers=2) as executor:
+                future1 = executor.submit(_post_data)
+                future2 = executor.submit(_post_data)
+
+                # Obtain the results from the futures
+                response1 = future1.result()
+                response2 = future2.result()
+                if response1.status_code == 429 or response2.status_code == 429:
+                    pass
+                else:
+                    raise Exception()
+        _run_in_parallel()
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")