fix(utils.py): fix proxy streaming spend tracking

Krrish Dholakia 2024-01-23 15:59:03 -08:00
parent 01a2514b98
commit f8870fb48e
4 changed files with 130 additions and 14 deletions

@@ -4,13 +4,20 @@
import pytest
import asyncio
import aiohttp
from openai import AsyncOpenAI
import sys, os

sys.path.insert(
    0, os.path.abspath("../")
)  # Adds the parent directory to the system path
import litellm


async def generate_key(session, i):
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
-        "models": ["azure-models"],
+        "models": ["azure-models", "gpt-4"],
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": None,
    }
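
The hunk cuts off mid-helper. A hedged sketch of how such a /key/generate request is typically completed (this is not the commit's actual code, just the same pattern the other helpers in this file use):

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()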
@@ -82,6 +89,34 @@ async def chat_completion(session, key, model="gpt-4"):
        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


async def chat_completion_streaming(session, key, model="gpt-4"):
    client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    data = {
        "model": model,
        "messages": messages,
        "stream": True,
    }

    response = await client.chat.completions.create(**data)
    content = ""
    async for chunk in response:
        content += chunk.choices[0].delta.content or ""
    print(f"content: {content}")

    prompt_tokens = litellm.token_counter(model="azure/gpt-35-turbo", messages=messages)
    completion_tokens = litellm.token_counter(
        model="azure/gpt-35-turbo", text=content, count_response_tokens=True
    )
    return prompt_tokens, completion_tokens
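
# Note: the streamed response carries no usage block here, so the helper
# recomputes token counts client-side with litellm.token_counter. A minimal
# standalone sketch of those two calls (the message and response strings are
# illustrative, not from the commit):
#
#     tokens_in = litellm.token_counter(
#         model="azure/gpt-35-turbo",
#         messages=[{"role": "user", "content": "Hello!"}],
#     )
#     tokens_out = litellm.token_counter(
#         model="azure/gpt-35-turbo", text="Hi there!", count_response_tokens=True
#     )
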
@pytest.mark.asyncio
async def test_key_update():
@@ -181,3 +216,49 @@ async def test_key_info():
        random_key = key_gen["key"]
        status = await get_key_info(session=session, get_key=key, call_key=random_key)
        assert status == 403


@pytest.mark.asyncio
async def test_key_info_spend_values():
    """
    - create key
    - make a streaming completion call with it
    - assert the spend tracked for the key matches the expected cost
    """
    async with aiohttp.ClientSession() as session:
        ## Test Spend Update ##
        # completion
        # response = await chat_completion(session=session, key=key)
        # prompt_cost, completion_cost = litellm.cost_per_token(
        #     model="azure/gpt-35-turbo",
        #     prompt_tokens=response["usage"]["prompt_tokens"],
        #     completion_tokens=response["usage"]["completion_tokens"],
        # )
        # response_cost = prompt_cost + completion_cost
        # await asyncio.sleep(5)  # allow db log to be updated
        # key_info = await get_key_info(session=session, get_key=key, call_key=key)
        # print(
        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
        # )
        # assert response_cost == key_info["info"]["spend"]

        ## streaming
        key_gen = await generate_key(session=session, i=0)
        new_key = key_gen["key"]
        prompt_tokens, completion_tokens = await chat_completion_streaming(
            session=session, key=new_key
        )
        print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
        prompt_cost, completion_cost = litellm.cost_per_token(
            model="azure/gpt-35-turbo",
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        response_cost = prompt_cost + completion_cost
        await asyncio.sleep(5)  # allow db log to be updated
        key_info = await get_key_info(
            session=session, get_key=new_key, call_key=new_key
        )
        print(
            f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
        )
        assert response_cost == key_info["info"]["spend"]
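
For context: litellm.cost_per_token returns a (prompt_cost, completion_cost) tuple in USD, so the test's expected spend is simply their sum. A minimal standalone sketch with made-up token counts:

    import litellm

    # per-token pricing lookup for the deployment's model
    prompt_cost, completion_cost = litellm.cost_per_token(
        model="azure/gpt-35-turbo", prompt_tokens=20, completion_tokens=10
    )
    print(f"expected spend: {prompt_cost + completion_cost:.8f} USD")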

@@ -68,6 +68,7 @@ async def chat_completion(session, key):
        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()


@pytest.mark.asyncio