mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge branch 'litellm_dynamo_db_keys'
This commit is contained in:
commit
8a7a745549
12 changed files with 547 additions and 58 deletions
BIN
dist/litellm-1.16.21.dev1-py3-none-any.whl
vendored
Normal file
BIN
dist/litellm-1.16.21.dev1-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.16.21.dev1.tar.gz
vendored
Normal file
BIN
dist/litellm-1.16.21.dev1.tar.gz
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.16.21.dev2-py3-none-any.whl
vendored
Normal file
BIN
dist/litellm-1.16.21.dev2-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.16.21.dev2.tar.gz
vendored
Normal file
BIN
dist/litellm-1.16.21.dev2.tar.gz
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.16.21.dev3-py3-none-any.whl
vendored
Normal file
BIN
dist/litellm-1.16.21.dev3-py3-none-any.whl
vendored
Normal file
Binary file not shown.
BIN
dist/litellm-1.16.21.dev3.tar.gz
vendored
Normal file
BIN
dist/litellm-1.16.21.dev3.tar.gz
vendored
Normal file
Binary file not shown.
|
@ -246,3 +246,35 @@ general_settings:
|
||||||
$ litellm --config /path/to/config.yaml
|
$ litellm --config /path/to/config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## [BETA] Dynamo DB
|
||||||
|
|
||||||
|
Only available in `v1.16.21.dev1`.
|
||||||
|
|
||||||
|
### Step 1. Save keys to env
|
||||||
|
|
||||||
|
```env
|
||||||
|
AWS_ACCESS_KEY_ID = "your-aws-access-key-id"
|
||||||
|
AWS_SECRET_ACCESS_KEY = "your-aws-secret-access-key"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2. Add details to config
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
general_settings:
|
||||||
|
master_key: sk-1234
|
||||||
|
database_type: "dynamo_db"
|
||||||
|
database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
|
||||||
|
"billing_mode": "PAY_PER_REQUEST",
|
||||||
|
"region_name": "us-west-2"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 3. Generate Key
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl --location 'http://0.0.0.0:8000/key/generate' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{"models": ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": null}'
|
||||||
|
```
|
|
@ -17,6 +17,13 @@ class LiteLLMBase(BaseModel):
|
||||||
# if using pydantic v1
|
# if using pydantic v1
|
||||||
return self.dict()
|
return self.dict()
|
||||||
|
|
||||||
|
def fields_set(self):
    """
    Return the set of field names that were explicitly set on this model.

    Reads `model_fields_set` first and falls back to `__fields_set__`
    (pydantic v1) when that attribute does not exist.
    """
    try:
        return self.model_fields_set  # noqa
    except AttributeError:
        # if using pydantic v1
        # Only a missing attribute triggers the fallback; a bare `except`
        # would also swallow KeyboardInterrupt / SystemExit.
        return self.__fields_set__
|
||||||
|
|
||||||
|
|
||||||
######### Request Class Definition ######
|
######### Request Class Definition ######
|
||||||
class ProxyChatCompletionRequest(LiteLLMBase):
|
class ProxyChatCompletionRequest(LiteLLMBase):
|
||||||
|
@ -182,6 +189,16 @@ class KeyManagementSystem(enum.Enum):
|
||||||
LOCAL = "local"
|
LOCAL = "local"
|
||||||
|
|
||||||
|
|
||||||
|
class DynamoDBArgs(LiteLLMBase):
    """
    Arguments supplied under `general_settings.database_args` when
    `database_type` is "dynamo_db".
    """

    # Capacity mode used when the tables are created.
    billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
    # DynamoDBWrapper requires both of these to be ints when billing_mode is
    # "PROVISIONED_THROUGHPUT".
    read_capacity_units: Optional[int] = None
    write_capacity_units: Optional[int] = None
    # AWS region the client talks to, e.g. "us-west-2".
    region_name: str
    # Physical table names; override to avoid clashing with existing tables.
    user_table_name: str = "LiteLLM_UserTable"
    key_table_name: str = "LiteLLM_VerificationToken"
    config_table_name: str = "LiteLLM_Config"
