diff --git a/docs/my-website/docs/providers/openai_compatible.md b/docs/my-website/docs/providers/openai_compatible.md index f86544c28..09dcd7e4c 100644 --- a/docs/my-website/docs/providers/openai_compatible.md +++ b/docs/my-website/docs/providers/openai_compatible.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # OpenAI-Compatible Endpoints To call models hosted behind an openai proxy, make 2 changes: @@ -39,4 +42,74 @@ response = litellm.embedding( input=["good morning from litellm"] ) print(response) -``` \ No newline at end of file +``` + + + +## Usage with LiteLLM Proxy Server + +Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server + +1. Modify the config.yaml + + ```yaml + model_list: + - model_name: my-model + litellm_params: + model: openai/ # add openai/ prefix to route as OpenAI provider + api_base: # add api base for OpenAI compatible provider + api_key: api-key # api key to send your model + ``` + +2. Start the proxy + + ```bash + $ litellm --config /path/to/config.yaml + ``` + +3. Send Request to LiteLLM Proxy Server + + + + + + ```python + import openai + client = openai.OpenAI( + api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys + base_url="http://0.0.0.0:4000" # litellm-proxy-base url + ) + + response = client.chat.completions.create( + model="my-model", + messages = [ + { + "role": "user", + "content": "what llm are you" + } + ], + ) + + print(response) + ``` + + + + + ```shell + curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "my-model", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + }' + ``` + + + diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index fa18525ee..05720db13 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -23,58 +23,105 @@ litellm.vertex_location = "us-central1" # proj location response = litellm.completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}]) ``` -## OpenAI Proxy Usage +## Usage with LiteLLM Proxy Server Here's how to use Vertex AI with the LiteLLM Proxy Server 1. Modify the config.yaml - + - + -Use this when you need to set a different location for each vertex model + Use this when you need to set a different location for each vertex model -```yaml -model_list: - - model_name: gemini-vision - litellm_params: - model: vertex_ai/gemini-1.0-pro-vision-001 - vertex_project: "project-id" - vertex_location: "us-central1" - - model_name: gemini-vision - litellm_params: - model: vertex_ai/gemini-1.0-pro-vision-001 - vertex_project: "project-id2" - vertex_location: "us-east" -``` + ```yaml + model_list: + - model_name: gemini-vision + litellm_params: + model: vertex_ai/gemini-1.0-pro-vision-001 + vertex_project: "project-id" + vertex_location: "us-central1" + - model_name: gemini-vision + litellm_params: + model: vertex_ai/gemini-1.0-pro-vision-001 + vertex_project: "project-id2" + vertex_location: "us-east" + ``` - + - + -Use this when you have one vertex location for all models + Use this when you have one vertex location for all models -```yaml -litellm_settings: - vertex_project: "hardy-device-38811" # Your Project ID - vertex_location: "us-central1" # proj location + ```yaml + litellm_settings: + vertex_project: "hardy-device-38811" # Your Project ID + vertex_location: "us-central1" # proj location -model_list: - -model_name: team1-gemini-pro - litellm_params: - model: gemini-pro -``` + model_list: + -model_name: team1-gemini-pro + litellm_params: + model: gemini-pro + ``` - + - + 2. Start the proxy -```bash -$ litellm --config /path/to/config.yaml -``` + ```bash + $ litellm --config /path/to/config.yaml + ``` + +3. Send Request to LiteLLM Proxy Server + + + + + + ```python + import openai + client = openai.OpenAI( + api_key="sk-1234", # pass litellm proxy key, if you're using virtual keys + base_url="http://0.0.0.0:4000" # litellm-proxy-base url + ) + + response = client.chat.completions.create( + model="team1-gemini-pro", + messages = [ + { + "role": "user", + "content": "what llm are you" + } + ], + ) + + print(response) + ``` + + + + + ```shell + curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "team1-gemini-pro", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + }' + ``` + + + ## Set Vertex Project & Vertex Location All calls using Vertex AI require the following parameters: diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md index 54f151a58..d25035760 100644 --- a/docs/my-website/docs/proxy/deploy.md +++ b/docs/my-website/docs/proxy/deploy.md @@ -11,23 +11,56 @@ You can find the Dockerfile to build litellm proxy [here](https://github.com/Ber -See the latest available ghcr docker image here: -https://github.com/berriai/litellm/pkgs/container/litellm +**Step 1. Create a file called `litellm_config.yaml`** -Your litellm config.yaml should be called `litellm_config.yaml` in the directory you run this command. -The `-v` command will mount that file + Example `litellm_config.yaml` (the `os.environ/` prefix means litellm will read `AZURE_API_BASE` from the env) + ```yaml + model_list: + - model_name: azure-gpt-3.5 + litellm_params: + model: azure/ + api_base: os.environ/AZURE_API_BASE + api_key: os.environ/AZURE_API_KEY + api_version: "2023-07-01-preview" + ``` -`AZURE_API_KEY` and `AZURE_API_BASE` are not required to start, just examples on how to pass .env vars +**Step 2. Run litellm docker image** -```shell -docker run \ - -v $(pwd)/litellm_config.yaml:/app/config.yaml \ - -e AZURE_API_KEY=d6*********** \ - -e AZURE_API_BASE=https://openai-***********/ \ - -p 4000:4000 \ - ghcr.io/berriai/litellm:main-latest \ - --config /app/config.yaml --detailed_debug -``` + See the latest available ghcr docker image here: + https://github.com/berriai/litellm/pkgs/container/litellm + + Your litellm config.yaml should be called `litellm_config.yaml` in the directory you run this command. + The `-v` command will mount that file + + Pass `AZURE_API_KEY` and `AZURE_API_BASE` since we set them in step 1 + + ```shell + docker run \ + -v $(pwd)/litellm_config.yaml:/app/config.yaml \ + -e AZURE_API_KEY=d6*********** \ + -e AZURE_API_BASE=https://openai-***********/ \ + -p 4000:4000 \ + ghcr.io/berriai/litellm:main-latest \ + --config /app/config.yaml --detailed_debug + ``` + +**Step 3. Send a Test Request** + + Pass `model=azure-gpt-3.5` this was set on step 1 + + ```shell + curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "azure-gpt-3.5", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ] + }' + ``` diff --git a/docs/my-website/docs/proxy/quick_start.md b/docs/my-website/docs/proxy/quick_start.md index d44970348..8c7d1c066 100644 --- a/docs/my-website/docs/proxy/quick_start.md +++ b/docs/my-website/docs/proxy/quick_start.md @@ -363,74 +363,6 @@ print(query_result[:5]) - GET `/models` - available models on server - POST `/key/generate` - generate a key to access the proxy -## Quick Start Docker Image: Github Container Registry - -### Pull the litellm ghcr docker image -See the latest available ghcr docker image here: -https://github.com/berriai/litellm/pkgs/container/litellm - -```shell -docker pull ghcr.io/berriai/litellm:main-latest -``` - -### Run the Docker Image -```shell -docker run ghcr.io/berriai/litellm:main-latest -``` - -#### Run the Docker Image with LiteLLM CLI args - -See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli): - -Here's how you can run the docker image and pass your config to `litellm` -```shell -docker run ghcr.io/berriai/litellm:main-latest --config your_config.yaml -``` - -Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8` -```shell -docker run ghcr.io/berriai/litellm:main-latest --port 8002 --num_workers 8 -``` - -#### Run the Docker Image using docker compose - -**Step 1** - -- (Recommended) Use the example file `docker-compose.example.yml` given in the project root. e.g. https://github.com/BerriAI/litellm/blob/main/docker-compose.example.yml - -- Rename the file `docker-compose.example.yml` to `docker-compose.yml`. - -Here's an example `docker-compose.yml` file -```yaml -version: "3.9" -services: - litellm: - image: ghcr.io/berriai/litellm:main - ports: - - "4000:4000" # Map the container port to the host, change the host port if necessary - volumes: - - ./litellm-config.yaml:/app/config.yaml # Mount the local configuration file - # You can change the port or number of workers as per your requirements or pass any new supported CLI augument. Make sure the port passed here matches with the container port defined above in `ports` value - command: [ "--config", "/app/config.yaml", "--port", "4000", "--num_workers", "8" ] - -# ...rest of your docker-compose config if any -``` - -**Step 2** - -Create a `litellm-config.yaml` file with your LiteLLM config relative to your `docker-compose.yml` file. - -Check the config doc [here](https://docs.litellm.ai/docs/proxy/configs) - -**Step 3** - -Run the command `docker-compose up` or `docker compose up` as per your docker installation. - -> Use `-d` flag to run the container in detached mode (background) e.g. `docker compose up -d` - - -Your LiteLLM container should be running now on the defined port e.g. `4000`. - ## Using with OpenAI compatible projects Set `base_url` to the LiteLLM Proxy server diff --git a/litellm/caching.py b/litellm/caching.py index 67fb1ec53..5a9008342 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -841,6 +841,17 @@ class DualCache(BaseCache): except Exception as e: traceback.print_exc() + async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): + try: + if self.in_memory_cache is not None: + await self.in_memory_cache.async_set_cache(key, value, **kwargs) + + if self.redis_cache is not None and local_only == False: + await self.redis_cache.async_set_cache(key, value, **kwargs) + except Exception as e: + print_verbose(f"LiteLLM Cache: Excepton async add_cache: {str(e)}") + traceback.print_exc() + def flush_cache(self): if self.in_memory_cache is not None: self.in_memory_cache.flush_cache() diff --git a/litellm/llms/custom_httpx/bedrock_async.py b/litellm/llms/custom_httpx/bedrock_async.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 4d8ad200a..b5c50b143 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -632,6 +632,8 @@ class LiteLLM_UserTable(LiteLLMBase): model_spend: Optional[Dict] = {} user_email: Optional[str] models: list = [] + tpm_limit: Optional[int] = None + rpm_limit: Optional[int] = None @root_validator(pre=True) def set_model_info(cls, values): @@ -650,6 +652,7 @@ class LiteLLM_EndUserTable(LiteLLMBase): blocked: bool alias: Optional[str] = None spend: float = 0.0 + litellm_budget_table: Optional[LiteLLM_BudgetTable] = None @root_validator(pre=True) def set_model_info(cls, values): diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py new file mode 100644 index 000000000..cd326cc6d --- /dev/null +++ b/litellm/proxy/auth/auth_checks.py @@ -0,0 +1,84 @@ +# What is this? +## Common auth checks between jwt + key based auth +""" +Got Valid Token from Cache, DB +Run checks for: + +1. If user can call model +2. If user is in budget +3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget +""" +from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable +from typing import Optional +from litellm.proxy.utils import PrismaClient +from litellm.caching import DualCache + + +def common_checks( + request_body: dict, + user_object: LiteLLM_UserTable, + end_user_object: Optional[LiteLLM_EndUserTable], +) -> bool: + _model = request_body.get("model", None) + # 1. If user can call model + if ( + _model is not None + and len(user_object.models) > 0 + and _model not in user_object.models + ): + raise Exception( + f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}" + ) + # 2. If user is in budget + if ( + user_object.max_budget is not None + and user_object.spend > user_object.max_budget + ): + raise Exception( + f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}" + ) + # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget + if end_user_object is not None and end_user_object.litellm_budget_table is not None: + end_user_budget = end_user_object.litellm_budget_table.max_budget + if end_user_budget is not None and end_user_object.spend > end_user_budget: + raise Exception( + f"End User={end_user_object.user_id} over budget. Spend={end_user_object.spend}, Budget={end_user_budget}" + ) + return True + + +async def get_end_user_object( + end_user_id: Optional[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, +) -> Optional[LiteLLM_EndUserTable]: + """ + Returns end user object, if in db. + + Do a isolated check for end user in table vs. doing a combined key + team + user + end-user check, as key might come in frequently for different end-users. Larger call will slowdown query time. This way we get to cache the constant (key/team/user info) and only update based on the changing value (end-user). + """ + if prisma_client is None: + raise Exception("No db connected") + + if end_user_id is None: + return None + + # check if in cache + cached_user_obj = user_api_key_cache.async_get_cache(key=end_user_id) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return LiteLLM_EndUserTable(**cached_user_obj) + elif isinstance(cached_user_obj, LiteLLM_EndUserTable): + return cached_user_obj + # else, check db + try: + response = await prisma_client.db.litellm_endusertable.find_unique( + where={"user_id": end_user_id} + ) + + if response is None: + raise Exception + + return LiteLLM_EndUserTable(**response.dict()) + except Exception as e: # if end-user not in db + return None diff --git a/litellm/proxy/auth/handle_jwt.py b/litellm/proxy/auth/handle_jwt.py index 4342f3365..ad69543d5 100644 --- a/litellm/proxy/auth/handle_jwt.py +++ b/litellm/proxy/auth/handle_jwt.py @@ -8,23 +8,27 @@ JWT token must have 'litellm_proxy_admin' in scope. import httpx import jwt - -print(jwt.__version__) # noqa from jwt.algorithms import RSAAlgorithm import json import os -from litellm.proxy._types import LiteLLMProxyRoles +from litellm.caching import DualCache +from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable +from litellm.proxy.utils import PrismaClient from typing import Optional class HTTPHandler: - def __init__(self): - self.client = httpx.AsyncClient() + def __init__(self, concurrent_limit=1000): + # Create a client with a connection pool + self.client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ) + ) - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def close(self): + # Close the client when you're done with it await self.client.aclose() async def get( @@ -47,10 +51,27 @@ class HTTPHandler: class JWTHandler: + """ + - treat the sub id passed in as the user id + - return an error if id making request doesn't exist in proxy user table + - track spend against the user id + - if role="litellm_proxy_user" -> allow making calls + info. Can not edit budgets + """ - def __init__(self) -> None: + prisma_client: Optional[PrismaClient] + user_api_key_cache: DualCache + + def __init__( + self, + ) -> None: self.http_handler = HTTPHandler() + def update_environment( + self, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache + ) -> None: + self.prisma_client = prisma_client + self.user_api_key_cache = user_api_key_cache + def is_jwt(self, token: str): parts = token.split(".") return len(parts) == 3 @@ -67,6 +88,46 @@ class JWTHandler: user_id = default_value return user_id + def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: + try: + team_id = token["azp"] + except KeyError: + team_id = default_value + return team_id + + async def get_user_object(self, user_id: str) -> LiteLLM_UserTable: + """ + - Check if user id in proxy User Table + - if valid, return LiteLLM_UserTable object with defined limits + - if not, then raise an error + """ + if self.prisma_client is None: + raise Exception( + "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys" + ) + + # check if in cache + cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id) + if cached_user_obj is not None: + if isinstance(cached_user_obj, dict): + return LiteLLM_UserTable(**cached_user_obj) + elif isinstance(cached_user_obj, LiteLLM_UserTable): + return cached_user_obj + # else, check db + try: + response = await self.prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_id} + ) + + if response is None: + raise Exception + + return LiteLLM_UserTable(**response.dict()) + except Exception as e: + raise Exception( + f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call." + ) + def get_scopes(self, token: dict) -> list: try: # Assuming the scopes are stored in 'scope' claim and are space-separated @@ -78,8 +139,10 @@ class JWTHandler: async def auth_jwt(self, token: str) -> dict: keys_url = os.getenv("JWT_PUBLIC_KEY_URL") - async with self.http_handler as http: - response = await http.get(keys_url) + if keys_url is None: + raise Exception("Missing JWT Public Key URL from environment.") + + response = await self.http_handler.get(keys_url) keys = response.json()["keys"] @@ -113,3 +176,6 @@ class JWTHandler: raise Exception(f"Validation fails: {str(e)}") raise Exception("Invalid JWT Submitted") + + async def close(self): + await self.http_handler.close() diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index fd0bb6cd9..0d6539358 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -110,6 +110,7 @@ from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.hooks.prompt_injection_detection import ( _OPTIONAL_PromptInjectionDetection, ) +from litellm.proxy.auth.auth_checks import common_checks, get_end_user_object try: from litellm._version import version @@ -364,20 +365,54 @@ async def user_api_key_auth( user_id = jwt_handler.get_user_id( token=valid_token, default_value=litellm_proxy_admin_name ) + + end_user_object = None + # get the request body + request_data = await _read_request_body(request=request) + # get user obj from cache/db -> run for admin too. Ensures, admin client id in db. + user_object = await jwt_handler.get_user_object(user_id=user_id) + if ( + request_data.get("user", None) + and request_data["user"] != user_object.user_id + ): + # get the end-user object + end_user_object = await get_end_user_object( + end_user_id=request_data["user"], + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + # save the end-user object to cache + await user_api_key_cache.async_set_cache( + key=request_data["user"], value=end_user_object + ) + + # run through common checks + _ = common_checks( + request_body=request_data, + user_object=user_object, + end_user_object=end_user_object, + ) + # save user object in cache + await user_api_key_cache.async_set_cache( + key=user_object.user_id, value=user_object + ) # if admin return if is_admin: - _user_api_key_obj = UserAPIKeyAuth( + return UserAPIKeyAuth( api_key=api_key, user_role="proxy_admin", user_id=user_id, ) - user_api_key_cache.set_cache( - key=hash_token(api_key), value=_user_api_key_obj + else: + # return UserAPIKeyAuth object + return UserAPIKeyAuth( + api_key=None, + user_id=user_object.user_id, + tpm_limit=user_object.tpm_limit, + rpm_limit=user_object.rpm_limit, + models=user_object.models, + user_role="app_owner", ) - - return _user_api_key_obj - else: - raise Exception("Invalid key error!") #### ELSE #### if master_key is None: if isinstance(api_key, str): @@ -442,7 +477,7 @@ async def user_api_key_auth( user_role="proxy_admin", user_id=litellm_proxy_admin_name, ) - user_api_key_cache.set_cache( + await user_api_key_cache.async_set_cache( key=hash_token(master_key), value=_user_api_key_obj ) @@ -607,7 +642,7 @@ async def user_api_key_auth( query_type="find_all", ) for _id in user_id_information: - user_api_key_cache.set_cache( + await user_api_key_cache.async_set_cache( key=_id["user_id"], value=_id, ttl=600 ) if custom_db_client is not None: @@ -795,7 +830,9 @@ async def user_api_key_auth( api_key = valid_token.token # Add hashed token to cache - user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=600) + await user_api_key_cache.async_set_cache( + key=api_key, value=valid_token, ttl=600 + ) valid_token_dict = _get_pydantic_json_dict(valid_token) valid_token_dict.pop("token", None) """ @@ -1077,7 +1114,10 @@ async def _PROXY_track_cost_callback( ) await update_cache( - token=user_api_key, user_id=user_id, response_cost=response_cost + token=user_api_key, + user_id=user_id, + end_user_id=end_user_id, + response_cost=response_cost, ) else: raise Exception("User API key missing from custom callback.") @@ -1352,9 +1392,10 @@ async def update_database( async def update_cache( - token, - user_id, - response_cost, + token: Optional[str], + user_id: Optional[str], + end_user_id: Optional[str], + response_cost: Optional[float], ): """ Use this to update the cache with new user spend. @@ -1369,12 +1410,17 @@ async def update_cache( hashed_token = hash_token(token=token) else: hashed_token = token + verbose_proxy_logger.debug(f"_update_key_cache: hashed_token={hashed_token}") existing_spend_obj = await user_api_key_cache.async_get_cache(key=hashed_token) verbose_proxy_logger.debug( - f"_update_key_db: existing spend: {existing_spend_obj}" + f"_update_key_cache: existing_spend_obj={existing_spend_obj}" + ) + verbose_proxy_logger.debug( + f"_update_key_cache: existing spend: {existing_spend_obj}" ) if existing_spend_obj is None: existing_spend = 0 + existing_spend_obj = LiteLLM_VerificationTokenView() else: existing_spend = existing_spend_obj.spend # Calculate the new cost by adding the existing cost and response_cost @@ -1430,18 +1476,7 @@ async def update_cache( async def _update_user_cache(): ## UPDATE CACHE FOR USER ID + GLOBAL PROXY - end_user_id = None - if isinstance(token, str) and token.startswith("sk-"): - hashed_token = hash_token(token=token) - else: - hashed_token = token - existing_token_obj = await user_api_key_cache.async_get_cache(key=hashed_token) - if existing_token_obj is None: - return - if existing_token_obj.user_id != user_id: # an end-user id was passed in - end_user_id = user_id - user_ids = [existing_token_obj.user_id, litellm_proxy_budget_name, end_user_id] - + user_ids = [user_id, litellm_proxy_budget_name, end_user_id] try: for _id in user_ids: # Fetch the existing cost for the given user @@ -1487,9 +1522,59 @@ async def update_cache( f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}" ) - asyncio.create_task(_update_key_cache()) + async def _update_end_user_cache(): + ## UPDATE CACHE FOR USER ID + GLOBAL PROXY + _id = end_user_id + try: + # Fetch the existing cost for the given user + existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id) + if existing_spend_obj is None: + # if user does not exist in LiteLLM_UserTable, create a new user + existing_spend = 0 + max_user_budget = None + if litellm.max_user_budget is not None: + max_user_budget = litellm.max_user_budget + existing_spend_obj = LiteLLM_EndUserTable( + user_id=_id, + spend=0, + blocked=False, + litellm_budget_table=LiteLLM_BudgetTable( + max_budget=max_user_budget + ), + ) + verbose_proxy_logger.debug( + f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}" + ) + if existing_spend_obj is None: + existing_spend = 0 + else: + if isinstance(existing_spend_obj, dict): + existing_spend = existing_spend_obj["spend"] + else: + existing_spend = existing_spend_obj.spend + # Calculate the new cost by adding the existing cost and response_cost + new_spend = existing_spend + response_cost + + # Update the cost column for the given user + if isinstance(existing_spend_obj, dict): + existing_spend_obj["spend"] = new_spend + user_api_key_cache.set_cache(key=_id, value=existing_spend_obj) + else: + existing_spend_obj.spend = new_spend + user_api_key_cache.set_cache(key=_id, value=existing_spend_obj.json()) + except Exception as e: + verbose_proxy_logger.debug( + f"An error occurred updating end user cache: {str(e)}\n\n{traceback.format_exc()}" + ) + + if token is not None: + asyncio.create_task(_update_key_cache()) + asyncio.create_task(_update_user_cache()) + if end_user_id is not None: + asyncio.create_task(_update_end_user_cache()) + def run_ollama_serve(): try: @@ -1881,7 +1966,7 @@ class ProxyConfig: elif key == "success_callback": litellm.success_callback = [] - # intialize success callbacks + # initialize success callbacks for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function if "." in callback: @@ -1906,7 +1991,7 @@ class ProxyConfig: elif key == "failure_callback": litellm.failure_callback = [] - # intialize success callbacks + # initialize success callbacks for callback in value: # user passed custom_callbacks.async_on_succes_logger. They need us to import a function if "." in callback: @@ -2604,6 +2689,11 @@ async def startup_event(): proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made + ## JWT AUTH ## + jwt_handler.update_environment( + prisma_client=prisma_client, user_api_key_cache=user_api_key_cache + ) + if use_background_health_checks: asyncio.create_task( _run_background_health_check() @@ -7771,6 +7861,8 @@ async def shutdown_event(): if litellm.cache is not None: await litellm.cache.disconnect() + await jwt_handler.close() + ## RESET CUSTOM VARIABLES ## cleanup_router_config_variables()