Merge branch 'litellm_dynamo_db_keys'

This commit is contained in:
Krrish Dholakia 2024-01-13 18:38:43 +05:30
commit 8a7a745549
12 changed files with 547 additions and 58 deletions

Binary file not shown.

BIN
dist/litellm-1.16.21.dev1.tar.gz vendored Normal file

Binary file not shown.

Binary file not shown.

BIN
dist/litellm-1.16.21.dev2.tar.gz vendored Normal file

Binary file not shown.

Binary file not shown.

BIN
dist/litellm-1.16.21.dev3.tar.gz vendored Normal file

Binary file not shown.

View file

@ -246,3 +246,35 @@ general_settings:
$ litellm --config /path/to/config.yaml
```
## [BETA] Dynamo DB
Only live in `v1.16.21.dev1`.
### Step 1. Save keys to env
```env
AWS_ACCESS_KEY_ID = "your-aws-access-key-id"
AWS_SECRET_ACCESS_KEY = "your-aws-secret-access-key"
```
### Step 2. Add details to config
```yaml
general_settings:
master_key: sk-1234
database_type: "dynamo_db"
database_args: { # 👈 all args - https://github.com/BerriAI/litellm/blob/befbcbb7ac8f59835ce47415c128decf37aac328/litellm/proxy/_types.py#L190
"billing_mode": "PAY_PER_REQUEST",
"region_name": "us-west-2"
}
```
### Step 3. Generate Key
```bash
curl --location 'http://0.0.0.0:8000/key/generate' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{"models": ["azure-models"], "aliases": {"mistral-7b": "gpt-3.5-turbo"}, "duration": null}'
```

View file

@ -17,6 +17,13 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1
return self.dict()
def fields_set(self):
try:
return self.model_fields_set # noqa
except:
# if using pydantic v1
return self.__fields_set__
######### Request Class Definition ######
class ProxyChatCompletionRequest(LiteLLMBase):
@ -182,6 +189,16 @@ class KeyManagementSystem(enum.Enum):
LOCAL = "local"
class DynamoDBArgs(LiteLLMBase):
billing_mode: Literal["PROVISIONED_THROUGHPUT", "PAY_PER_REQUEST"]
read_capacity_units: Optional[int] = None
write_capacity_units: Optional[int] = None
region_name: str
user_table_name: str = "LiteLLM_UserTable"
key_table_name: str = "LiteLLM_VerificationToken"
config_table_name: str = "LiteLLM_Config"
class ConfigGeneralSettings(LiteLLMBase):
"""
Documents all the fields supported by `general_settings` in config.yaml
@ -206,6 +223,13 @@ class ConfigGeneralSettings(LiteLLMBase):
None,
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
database_type: Optional[Literal["dynamo_db"]] = Field(
None, description="to use dynamodb instead of postgres db"
)
database_args: Optional[DynamoDBArgs] = Field(
None,
description="custom args for instantiating dynamodb client - e.g. billing provision",
)
otel: Optional[bool] = Field(
None,
description="[BETA] OpenTelemetry support - this might change, use with caution.",
@ -258,3 +282,33 @@ class ConfigYAML(LiteLLMBase):
class Config:
protected_namespaces = ()
class LiteLLM_VerificationToken(LiteLLMBase):
token: str
spend: float = 0.0
expires: Union[str, None]
models: List[str]
aliases: Dict[str, str] = {}
config: Dict[str, str] = {}
user_id: Union[str, None]
max_parallel_requests: Union[int, None]
metadata: Dict[str, str] = {}
class LiteLLM_Config(LiteLLMBase):
param_name: str
param_value: Dict
class LiteLLM_UserTable(LiteLLMBase):
user_id: str
max_budget: Optional[float]
spend: float = 0.0
user_email: Optional[str]
@root_validator(pre=True)
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values

View file

@ -0,0 +1,53 @@
from typing import Any, Literal, List
class CustomDB:
"""
Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
"""
def __init__(self) -> None:
pass
def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
"""
Check if key valid
"""
pass
def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
"""
For new key / user logic
"""
pass
def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For cost tracking logic
"""
pass
def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
For /key/delete endpoint s
"""
def connect(
self,
):
"""
For connecting to db and creating / updating any tables
"""
pass
def disconnect(
self,
):
"""
For closing connection on server shutdown
"""
pass

View file

@ -0,0 +1,204 @@
import json
from aiodynamo.client import Client
from aiodynamo.credentials import Credentials, StaticCredentials
from aiodynamo.http.httpx import HTTPX
from aiodynamo.models import Throughput, KeySchema, KeySpec, KeyType, PayPerRequest
from yarl import URL
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy._types import (
DynamoDBArgs,
LiteLLM_VerificationToken,
LiteLLM_Config,
LiteLLM_UserTable,
)
from litellm import get_secret
from typing import Any, List, Literal, Optional, Union
from aiodynamo.expressions import UpdateExpression, F, Value
from aiodynamo.models import ReturnValues
from aiodynamo.http.aiohttp import AIOHTTP
from aiohttp import ClientSession
from datetime import datetime
class DynamoDBWrapper(CustomDB):
credentials: Credentials
def __init__(self, database_arguments: DynamoDBArgs):
self.throughput_type = None
if database_arguments.billing_mode == "PAY_PER_REQUEST":
self.throughput_type = PayPerRequest()
elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
if (
database_arguments.read_capacity_units is not None
and isinstance(database_arguments.read_capacity_units, int)
and database_arguments.write_capacity_units is not None
and isinstance(database_arguments.write_capacity_units, int)
):
self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
else:
raise Exception(
f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
)
self.database_arguments = database_arguments
self.region_name = database_arguments.region_name
async def connect(self):
"""
Connect to DB, and creating / updating any tables
"""
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
## User
try:
error_occurred = False
table = client.table(self.database_arguments.user_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("user_id", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.user_table_name}.\nPlease create a new table called {self.database_arguments.user_table_name}\nAND set `hash_key` as 'user_id'"
)
## Token
try:
error_occurred = False
table = client.table(self.database_arguments.key_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("token", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.key_table_name}.\nPlease create a new table called {self.database_arguments.key_table_name}\nAND set `hash_key` as 'token'"
)
## Config
try:
error_occurred = False
table = client.table(self.database_arguments.config_table_name)
if not await table.exists():
await table.create(
self.throughput_type,
KeySchema(hash_key=KeySpec("param_name", KeyType.string)),
)
except Exception as e:
error_occurred = True
if error_occurred == True:
raise Exception(
f"Failed to create table - {self.database_arguments.config_table_name}.\nPlease create a new table called {self.database_arguments.config_table_name}\nAND set `hash_key` as 'param_name'"
)
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
for k, v in value.items():
if isinstance(v, datetime):
value[k] = v.isoformat()
await table.put_item(item=value)
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
key_name = None
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
key_name = "user_id"
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
key_name = "token"
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
key_name = "param_name"
response = await table.get_item({key_name: key})
new_response: Any = None
if table_name == "user":
new_response = LiteLLM_UserTable(**response)
elif table_name == "key":
new_response = {}
for k, v in response.items(): # handle json string
if (
(k == "aliases" or k == "config" or k == "metadata")
and v is not None
and isinstance(v, str)
):
new_response[k] = json.loads(v)
else:
new_response[k] = v
new_response = LiteLLM_VerificationToken(**new_response)
elif table_name == "config":
new_response = LiteLLM_Config(**response)
return new_response
async def update_data(
self, key: str, value: dict, table_name: Literal["user", "key", "config"]
):
async with ClientSession() as session:
client = Client(AIOHTTP(session), Credentials.auto(), self.region_name)
table = None
key_name = None
try:
if table_name == "user":
table = client.table(self.database_arguments.user_table_name)
key_name = "user_id"
elif table_name == "key":
table = client.table(self.database_arguments.key_table_name)
key_name = "token"
elif table_name == "config":
table = client.table(self.database_arguments.config_table_name)
key_name = "param_name"
else:
raise Exception(
f"Invalid table name. Needs to be one of - {self.database_arguments.user_table_name}, {self.database_arguments.key_table_name}, {self.database_arguments.config_table_name}"
)
except Exception as e:
raise Exception(f"Error connecting to table - {str(e)}")
# Initialize an empty UpdateExpression
actions: List = []
for k, v in value.items():
# Convert datetime object to ISO8601 string
if isinstance(v, datetime):
v = v.isoformat()
# Accumulate updates
actions.append((F(k), Value(value=v)))
update_expression = UpdateExpression(set_updates=actions)
# Perform the update in DynamoDB
result = await table.update_item(
key={key_name: key},
update_expression=update_expression,
return_values=ReturnValues.none,
)
return result
async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
Not Implemented yet.
"""
return super().delete_data(keys, table_name)

View file

@ -67,6 +67,7 @@ def generate_feedback_box():
import litellm
from litellm.proxy.utils import (
PrismaClient,
DBClient,
get_instance_fn,
ProxyLogging,
_cache_user_row,
@ -141,6 +142,7 @@ worker_config = None
master_key = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
use_background_health_checks = None
@ -184,12 +186,12 @@ def _get_pydantic_json_dict(pydantic_obj: BaseModel) -> dict:
async def user_api_key_auth(
request: Request, api_key: str = fastapi.Security(api_key_header)
) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth
global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client
try:
if isinstance(api_key, str):
api_key = _get_bearer_token(api_key=api_key)
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth:
if user_custom_auth is not None:
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
### LITELLM-DEFINED AUTH FUNCTION ###
@ -231,7 +233,7 @@ async def user_api_key_auth(
)
if (
prisma_client is None
prisma_client is None and custom_db_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
@ -241,15 +243,22 @@ async def user_api_key_auth(
if valid_token is None:
## check db
verbose_proxy_logger.debug(f"api key: {api_key}")
if prisma_client is not None:
valid_token = await prisma_client.get_data(
token=api_key,
)
expires = datetime.utcnow().replace(tzinfo=timezone.utc)
elif custom_db_client is not None:
valid_token = await custom_db_client.get_data(
key=api_key, table_name="key"
)
# Token exists, now check expiration.
if valid_token.expires is not None and expires is not None:
if valid_token.expires >= expires:
if valid_token.expires is not None:
expiry_time = datetime.fromisoformat(valid_token.expires)
if expiry_time >= datetime.utcnow():
# Token exists and is not expired.
return valid_token
return response
else:
# Token exists but is expired.
raise HTTPException(
@ -291,6 +300,7 @@ async def user_api_key_auth(
This makes the user row data accessible to pre-api call hooks.
"""
if prisma_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
@ -298,6 +308,14 @@ async def user_api_key_auth(
db=prisma_client,
)
)
elif custom_db_client is not None:
asyncio.create_task(
_cache_user_row(
user_id=valid_token.user_id,
cache=user_api_key_cache,
db=custom_db_client,
)
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
@ -371,8 +389,8 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking():
global prisma_client
if prisma_client is not None:
global prisma_client, custom_db_client
if prisma_client is not None or custom_db_client is not None:
if isinstance(litellm.success_callback, list):
verbose_proxy_logger.debug("setting litellm success callback to track cost")
if (track_cost_callback) not in litellm.success_callback: # type: ignore
@ -385,7 +403,7 @@ async def track_cost_callback(
start_time=None,
end_time=None, # start/end time for completion
):
global prisma_client
global prisma_client, custom_db_client
try:
# check if it has collected an entire stream response
verbose_proxy_logger.debug(
@ -404,10 +422,10 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
if user_api_key and prisma_client:
await update_prisma_database(
token=user_api_key, response_cost=response_cost
)
if user_api_key and (
prisma_client is not None or custom_db_client is not None
):
await update_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
response_cost = litellm.completion_cost(
completion_response=completion_response
@ -418,15 +436,17 @@ async def track_cost_callback(
user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
if user_api_key and prisma_client:
await update_prisma_database(
if user_api_key and (
prisma_client is not None or custom_db_client is not None
):
await update_database(
token=user_api_key, response_cost=response_cost, user_id=user_id
)
except Exception as e:
verbose_proxy_logger.debug(f"error in tracking cost callback - {str(e)}")
async def update_prisma_database(token, response_cost, user_id=None):
async def update_database(token, response_cost, user_id=None):
try:
verbose_proxy_logger.debug(
f"Enters prisma db call, token: {token}; user_id: {user_id}"
@ -436,7 +456,12 @@ async def update_prisma_database(token, response_cost, user_id=None):
async def _update_user_db():
if user_id is None:
return
if prisma_client is not None:
existing_spend_obj = await prisma_client.get_data(user_id=user_id)
elif custom_db_client is not None:
existing_spend_obj = await custom_db_client.get_data(
key=user_id, table_name="user"
)
if existing_spend_obj is None:
existing_spend = 0
else:
@ -447,10 +472,18 @@ async def update_prisma_database(token, response_cost, user_id=None):
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given user id
await prisma_client.update_data(user_id=user_id, data={"spend": new_spend})
if prisma_client is not None:
await prisma_client.update_data(
user_id=user_id, data={"spend": new_spend}
)
elif custom_db_client is not None:
await custom_db_client.update_data(
key=user_id, value={"spend": new_spend}, table_name="user"
)
### UPDATE KEY SPEND ###
async def _update_key_db():
if prisma_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await prisma_client.get_data(token=token)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
@ -464,6 +497,24 @@ async def update_prisma_database(token, response_cost, user_id=None):
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await prisma_client.update_data(token=token, data={"spend": new_spend})
elif custom_db_client is not None:
# Fetch the existing cost for the given token
existing_spend_obj = await custom_db_client.get_data(
key=token, table_name="key"
)
verbose_proxy_logger.debug(f"existing spend: {existing_spend_obj}")
if existing_spend_obj is None:
existing_spend = 0
else:
existing_spend = existing_spend_obj.spend
# Calculate the new cost by adding the existing cost and response_cost
new_spend = existing_spend + response_cost
verbose_proxy_logger.debug(f"new cost: {new_spend}")
# Update the cost column for the given token
await custom_db_client.update_data(
key=token, value={"spend": new_spend}, table_name="key"
)
tasks = []
tasks.append(_update_user_db())
@ -622,7 +673,7 @@ class ProxyConfig:
"""
Load config values into proxy global state
"""
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue, custom_db_client
# Load existing config
config = await self.get_config(config_file_path=config_file_path)
@ -787,8 +838,6 @@ class ProxyConfig:
verbose_proxy_logger.debug(f"GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
verbose_proxy_logger.debug(f"RETRIEVED DB URL: {database_url}")
## COST TRACKING ##
cost_tracking()
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
@ -796,11 +845,23 @@ class ProxyConfig:
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
### CUSTOM API KEY AUTH ###
## pass filepath
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
if custom_auth is not None:
user_custom_auth = get_instance_fn(
value=custom_auth, config_file_path=config_file_path
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## COST TRACKING ##
cost_tracking()
### BACKGROUND HEALTH CHECKS ###
# Enable background health checks
use_background_health_checks = general_settings.get(
@ -867,9 +928,9 @@ async def generate_key_helper_fn(
max_parallel_requests: Optional[int] = None,
metadata: Optional[dict] = {},
):
global prisma_client
global prisma_client, custom_db_client
if prisma_client is None:
if prisma_client is None and custom_db_client is None:
raise Exception(
f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
)
@ -908,7 +969,13 @@ async def generate_key_helper_fn(
user_id = user_id or str(uuid.uuid4())
try:
# Create a new verification token (you may want to enhance this logic based on your needs)
verification_token_data = {
user_data = {
"max_budget": max_budget,
"user_email": user_email,
"user_id": user_id,
"spend": spend,
}
key_data = {
"token": token,
"expires": expires,
"models": models,
@ -918,19 +985,23 @@ async def generate_key_helper_fn(
"user_id": user_id,
"max_parallel_requests": max_parallel_requests,
"metadata": metadata_json,
"max_budget": max_budget,
"user_email": user_email,
}
if prisma_client is not None:
verification_token_data = dict(key_data)
verification_token_data.update(user_data)
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
new_verification_token = await prisma_client.insert_data(
data=verification_token_data
)
await prisma_client.insert_data(data=verification_token_data)
elif custom_db_client is not None:
## CREATE USER (If necessary)
await custom_db_client.insert_data(value=user_data, table_name="user")
## CREATE KEY
await custom_db_client.insert_data(value=key_data, table_name="key")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return {
"token": token,
"expires": new_verification_token.expires,
"expires": expires,
"user_id": user_id,
"max_budget": max_budget,
}
@ -1174,12 +1245,21 @@ async def startup_event():
if prisma_client is not None:
await prisma_client.connect()
if custom_db_client is not None:
await custom_db_client.connect()
if prisma_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
if custom_db_client is not None and master_key is not None:
# add master key to db
await generate_key_helper_fn(
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
)
#### API ENDPOINTS ####
@router.get(

View file

@ -1,11 +1,13 @@
from typing import Optional, List, Any, Literal
from typing import Optional, List, Any, Literal, Union
import os, subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import UserAPIKeyAuth, DynamoDBArgs
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
from litellm.proxy.hooks.max_budget_limiter import MaxBudgetLimiter
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.db.base_client import CustomDB
from litellm.proxy.db.dynamo_db import DynamoDBWrapper
from fastapi import HTTPException, status
import smtplib
from email.mime.text import MIMEText
@ -578,6 +580,65 @@ class PrismaClient:
raise e
class DBClient:
"""
Routes requests for CustomAuth
[TODO] route b/w customauth and prisma
"""
def __init__(
self, custom_db_type: Literal["dynamo_db"], custom_db_args: dict
) -> None:
if custom_db_type == "dynamo_db":
self.db = DynamoDBWrapper(database_arguments=DynamoDBArgs(**custom_db_args))
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
"""
Check if key valid
"""
return await self.db.get_data(key=key, table_name=table_name)
async def insert_data(
self, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For new key / user logic
"""
return await self.db.insert_data(value=value, table_name=table_name)
async def update_data(
self, key: str, value: Any, table_name: Literal["user", "key", "config"]
):
"""
For cost tracking logic
key - hash_key value \n
value - dict with updated values
"""
return await self.db.update_data(key=key, value=value, table_name=table_name)
async def delete_data(
self, keys: List[str], table_name: Literal["user", "key", "config"]
):
"""
For /key/delete endpoints
"""
return await self.db.delete_data(keys=keys, table_name=table_name)
async def connect(self):
"""
For connecting to db and creating / updating any tables
"""
return await self.db.connect()
async def disconnect(self):
"""
For closing connection on server shutdown
"""
return await self.db.disconnect()
### CUSTOM FILE ###
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
try:
@ -618,7 +679,9 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
### HELPER FUNCTIONS ###
async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
async def _cache_user_row(
user_id: str, cache: DualCache, db: Union[PrismaClient, DBClient]
):
"""
Check if a user_id exists in cache,
if not retrieve it.
@ -627,7 +690,10 @@ async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
cache_key = f"{user_id}_user_api_key_user_id"
response = cache.get_cache(key=cache_key)
if response is None: # Cache miss
if isinstance(db, PrismaClient):
user_row = await db.get_data(user_id=user_id)
elif isinstance(db, DBClient):
user_row = await db.get_data(key=user_id, table_name="user")
if user_row is not None:
print_verbose(f"User Row: {user_row}, type = {type(user_row)}")
if hasattr(user_row, "model_dump_json") and callable(