diff --git a/.circleci/config.yml b/.circleci/config.yml index 9ec6c8db2..8685f4579 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -115,6 +115,25 @@ jobs: pip install "pytest==7.3.1" pip install "pytest-asyncio==0.21.1" pip install aiohttp + pip install openai + python -m pip install --upgrade pip + python -m pip install -r .circleci/requirements.txt + pip install "pytest==7.3.1" + pip install "pytest-asyncio==0.21.1" + pip install mypy + pip install "google-generativeai>=0.3.2" + pip install "google-cloud-aiplatform>=1.38.0" + pip install "boto3>=1.28.57" + pip install langchain + pip install "langfuse>=2.0.0" + pip install numpydoc + pip install prisma + pip install "httpx==0.24.1" + pip install "gunicorn==21.2.0" + pip install "anyio==3.7.1" + pip install "aiodynamo==23.10.1" + pip install "asyncio==3.4.3" + pip install "PyGithub==1.59.1" # Run pytest and generate JUnit XML report - run: name: Build Docker image diff --git a/litellm/proxy/admin_ui.py b/litellm/proxy/admin_ui.py index d50d8be90..c72cd88f0 100644 --- a/litellm/proxy/admin_ui.py +++ b/litellm/proxy/admin_ui.py @@ -98,7 +98,7 @@ def list_models(): st.error(f"An error occurred while requesting models: {e}") else: st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) @@ -151,7 +151,7 @@ def create_key(): raise e else: st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 9b21aa880..874731f1d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -598,9 +598,9 @@ async def track_cost_callback( end_time=end_time, ) else: - if ( - kwargs["stream"] != True - or kwargs.get("complete_streaming_response", None) is not None + if kwargs["stream"] != True or ( + kwargs["stream"] == True + and kwargs.get("complete_streaming_response") in kwargs ): raise Exception( f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing" @@ -701,6 +701,7 @@ async def update_database( valid_token.spend = new_spend user_api_key_cache.set_cache(key=token, value=valid_token) + ### UPDATE SPEND LOGS ### async def _insert_spend_log_to_db(): # Helper to generate payload to log verbose_proxy_logger.debug("inserting spend log to db") @@ -1438,6 +1439,28 @@ async def async_data_generator(response, user_api_key_dict): yield f"data: {str(e)}\n\n" +def select_data_generator(response, user_api_key_dict): + try: + # since boto3 - sagemaker does not support async calls, we should use a sync data_generator + if ( + hasattr(response, "custom_llm_provider") + and response.custom_llm_provider == "sagemaker" + ): + return data_generator( + response=response, + ) + else: + # default to async_data_generator + return async_data_generator( + response=response, user_api_key_dict=user_api_key_dict + ) + except: + # worst case - use async_data_generator + return async_data_generator( + response=response, user_api_key_dict=user_api_key_dict + ) + + def get_litellm_model_info(model: dict = {}): model_info = model.get("model_info", {}) model_to_lookup = model.get("litellm_params", {}).get("model", None) @@ -1679,11 +1702,12 @@ async def completion( "stream" in data and data["stream"] == True ): # use generate_responses to stream responses custom_headers = {"x-litellm-model-id": model_id} + selected_data_generator = select_data_generator( + response=response, user_api_key_dict=user_api_key_dict + ) + return StreamingResponse( - async_data_generator( - user_api_key_dict=user_api_key_dict, - response=response, - ), + selected_data_generator, media_type="text/event-stream", headers=custom_headers, ) @@ -1841,11 +1865,12 @@ async def chat_completion( "stream" in data and data["stream"] == True ): # use generate_responses to stream responses custom_headers = {"x-litellm-model-id": model_id} + selected_data_generator = select_data_generator( + response=response, user_api_key_dict=user_api_key_dict + ) + return StreamingResponse( - async_data_generator( - user_api_key_dict=user_api_key_dict, - response=response, - ), + selected_data_generator, media_type="text/event-stream", headers=custom_headers, ) @@ -2305,6 +2330,94 @@ async def info_key_fn( ) +@router.get( + "/spend/keys", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], +) +async def spend_key_fn(): + """ + View all keys created, ordered by spend + + Example Request: + ``` + curl -X GET "http://0.0.0.0:8000/spend/keys" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + global prisma_client + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + key_info = await prisma_client.get_data(table_name="key", query_type="find_all") + + return key_info + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/spend/logs", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], +) +async def view_spend_logs( + request_id: Optional[str] = fastapi.Query( + default=None, + description="request_id to get spend logs for specific request_id. If none passed then pass spend logs for all requests", + ), +): + """ + View all spend logs, if request_id is provided, only logs for that request_id will be returned + + Example Request for all logs + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs" \ +-H "Authorization: Bearer sk-1234" + ``` + + Example Request for specific request_id + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + global prisma_client + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + spend_logs = [] + if request_id is not None: + spend_log = await prisma_client.get_data( + table_name="spend", + query_type="find_unique", + request_id=request_id, + ) + return [spend_log] + else: + spend_logs = await prisma_client.get_data( + table_name="spend", query_type="find_all" + ) + return spend_logs + + return None + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + #### USER MANAGEMENT #### @router.post( "/user/new", diff --git a/litellm/proxy/tests/test_openai_js.js b/litellm/proxy/tests/test_openai_js.js index 7e74eeca3..c0f25cf05 100644 --- a/litellm/proxy/tests/test_openai_js.js +++ b/litellm/proxy/tests/test_openai_js.js @@ -4,7 +4,7 @@ const openai = require('openai'); process.env.DEBUG=false; async function runOpenAI() { const client = new openai.OpenAI({ - apiKey: 'sk-yPX56TDqBpr23W7ruFG3Yg', + apiKey: 'sk-JkKeNi6WpWDngBsghJ6B9g', baseURL: 'http://0.0.0.0:8000' }); diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0f132d79b..1b3581427 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -361,7 +361,8 @@ class PrismaClient: self, token: Optional[str] = None, user_id: Optional[str] = None, - table_name: Optional[Literal["user", "key", "config"]] = None, + request_id: Optional[str] = None, + table_name: Optional[Literal["user", "key", "config", "spend"]] = None, query_type: Literal["find_unique", "find_all"] = "find_unique", expires: Optional[datetime] = None, reset_at: Optional[datetime] = None, @@ -411,6 +412,10 @@ class PrismaClient: for r in response: if isinstance(r.expires, datetime): r.expires = r.expires.isoformat() + elif query_type == "find_all": + response = await self.db.litellm_verificationtoken.find_many( + order={"spend": "desc"}, + ) print_verbose(f"PrismaClient: response={response}") if response is not None: return response @@ -427,6 +432,23 @@ class PrismaClient: } ) return response + elif table_name == "spend": + verbose_proxy_logger.debug( + f"PrismaClient: get_data: table_name == 'spend'" + ) + if request_id is not None: + response = await self.db.litellm_spendlogs.find_unique( # type: ignore + where={ + "request_id": request_id, + } + ) + return response + else: + response = await self.db.litellm_spendlogs.find_many( # type: ignore + order={"startTime": "desc"}, + ) + return response + except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") import traceback @@ -549,21 +571,20 @@ class PrismaClient: db_data = self.jsonify_object(data=data) if token is not None: print_verbose(f"token: {token}") - if query_type == "update": - # check if plain text or hash - if token.startswith("sk-"): - token = self.hash_token(token=token) - db_data["token"] = token - response = await self.db.litellm_verificationtoken.update( - where={"token": token}, # type: ignore - data={**db_data}, # type: ignore - ) - print_verbose( - "\033[91m" - + f"DB Token Table update succeeded {response}" - + "\033[0m" - ) - return {"token": token, "data": db_data} + # check if plain text or hash + if token.startswith("sk-"): + token = self.hash_token(token=token) + db_data["token"] = token + response = await self.db.litellm_verificationtoken.update( + where={"token": token}, # type: ignore + data={**db_data}, # type: ignore + ) + verbose_proxy_logger.debug( + "\033[91m" + + f"DB Token Table update succeeded {response}" + + "\033[0m" + ) + return {"token": token, "data": db_data} elif user_id is not None: """ If data['spend'] + data['user'], update the user table with spend info as well @@ -885,10 +906,15 @@ def get_logging_payload(kwargs, response_obj, start_time, end_time): usage = response_obj["usage"] id = response_obj.get("id", str(uuid.uuid4())) api_key = metadata.get("user_api_key", "") - if api_key is not None and type(api_key) == str: + if api_key is not None and isinstance(api_key, str) and api_key.startswith("sk-"): # hash the api_key api_key = hash_token(api_key) + if "headers" in metadata and "authorization" in metadata["headers"]: + metadata["headers"].pop( + "authorization" + ) # do not store the original `sk-..` api key in the db + payload = { "request_id": id, "call_type": call_type, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 421580253..b2c69804c 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1408,9 +1408,15 @@ def test_completion_sagemaker_stream(): ) complete_streaming_response = "" - - for chunk in response: + first_chunk_id, chunk_id = None, None + for i, chunk in enumerate(response): print(chunk) + chunk_id = chunk.id + print(chunk_id) + if i == 0: + first_chunk_id = chunk_id + else: + assert chunk_id == first_chunk_id complete_streaming_response += chunk.choices[0].delta.content or "" # Add any assertions here to check the response # print(response) diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py index 6c5e8ee7d..d7ab4b880 100644 --- a/litellm/tests/test_router.py +++ b/litellm/tests/test_router.py @@ -960,3 +960,29 @@ def test_router_anthropic_key_dynamic(): messages = [{"role": "user", "content": "Hey, how's it going?"}] router.completion(model="anthropic-claude", messages=messages) os.environ["ANTHROPIC_API_KEY"] = anthropic_api_key + + +def test_router_timeout(): + model_list = [ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": "os.environ/OPENAI_API_KEY", + }, + } + ] + router = Router(model_list=model_list) + messages = [{"role": "user", "content": "Hey, how's it going?"}] + start_time = time.time() + try: + res = router.completion( + model="gpt-3.5-turbo", messages=messages, timeout=0.0001 + ) + print(res) + pytest.fail("this should have timed out") + except litellm.exceptions.Timeout as e: + print("got timeout exception") + print(e) + print(vars(e)) + pass diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 959e63d59..14b1a7210 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -733,8 +733,15 @@ def test_completion_bedrock_claude_stream(): complete_response = "" has_finish_reason = False # Add any assertions here to check the response + first_chunk_id = None for idx, chunk in enumerate(response): # print + if idx == 0: + first_chunk_id = chunk.id + else: + assert ( + chunk.id == first_chunk_id + ), f"chunk ids do not match: {chunk.id} != first chunk id{first_chunk_id}" chunk, finished = streaming_format_tests(idx, chunk) has_finish_reason = finished complete_response += chunk diff --git a/litellm/utils.py b/litellm/utils.py index 7f8a447ad..03d38ff35 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1067,7 +1067,6 @@ class Logging: ## if model in model cost map - log the response cost ## else set cost to None verbose_logger.debug(f"Model={self.model}; result={result}") - verbose_logger.debug(f"self.stream: {self.stream}") if ( result is not None and ( @@ -1109,6 +1108,12 @@ class Logging: self, result=None, start_time=None, end_time=None, cache_hit=None, **kwargs ): verbose_logger.debug(f"Logging Details LiteLLM-Success Call") + start_time, end_time, result = self._success_handler_helper_fn( + start_time=start_time, + end_time=end_time, + result=result, + cache_hit=cache_hit, + ) # print(f"original response in success handler: {self.model_call_details['original_response']}") try: verbose_logger.debug(f"success callbacks: {litellm.success_callback}") @@ -1124,6 +1129,8 @@ class Logging: complete_streaming_response = litellm.stream_chunk_builder( self.sync_streaming_chunks, messages=self.model_call_details.get("messages", None), + start_time=start_time, + end_time=end_time, ) except: complete_streaming_response = None @@ -1137,13 +1144,19 @@ class Logging: self.model_call_details[ "complete_streaming_response" ] = complete_streaming_response + try: + self.model_call_details["response_cost"] = litellm.completion_cost( + completion_response=complete_streaming_response, + ) + verbose_logger.debug( + f"Model={self.model}; cost={self.model_call_details['response_cost']}" + ) + except litellm.NotFoundError as e: + verbose_logger.debug( + f"Model={self.model} not found in completion cost map." + ) + self.model_call_details["response_cost"] = None - start_time, end_time, result = self._success_handler_helper_fn( - start_time=start_time, - end_time=end_time, - result=result, - cache_hit=cache_hit, - ) for callback in litellm.success_callback: try: if callback == "lite_debugger": @@ -1487,14 +1500,27 @@ class Logging: end_time=end_time, ) if callable(callback): # custom logger functions - await customLogger.async_log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - callback_func=callback, - ) + if self.stream: + if "complete_streaming_response" in self.model_call_details: + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=self.model_call_details[ + "complete_streaming_response" + ], + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) + else: + await customLogger.async_log_event( + kwargs=self.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + print_verbose=print_verbose, + callback_func=callback, + ) if callback == "dynamodb": global dynamoLogger if dynamoLogger is None: @@ -2915,17 +2941,25 @@ def cost_per_token( ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model_with_provider in model_cost_ref: - print_verbose(f"Looking up model={model_with_provider} in model_cost_map") + verbose_logger.debug( + f"Looking up model={model_with_provider} in model_cost_map" + ) + verbose_logger.debug( + f"applying cost={model_cost_ref[model_with_provider]['input_cost_per_token']} for prompt_tokens={prompt_tokens}" + ) prompt_tokens_cost_usd_dollar = ( model_cost_ref[model_with_provider]["input_cost_per_token"] * prompt_tokens ) + verbose_logger.debug( + f"applying cost={model_cost_ref[model_with_provider]['output_cost_per_token']} for completion_tokens={completion_tokens}" + ) completion_tokens_cost_usd_dollar = ( model_cost_ref[model_with_provider]["output_cost_per_token"] * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif "ft:gpt-3.5-turbo" in model: - print_verbose(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") + verbose_logger.debug(f"Cost Tracking: {model} is an OpenAI FinteTuned LLM") # fuzzy match ft:gpt-3.5-turbo:abcd-id-cool-litellm prompt_tokens_cost_usd_dollar = ( model_cost_ref["ft:gpt-3.5-turbo"]["input_cost_per_token"] * prompt_tokens @@ -2936,17 +2970,23 @@ def cost_per_token( ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in litellm.azure_llms: - print_verbose(f"Cost Tracking: {model} is an Azure LLM") + verbose_logger.debug(f"Cost Tracking: {model} is an Azure LLM") model = litellm.azure_llms[model] + verbose_logger.debug( + f"applying cost={model_cost_ref[model]['input_cost_per_token']} for prompt_tokens={prompt_tokens}" + ) prompt_tokens_cost_usd_dollar = ( model_cost_ref[model]["input_cost_per_token"] * prompt_tokens ) + verbose_logger.debug( + f"applying cost={model_cost_ref[model]['output_cost_per_token']} for completion_tokens={completion_tokens}" + ) completion_tokens_cost_usd_dollar = ( model_cost_ref[model]["output_cost_per_token"] * completion_tokens ) return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar elif model in litellm.azure_embedding_models: - print_verbose(f"Cost Tracking: {model} is an Azure Embedding Model") + verbose_logger.debug(f"Cost Tracking: {model} is an Azure Embedding Model") model = litellm.azure_embedding_models[model] prompt_tokens_cost_usd_dollar = ( model_cost_ref[model]["input_cost_per_token"] * prompt_tokens @@ -7061,6 +7101,7 @@ class CustomStreamWrapper: self._hidden_params = { "model_id": (_model_info.get("id", None)) } # returned as x-litellm-model-id response header in proxy + self.response_id = None def __iter__(self): return self @@ -7633,6 +7674,10 @@ class CustomStreamWrapper: def chunk_creator(self, chunk): model_response = ModelResponse(stream=True, model=self.model) + if self.response_id is not None: + model_response.id = self.response_id + else: + self.response_id = model_response.id model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider model_response.choices = [StreamingChoices()] model_response.choices[0].finish_reason = None @@ -7752,10 +7797,8 @@ class CustomStreamWrapper: ] self.sent_last_chunk = True elif self.custom_llm_provider == "sagemaker": - print_verbose(f"ENTERS SAGEMAKER STREAMING") - new_chunk = next(self.completion_stream) - - completion_obj["content"] = new_chunk + print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") + completion_obj["content"] = chunk elif self.custom_llm_provider == "petals": if len(self.completion_stream) == 0: if self.sent_last_chunk: @@ -7874,7 +7917,7 @@ class CustomStreamWrapper: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) - print_verbose(f"model_response: {model_response}") + print_verbose(f"returning model_response: {model_response}") return model_response else: return diff --git a/pyproject.toml b/pyproject.toml index 0a18f6af1..cd21db903 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.18.10" +version = "1.18.11" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -63,7 +63,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.18.10" +version = "1.18.11" version_files = [ "pyproject.toml:^version" ] diff --git a/tests/test_keys.py b/tests/test_keys.py index f209f4c5a..917c50823 100644 --- a/tests/test_keys.py +++ b/tests/test_keys.py @@ -2,15 +2,22 @@ ## Tests /key endpoints. import pytest -import asyncio +import asyncio, time import aiohttp +from openai import AsyncOpenAI +import sys, os + +sys.path.insert( + 0, os.path.abspath("../") +) # Adds the parent directory to the system path +import litellm async def generate_key(session, i): url = "http://0.0.0.0:4000/key/generate" headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} data = { - "models": ["azure-models"], + "models": ["azure-models", "gpt-4"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": None, } @@ -82,6 +89,35 @@ async def chat_completion(session, key, model="gpt-4"): if status != 200: raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() + + +async def chat_completion_streaming(session, key, model="gpt-4"): + client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000") + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": f"Hello! {time.time()}"}, + ] + prompt_tokens = litellm.token_counter(model="gpt-35-turbo", messages=messages) + data = { + "model": model, + "messages": messages, + "stream": True, + } + response = await client.chat.completions.create(**data) + + content = "" + async for chunk in response: + content += chunk.choices[0].delta.content or "" + + print(f"content: {content}") + + completion_tokens = litellm.token_counter( + model="gpt-35-turbo", text=content, count_response_tokens=True + ) + + return prompt_tokens, completion_tokens + @pytest.mark.asyncio async def test_key_update(): @@ -181,3 +217,49 @@ async def test_key_info(): random_key = key_gen["key"] status = await get_key_info(session=session, get_key=key, call_key=random_key) assert status == 403 + + +@pytest.mark.asyncio +async def test_key_info_spend_values(): + """ + - create key + - make completion call + - assert cost is expected value + """ + async with aiohttp.ClientSession() as session: + ## Test Spend Update ## + # completion + # response = await chat_completion(session=session, key=key) + # prompt_cost, completion_cost = litellm.cost_per_token( + # model="azure/gpt-35-turbo", + # prompt_tokens=response["usage"]["prompt_tokens"], + # completion_tokens=response["usage"]["completion_tokens"], + # ) + # response_cost = prompt_cost + completion_cost + # await asyncio.sleep(5) # allow db log to be updated + # key_info = await get_key_info(session=session, get_key=key, call_key=key) + # print( + # f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}" + # ) + # assert response_cost == key_info["info"]["spend"] + ## streaming + key_gen = await generate_key(session=session, i=0) + new_key = key_gen["key"] + prompt_tokens, completion_tokens = await chat_completion_streaming( + session=session, key=new_key + ) + print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}") + prompt_cost, completion_cost = litellm.cost_per_token( + model="azure/gpt-35-turbo", + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + response_cost = prompt_cost + completion_cost + await asyncio.sleep(5) # allow db log to be updated + key_info = await get_key_info( + session=session, get_key=new_key, call_key=new_key + ) + print( + f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}" + ) + assert response_cost == key_info["info"]["spend"] diff --git a/tests/test_openai_endpoints.py b/tests/test_openai_endpoints.py index 5a91bffa7..67d7c4db9 100644 --- a/tests/test_openai_endpoints.py +++ b/tests/test_openai_endpoints.py @@ -68,6 +68,7 @@ async def chat_completion(session, key): if status != 200: raise Exception(f"Request did not return a 200 status code: {status}") + return await response.json() @pytest.mark.asyncio diff --git a/ui/admin.py b/ui/admin.py index 2d823d85d..8b5c6b3ab 100644 --- a/ui/admin.py +++ b/ui/admin.py @@ -6,6 +6,9 @@ from dotenv import load_dotenv load_dotenv() import streamlit as st import base64, os, json, uuid, requests +import pandas as pd +import plotly.express as px +import click # Replace your_base_url with the actual URL where the proxy auth app is hosted your_base_url = os.getenv("BASE_URL") # Example base URL @@ -75,7 +78,7 @@ def add_new_model(): and st.session_state.get("proxy_key", None) is None ): st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) model_name = st.text_input( @@ -174,10 +177,70 @@ def list_models(): st.error(f"An error occurred while requesting models: {e}") else: st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) +def spend_per_key(): + import streamlit as st + import requests + + # Check if the necessary configuration is available + if ( + st.session_state.get("api_url", None) is not None + and st.session_state.get("proxy_key", None) is not None + ): + # Make the GET request + try: + complete_url = "" + if isinstance(st.session_state["api_url"], str) and st.session_state[ + "api_url" + ].endswith("/"): + complete_url = f"{st.session_state['api_url']}/spend/keys" + else: + complete_url = f"{st.session_state['api_url']}/spend/keys" + response = requests.get( + complete_url, + headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"}, + ) + # Check if the request was successful + if response.status_code == 200: + spend_per_key = response.json() + # Create DataFrame + spend_df = pd.DataFrame(spend_per_key) + + # Display the spend per key as a graph + st.header("Spend ($) per API Key:") + top_10_df = spend_df.nlargest(10, "spend") + fig = px.bar( + top_10_df, + x="token", + y="spend", + title="Top 10 Spend per Key", + height=550, # Adjust the height + width=1200, # Adjust the width) + hover_data=["token", "spend", "user_id", "team_id"], + ) + st.plotly_chart(fig) + + # Display the spend per key as a table + st.write("Spend per Key - Full Table:") + st.table(spend_df) + + else: + st.error(f"Failed to get models. Status code: {response.status_code}") + except Exception as e: + st.error(f"An error occurred while requesting models: {e}") + else: + st.warning( + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" + ) + + +def spend_per_user(): + pass + + def create_key(): import streamlit as st import json, requests, uuid @@ -187,7 +250,7 @@ def create_key(): and st.session_state.get("proxy_key", None) is None ): st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h") @@ -235,7 +298,7 @@ def update_config(): and st.session_state.get("proxy_key", None) is None ): st.warning( - "Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page." + f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}" ) st.markdown("#### Alerting") @@ -324,19 +387,25 @@ def update_config(): raise e -def admin_page(is_admin="NOT_GIVEN"): +def admin_page(is_admin="NOT_GIVEN", input_api_url=None, input_proxy_key=None): # Display the form for the admin to set the proxy URL and allowed email subdomain + st.set_page_config( + layout="wide", # Use "wide" layout for more space + ) st.header("Admin Configuration") st.session_state.setdefault("is_admin", is_admin) # Add a navigation sidebar st.sidebar.title("Navigation") + page = st.sidebar.radio( "Go to", ( "Connect to Proxy", + "View Spend Per Key", + "View Spend Per User", + "List Models", "Update Config", "Add Models", - "List Models", "Create Key", "End-User Auth", ), @@ -344,16 +413,23 @@ def admin_page(is_admin="NOT_GIVEN"): # Display different pages based on navigation selection if page == "Connect to Proxy": # Use text inputs with intermediary variables - input_api_url = st.text_input( - "Proxy Endpoint", - value=st.session_state.get("api_url", ""), - placeholder="http://0.0.0.0:8000", - ) - input_proxy_key = st.text_input( - "Proxy Key", - value=st.session_state.get("proxy_key", ""), - placeholder="sk-...", - ) + if input_api_url is None: + input_api_url = st.text_input( + "Proxy Endpoint", + value=st.session_state.get("api_url", ""), + placeholder="http://0.0.0.0:8000", + ) + else: + st.session_state["api_url"] = input_api_url + + if input_proxy_key is None: + input_proxy_key = st.text_input( + "Proxy Key", + value=st.session_state.get("proxy_key", ""), + placeholder="sk-...", + ) + else: + st.session_state["proxy_key"] = input_proxy_key # When the "Save" button is clicked, update the session state if st.button("Save"): st.session_state["api_url"] = input_api_url @@ -369,6 +445,21 @@ def admin_page(is_admin="NOT_GIVEN"): list_models() elif page == "Create Key": create_key() + elif page == "View Spend Per Key": + spend_per_key() + elif page == "View Spend Per User": + spend_per_user() -admin_page() +# admin_page() + + +@click.command() +@click.option("--proxy_endpoint", type=str, help="Proxy Endpoint") +@click.option("--proxy_master_key", type=str, help="Proxy Master Key") +def main(proxy_endpoint, proxy_master_key): + admin_page(input_api_url=proxy_endpoint, input_proxy_key=proxy_master_key) + + +if __name__ == "__main__": + main()