Merge pull request #1574 from BerriAI/litellm_fix_streaming_spend_tracking

[WIP] fix(utils.py): fix proxy streaming spend tracking
Krish Dholakia 2024-01-23 17:07:40 -08:00 committed by GitHub
commit 4ca4913468
6 changed files with 179 additions and 23 deletions
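In short, the fix prices a streaming request from the assembled final response instead of skipping it. A minimal sketch of that idea, assuming `chunks` holds the ModelResponse chunks collected while streaming and `messages` the original request (the helper name is illustrative, not the proxy's actual call path):

import litellm

def cost_for_stream(chunks, messages):
    # assemble the streamed chunks into one response object, then price it
    complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
    if complete_response is None:
        return None
    try:
        # completion_cost looks the model up in litellm's model cost map
        return litellm.completion_cost(completion_response=complete_response)
    except litellm.NotFoundError:
        # model not in the cost map - same fallback the diff below uses
        return None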


@@ -114,6 +114,25 @@ jobs:
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install openai
python -m pip install --upgrade pip
python -m pip install -r .circleci/requirements.txt
pip install "pytest==7.3.1"
pip install "pytest-asyncio==0.21.1"
pip install mypy
pip install "google-generativeai>=0.3.2"
pip install "google-cloud-aiplatform>=1.38.0"
pip install "boto3>=1.28.57"
pip install langchain
pip install "langfuse>=2.0.0"
pip install numpydoc
pip install prisma
pip install "httpx==0.24.1"
pip install "gunicorn==21.2.0"
pip install "anyio==3.7.1"
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image


@@ -570,7 +570,7 @@ async def track_cost_callback(
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {}
user_id = proxy_server_request.get("body", {}).get("user", None)
- if "response_cost" in kwargs:
+ if kwargs.get("response_cost", None) is not None:
response_cost = kwargs["response_cost"]
user_api_key = kwargs["litellm_params"]["metadata"].get(
"user_api_key", None
@@ -596,6 +596,10 @@ async def track_cost_callback(
end_time=end_time,
)
else:
if kwargs["stream"] != True or (
kwargs["stream"] == True
and kwargs.get("complete_streaming_response") in kwargs
):
raise Exception(
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
@@ -695,6 +699,7 @@ async def update_database(
valid_token.spend = new_spend
user_api_key_cache.set_cache(key=token, value=valid_token)
### UPDATE SPEND LOGS ###
async def _insert_spend_log_to_db():
# Helper to generate payload to log
verbose_proxy_logger.debug("inserting spend log to db")
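The reworked guard in `track_cost_callback` only treats a missing cost as an error when a cost should have been computable, i.e. for non-streaming requests or for streams whose complete response has already been assembled. A rough sketch of that intent (the helper name and simplified checks are illustrative, not the proxy's exact code):

def get_tracked_cost(kwargs: dict):
    # cost is present: record it
    response_cost = kwargs.get("response_cost", None)
    if response_cost is not None:
        return response_cost
    # mid-stream chunks have no final cost yet, so they are not an error
    is_stream = kwargs.get("stream", False) is True
    stream_complete = "complete_streaming_response" in kwargs
    if not is_stream or stream_complete:
        raise Exception(
            "Model not in litellm model cost map. Add custom pricing - "
            "https://docs.litellm.ai/docs/proxy/custom_pricing"
        )
    return None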


@@ -856,10 +856,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
usage = response_obj["usage"]
id = response_obj.get("id", str(uuid.uuid4()))
api_key = metadata.get("user_api_key", "")
- if api_key is not None and type(api_key) == str:
+ if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
# hash the api_key
api_key = hash_token(api_key)
if "headers" in metadata and "authorization" in metadata["headers"]:
metadata["headers"].pop(
"authorization"
) # do not store the original `sk-..` api key in the db
payload = {
"request_id": id,
"call_type": call_type,


@@ -1067,9 +1067,13 @@ class Logging:
## if model in model cost map - log the response cost
## else set cost to None
verbose_logger.debug(f"Model={self.model}; result={result}")
- if result is not None and (
+ if (
result is not None
and (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
)
and self.stream != True
):
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
@@ -1104,6 +1108,12 @@
self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
):
verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time,
end_time=end_time,
result=result,
cache_hit=cache_hit,
)
# print(f"original response in success handler: {self.model_call_details['original_response']}") # print(f"original response in success handler: {self.model_call_details['original_response']}")
try: try:
verbose_logger.debug(f"success callbacks: {litellm.success_callback}") verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@ -1119,6 +1129,8 @@ class Logging:
complete_streaming_response = litellm.stream_chunk_builder( complete_streaming_response = litellm.stream_chunk_builder(
self.sync_streaming_chunks, self.sync_streaming_chunks,
messages=self.model_call_details.get("messages", None), messages=self.model_call_details.get("messages", None),
start_time=start_time,
end_time=end_time,
)
except:
complete_streaming_response = None
@@ -1132,13 +1144,19 @@ class Logging:
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
try:
- start_time, end_time, result = self._success_handler_helper_fn(
- start_time=start_time,
- end_time=end_time,
- result=result,
- cache_hit=cache_hit,
+ self.model_call_details["response_cost"] = litellm.completion_cost(
+ completion_response=complete_streaming_response,
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@@ -1423,6 +1441,18 @@ class Logging:
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
try:
self.model_call_details["response_cost"] = litellm.completion_cost(
completion_response=complete_streaming_response,
)
verbose_logger.debug(
f"Model={self.model}; cost={self.model_call_details['response_cost']}"
)
except litellm.NotFoundError as e:
verbose_logger.debug(
f"Model={self.model} not found in completion cost map."
)
self.model_call_details["response_cost"] = None
for callback in litellm._async_success_callback:
try:
@@ -1470,6 +1500,19 @@ class Logging:
end_time=end_time,
)
if callable(callback): # custom logger functions
if self.stream:
if "complete_streaming_response" in self.model_call_details:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
"complete_streaming_response"
],
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
callback_func=callback,
)
else:
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
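One consequence of the handler changes above is that custom callback functions registered with litellm now receive the assembled `complete_streaming_response` (with `response_cost` filled in) for streaming calls, rather than the last chunk. A hedged usage sketch of such a callback, following litellm's custom-callback signature (the logger name is made up):

import litellm

def log_spend(kwargs, completion_response, start_time, end_time):
    # for streaming calls, completion_response is the assembled
    # complete_streaming_response and kwargs["response_cost"] covers the
    # whole stream, not a single chunk
    cost = kwargs.get("response_cost", None)
    print(f"model={kwargs.get('model')} cost={cost}")

litellm.success_callback = [log_spend]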


@@ -4,13 +4,20 @@
import pytest
import asyncio
import aiohttp
from openai import AsyncOpenAI
import sys, os
sys.path.insert(
0, os.path.abspath("../")
) # Adds the parent directory to the system path
import litellm
async def generate_key(session, i):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
- "models": ["azure-models"],
+ "models": ["azure-models", "gpt-4"],
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
}
@@ -82,6 +89,36 @@ async def chat_completion(session, key, model="gpt-4"):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
async def chat_completion_streaming(session, key, model="gpt-4"):
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
]
prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
assert prompt_tokens == 19
data = {
"model": model,
"messages": messages,
"stream": True,
}
response = await client.chat.completions.create(**data)
content = ""
async for chunk in response:
content += chunk.choices[0].delta.content or ""
print(f"content: {content}")
completion_tokens = litellm.token_counter(
model="azure/gpt-35-turbo", text=content, count_response_tokens=True
)
return prompt_tokens, completion_tokens
@pytest.mark.asyncio
async def test_key_update():
@@ -181,3 +218,49 @@ async def test_key_info():
random_key = key_gen["key"]
status = await get_key_info(session=session, get_key=key, call_key=random_key)
assert status == 403
@pytest.mark.asyncio
async def test_key_info_spend_values():
"""
- create key
- make completion call
- assert cost is expected value
"""
async with aiohttp.ClientSession() as session:
## Test Spend Update ##
# completion
# response = await chat_completion(session=session, key=key)
# prompt_cost, completion_cost = litellm.cost_per_token(
# model="azure/gpt-35-turbo",
# prompt_tokens=response["usage"]["prompt_tokens"],
# completion_tokens=response["usage"]["completion_tokens"],
# )
# response_cost = prompt_cost + completion_cost
# await asyncio.sleep(5) # allow db log to be updated
# key_info = await get_key_info(session=session, get_key=key, call_key=key)
# print(
# f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
# )
# assert response_cost == key_info["info"]["spend"]
## streaming
key_gen = await generate_key(session=session, i=0)
new_key = key_gen["key"]
prompt_tokens, completion_tokens = await chat_completion_streaming(
session=session, key=new_key
)
print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
prompt_cost, completion_cost = litellm.cost_per_token(
model="gpt-35-turbo",
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response_cost = prompt_cost + completion_cost
await asyncio.sleep(5) # allow db log to be updated
key_info = await get_key_info(
session=session, get_key=new_key, call_key=new_key
)
print(
f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
)
assert response_cost == key_info["info"]["spend"]
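For reference, the expected spend asserted above is just the sum of the two USD values returned by `litellm.cost_per_token`. A quick stand-alone check with illustrative counts (the 19 prompt tokens match the assertion in the streaming helper; the completion count is made up):

import litellm

prompt_cost, completion_cost = litellm.cost_per_token(
    model="gpt-35-turbo", prompt_tokens=19, completion_tokens=30
)
# the key's recorded spend should equal this total
print(round(prompt_cost + completion_cost, 8))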


@@ -68,6 +68,7 @@ async def chat_completion(session, key):
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio