Mirror of https://github.com/BerriAI/litellm.git
test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs
This commit is contained in:
parent 1d2f5ce975
commit ea89a8a938
8 changed files with 501 additions and 122 deletions
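The commit message mentions new unit tests for custom callbacks. As a rough illustration only (not the test file added by this commit), a minimal custom-callback unit test against litellm's CustomLogger interface could look like the sketch below; the mock_response parameter and the log_success_event hook are taken from litellm's public callback API, and the short sleep accounts for success callbacks that may run in a background thread.

```python
# Minimal sketch of a custom-callback unit test, assuming litellm's CustomLogger
# interface and the mock_response testing parameter; illustration only, not the
# test file added in this commit.
import time

import litellm
from litellm.integrations.custom_logger import CustomLogger


class TrackingHandler(CustomLogger):
    def __init__(self):
        super().__init__()
        self.success_models = []

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # record every successful call so the test can assert on it
        self.success_models.append(kwargs.get("model"))


def test_custom_callback_fires_on_success():
    handler = TrackingHandler()
    litellm.callbacks = [handler]
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="Hello!",  # avoids a real API call
    )
    time.sleep(1)  # success callbacks may be dispatched on a separate thread
    assert handler.success_models == ["gpt-3.5-turbo"]
```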
@@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
if prisma_client:
## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
if valid_token:
litellm.model_alias_map = valid_token.aliases
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
data = await request.json()
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")

## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
if valid_token:
litellm.model_alias_map = valid_token.aliases
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
data = await request.json()
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
print(f"An exception occurred - {traceback.format_exc()}")
if isinstance(e, HTTPException):
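The key check in this hunk follows a cache-aside pattern: look the API key up in an in-memory cache first, fall back to the database on a miss, and cache the result with a 60-second TTL. A minimal, self-contained sketch of that pattern (plain Python with a stand-in cache and "database", not litellm's own cache or Prisma client) is:

```python
# Generic sketch of the cache-aside lookup above: in-memory cache first, then a
# stand-in database lookup, cached with a 60-second TTL. The cache class and
# FAKE_DB are hypothetical stand-ins, not litellm internals.
import time
from typing import Any, Optional


class InMemoryTTLCache:
    def __init__(self) -> None:
        self._store: dict[str, tuple[Any, float]] = {}

    def get_cache(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() > expires_at:  # expired entry counts as a miss
            self._store.pop(key, None)
            return None
        return value

    def set_cache(self, key: str, value: Any, ttl: int) -> None:
        self._store[key] = (value, time.time() + ttl)


FAKE_DB = {"sk-123": {"token": "sk-123", "models": []}}  # stand-in for the Prisma table


def lookup_api_key(api_key: str, cache: InMemoryTTLCache) -> Optional[dict]:
    valid_token = cache.get_cache(key=api_key)           # 1. check the in-memory cache
    if valid_token is None:
        valid_token = FAKE_DB.get(api_key)               # 2. fall back to the database
        cache.set_cache(key=api_key, value=valid_token, ttl=60)  # 3. cache for 60s
    return valid_token


cache = InMemoryTTLCache()
print(lookup_api_key("sk-123", cache))  # miss -> "db" lookup -> cached
print(lookup_api_key("sk-123", cache))  # served from the cache
```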
@@ -380,25 +382,14 @@ async def track_cost_callback(
if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
)
response_cost = litellm.completion_cost(completion_response=completion_response)
print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
input_text = kwargs.get("messages", "")
print(f"type of input_text: {type(input_text)}")
if isinstance(input_text, list):
response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
elif isinstance(input_text, str):
response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
response_cost = litellm.completion_cost(completion_response=completion_response)
print(f"received completion response: {completion_response}")

print(f"regular response_cost: {response_cost}")
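This hunk collapses the cost tracking to a single litellm.completion_cost(completion_response=...) call instead of re-deriving cost from the prompt and completion text. Both call shapes appear in the diff; the sketch below shows them stand-alone, with the model name and strings as example inputs rather than values taken from the proxy.

```python
# Sketch of the two completion_cost call shapes seen in the diff; example
# inputs only, not proxy values.
import litellm

# Text-based form (the shape the callback used to build by hand): litellm
# counts tokens for the messages and the completion text, then prices them.
cost_from_text = litellm.completion_cost(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    completion="I'm doing well, thank you!",
)
print(f"cost from prompt/completion text: {cost_from_text}")

# Response-based form (what the callback now uses): pass the response object
# returned by litellm.completion, or the reassembled streaming response.
# response = litellm.completion(model="gpt-3.5-turbo", messages=[...])
# cost_from_response = litellm.completion_cost(completion_response=response)
```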