Merge branch 'main' into litellm_reset_key_budget

commit 9784d03d65
Author: Krish Dholakia (committed by GitHub)
Date: 2024-01-23 18:10:32 -08:00
13 changed files with 492 additions and 78 deletions

View file

@@ -115,6 +115,25 @@ jobs:
 pip install "pytest==7.3.1"
 pip install "pytest-asyncio==0.21.1"
 pip install aiohttp
+pip install openai
+python -m pip install --upgrade pip
+python -m pip install -r .circleci/requirements.txt
+pip install "pytest==7.3.1"
+pip install "pytest-asyncio==0.21.1"
+pip install mypy
+pip install "google-generativeai>=0.3.2"
+pip install "google-cloud-aiplatform>=1.38.0"
+pip install "boto3>=1.28.57"
+pip install langchain
+pip install "langfuse>=2.0.0"
+pip install numpydoc
+pip install prisma
+pip install "httpx==0.24.1"
+pip install "gunicorn==21.2.0"
+pip install "anyio==3.7.1"
+pip install "aiodynamo==23.10.1"
+pip install "asyncio==3.4.3"
+pip install "PyGithub==1.59.1"
 # Run pytest and generate JUnit XML report
 - run:
 name: Build Docker image

View file

@@ -98,7 +98,7 @@ def list_models():
 st.error(f"An error occurred while requesting models: {e}")
 else:
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )
@@ -151,7 +151,7 @@ def create_key():
 raise e
 else:
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )

View file

@@ -598,9 +598,9 @@ async def track_cost_callback(
 end_time=end_time,
 )
 else:
-if (
-kwargs["stream"] != True
-or kwargs.get("complete_streaming_response", None) is not None
+if kwargs["stream"] != True or (
+kwargs["stream"] == True
+and kwargs.get("complete_streaming_response") in kwargs
 ):
 raise Exception(
 f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
@@ -701,6 +701,7 @@ async def update_database(
 valid_token.spend = new_spend
 user_api_key_cache.set_cache(key=token, value=valid_token)
+### UPDATE SPEND LOGS ###
 async def _insert_spend_log_to_db():
 # Helper to generate payload to log
 verbose_proxy_logger.debug("inserting spend log to db")
@@ -1438,6 +1439,28 @@ async def async_data_generator(response, user_api_key_dict):
 yield f"data: {str(e)}\n\n"
+def select_data_generator(response, user_api_key_dict):
+try:
+# since boto3 - sagemaker does not support async calls, we should use a sync data_generator
+if (
+hasattr(response, "custom_llm_provider")
+and response.custom_llm_provider == "sagemaker"
+):
+return data_generator(
+response=response,
+)
+else:
+# default to async_data_generator
+return async_data_generator(
+response=response, user_api_key_dict=user_api_key_dict
+)
+except:
+# worst case - use async_data_generator
+return async_data_generator(
+response=response, user_api_key_dict=user_api_key_dict
+)
 def get_litellm_model_info(model: dict = {}):
 model_info = model.get("model_info", {})
 model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -1679,11 +1702,12 @@ async def completion(
 "stream" in data and data["stream"] == True
 ): # use generate_responses to stream responses
 custom_headers = {"x-litellm-model-id": model_id}
+selected_data_generator = select_data_generator(
+response=response, user_api_key_dict=user_api_key_dict
+)
 return StreamingResponse(
-async_data_generator(
-user_api_key_dict=user_api_key_dict,
-response=response,
-),
+selected_data_generator,
 media_type="text/event-stream",
 headers=custom_headers,
 )
@@ -1841,11 +1865,12 @@ async def chat_completion(
 "stream" in data and data["stream"] == True
 ): # use generate_responses to stream responses
 custom_headers = {"x-litellm-model-id": model_id}
+selected_data_generator = select_data_generator(
+response=response, user_api_key_dict=user_api_key_dict
+)
 return StreamingResponse(
-async_data_generator(
-user_api_key_dict=user_api_key_dict,
-response=response,
-),
+selected_data_generator,
 media_type="text/event-stream",
 headers=custom_headers,
 )
@@ -2305,6 +2330,94 @@ async def info_key_fn(
 )
+@router.get(
+"/spend/keys",
+tags=["Budget & Spend Tracking"],
+dependencies=[Depends(user_api_key_auth)],
+)
+async def spend_key_fn():
+"""
+View all keys created, ordered by spend
+Example Request:
+```
+curl -X GET "http://0.0.0.0:8000/spend/keys" \
+-H "Authorization: Bearer sk-1234"
+```
+"""
+global prisma_client
+try:
+if prisma_client is None:
+raise Exception(
+f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+)
+key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
+return key_info
+except Exception as e:
+raise HTTPException(
+status_code=status.HTTP_400_BAD_REQUEST,
+detail={"error": str(e)},
+)
+@router.get(
+"/spend/logs",
+tags=["Budget & Spend Tracking"],
+dependencies=[Depends(user_api_key_auth)],
+)
+async def view_spend_logs(
+request_id: Optional[str] = fastapi.Query(
+default=None,
+description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests",
+),
+):
+"""
+View all spend logs, if request_id is provided, only logs for that request_id will be returned
+Example Request for all logs
+```
+curl -X GET "http://0.0.0.0:8000/spend/logs" \
+-H "Authorization: Bearer sk-1234"
+```
+Example Request for specific request_id
+```
+curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \
+-H "Authorization: Bearer sk-1234"
+```
+"""
+global prisma_client
+try:
+if prisma_client is None:
+raise Exception(
+f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+)
+spend_logs = []
+if request_id is not None:
+spend_log = await prisma_client.get_data(
+table_name="spend",
+query_type="find_unique",
+request_id=request_id,
+)
+return [spend_log]
+else:
+spend_logs = await prisma_client.get_data(
+table_name="spend", query_type="find_all"
+)
+return spend_logs
+return None
+except Exception as e:
+raise HTTPException(
+status_code=status.HTTP_400_BAD_REQUEST,
+detail={"error": str(e)},
+)
 #### USER MANAGEMENT ####
 @router.post(
 "/user/new",

View file

@@ -4,7 +4,7 @@ const openai = require('openai');
 process.env.DEBUG=false;
 async function runOpenAI() {
 const client = new openai.OpenAI({
-apiKey: 'sk-yPX56TDqBpr23W7ruFG3Yg',
+apiKey: 'sk-JkKeNi6WpWDngBsghJ6B9g',
 baseURL: 'http://0.0.0.0:8000'
 });

View file

@@ -361,7 +361,8 @@ class PrismaClient:
 self,
 token: Optional[str] = None,
 user_id: Optional[str] = None,
-table_name: Optional[Literal["user", "key", "config"]] = None,
+request_id: Optional[str] = None,
+table_name: Optional[Literal["user", "key", "config", "spend"]] = None,
 query_type: Literal["find_unique", "find_all"] = "find_unique",
 expires: Optional[datetime] = None,
 reset_at: Optional[datetime] = None,
@@ -411,6 +412,10 @@
 for r in response:
 if isinstance(r.expires, datetime):
 r.expires = r.expires.isoformat()
+elif query_type == "find_all":
+response = await self.db.litellm_verificationtoken.find_many(
+order={"spend": "desc"},
+)
 print_verbose(f"PrismaClient: response={response}")
 if response is not None:
 return response
@@ -427,6 +432,23 @@
 }
 )
 return response
+elif table_name == "spend":
+verbose_proxy_logger.debug(
+f"PrismaClient: get_data: table_name == 'spend'"
+)
+if request_id is not None:
+response = await self.db.litellm_spendlogs.find_unique( # type: ignore
+where={
+"request_id": request_id,
+}
+)
+return response
+else:
+response = await self.db.litellm_spendlogs.find_many( # type: ignore
+order={"startTime": "desc"},
+)
+return response
 except Exception as e:
 print_verbose(f"LiteLLM Prisma Client Exception: {e}")
 import traceback
@@ -549,21 +571,20 @@
 db_data = self.jsonify_object(data=data)
 if token is not None:
 print_verbose(f"token: {token}")
-if query_type == "update":
-# check if plain text or hash
-if token.startswith("sk-"):
-token = self.hash_token(token=token)
-db_data["token"] = token
-response = await self.db.litellm_verificationtoken.update(
-where={"token": token}, # type: ignore
-data={**db_data}, # type: ignore
-)
-print_verbose(
-"\033[91m"
-+ f"DB Token Table update succeeded {response}"
-+ "\033[0m"
-)
-return {"token": token, "data": db_data}
+# check if plain text or hash
+if token.startswith("sk-"):
+token = self.hash_token(token=token)
+db_data["token"] = token
+response = await self.db.litellm_verificationtoken.update(
+where={"token": token}, # type: ignore
+data={**db_data}, # type: ignore
+)
+verbose_proxy_logger.debug(
+"\033[91m"
++ f"DB Token Table update succeeded {response}"
++ "\033[0m"
+)
+return {"token": token, "data": db_data}
 elif user_id is not None:
 """
 If data['spend'] + data['user'], update the user table with spend info as well
@@ -885,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time):
 usage = response_obj["usage"]
 id = response_obj.get("id", str(uuid.uuid4()))
 api_key = metadata.get("user_api_key", "")
-if api_key is not None and type(api_key) == str:
+if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"):
 # hash the api_key
 api_key = hash_token(api_key)
+if "headers" in metadata and "authorization" in metadata["headers"]:
+metadata["headers"].pop(
+"authorization"
+) # do not store the original `sk-..` api key in the db
 payload = {
 "request_id": id,
 "call_type": call_type,

View file

@@ -1408,9 +1408,15 @@ def test_completion_sagemaker_stream():
 )
 complete_streaming_response = ""
-for chunk in response:
+first_chunk_id, chunk_id = None, None
+for i, chunk in enumerate(response):
 print(chunk)
+chunk_id = chunk.id
+print(chunk_id)
+if i == 0:
+first_chunk_id = chunk_id
+else:
+assert chunk_id == first_chunk_id
 complete_streaming_response += chunk.choices[0].delta.content or ""
 # Add any assertions here to check the response
 # print(response)

View file

@@ -960,3 +960,29 @@ def test_router_anthropic_key_dynamic():
 messages = [{"role": "user", "content": "Hey, how's it going?"}]
 router.completion(model="anthropic-claude", messages=messages)
 os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key
+def test_router_timeout():
+model_list = [
+{
+"model_name": "gpt-3.5-turbo",
+"litellm_params": {
+"model": "gpt-3.5-turbo",
+"api_key": "os.environ/OPENAI_API_KEY",
+},
+}
+]
+router = Router(model_list=model_list)
+messages = [{"role": "user", "content": "Hey, how's it going?"}]
+start_time = time.time()
+try:
+res = router.completion(
+model="gpt-3.5-turbo", messages=messages, timeout=0.0001
+)
+print(res)
+pytest.fail("this should have timed out")
+except litellm.exceptions.Timeout as e:
+print("got timeout exception")
+print(e)
+print(vars(e))
+pass

View file

@@ -733,8 +733,15 @@ def test_completion_bedrock_claude_stream():
 complete_response = ""
 has_finish_reason = False
 # Add any assertions here to check the response
+first_chunk_id = None
 for idx, chunk in enumerate(response):
 # print
+if idx == 0:
+first_chunk_id = chunk.id
+else:
+assert (
+chunk.id == first_chunk_id
+), f"chunk ids do not match: {chunk.id} != first chunk id{first_chunk_id}"
 chunk, finished = streaming_format_tests(idx, chunk)
 has_finish_reason = finished
 complete_response += chunk

View file

@@ -1067,7 +1067,6 @@ class Logging:
 ## if model in model cost map - log the response cost
 ## else set cost to None
 verbose_logger.debug(f"Model={self.model}; result={result}")
-verbose_logger.debug(f"self.stream: {self.stream}")
 if (
 result is not None
 and (
@@ -1109,6 +1108,12 @@
 self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs
 ):
 verbose_logger.debug(f"Logging Details LiteLLM-Success Call")
+start_time, end_time, result = self._success_handler_helper_fn(
+start_time=start_time,
+end_time=end_time,
+result=result,
+cache_hit=cache_hit,
+)
 # print(f"original response in success handler: {self.model_call_details['original_response']}")
 try:
 verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
@@ -1124,6 +1129,8 @@
 complete_streaming_response = litellm.stream_chunk_builder(
 self.sync_streaming_chunks,
 messages=self.model_call_details.get("messages", None),
+start_time=start_time,
+end_time=end_time,
 )
 except:
 complete_streaming_response = None
@@ -1137,13 +1144,19 @@
 self.model_call_details[
 "complete_streaming_response"
 ] = complete_streaming_response
+try:
+self.model_call_details["response_cost"] = litellm.completion_cost(
+completion_response=complete_streaming_response,
+)
+verbose_logger.debug(
+f"Model={self.model}; cost={self.model_call_details['response_cost']}"
+)
+except litellm.NotFoundError as e:
+verbose_logger.debug(
+f"Model={self.model} not found in completion cost map."
+)
+self.model_call_details["response_cost"] = None
-start_time, end_time, result = self._success_handler_helper_fn(
-start_time=start_time,
-end_time=end_time,
-result=result,
-cache_hit=cache_hit,
-)
 for callback in litellm.success_callback:
 try:
 if callback == "lite_debugger":
@@ -1487,14 +1500,27 @@
 end_time=end_time,
 )
 if callable(callback): # custom logger functions
-await customLogger.async_log_event(
-kwargs=self.model_call_details,
-response_obj=result,
-start_time=start_time,
-end_time=end_time,
-print_verbose=print_verbose,
-callback_func=callback,
-)
+if self.stream:
+if "complete_streaming_response" in self.model_call_details:
+await customLogger.async_log_event(
+kwargs=self.model_call_details,
+response_obj=self.model_call_details[
+"complete_streaming_response"
+],
+start_time=start_time,
+end_time=end_time,
+print_verbose=print_verbose,
+callback_func=callback,
+)
+else:
+await customLogger.async_log_event(
+kwargs=self.model_call_details,
+response_obj=result,
+start_time=start_time,
+end_time=end_time,
+print_verbose=print_verbose,
+callback_func=callback,
+)
 if callback == "dynamodb":
 global dynamoLogger
 if dynamoLogger is None:
@@ -2915,17 +2941,25 @@ def cost_per_token(
 )
 return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
 elif model_with_provider in model_cost_ref:
-print_verbose(f"Looking up model={model_with_provider} in model_cost_map")
+verbose_logger.debug(
+f"Looking up model={model_with_provider} in model_cost_map"
+)
+verbose_logger.debug(
+f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+)
 prompt_tokens_cost_usd_dollar = (
 model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens
 )
+verbose_logger.debug(
+f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+)
 completion_tokens_cost_usd_dollar = (
 model_cost_ref[model_with_provider]["output_cost_per_token"]
 * completion_tokens
 )
 return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
 elif "ft:gpt-3.5-turbo" in model:
-print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
+verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM")
 # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm
 prompt_tokens_cost_usd_dollar = (
 model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens
@@ -2936,17 +2970,23 @@
 )
 return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
 elif model in litellm.azure_llms:
-print_verbose(f"Cost Tracking: {model} is an Azure LLM")
+verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM")
 model = litellm.azure_llms[model]
+verbose_logger.debug(
+f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}"
+)
 prompt_tokens_cost_usd_dollar = (
 model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
 )
+verbose_logger.debug(
+f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}"
+)
 completion_tokens_cost_usd_dollar = (
 model_cost_ref[model]["output_cost_per_token"] * completion_tokens
 )
 return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
 elif model in litellm.azure_embedding_models:
-print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model")
+verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model")
 model = litellm.azure_embedding_models[model]
 prompt_tokens_cost_usd_dollar = (
 model_cost_ref[model]["input_cost_per_token"] * prompt_tokens
@@ -7061,6 +7101,7 @@ class CustomStreamWrapper:
 self._hidden_params = {
 "model_id": (_model_info.get("id", None))
 } # returned as x-litellm-model-id response header in proxy
+self.response_id = None
 def __iter__(self):
 return self
@@ -7633,6 +7674,10 @@
 def chunk_creator(self, chunk):
 model_response = ModelResponse(stream=True, model=self.model)
+if self.response_id is not None:
+model_response.id = self.response_id
+else:
+self.response_id = model_response.id
 model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
 model_response.choices = [StreamingChoices()]
 model_response.choices[0].finish_reason = None
@@ -7752,10 +7797,8 @@
 ]
 self.sent_last_chunk = True
 elif self.custom_llm_provider == "sagemaker":
-print_verbose(f"ENTERS SAGEMAKER STREAMING")
-new_chunk = next(self.completion_stream)
-completion_obj["content"] = new_chunk
+print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+completion_obj["content"] = chunk
 elif self.custom_llm_provider == "petals":
 if len(self.completion_stream) == 0:
 if self.sent_last_chunk:
@@ -7874,7 +7917,7 @@
 completion_obj["role"] = "assistant"
 self.sent_first_chunk = True
 model_response.choices[0].delta = Delta(**completion_obj)
-print_verbose(f"model_response: {model_response}")
+print_verbose(f"returning model_response: {model_response}")
 return model_response
 else:
 return
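
The new debug lines in cost_per_token make the underlying arithmetic explicit: spend is per-token prices multiplied by token counts. A small worked sketch follows, using made-up prices rather than values from the real model cost map:

```python
# Hypothetical per-token prices; the real values live in litellm's model cost map.
model_cost_ref = {
    "example-model": {
        "input_cost_per_token": 1.5e-06,
        "output_cost_per_token": 2.0e-06,
    }
}
prompt_tokens, completion_tokens = 120, 80

prompt_cost = model_cost_ref["example-model"]["input_cost_per_token"] * prompt_tokens
completion_cost = (
    model_cost_ref["example-model"]["output_cost_per_token"] * completion_tokens
)
print(prompt_cost + completion_cost)  # 0.00018 + 0.00016 = 0.00034
```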

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.18.10"
+version = "1.18.11"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 [tool.commitizen]
-version = "1.18.10"
+version = "1.18.11"
 version_files = [
 "pyproject.toml:^version"
 ]

View file

@@ -2,15 +2,22 @@
 ## Tests /key endpoints.
 import pytest
-import asyncio
+import asyncio, time
 import aiohttp
+from openai import AsyncOpenAI
+import sys, os
+sys.path.insert(
+0, os.path.abspath("../")
+) # Adds the parent directory to the system path
+import litellm
 async def generate_key(session, i):
 url = "http://0.0.0.0:4000/key/generate"
 headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
 data = {
-"models": ["azure-models"],
+"models": ["azure-models", "gpt-4"],
 "aliases": {"mistral-7b": "gpt-3.5-turbo"},
 "duration": None,
 }
@@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"):
 if status != 200:
 raise Exception(f"Request did not return a 200 status code: {status}")
+return await response.json()
+async def chat_completion_streaming(session, key, model="gpt-4"):
+client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
+messages = [
+{"role": "system", "content": "You are a helpful assistant"},
+{"role": "user", "content": f"Hello! {time.time()}"},
+]
+prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages)
+data = {
+"model": model,
+"messages": messages,
+"stream": True,
+}
+response = await client.chat.completions.create(**data)
+content = ""
+async for chunk in response:
+content += chunk.choices[0].delta.content or ""
+print(f"content: {content}")
+completion_tokens = litellm.token_counter(
+model="gpt-35-turbo", text=content, count_response_tokens=True
+)
+return prompt_tokens, completion_tokens
 @pytest.mark.asyncio
 async def test_key_update():
@@ -181,3 +217,49 @@ async def test_key_info():
 random_key = key_gen["key"]
 status = await get_key_info(session=session, get_key=key, call_key=random_key)
 assert status == 403
+@pytest.mark.asyncio
+async def test_key_info_spend_values():
+"""
+- create key
+- make completion call
+- assert cost is expected value
+"""
+async with aiohttp.ClientSession() as session:
+## Test Spend Update ##
+# completion
+# response = await chat_completion(session=session, key=key)
+# prompt_cost, completion_cost = litellm.cost_per_token(
+# model="azure/gpt-35-turbo",
+# prompt_tokens=response["usage"]["prompt_tokens"],
+# completion_tokens=response["usage"]["completion_tokens"],
+# )
+# response_cost = prompt_cost + completion_cost
+# await asyncio.sleep(5) # allow db log to be updated
+# key_info = await get_key_info(session=session, get_key=key, call_key=key)
+# print(
+# f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+# )
+# assert response_cost == key_info["info"]["spend"]
+## streaming
+key_gen = await generate_key(session=session, i=0)
+new_key = key_gen["key"]
+prompt_tokens, completion_tokens = await chat_completion_streaming(
+session=session, key=new_key
+)
+print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+prompt_cost, completion_cost = litellm.cost_per_token(
+model="azure/gpt-35-turbo",
+prompt_tokens=prompt_tokens,
+completion_tokens=completion_tokens,
+)
+response_cost = prompt_cost + completion_cost
+await asyncio.sleep(5) # allow db log to be updated
+key_info = await get_key_info(
+session=session, get_key=new_key, call_key=new_key
+)
+print(
+f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+)
+assert response_cost == key_info["info"]["spend"]

View file

@@ -68,6 +68,7 @@ async def chat_completion(session, key):
 if status != 200:
 raise Exception(f"Request did not return a 200 status code: {status}")
+return await response.json()
 @pytest.mark.asyncio

View file

@@ -6,6 +6,9 @@ from dotenv import load_dotenv
 load_dotenv()
 import streamlit as st
 import base64, os, json, uuid, requests
+import pandas as pd
+import plotly.express as px
+import click
 # Replace your_base_url with the actual URL where the proxy auth app is hosted
 your_base_url = os.getenv("BASE_URL") # Example base URL
@@ -75,7 +78,7 @@ def add_new_model():
 and st.session_state.get("proxy_key", None) is None
 ):
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )
 model_name = st.text_input(
@@ -174,10 +177,70 @@ def list_models():
 st.error(f"An error occurred while requesting models: {e}")
 else:
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )
+def spend_per_key():
+import streamlit as st
+import requests
+# Check if the necessary configuration is available
+if (
+st.session_state.get("api_url", None) is not None
+and st.session_state.get("proxy_key", None) is not None
+):
+# Make the GET request
+try:
+complete_url = ""
+if isinstance(st.session_state["api_url"], str) and st.session_state[
+"api_url"
+].endswith("/"):
+complete_url = f"{st.session_state['api_url']}/spend/keys"
+else:
+complete_url = f"{st.session_state['api_url']}/spend/keys"
+response = requests.get(
+complete_url,
+headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
+)
+# Check if the request was successful
+if response.status_code == 200:
+spend_per_key = response.json()
+# Create DataFrame
+spend_df = pd.DataFrame(spend_per_key)
+# Display the spend per key as a graph
+st.header("Spend ($) per API Key:")
+top_10_df = spend_df.nlargest(10, "spend")
+fig = px.bar(
+top_10_df,
+x="token",
+y="spend",
+title="Top 10 Spend per Key",
+height=550, # Adjust the height
+width=1200, # Adjust the width)
+hover_data=["token", "spend", "user_id", "team_id"],
+)
+st.plotly_chart(fig)
+# Display the spend per key as a table
+st.write("Spend per Key - Full Table:")
+st.table(spend_df)
+else:
+st.error(f"Failed to get models. Status code: {response.status_code}")
+except Exception as e:
+st.error(f"An error occurred while requesting models: {e}")
+else:
+st.warning(
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
+)
+def spend_per_user():
+pass
 def create_key():
 import streamlit as st
 import json, requests, uuid
@@ -187,7 +250,7 @@ def create_key():
 and st.session_state.get("proxy_key", None) is None
 ):
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )
 duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
@@ -235,7 +298,7 @@ def update_config():
 and st.session_state.get("proxy_key", None) is None
 ):
 st.warning(
-"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page."
+f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
 )
 st.markdown("#### Alerting")
@@ -324,19 +387,25 @@ def update_config():
 raise e
-def admin_page(is_admin="NOT_GIVEN"):
+def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None):
 # Display the form for the admin to set the proxy URL and allowed email subdomain
+st.set_page_config(
+layout="wide", # Use "wide" layout for more space
+)
 st.header("Admin Configuration")
 st.session_state.setdefault("is_admin", is_admin)
 # Add a navigation sidebar
 st.sidebar.title("Navigation")
 page = st.sidebar.radio(
 "Go to",
 (
 "Connect to Proxy",
+"View Spend Per Key",
+"View Spend Per User",
+"List Models",
 "Update Config",
 "Add Models",
-"List Models",
 "Create Key",
 "End-User Auth",
 ),
@@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"):
 # Display different pages based on navigation selection
 if page == "Connect to Proxy":
 # Use text inputs with intermediary variables
-input_api_url = st.text_input(
-"Proxy Endpoint",
-value=st.session_state.get("api_url", ""),
-placeholder="http://0.0.0.0:8000",
-)
-input_proxy_key = st.text_input(
-"Proxy Key",
-value=st.session_state.get("proxy_key", ""),
-placeholder="sk-...",
-)
+if input_api_url is None:
+input_api_url = st.text_input(
+"Proxy Endpoint",
+value=st.session_state.get("api_url", ""),
+placeholder="http://0.0.0.0:8000",
+)
+else:
+st.session_state["api_url"] = input_api_url
+if input_proxy_key is None:
+input_proxy_key = st.text_input(
+"Proxy Key",
+value=st.session_state.get("proxy_key", ""),
+placeholder="sk-...",
+)
+else:
+st.session_state["proxy_key"] = input_proxy_key
 # When the "Save" button is clicked, update the session state
 if st.button("Save"):
 st.session_state["api_url"] = input_api_url
@@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"):
 list_models()
 elif page == "Create Key":
 create_key()
+elif page == "View Spend Per Key":
+spend_per_key()
+elif page == "View Spend Per User":
+spend_per_user()
-admin_page()
+# admin_page()
+@click.command()
+@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint")
+@click.option("--proxy_master_key", type=str, help="Proxy Master Key")
+def main(proxy_endpoint, proxy_master_key):
+admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key)
+if __name__ == "__main__":
+main()
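
With the new click entry point, the dashboard can be pointed at a proxy without typing credentials into the form. One plausible invocation, not shown in this diff and so only an assumption about how the script is launched, is: streamlit run <admin_ui_script>.py -- --proxy_endpoint http://0.0.0.0:8000 --proxy_master_key sk-1234, where arguments after "--" are forwarded to the script and picked up by main(); calling admin_page(input_api_url=..., input_proxy_key=...) directly has the same effect of pre-populating st.session_state so the "Connect to Proxy" form can be skipped.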