|
||||||
|
|
||||||
|
|
||||||
class ConfigGeneralSettings(LiteLLMBase):
|
class ConfigGeneralSettings(LiteLLMBase):
|
||||||
"""
|
"""
|
||||||
Documents all the fields supported by `general_settings` in config.yaml
|
Documents all the fields supported by `general_settings` in config.yaml
|
||||||
|
@ -206,6 +223,13 @@ class ConfigGeneralSettings(LiteLLMBase):
|
||||||
None,
|
None,
|
||||||
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
|
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
|
||||||
)
|
)
|
||||||
|
database_type: Optional[Literal["dynamo_db"]] = Field(
|
||||||
|
None, description="to use dynamodb instead of postgres db"
|
||||||
|
)
|
||||||
|
database_args: Optional[DynamoDBArgs] = Field(
|
||||||
|
None,
|
||||||
|
description="custom args for instantiating dynamodb client - e.g. billing provision",
|
||||||
|
)
|
||||||
otel: Optional[bool] = Field(
|
otel: Optional[bool] = Field(
|
||||||
None,
|
None,
|
||||||
description="[BETA] OpenTelemetry support - this might change, use with caution.",
|
description="[BETA] OpenTelemetry support - this might change, use with caution.",
|
||||||
|
@ -258,3 +282,33 @@ class ConfigYAML(LiteLLMBase):
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_VerificationToken(LiteLLMBase):
    """
    Row of the key table: one generated API key and its settings.
    """

    # Hash key of the key table.
    token: str
    spend: float = 0.0
    # Stored as a string; the DynamoDB wrapper writes datetimes in ISO-8601 form.
    expires: Union[str, None]
    models: List[str]
    # aliases / config / metadata may come back from DynamoDB as JSON strings;
    # DynamoDBWrapper.get_data decodes them before building this model.
    aliases: Dict[str, str] = {}
    config: Dict[str, str] = {}
    user_id: Union[str, None]
    max_parallel_requests: Union[int, None]
    metadata: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_Config(LiteLLMBase):
    """
    Row of the config table: a named parameter and its value.
    """

    # Hash key of the config table.
    param_name: str
    param_value: Dict
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_UserTable(LiteLLMBase):
    """
    Row of the user table: a user's budget and accumulated spend.
    """

    # Hash key of the user table.
    user_id: str
    max_budget: Optional[float]
    spend: float = 0.0
    user_email: Optional[str]

    @root_validator(pre=True)
    def set_model_info(cls, values):
        """Replace a missing or None `spend` with 0.0 before field validation."""
        if values.get("spend") is None:
            values["spend"] = 0.0
        return values
|
||||||
|
|
53
litellm/proxy/db/base_client.py
Normal file
53
litellm/proxy/db/base_client.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import Any, Literal, List
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDB:
    """
    Base class that any custom db implementation (e.g. DynamoDB) is expected
    to follow.

    Every method here is a no-op that returns None; concrete backends
    override them.
    """

    def __init__(self) -> None:
        """Nothing to set up in the base class."""

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Look up the row stored under `key` in `table_name`
        (used to check whether a key is valid).
        """

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """
        Store `value` as a new row in `table_name` (new key / user logic).
        """

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        Apply `value` to the row stored under `key` (cost tracking logic).
        """

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        Remove the rows stored under `keys` (for the /key/delete endpoints).
        """

    def connect(self):
        """
        Connect to the db and create / update any tables.
        """

    def disconnect(self):
        """
        Close the connection on server shutdown.
        """
|
204
litellm/proxy/db/dynamo_db.py
Normal file
204
litellm/proxy/db/dynamo_db.py
Normal file
|
@ -0,0 +1,204 @@
|
||||||
|
import json
|
||||||
|
from aiodynamo.client import Client
|
||||||
|
from aiodynamo.credentials import Credentials, StaticCredentials
|
||||||
|
from aiodynamo.http.httpx import HTTPX
|
||||||
|
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
|
||||||
|
from yarl import URL
|
||||||
|
from litellm.proxy.db.base_client import CustomDB
|
||||||
|
from litellm.proxy._types import (
|
||||||
|
DynamoDBArgs,
|
||||||
|
LiteLLM_VerificationToken,
|
||||||
|
LiteLLM_Config,
|
||||||
|
LiteLLM_UserTable,
|
||||||
|
)
|
||||||
|
from litellm import get_secret
|
||||||
|
from typing import Any, List, Literal, Optional, Union
|
||||||
|
from aiodynamo.expressions import UpdateExpression, F, Value
|
||||||
|
from aiodynamo.models import ReturnValues
|
||||||
|
from aiodynamo.http.aiohttp import AIOHTTP
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class DynamoDBWrapper(CustomDB):
    """
    `CustomDB` implementation backed by DynamoDB through aiodynamo.

    Each operation opens its own aiohttp `ClientSession` and builds a new
    aiodynamo `Client` with `Credentials.auto()` for the configured region.
    """

    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Keep the connection arguments and derive the throughput setting used
        when tables are created.

        Raises:
            Exception: `billing_mode` is "PROVISIONED_THROUGHPUT" but
                `read_capacity_units` / `write_capacity_units` are not both ints.
        """
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            read_units = database_arguments.read_capacity_units
            write_units = database_arguments.write_capacity_units
            # isinstance(None, int) is False, so unset values are rejected too.
            if isinstance(read_units, int) and isinstance(write_units, int):
                self.throughput_type = Throughput(read=read_units, write=write_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def _table_spec(self, table_name: str):
        """
        Map a logical table name ("user" / "key" / "config") to the pair
        (configured DynamoDB table name, hash key attribute name).

        Raises:
            Exception: `table_name` is not one of the three logical names.
        """
        if table_name == "user":
            return self.database_arguments.user_table_name, "user_id"
        if table_name == "key":
            return self.database_arguments.key_table_name, "token"
        if table_name == "config":
            return self.database_arguments.config_table_name, "param_name"
        raise Exception(
            f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}"
        )

    async def connect(self):
        """
        Connect to DB, creating the user, key and config tables if they do
        not exist yet.

        Raises:
            Exception: checking for or creating a table failed; the original
                error is attached as `__cause__`.
        """
        tables = (
            (self.database_arguments.user_table_name, "user_id"),
            (self.database_arguments.key_table_name, "token"),
            (self.database_arguments.config_table_name, "param_name"),
        )
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            for physical_name, hash_key in tables:
                try:
                    table = client.table(physical_name)
                    if not await table.exists():
                        await table.create(
                            self.throughput_type,
                            KeySchema(hash_key=KeySpec(hash_key, KeyType.string)),
                        )
                except Exception as e:
                    # Chain the underlying error so the real cause (bad
                    # credentials, wrong region, ...) is not lost.
                    raise Exception(
                        f"Failed to create table - {physical_name}.\nPlease create a new table called {physical_name}\nAND set `hash_key` as '{hash_key}'"
                    ) from e

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        Put `value` (a mapping of attribute name -> value) into `table_name`.

        datetime values are written as ISO-8601 strings. The caller's mapping
        is left untouched.

        Raises:
            Exception: `table_name` is not a known logical table.
        """
        physical_name, _ = self._table_spec(table_name)
        # Build a fresh item rather than rewriting the caller's dict in place.
        item = {
            k: v.isoformat() if isinstance(v, datetime) else v
            for k, v in value.items()
        }
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = client.table(physical_name)
            await table.put_item(item=item)

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Fetch the row whose hash key equals `key` and return it as the
        matching model (`LiteLLM_UserTable`, `LiteLLM_VerificationToken` or
        `LiteLLM_Config`).

        Raises:
            Exception: `table_name` is not a known logical table.
        """
        physical_name, key_name = self._table_spec(table_name)
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            table = client.table(physical_name)
            # NOTE(review): what happens when no item matches depends on
            # aiodynamo's get_item - confirm, since callers compare the
            # result of get_data against None.
            response = await table.get_item({key_name: key})

            if table_name == "user":
                return LiteLLM_UserTable(**response)
            if table_name == "key":
                # These attributes may be stored as JSON strings; decode them.
                json_fields = ("aliases", "config", "metadata")
                decoded = {
                    k: json.loads(v) if k in json_fields and isinstance(v, str) else v
                    for k, v in response.items()
                }
                return LiteLLM_VerificationToken(**decoded)
            # _table_spec already validated the name, so this is "config".
            return LiteLLM_Config(**response)

    async def update_data(
        self, key: str, value: dict, table_name: Literal["user", "key", "config"]
    ):
        """
        Set every attribute in `value` on the row whose hash key equals `key`.

        datetime values are written as ISO-8601 strings.

        Returns:
            Whatever aiodynamo's `update_item` returns for `ReturnValues.none`.

        Raises:
            Exception: the table could not be resolved (message is prefixed
                with "Error connecting to table - ").
        """
        async with ClientSession() as session:
            client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
            try:
                physical_name, key_name = self._table_spec(table_name)
                table = client.table(physical_name)
            except Exception as e:
                raise Exception(f"Error connecting to table - {str(e)}") from e

            # One SET action per attribute.
            actions: List = [
                (F(k), Value(value=v.isoformat() if isinstance(v, datetime) else v))
                for k, v in value.items()
            ]
            update_expression = UpdateExpression(set_updates=actions)
            # Perform the update in DynamoDB
            result = await table.update_item(
                key={key_name: key},
                update_expression=update_expression,
                return_values=ReturnValues.none,
            )
            return result

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        Not Implemented yet. Delegates to the base class.
        """
        return super().delete_data(keys, table_name)
|
|
@ -67,6 +67,7 @@ def generate_feedback_box():
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.proxy.utils import (
|
from litellm.proxy.utils import (
|
||||||
PrismaClient,
|
PrismaClient,
|
||||||
|
DBClient,
|
||||||
get_instance_fn,
|
get_instance_fn,
|
||||||
ProxyLogging,
|
ProxyLogging,
|
||||||
_cache_user_row,
|
_cache_user_row,
|
||||||
|
@ -141,6 +142,7 @@ worker_config = None
|
||||||
master_key = None
|
master_key = None
|
||||||
otel_logging = False
|
otel_logging = False
|
||||||
prisma_client: Optional[PrismaClient] = None
|
prisma_client: Optional[PrismaClient] = None
|
||||||
|
custom_db_client: Optional[DBClient] = None
|
||||||
user_api_key_cache = DualCache()
|
user_api_key_cache = DualCache()
|
||||||
user_custom_auth = None
|
user_custom_auth = None
|
||||||
use_background_health_checks = None
|
use_background_health_checks = None
|
||||||
|
@ -184,12 +186,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
|
||||||
async def user_api_key_auth(
|
async def user_api_key_auth(
|
||||||
request: Request, api_key: str = fastapi.Security(api_key_header)
|
request: Request, api_key: str = fastapi.Security(api_key_header)
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
|
||||||
try:
|
try:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
api_key = _get_bearer_token(api_key=api_key)
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
### USER-DEFINED AUTH FUNCTION ###
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
if user_custom_auth:
|
if user_custom_auth is not None:
|
||||||
response = await user_custom_auth(request=request, api_key=api_key)
|
response = await user_custom_auth(request=request, api_key=api_key)
|
||||||
return UserAPIKeyAuth.model_validate(response)
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
### LITELLM-DEFINED AUTH FUNCTION ###
|
### LITELLM-DEFINED AUTH FUNCTION ###
|
||||||
|
@ -231,7 +233,7 @@ async def user_api_key_auth(
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
prisma_client is None
|
prisma_client is None and custom_db_client is None
|
||||||
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||||
raise Exception("No connected db.")
|
raise Exception("No connected db.")
|
||||||
|
|
||||||
|
@ -241,15 +243,22 @@ async def user_api_key_auth(
|
||||||
if valid_token is None:
|
if valid_token is None:
|
||||||
## check db
|
## check db
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
valid_token = await prisma_client.get_data(
|
if prisma_client is not None:
|
||||||
token=api_key,
|
valid_token = await prisma_client.get_data(
|
||||||
)
|
token=api_key,
|
||||||
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
|
)
|
||||||
|
|
||||||
|
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
valid_token = await custom_db_client.get_data(
|
||||||
|
key=api_key, table_name="key"
|
||||||
|
)
|
||||||
# Token exists, now check expiration.
|
# Token exists, now check expiration.
|
||||||
if valid_token.expires is not None and expires is not None:
|
if valid_token.expires is not None:
|
||||||
if valid_token.expires >= expires:
|
expiry_time = datetime.fromisoformat(valid_token.expires)
|
||||||
|
if expiry_time >= datetime.utcnow():
|
||||||
# Token exists and is not expired.
|
# Token exists and is not expired.
|
||||||
return valid_token
|
return response
|
||||||
else:
|
else:
|
||||||
# Token exists but is expired.
|
# Token exists but is expired.
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -291,13 +300,22 @@ async def user_api_key_auth(
|
||||||
|
|
||||||
This makes the user row data accessible to pre-api call hooks.
|
This makes the user row data accessible to pre-api call hooks.
|
||||||
"""
|
"""
|
||||||
asyncio.create_task(
|
if prisma_client is not None:
|
||||||
_cache_user_row(
|
asyncio.create_task(
|
||||||
user_id=valid_token.user_id,
|
_cache_user_row(
|
||||||
cache=user_api_key_cache,
|
user_id=valid_token.user_id,
|
||||||
db=prisma_client,
|
cache=user_api_key_cache,
|
||||||
|
db=prisma_client,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
_cache_user_row(
|
||||||
|
user_id=valid_token.user_id,
|
||||||
|
cache=user_api_key_cache,
|
||||||
|
db=custom_db_client,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid token")
|
raise Exception(f"Invalid token")
|
||||||
|
@ -371,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||||
|
|
||||||
|
|
||||||
def cost_tracking():
|
def cost_tracking():
|
||||||
global prisma_client
|
global prisma_client, custom_db_client
|
||||||
if prisma_client is not None:
|
if prisma_client is not None or custom_db_client is not None:
|
||||||
if isinstance(litellm.success_callback, list):
|
if isinstance(litellm.success_callback, list):
|
||||||
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
verbose_proxy_logger.debug("setting litellm success callback to track cost")
|
||||||
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
if (track_cost_callback) not in litellm.success_callback: # type: ignore
|
||||||
|
@ -385,7 +403,7 @@ async def track_cost_callback(
|
||||||
start_time=None,
|
start_time=None,
|
||||||
end_time=None, # start/end time for completion
|
end_time=None, # start/end time for completion
|
||||||
):
|
):
|
||||||
global prisma_client
|
global prisma_client, custom_db_client
|
||||||
try:
|
try:
|
||||||
# check if it has collected an entire stream response
|
# check if it has collected an entire stream response
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -404,10 +422,10 @@ async def track_cost_callback(
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
if user_api_key and prisma_client:
|
if user_api_key and (
|
||||||
await update_prisma_database(
|
prisma_client is not None or custom_db_client is not None
|
||||||
token=user_api_key, response_cost=response_cost
|
):
|
||||||
)
|
await update_database(token=user_api_key, response_cost=response_cost)
|
||||||
elif kwargs["stream"] == False: # for non streaming responses
|
elif kwargs["stream"] == False: # for non streaming responses
|
||||||
response_cost = litellm.completion_cost(
|
response_cost = litellm.completion_cost(
|
||||||
completion_response=completion_response
|
completion_response=completion_response
|
||||||
|
@ -418,15 +436,17 @@ async def track_cost_callback(
|
||||||
user_id = kwargs["litellm_params"]["metadata"].get(
|
user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
if user_api_key and prisma_client:
|
if user_api_key and (
|
||||||
await update_prisma_database(
|
prisma_client is not None or custom_db_client is not None
|
||||||
|
):
|
||||||
|
await update_database(
|
||||||
token=user_api_key, response_cost=response_cost, user_id=user_id
|
token=user_api_key, response_cost=response_cost, user_id=user_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def update_prisma_database(token, response_cost, user_id=None):
|
async def update_database(token, response_cost, user_id=None):
|
||||||
try:
|
try:
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
f"Enters prisma db call, token: {token}; user_id: {user_id}"
|
||||||
|
@ -436,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
||||||
async def _update_user_db():
|
async def _update_user_db():
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
return
|
return
|
||||||
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
if prisma_client is not None:
|
||||||
|
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
key=user_id, table_name="user"
|
||||||
|
)
|
||||||
if existing_spend_obj is None:
|
if existing_spend_obj is None:
|
||||||
existing_spend = 0
|
existing_spend = 0
|
||||||
else:
|
else:
|
||||||
|
@ -447,23 +472,49 @@ async def update_prisma_database(token, response_cost, user_id=None):
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
# Update the cost column for the given user id
|
# Update the cost column for the given user id
|
||||||
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
|
if prisma_client is not None:
|
||||||
|
await prisma_client.update_data(
|
||||||
|
user_id=user_id, data={"spend": new_spend}
|
||||||
|
)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
await custom_db_client.update_data(
|
||||||
|
key=user_id, value={"spend": new_spend}, table_name="user"
|
||||||
|
)
|
||||||
|
|
||||||
### UPDATE KEY SPEND ###
|
### UPDATE KEY SPEND ###
|
||||||
async def _update_key_db():
|
async def _update_key_db():
|
||||||
# Fetch the existing cost for the given token
|
if prisma_client is not None:
|
||||||
existing_spend_obj = await prisma_client.get_data(token=token)
|
# Fetch the existing cost for the given token
|
||||||
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
existing_spend_obj = await prisma_client.get_data(token=token)
|
||||||
if existing_spend_obj is None:
|
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||||
existing_spend = 0
|
if existing_spend_obj is None:
|
||||||
else:
|
existing_spend = 0
|
||||||
existing_spend = existing_spend_obj.spend
|
else:
|
||||||
# Calculate the new cost by adding the existing cost and response_cost
|
existing_spend = existing_spend_obj.spend
|
||||||
new_spend = existing_spend + response_cost
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
# Update the cost column for the given token
|
# Update the cost column for the given token
|
||||||
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
await prisma_client.update_data(token=token, data={"spend": new_spend})
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
# Fetch the existing cost for the given token
|
||||||
|
existing_spend_obj = await custom_db_client.get_data(
|
||||||
|
key=token, table_name="key"
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
|
||||||
|
if existing_spend_obj is None:
|
||||||
|
existing_spend = 0
|
||||||
|
else:
|
||||||
|
existing_spend = existing_spend_obj.spend
|
||||||
|
# Calculate the new cost by adding the existing cost and response_cost
|
||||||
|
new_spend = existing_spend + response_cost
|
||||||
|
|
||||||
|
verbose_proxy_logger.debug(f"new cost: {new_spend}")
|
||||||
|
# Update the cost column for the given token
|
||||||
|
await custom_db_client.update_data(
|
||||||
|
key=token, value={"spend": new_spend}, table_name="key"
|
||||||
|
)
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
tasks.append(_update_user_db())
|
tasks.append(_update_user_db())
|
||||||
|
@ -622,7 +673,7 @@ class ProxyConfig:
|
||||||
"""
|
"""
|
||||||
Load config values into proxy global state
|
Load config values into proxy global state
|
||||||
"""
|
"""
|
||||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
|
||||||
|
|
||||||
# Load existing config
|
# Load existing config
|
||||||
config = await self.get_config(config_file_path=config_file_path)
|
config = await self.get_config(config_file_path=config_file_path)
|
||||||
|
@ -787,8 +838,6 @@ class ProxyConfig:
|
||||||
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
|
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
|
||||||
database_url = litellm.get_secret(database_url)
|
database_url = litellm.get_secret(database_url)
|
||||||
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
|
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
|
||||||
## COST TRACKING ##
|
|
||||||
cost_tracking()
|
|
||||||
### MASTER KEY ###
|
### MASTER KEY ###
|
||||||
master_key = general_settings.get(
|
master_key = general_settings.get(
|
||||||
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
|
||||||
|
@ -796,11 +845,23 @@ class ProxyConfig:
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
master_key = litellm.get_secret(master_key)
|
master_key = litellm.get_secret(master_key)
|
||||||
### CUSTOM API KEY AUTH ###
|
### CUSTOM API KEY AUTH ###
|
||||||
|
## pass filepath
|
||||||
custom_auth = general_settings.get("custom_auth", None)
|
custom_auth = general_settings.get("custom_auth", None)
|
||||||
if custom_auth:
|
if custom_auth is not None:
|
||||||
user_custom_auth = get_instance_fn(
|
user_custom_auth = get_instance_fn(
|
||||||
value=custom_auth, config_file_path=config_file_path
|
value=custom_auth, config_file_path=config_file_path
|
||||||
)
|
)
|
||||||
|
## dynamodb
|
||||||
|
database_type = general_settings.get("database_type", None)
|
||||||
|
if database_type is not None and (
|
||||||
|
database_type == "dynamo_db" or database_type == "dynamodb"
|
||||||
|
):
|
||||||
|
database_args = general_settings.get("database_args", None)
|
||||||
|
custom_db_client = DBClient(
|
||||||
|
custom_db_args=database_args, custom_db_type=database_type
|
||||||
|
)
|
||||||
|
## COST TRACKING ##
|
||||||
|
cost_tracking()
|
||||||
### BACKGROUND HEALTH CHECKS ###
|
### BACKGROUND HEALTH CHECKS ###
|
||||||
# Enable background health checks
|
# Enable background health checks
|
||||||
use_background_health_checks = general_settings.get(
|
use_background_health_checks = general_settings.get(
|
||||||
|
@ -867,9 +928,9 @@ async def generate_key_helper_fn(
|
||||||
max_parallel_requests: Optional[int] = None,
|
max_parallel_requests: Optional[int] = None,
|
||||||
metadata: Optional[dict] = {},
|
metadata: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
global prisma_client
|
global prisma_client, custom_db_client
|
||||||
|
|
||||||
if prisma_client is None:
|
if prisma_client is None and custom_db_client is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
|
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
|
||||||
)
|
)
|
||||||
|
@ -908,7 +969,13 @@ async def generate_key_helper_fn(
|
||||||
user_id = user_id or str(uuid.uuid4())
|
user_id = user_id or str(uuid.uuid4())
|
||||||
try:
|
try:
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||||
verification_token_data = {
|
user_data = {
|
||||||
|
"max_budget": max_budget,
|
||||||
|
"user_email": user_email,
|
||||||
|
"user_id": user_id,
|
||||||
|
"spend": spend,
|
||||||
|
}
|
||||||
|
key_data = {
|
||||||
"token": token,
|
"token": token,
|
||||||
"expires": expires,
|
"expires": expires,
|
||||||
"models": models,
|
"models": models,
|
||||||
|
@ -918,19 +985,23 @@ async def generate_key_helper_fn(
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"metadata": metadata_json,
|
"metadata": metadata_json,
|
||||||
"max_budget": max_budget,
|
|
||||||
"user_email": user_email,
|
|
||||||
}
|
}
|
||||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
if prisma_client is not None:
|
||||||
new_verification_token = await prisma_client.insert_data(
|
verification_token_data = dict(key_data)
|
||||||
data=verification_token_data
|
verification_token_data.update(user_data)
|
||||||
)
|
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
||||||
|
await prisma_client.insert_data(data=verification_token_data)
|
||||||
|
elif custom_db_client is not None:
|
||||||
|
## CREATE USER (If necessary)
|
||||||
|
await custom_db_client.insert_data(value=user_data, table_name="user")
|
||||||
|
## CREATE KEY
|
||||||
|
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
return {
|
return {
|
||||||
"token": token,
|
"token": token,
|
||||||
"expires": new_verification_token.expires,
|
"expires": expires,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"max_budget": max_budget,
|
"max_budget": max_budget,
|
||||||
}
|
}
|
||||||
|
@ -1174,12 +1245,21 @@ async def startup_event():
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
await prisma_client.connect()
|
await prisma_client.connect()
|
||||||
|
|
||||||
|
if custom_db_client is not None:
|
||||||
|
await custom_db_client.connect()
|
||||||
|
|
||||||
if prisma_client is not None and master_key is not None:
|
if prisma_client is not None and master_key is not None:
|
||||||
# add master key to db
|
# add master key to db
|
||||||
await generate_key_helper_fn(
|
await generate_key_helper_fn(
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if custom_db_client is not None and master_key is not None:
|
||||||
|
# add master key to db
|
||||||
|
await generate_key_helper_fn(
|
||||||
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get(
|
@router.get(
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
from typing import Optional, List, Any, Literal
|
from typing import Optional, List, Any, Literal, Union
|
||||||
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
|
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
|
||||||
import litellm, backoff
|
import litellm, backoff
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
|
||||||
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.proxy.db.base_client import CustomDB
|
||||||
|
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
import smtplib
|
import smtplib
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
|
@ -578,6 +580,65 @@ class PrismaClient:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
class DBClient:
    """
    Routes database requests to a custom (non-Prisma) database backend.

    Every public coroutine is a thin pass-through to the backend wrapper
    stored on ``self.db``. The only backend wired up here is DynamoDB.

    [TODO] route b/w customauth and prisma
    """

    # Spellings of the DynamoDB backend accepted from config. The proxy
    # config loader lets both "dynamo_db" and "dynamodb" through, so both
    # must be handled here or `self.db` would never be assigned.
    _DYNAMO_DB_TYPES = ("dynamo_db", "dynamodb")

    def __init__(
        self, custom_db_type: Literal["dynamo_db", "dynamodb"], custom_db_args: dict
    ) -> None:
        """
        Build the backend wrapper for the requested database type.

        custom_db_type - which backend to use ("dynamo_db" or "dynamodb")
        custom_db_args - keyword arguments forwarded to `DynamoDBArgs`

        Raises ValueError for an unsupported `custom_db_type`, instead of
        leaving the client without a `db` attribute and failing later with
        an AttributeError on first use.
        """
        if custom_db_type not in self._DYNAMO_DB_TYPES:
            raise ValueError(
                f"Unsupported custom_db_type={custom_db_type!r}. "
                f"Supported types: {', '.join(self._DYNAMO_DB_TYPES)}"
            )
        # `database_args` may be absent from the config (None); pass an empty
        # mapping so DynamoDBArgs performs its own validation of the fields.
        self.db = DynamoDBWrapper(
            database_arguments=DynamoDBArgs(**(custom_db_args or {}))
        )

    async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """
        Check if key valid

        Delegates the lookup of `key` in `table_name` to the backend and
        returns whatever the backend returns.
        """
        return await self.db.get_data(key=key, table_name=table_name)

    async def insert_data(
        self, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For new key / user logic

        Delegates the insert of `value` into `table_name` to the backend.
        """
        return await self.db.insert_data(value=value, table_name=table_name)

    async def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """
        For cost tracking logic

        key - hash_key value
        value - dict with updated values
        """
        return await self.db.update_data(key=key, value=value, table_name=table_name)

    async def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """
        For /key/delete endpoints

        Delegates deletion of every entry in `keys` from `table_name` to the
        backend.
        """
        return await self.db.delete_data(keys=keys, table_name=table_name)

    async def connect(self):
        """
        For connecting to db and creating / updating any tables
        """
        return await self.db.connect()

    async def disconnect(self):
        """
        For closing connection on server shutdown
        """
        return await self.db.disconnect()
|
||||||
|
|
||||||
|
|
||||||
### CUSTOM FILE ###
|
### CUSTOM FILE ###
|
||||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||||
try:
|
try:
|
||||||
|
@ -618,7 +679,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||||
|
|
||||||
|
|
||||||
### HELPER FUNCTIONS ###
|
### HELPER FUNCTIONS ###
|
||||||
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
async def _cache_user_row(
|
||||||
|
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Check if a user_id exists in cache,
|
Check if a user_id exists in cache,
|
||||||
if not retrieve it.
|
if not retrieve it.
|
||||||
|
@ -627,7 +690,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
|
||||||
cache_key = f"{user_id}_user_api_key_user_id"
|
cache_key = f"{user_id}_user_api_key_user_id"
|
||||||
response = cache.get_cache(key=cache_key)
|
response = cache.get_cache(key=cache_key)
|
||||||
if response is None: # Cache miss
|
if response is None: # Cache miss
|
||||||
user_row = await db.get_data(user_id=user_id)
|
if isinstance(db, PrismaClient):
|
||||||
|
user_row = await db.get_data(user_id=user_id)
|
||||||
|
elif isinstance(db, DBClient):
|
||||||
|
user_row = await db.get_data(key=user_id, table_name="user")
|
||||||
if user_row is not None:
|
if user_row is not None:
|
||||||
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
|
||||||
if hasattr(user_row, "model_dump_json") and callable(
|
if hasattr(user_row, "model_dump_json") and callable(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue