Merge branch 'litellm_dynamo_db_keys'

Krrish Dholakia 2024-01-13 18:38:43 +05:30
commit 8a7a745549
12 changed files with 547 additions and 58 deletions

BIN dist/litellm-1.16.21.dev1.tar.gz vendored Normal file
Binary file not shown.

BIN dist/litellm-1.16.21.dev2.tar.gz vendored Normal file
Binary file not shown.

BIN dist/litellm-1.16.21.dev3.tar.gz vendored Normal file
Binary file not shown.
View file

@ -246,3 +246,35 @@ general_settings:
$ litellm --config /path/to/config.yaml
```
## [BETA] Dynamo DB
Only live in `v1.16.21.dev1`.
### Step 1. Save keys to env
```env
AWS_ACCESS_KEY_ID = "your-aws-access-key-id"
AWS_SECRET_ACCESS_KEY = "your-aws-secret-access-key"
```
### Step 2. Add details to config
```yaml
general_settings:
master_key: sk-1234
database_type: "dynamo_db"
database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
"billing_mode": "PAY_PER_REQUEST",
"region_name": "us-west-2"
}
```
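
If you prefer provisioned capacity over on-demand billing, note that `database_args` is validated against the `DynamoDBArgs` model added in `litellm/proxy/_types.py` (diff below), which requires both capacity values in that mode. A hedged sketch, with placeholder capacity numbers:
```yaml
general_settings:
  master_key: sk-1234
  database_type: "dynamo_db"
  database_args: {
    "billing_mode": "PROVISIONED_THROUGHPUT",
    "read_capacity_units": 5,
    "write_capacity_units": 5,
    "region_name": "us-west-2"
  }
```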
### Step 3. Generate Key
```bash
curl --location 'http://0.0.0.0:8000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{"models": ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": null}'
```
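
The generated key is stored in the `LiteLLM_VerificationToken` table and then used like any other proxy key, replacing the master key in the `Authorization` header. A hedged example request, assuming the proxy is still listening on `0.0.0.0:8000`; the key is a placeholder for whatever `/key/generate` returned, and the model reuses the `azure-models` group granted above:
```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Authorization: Bearer sk-<generated-key>' \
--header 'Content-Type: application/json' \
--data '{
    "model": "azure-models",
    "messages": [{"role": "user", "content": "hello"}]
}'
```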

litellm/proxy/_types.py View file

@ -17,6 +17,13 @@ class LiteLLMBase(BaseModel):
            # if using pydantic v1
            return self.dict()

    def fields_set(self):
        try:
            return self.model_fields_set  # noqa
        except:
            # if using pydantic v1
            return self.__fields_set__

######### Request Class Definition ######
class ProxyChatCompletionRequest(LiteLLMBase):
@ -182,6 +189,16 @@ class KeyManagementSystem(enum.Enum):
    LOCAL = "local"

class DynamoDBArgs(LiteLLMBase):
    billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
    read_capacity_units: Optional[int] = None
    write_capacity_units: Optional[int] = None
    region_name: str
    user_table_name: str = "LiteLLM_UserTable"
    key_table_name: str = "LiteLLM_VerificationToken"
    config_table_name: str = "LiteLLM_Config"

class ConfigGeneralSettings(LiteLLMBase):
    """
    Documents all the fields supported by `general_settings` in config.yaml
@ -206,6 +223,13 @@ class ConfigGeneralSettings(LiteLLMBase):
        None,
        description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
    )
    database_type: Optional[Literal["dynamo_db"]] = Field(
        None, description="to use dynamodb instead of postgres db"
    )
    database_args: Optional[DynamoDBArgs] = Field(
        None,
        description="custom args for instantiating dynamodb client - e.g. billing provision",
    )
    otel: Optional[bool] = Field(
        None,
        description="[BETA] OpenTelemetry support - this might change, use with caution.",
@ -258,3 +282,33 @@ class ConfigYAML(LiteLLMBase):
    class Config:
        protected_namespaces = ()

class LiteLLM_VerificationToken(LiteLLMBase):
    token: str
    spend: float = 0.0
    expires: Union[str, None]
    models: List[str]
    aliases: Dict[str, str] = {}
    config: Dict[str, str] = {}
    user_id: Union[str, None]
    max_parallel_requests: Union[int, None]
    metadata: Dict[str, str] = {}

class LiteLLM_Config(LiteLLMBase):
    param_name: str
    param_value: Dict

class LiteLLM_UserTable(LiteLLMBase):
    user_id: str
    max_budget: Optional[float]
    spend: float = 0.0
    user_email: Optional[str]

    @root_validator(pre=True)
    def set_model_info(cls, values):
        if values.get("spend") is None:
            values.update({"spend": 0.0})
        return values
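
The `database_args` mapping from Step 2 of the docs above is parsed straight into `DynamoDBArgs`, and the three `LiteLLM_*` models mirror the DynamoDB tables. A minimal sketch of that validation, assuming pydantic v2 is installed (for `model_dump`); the values are the ones from the example config:
```python
from litellm.proxy._types import DynamoDBArgs

# dict as it would arrive from general_settings.database_args in config.yaml
raw_args = {"billing_mode": "PAY_PER_REQUEST", "region_name": "us-west-2"}

args = DynamoDBArgs(**raw_args)
print(args.key_table_name)                 # -> "LiteLLM_VerificationToken" (default)
print(args.model_dump(exclude_none=True))  # validated settings, defaults included
```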

litellm/proxy/db/base_client.py View file

@ -0,0 +1,53 @@
from typing import Any, Literal, List

class CustomDB:
    """
    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
    """

    def __init__(self) -> None:
        pass

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid
        """
        pass

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """
        For new key / user logic
        """
        pass

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For cost tracking logic
        """
        pass

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        For /key/delete endpoints
        """

    def connect(
        self,
    ):
        """
        For connecting to db and creating / updating any tables
        """
        pass

    def disconnect(
        self,
    ):
        """
        For closing connection on server shutdown
        """
        pass
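
`DynamoDBWrapper` in the next file is the only implementation shipped here, but the contract is easiest to see with a toy subclass. A hedged, in-memory sketch (the class and its storage dict are hypothetical, not part of this commit):
```python
from typing import Any, List, Literal

from litellm.proxy.db.base_client import CustomDB


class InMemoryDB(CustomDB):  # hypothetical example, not in this commit
    def __init__(self) -> None:
        # one dict per logical table
        self.tables: dict = {"user": {}, "key": {}, "config": {}}

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        return self.tables[table_name].get(key)

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        # assumes `value` is a dict keyed by the table's hash key
        hash_key = {"user": "user_id", "key": "token", "config": "param_name"}[table_name]
        self.tables[table_name][value[hash_key]] = value

    def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
        self.tables[table_name].setdefault(key, {}).update(value)

    def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
        for k in keys:
            self.tables[table_name].pop(k, None)

    def connect(self):  # nothing to set up for an in-memory store
        pass

    def disconnect(self):
        pass
```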

litellm/proxy/db/dynamo_db.py View file

@ -0,0 +1,204 @@
import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
credentials: Credentials
def __init__(self, database_arguments: DynamoDBArgs):
self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST":
self.throughput_type = PayPerRequest()
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
if (
database_arguments.read_capacity_units is not None
and isinstance(database_arguments.read_capacity_units, int)
and database_arguments.write_capacity_units is not None
and isinstance(database_arguments.write_capacity_units, int)
):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
)
self.database_arguments = database_arguments
self.region_name = database_arguments.region_name
async def connect(self):
"""
Connect to DB, and create / update any tables
"""
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
## User
try:
error_occurred = False
table = client.table(self.database_arguments.user_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.user_table_name}.\nPlease create a new table called {self.database_arguments.user_table_name}\nAND set `hash_key` as 'user_id'"
)
## Token
try:
error_occurred = False
table = client.table(self.database_arguments.key_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("token", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'"
)
## Config
try:
error_occurred = False
table = client.table(self.database_arguments.config_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
)
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
for k, v in value.items():
if isinstance(v, datetime):
value[k] = v.isoformat()
await table.put_item(item=value)
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
key_name = None
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
key_name = "user_id"
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
key_name = "token"
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
key_name = "param_name"
response = await table.get_item({key_name: key})
new_response: Any = None
if table_name == "user":
new_response = LiteLLM_UserTable(**response)
elif table_name == "key":
new_response = {}
for k, v in response.items(): # handle json string
if (
(k == "aliases" or k == "config" or k == "metadata")
and v is not None
and isinstance(v, str)
):
new_response[k] = json.loads(v)
else:
new_response[k] = v
new_response = LiteLLM_VerificationToken(**new_response)
elif table_name == "config":
new_response = LiteLLM_Config(**response)
return new_response
async def update_data(
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
key_name = None
try:
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
key_name = "user_id"
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
key_name = "token"
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
key_name = "param_name"
else:
raise Exception(
f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}"
)
except Exception as e:
raise Exception(f"Error connecting to table - {str(e)}")
# Initialize an empty UpdateExpression
actions: List = []
for k, v in value.items():
# Convert datetime object to ISO8601 string
if isinstance(v, datetime):
v = v.isoformat()
# Accumulate updates
actions.append((F(k), Value(value=v)))
update_expression = UpdateExpression(set_updates=actions)
# Perform the update in DynamoDB
result = await table.update_item(
key={key_name: key},
update_expression=update_expression,
return_values=ReturnValues.none,
)
return result
async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
Not Implemented yet.
"""
return super().delete_data(keys, table_name)
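
A hedged end-to-end sketch of the wrapper above (not part of this commit). It assumes the AWS credentials from Step 1 of the docs are exported so `Credentials.auto()` can resolve them, and uses the config table because `LiteLLM_Config` only needs `param_name` / `param_value`:
```python
import asyncio

from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.dynamo_db import DynamoDBWrapper


async def main():
    db = DynamoDBWrapper(
        database_arguments=DynamoDBArgs(
            billing_mode="PAY_PER_REQUEST", region_name="us-west-2"
        )
    )
    await db.connect()  # creates the user / key / config tables if they don't exist
    await db.insert_data(
        value={"param_name": "router_settings", "param_value": {"num_retries": "3"}},
        table_name="config",
    )
    config_row = await db.get_data(key="router_settings", table_name="config")
    print(config_row)  # -> LiteLLM_Config(param_name='router_settings', ...)


asyncio.run(main())
```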

litellm/proxy/proxy_server.py View file

@ -67,6 +67,7 @@ def generate_feedback_box():
import litellm
from litellm.proxy.utils import (
    PrismaClient,
    DBClient,
    get_instance_fn,
    ProxyLogging,
    _cache_user_row,
@ -141,6 +142,7 @@ worker_config = None
master_key = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
use_background_health_checks = None
@ -184,12 +186,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
async def user_api_key_auth( async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header) request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try: try:
if isinstance(api_key, str): if isinstance(api_key, str):
api_key = _get_bearer_token(api_key=api_key) api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ### ### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth: if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key) response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response) return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ### ### LITELLM-DEFINED AUTH FUNCTION ###
@ -231,7 +233,7 @@ async def user_api_key_auth(
) )
if ( if (
prisma_client is None prisma_client is None and custom_db_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.") raise Exception("No connected db.")
@ -241,15 +243,22 @@ async def user_api_key_auth(
if valid_token is None: if valid_token is None:
## check db ## check db
verbose_proxy_logger.debug(f"api key: {api_key}") verbose_proxy_logger.debug(f"api key: {api_key}")
valid_token = await prisma_client.get_data( if prisma_client is not None:
token=api_key, valid_token = await prisma_client.get_data(
) token=api_key,
expires = datetime.utcnow().replace(tzinfo=timezone.utc) )
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(
key=api_key, table_name="key"
)
# Token exists, now check expiration. # Token exists, now check expiration.
if valid_token.expires is not None and expires is not None: if valid_token.expires is not None:
if valid_token.expires >= expires: expiry_time = datetime.fromisoformat(valid_token.expires)
if expiry_time >= datetime.utcnow():
# Token exists and is not expired. # Token exists and is not expired.
return valid_token return response
else: else:
# Token exists but is expired. # Token exists but is expired.
raise HTTPException( raise HTTPException(
@ -291,13 +300,22 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks. This makes the user row data accessible to pre-api call hooks.
""" """
asyncio.create_task( if prisma_client is not None:
_cache_user_row( asyncio.create_task(
user_id=valid_token.user_id, _cache_user_row(
cache=user_api_key_cache, user_id=valid_token.user_id,
db=prisma_client, cache=user_api_key_cache,
db=prisma_client,
)
)
elif custom_db_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=custom_db_client,
)
) )
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict) return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else: else:
raise Exception(f"Invalid token") raise Exception(f"Invalid token")
@ -371,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking(): def cost_tracking():
global prisma_client global prisma_client, custom_db_client
if prisma_client is not None: if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list): if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost") verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore if (track_cost_callback) not in litellm.success_callback: # type: ignore
@ -385,7 +403,7 @@ async def track_cost_callback(
start_time=None, start_time=None,
end_time=None, # start/end time for completion end_time=None, # start/end time for completion
): ):
global prisma_client global prisma_client, custom_db_client
try: try:
# check if it has collected an entire stream response # check if it has collected an entire stream response
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
@ -404,10 +422,10 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get( user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None "user_api_key_user_id", None
) )
if user_api_key and prisma_client: if user_api_key and (
await update_prisma_database( prisma_client is not None or custom_db_client is not None
token=user_api_key, response_cost=response_cost ):
) await update_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost( response_cost = litellm.completion_cost(
completion_response=completion_response completion_response=completion_response
@ -418,15 +436,17 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get( user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None "user_api_key_user_id", None
) )
if user_api_key and prisma_client: if user_api_key and (
await update_prisma_database( prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id token=user_api_key, response_cost=response_cost, user_id=user_id
) )
except Exception as e: except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}") verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost, user_id=None): async def update_database(token, response_cost, user_id=None):
try: try:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}" f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -436,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
async def _update_user_db(): async def _update_user_db():
if user_id is None: if user_id is None:
return return
existing_spend_obj = await prisma_client.get_data(user_id=user_id) if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
elif custom_db_client is not None:
existing_spend_obj = await custom_db_client.get_data(
key=user_id, table_name="user"
)
if existing_spend_obj is None: if existing_spend_obj is None:
existing_spend = 0 existing_spend = 0
else: else:
@ -447,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
verbose_proxy_logger.debug(f"new cost: {new_spend}") verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given user id # Update the cost column for the given user id
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend}) if prisma_client is not None:
await prisma_client.update_data(
user_id=user_id, data={"spend": new_spend}
)
elif custom_db_client is not None:
await custom_db_client.update_data(
key=user_id, value={"spend": new_spend}, table_name="user"
)
### UPDATE KEY SPEND ### ### UPDATE KEY SPEND ###
async def _update_key_db(): async def _update_key_db():
# Fetch the existing cost for the given token if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(token=token) # Fetch the existing cost for the given token
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}") existing_spend_obj = await prisma_client.get_data(token=token)
if existing_spend_obj is None: verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
existing_spend = 0 if existing_spend_obj is None:
else: existing_spend = 0
existing_spend = existing_spend_obj.spend else:
# Calculate the new cost by adding the existing cost and response_cost existing_spend = existing_spend_obj.spend
new_spend = existing_spend + response_cost # Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}") verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token # Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend}) await prisma_client.update_data(token=token, data={"spend": new_spend})
elif custom_db_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await custom_db_client.get_data(
key=token, table_name="key"
)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await custom_db_client.update_data(
key=token, value={"spend": new_spend}, table_name="key"
)
tasks = [] tasks = []
tasks.append(_update_user_db()) tasks.append(_update_user_db())
@ -622,7 +673,7 @@ class ProxyConfig:
""" """
Load config values into proxy global state Load config values into proxy global state
""" """
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
# Load existing config # Load existing config
config = await self.get_config(config_file_path=config_file_path) config = await self.get_config(config_file_path=config_file_path)
@ -787,8 +838,6 @@ class ProxyConfig:
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!") verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url) database_url = litellm.get_secret(database_url)
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}") verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
## COST TRACKING ##
cost_tracking()
### MASTER KEY ### ### MASTER KEY ###
master_key = general_settings.get( master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None) "master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
@ -796,11 +845,23 @@ class ProxyConfig:
if master_key and master_key.startswith("os.environ/"): if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key) master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ### ### CUSTOM API KEY AUTH ###
## pass filepath
custom_auth = general_settings.get("custom_auth", None) custom_auth = general_settings.get("custom_auth", None)
if custom_auth: if custom_auth is not None:
user_custom_auth = get_instance_fn( user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path value=custom_auth, config_file_path=config_file_path
) )
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## COST TRACKING ##
cost_tracking()
### BACKGROUND HEALTH CHECKS ### ### BACKGROUND HEALTH CHECKS ###
# Enable background health checks # Enable background health checks
use_background_health_checks = general_settings.get( use_background_health_checks = general_settings.get(
@ -867,9 +928,9 @@ async def generate_key_helper_fn(
max_parallel_requests: Optional[int] = None, max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {}, metadata: Optional[dict] = {},
): ):
global prisma_client global prisma_client, custom_db_client
if prisma_client is None: if prisma_client is None and custom_db_client is None:
raise Exception( raise Exception(
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
) )
@ -908,7 +969,13 @@ async def generate_key_helper_fn(
user_id = user_id or str(uuid.uuid4()) user_id = user_id or str(uuid.uuid4())
try: try:
# Create a new verification token (you may want to enhance this logic based on your needs) # Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = { user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id,
"spend": spend,
}
key_data = {
"token": token, "token": token,
"expires": expires, "expires": expires,
"models": models, "models": models,
@ -918,19 +985,23 @@ async def generate_key_helper_fn(
"user_id": user_id, "user_id": user_id,
"max_parallel_requests": max_parallel_requests, "max_parallel_requests": max_parallel_requests,
"metadata": metadata_json, "metadata": metadata_json,
"max_budget": max_budget,
"user_email": user_email,
} }
verbose_proxy_logger.debug("PrismaClient: Before Insert Data") if prisma_client is not None:
new_verification_token = await prisma_client.insert_data( verification_token_data = dict(key_data)
data=verification_token_data verification_token_data.update(user_data)
) verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
await prisma_client.insert_data(data=verification_token_data)
elif custom_db_client is not None:
## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return { return {
"token": token, "token": token,
"expires": new_verification_token.expires, "expires": expires,
"user_id": user_id, "user_id": user_id,
"max_budget": max_budget, "max_budget": max_budget,
} }
@ -1174,12 +1245,21 @@ async def startup_event():
if prisma_client is not None: if prisma_client is not None:
await prisma_client.connect() await prisma_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if prisma_client is not None and master_key is not None: if prisma_client is not None and master_key is not None:
# add master key to db # add master key to db
await generate_key_helper_fn( await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
) )
if custom_db_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
#### API ENDPOINTS #### #### API ENDPOINTS ####
@router.get( @router.get(

litellm/proxy/utils.py View file

@ -1,11 +1,13 @@
from typing import Optional, List, Any, Literal from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
from fastapi import HTTPException, status from fastapi import HTTPException, status
import smtplib import smtplib
from email.mime.text import MIMEText from email.mime.text import MIMEText
@ -578,6 +580,65 @@ class PrismaClient:
            raise e

class DBClient:
    """
    Routes requests for CustomAuth
    [TODO] route b/w customauth and prisma
    """

    def __init__(
        self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
    ) -> None:
        if custom_db_type == "dynamo_db":
            self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid
        """
        return await self.db.get_data(key=key, table_name=table_name)

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For new key / user logic
        """
        return await self.db.insert_data(value=value, table_name=table_name)

    async def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For cost tracking logic
        key - hash_key value \n
        value - dict with updated values
        """
        return await self.db.update_data(key=key, value=value, table_name=table_name)

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        For /key/delete endpoints
        """
        return await self.db.delete_data(keys=keys, table_name=table_name)

    async def connect(self):
        """
        For connecting to db and creating / updating any tables
        """
        return await self.db.connect()

    async def disconnect(self):
        """
        For closing connection on server shutdown
        """
        return await self.db.disconnect()

### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    try:
@ -618,7 +679,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ### ### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient): async def _cache_user_row(
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
""" """
Check if a user_id exists in cache, Check if a user_id exists in cache,
if not retrieve it. if not retrieve it.
@ -627,7 +690,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
cache_key = f"{user_id}_user_api_key_user_id" cache_key = f"{user_id}_user_api_key_user_id"
response = cache.get_cache(key=cache_key) response = cache.get_cache(key=cache_key)
if response is None: # Cache miss if response is None: # Cache miss
user_row = await db.get_data(user_id=user_id) if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient):
user_row = await db.get_data(key=user_id, table_name="user")
if user_row is not None: if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}") print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable( if hasattr(user_row, "model_dump_json") and callable(
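
To tie the pieces together: at config-load time `proxy_server.py` builds a `DBClient` from `database_type` / `database_args`, and the spend-tracking helpers then call `update_data` / `get_data` on it. A hedged sketch of that wiring used standalone, with a placeholder token for an existing key row (not part of this commit):
```python
import asyncio

from litellm.proxy.utils import DBClient


async def main():
    custom_db_client = DBClient(
        custom_db_type="dynamo_db",
        custom_db_args={"billing_mode": "PAY_PER_REQUEST", "region_name": "us-west-2"},
    )
    await custom_db_client.connect()

    # cost-tracking style update: bump the spend column for an existing key
    await custom_db_client.update_data(
        key="sk-<existing-token>", value={"spend": 0.42}, table_name="key"
    )
    key_row = await custom_db_client.get_data(key="sk-<existing-token>", table_name="key")
    print(key_row.spend)


asyncio.run(main())
```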