litellm-mirror/litellm/proxy/db/prisma_client.py
Krish Dholakia 0a03f2f11e LiteLLM Minor Fixes & Improvements (09/25/2024) (#5893)
* fix(langfuse.py): support new langfuse prompt_chat class init params

* fix(langfuse.py): handle new init values on prompt chat + prompt text templates

fixes error caused during langfuse logging

* docs(openai_compatible.md): clarify `openai/` handles correct routing for `/v1/completions` route

Fixes https://github.com/BerriAI/litellm/issues/5876

* fix(utils.py): handle unmapped gemini model optional param translation

Fixes https://github.com/BerriAI/litellm/issues/5888

* fix(o1_transformation.py): fix o1 validation to not raise an error if temperature=1

Fixes https://github.com/BerriAI/litellm/issues/5884

* fix(prisma_client.py): refresh iam token

Fixes https://github.com/BerriAI/litellm/issues/5896

* fix: pass drop params where required

* fix(utils.py): pass drop_params correctly

* fix(types/vertex_ai.py): fix generation config

* test(test_max_completion_tokens.py): fix test

* fix(vertex_and_google_ai_studio_gemini.py): fix map openai params
2024-09-26 16:41:44 -07:00

106 lines
3.7 KiB
Python

import asyncio
import os
import urllib
import urllib.parse
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Optional
class PrismaWrapper:
    """Proxy around a Prisma client that can refresh RDS IAM credentials.

    Attribute lookups not found on the wrapper itself are forwarded to the
    wrapped Prisma client. When ``iam_token_db_auth`` is enabled, each forwarded
    lookup first checks whether the IAM token embedded in the ``DATABASE_URL``
    environment variable has expired. If it has, a new token is generated and
    the Prisma client is recreated and reconnected.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        """
        Args:
            original_prisma: The Prisma client instance to proxy.
            iam_token_db_auth: If True, regenerate the RDS IAM auth token held in
                ``DATABASE_URL`` whenever it has expired.
        """
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth
        # Strong reference to a reconnect scheduled on a running event loop, so
        # the task cannot be garbage-collected before it completes.
        self._refresh_task: Optional[asyncio.Task] = None

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """Return True if the IAM token carried by ``token_url`` has expired.

        The whole URL is percent-decoded before parsing. Presumably this is
        because the token, which contains a presigned query string, is
        percent-encoded inside the URL's password field. After decoding, the
        ``X-Amz-Date`` / ``X-Amz-Expires`` parameters are read from the parsed
        query component.

        Args:
            token_url: Database URL carrying an RDS IAM token, or None.

        Returns:
            True if ``token_url`` is None or ``X-Amz-Date + X-Amz-Expires``
            (seconds) lies in the past (UTC); False otherwise.

        Raises:
            ValueError: If ``X-Amz-Expires`` or ``X-Amz-Date`` is missing or malformed.
        """
        if token_url is None:
            return True

        decoded_url = urllib.parse.unquote(token_url)
        query_params = urllib.parse.parse_qs(urllib.parse.urlparse(decoded_url).query)

        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")
        try:
            expires_int = int(expires)
        except ValueError as e:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.") from e

        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")
        try:
            # The trailing "Z" marks the signing timestamp as UTC.
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ").replace(
                tzinfo=timezone.utc
            )
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}") from e

        expiration_time = token_time + timedelta(seconds=expires_int)
        # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
        return datetime.now(timezone.utc) > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """Generate a fresh RDS IAM token and build a Postgres URL from it.

        Connection details are read from the DATABASE_HOST, DATABASE_PORT,
        DATABASE_USER, DATABASE_NAME and optional DATABASE_SCHEMA environment
        variables. As a side effect, the new URL is written to ``DATABASE_URL``.

        Returns:
            The new database URL, or None when IAM auth is disabled.
        """
        if not self.iam_token_db_auth:
            return None

        from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

        db_host = os.getenv("DATABASE_HOST")
        db_port = os.getenv("DATABASE_PORT")
        db_user = os.getenv("DATABASE_USER")
        db_name = os.getenv("DATABASE_NAME")
        db_schema = os.getenv("DATABASE_SCHEMA")

        token = generate_iam_auth_token(
            db_host=db_host, db_port=db_port, db_user=db_user
        )
        _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
        if db_schema:
            _db_url += f"?schema={db_schema}"
        os.environ["DATABASE_URL"] = _db_url
        return _db_url

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """Replace the wrapped client with a newly created, connected Prisma client.

        The new client is connected before it is swapped in. Concurrent callers
        therefore never receive an unconnected client, and a failed connect
        leaves the previous client in place.

        NOTE(review): ``new_db_url`` is not passed to ``Prisma``. The new client
        presumably reads its URL from the ``DATABASE_URL`` environment variable,
        which ``get_rds_iam_token`` updates; confirm against the Prisma schema.
        The previous client is not disconnected here. TODO: decide whether its
        engine should be closed once in-flight queries have drained.

        Args:
            new_db_url: The refreshed database URL (see note above).
            http_client: Optional HTTP client forwarded to ``Prisma(http=...)``.
        """
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            new_client = Prisma(http=http_client)
        else:
            new_client = Prisma()
        await new_client.connect()
        self._original_prisma = new_client

    def _refresh_client(self) -> None:
        """Generate a new IAM token, then reconnect the wrapped client (now or soon)."""
        db_url = self.get_rds_iam_token()
        if not db_url:
            raise ValueError("Failed to get RDS IAM token")

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread, so block until reconnected.
            asyncio.run(self.recreate_prisma_client(db_url))
            return

        # asyncio.run() is not allowed inside a running loop, and waiting here
        # would stall it, so schedule the reconnect instead. The current client
        # keeps serving until the task swaps in the new one.
        self._refresh_task = loop.create_task(self.recreate_prisma_client(db_url))

    def __getattr__(self, name: str):
        """Forward attribute access to the wrapped Prisma client.

        With IAM auth enabled, an expired token in ``DATABASE_URL`` first
        triggers a token refresh and client reconnect:

        * With no event loop running in this thread, the reconnect completes
          synchronously and the attribute is read from the new client.
        * Inside a running event loop, the reconnect is scheduled as a task and
          this lookup is served by the current client.

        Raises:
            AttributeError: If the wrapped client has no attribute ``name``.
            ValueError: If a new IAM token could not be obtained, or the current
                ``DATABASE_URL`` has missing or malformed X-Amz parameters.
        """
        # __getattr__ runs only for names not found normally. On a bare instance
        # (e.g. one created by copy/pickle before its state is restored), the
        # wrapper's own fields are absent. Reading them via self.<attr> would
        # recurse forever, so fail fast with AttributeError instead.
        state = self.__dict__
        if "_original_prisma" not in state or "iam_token_db_auth" not in state:
            raise AttributeError(name)

        if self.iam_token_db_auth and self.is_token_expired(os.getenv("DATABASE_URL")):
            self._refresh_client()

        # Look up after any refresh so a synchronous reconnect takes effect now.
        return getattr(self._original_prisma, name)