Litellm dev 01 13 2025 p2 (#7758)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s

* fix(factory.py): fix bedrock document url check

Make check more generic - if starts with 'text' or 'application' assume it's a document and let it go through

 Fixes https://github.com/BerriAI/litellm/issues/7746

* feat(key_management_endpoints.py): support writing new key alias to aws secret manager - on key rotation

adds rotation endpoint to aws key management hook - allows for rotated litellm virtual keys with new key alias to be written to it

* feat(key_management_event_hooks.py): support rotating keys and updating secret manager

* refactor(base_secret_manager.py): support rotate secret at the base level

since it's just an abstraction function, it's easy to implement at the base manager level

* style: cleanup unused imports
This commit is contained in:
Krish Dholakia 2025-01-14 17:04:01 -08:00 committed by GitHub
parent 7b27cfb0ae
commit 35919d9fec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 209 additions and 54 deletions

View file

@ -2185,12 +2185,7 @@ def get_image_details(image_url) -> Tuple[str, str]:
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
return base64_bytes, content_type
except Exception as e:
raise e
@ -2216,50 +2211,37 @@ def _process_bedrock_converse_image_block(
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
supported_document_types = (
litellm.AmazonConverseConfig().get_supported_document_types()
)
if image_format in supported_image_formats:
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif image_format in supported_document_types:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
image_bytes, mime_type = get_image_details(image_url)
image_format = mime_type.split("/")[1]
_blob = BedrockSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
supported_document_types = (
litellm.AmazonConverseConfig().get_supported_document_types()
)
if image_format in supported_image_formats:
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif image_format in supported_document_types:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
supported_image_formats = litellm.AmazonConverseConfig().get_supported_image_types()
document_types = ["application", "text"]
is_document = any(
mime_type.startswith(document_type) for document_type in document_types
)
if image_format in supported_image_formats:
return BedrockContentBlock(image=BedrockImageBlock(source=_blob, format=image_format)) # type: ignore
elif is_document:
return BedrockContentBlock(document=BedrockDocumentBlock(source=_blob, format=image_format, name="DocumentPDFmessages_{}".format(str(uuid.uuid4())))) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,

View file

@ -4,4 +4,4 @@ model_list:
model: "azure/gpt-4o"
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE

View file

@ -657,6 +657,7 @@ class GenerateKeyResponse(KeyRequestBase):
user_id: Optional[str] = None
token_id: Optional[str] = None
litellm_budget_table: Optional[Any] = None
token: Optional[str] = None
@model_validator(mode="before")
@classmethod

View file

@ -10,12 +10,14 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
GenerateKeyRequest,
GenerateKeyResponse,
KeyRequest,
LiteLLM_AuditLogs,
LiteLLM_VerificationToken,
LitellmTableNames,
ProxyErrorTypes,
ProxyException,
RegenerateKeyRequest,
UpdateKeyRequest,
UserAPIKeyAuth,
WebhookEvent,
@ -30,7 +32,7 @@ class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: dict,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
@ -48,11 +50,13 @@ class KeyManagementEventHooks:
from litellm.proxy.proxy_server import litellm_proxy_admin_name
if data.send_invite_email is True:
await KeyManagementEventHooks._send_key_created_email(response)
await KeyManagementEventHooks._send_key_created_email(
response.model_dump(exclude_none=True)
)
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(response, default=str)
_updated_values = response.model_dump_json(exclude_none=True)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
@ -63,7 +67,7 @@ class KeyManagementEventHooks:
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.get("token_id", ""),
object_id=response.token_id or "",
action="created",
updated_values=_updated_values,
before_value=None,
@ -72,8 +76,8 @@ class KeyManagementEventHooks:
)
# store the generated key in the secret manager
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{uuid.uuid4()}",
secret_token=response.get("token", ""),
secret_name=data.key_alias or f"virtual-key-{response.token_id}",
secret_token=response.key,
)
@staticmethod
@ -119,7 +123,25 @@ class KeyManagementEventHooks:
)
)
)
pass
@staticmethod
async def async_key_rotated_hook(
data: Optional[RegenerateKeyRequest],
existing_key_row: Any,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
# store the generated key in the secret manager
if data is not None and response.token_id is not None:
initial_secret_name = (
existing_key_row.key_alias or f"virtual-key-{existing_key_row.token}"
)
await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
current_secret_name=initial_secret_name,
new_secret_name=data.key_alias or f"virtual-key-{response.token_id}",
new_secret_value=response.key,
)
@staticmethod
async def async_key_deleted_hook(
@ -207,6 +229,35 @@ class KeyManagementEventHooks:
secret_value=secret_token,
)
@staticmethod
async def _rotate_virtual_key_in_secret_manager(
current_secret_name: str, new_secret_name: str, new_secret_value: str
):
"""
Update a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
# store the key in the secret manager
if isinstance(litellm.secret_manager_client, BaseSecretManager):
await litellm.secret_manager_client.async_rotate_secret(
current_secret_name=KeyManagementEventHooks._get_secret_name(
current_secret_name
),
new_secret_name=KeyManagementEventHooks._get_secret_name(
new_secret_name
),
new_secret_value=new_secret_value,
)
@staticmethod
def _get_secret_name(secret_name: str) -> str:
if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(

View file

@ -523,6 +523,8 @@ async def generate_key_fn( # noqa: PLR0915
data.soft_budget
) # include the user-input soft budget in the response
response = GenerateKeyResponse(**response)
asyncio.create_task(
KeyManagementEventHooks.async_key_generated_hook(
data=data,
@ -532,7 +534,7 @@ async def generate_key_fn( # noqa: PLR0915
)
)
return GenerateKeyResponse(**response)
return response
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
@ -1517,7 +1519,7 @@ async def regenerate_key_fn(
updated_token_dict = dict(updated_token)
updated_token_dict["key"] = new_token
updated_token_dict.pop("token")
updated_token_dict["token_id"] = updated_token_dict.pop("token")
### 3. remove existing key entry from cache
######################################################################
@ -1535,9 +1537,21 @@ async def regenerate_key_fn(
proxy_logging_obj=proxy_logging_obj,
)
return GenerateKeyResponse(
response = GenerateKeyResponse(
**updated_token_dict,
)
asyncio.create_task(
KeyManagementEventHooks.async_key_rotated_hook(
data=data,
existing_key_row=_key_in_db,
response=response,
user_api_key_dict=user_api_key_dict,
litellm_changed_by=litellm_changed_by,
)
)
return response
except Exception as e:
raise handle_exception_on_proxy(e)

View file

@ -3,6 +3,8 @@ from typing import Any, Dict, Optional, Union
import httpx
from litellm import verbose_logger
class BaseSecretManager(ABC):
"""
@ -93,3 +95,82 @@ class BaseSecretManager(ABC):
dict: Response from the secret manager containing deletion details
"""
pass
async def async_rotate_secret(
self,
current_secret_name: str,
new_secret_name: str,
new_secret_value: str,
optional_params: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
) -> dict:
"""
Async function to rotate a secret by creating a new one and deleting the old one.
This allows for both value and name changes during rotation.
Args:
current_secret_name: Current name of the secret
new_secret_name: New name for the secret
new_secret_value: New value for the secret
optional_params: Additional AWS parameters
timeout: Request timeout
Returns:
dict: Response containing the new secret details
Raises:
ValueError: If the secret doesn't exist or if there's an HTTP error
"""
try:
# First verify the old secret exists
old_secret = await self.async_read_secret(
secret_name=current_secret_name,
optional_params=optional_params,
timeout=timeout,
)
if old_secret is None:
raise ValueError(f"Current secret {current_secret_name} not found")
# Create new secret with new name and value
create_response = await self.async_write_secret(
secret_name=new_secret_name,
secret_value=new_secret_value,
description=f"Rotated from {current_secret_name}",
optional_params=optional_params,
timeout=timeout,
)
# Verify new secret was created successfully
new_secret = await self.async_read_secret(
secret_name=new_secret_name,
optional_params=optional_params,
timeout=timeout,
)
if new_secret is None:
raise ValueError(f"Failed to verify new secret {new_secret_name}")
# If everything is successful, delete the old secret
await self.async_delete_secret(
secret_name=current_secret_name,
recovery_window_in_days=7, # Keep for recovery if needed
optional_params=optional_params,
timeout=timeout,
)
return create_response
except httpx.HTTPStatusError as err:
verbose_logger.exception(
"Error rotating secret in AWS Secrets Manager: %s",
str(err.response.text),
)
raise ValueError(f"HTTP error occurred: {err.response.text}")
except httpx.TimeoutException:
raise ValueError("Timeout error occurred")
except Exception as e:
verbose_logger.exception(
"Error rotating secret in AWS Secrets Manager: %s", str(e)
)
raise

View file

@ -222,6 +222,16 @@ class HashicorpSecretManager(BaseSecretManager):
verbose_logger.exception(f"Error writing secret to Hashicorp Vault: {e}")
return {"status": "error", "message": str(e)}
async def async_rotate_secret(
self,
current_secret_name: str,
new_secret_name: str,
new_secret_value: str,
optional_params: Dict | None = None,
timeout: float | httpx.Timeout | None = None,
) -> Dict:
raise NotImplementedError("Hashicorp does not support secret rotation")
async def async_delete_secret(
self,
secret_name: str,

View file

@ -2379,3 +2379,15 @@ class TestBedrockEmbedding(BaseLLMEmbeddingTest):
transformed_request[
"inputImage"
] == "iVBORw0KGgoAAAANSUhEUgAAAGQAAABkBAMAAACCzIhnAAAAG1BMVEURAAD///+ln5/h39/Dv79qX18uHx+If39MPz9oMSdmAAAACXBIWXMAAA7EAAAOxAGVKw4bAAABB0lEQVRYhe2SzWrEIBCAh2A0jxEs4j6GLDS9hqWmV5Flt0cJS+lRwv742DXpEjY1kOZW6HwHFZnPmVEBEARBEARB/jd0KYA/bcUYbPrRLh6amXHJ/K+ypMoyUaGthILzw0l+xI0jsO7ZcmCcm4ILd+QuVYgpHOmDmz6jBeJImdcUCmeBqQpuqRIbVmQsLCrAalrGpfoEqEogqbLTWuXCPCo+Ki1XGqgQ+jVVuhB8bOaHkvmYuzm/b0KYLWwoK58oFqi6XfxQ4Uz7d6WeKpna6ytUs5e8betMcqAv5YPC5EZB2Lm9FIn0/VP6R58+/GEY1X1egVoZ/3bt/EqF6malgSAIgiDIH+QL41409QMY0LMAAAAASUVORK5CYII="
def test_process_bedrock_converse_image_block():
from litellm.litellm_core_utils.prompt_templates.factory import (
_process_bedrock_converse_image_block,
)
block = _process_bedrock_converse_image_block(
image_url="data:text/plain;base64,base64file"
)
assert block["document"] is not None

View file

@ -327,7 +327,9 @@ def test_bedrock_parallel_tool_calling_pt(provider):
"""
Make sure parallel tool call blocks are merged correctly - https://github.com/BerriAI/litellm/issues/5277
"""
from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_converse_messages_pt
from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
)
from litellm.types.utils import ChatCompletionMessageToolCall, Function, Message
messages = [
@ -682,7 +684,9 @@ def test_alternating_roles_e2e():
def test_just_system_message():
from litellm.litellm_core_utils.prompt_templates.factory import _bedrock_converse_messages_pt
from litellm.litellm_core_utils.prompt_templates.factory import (
_bedrock_converse_messages_pt,
)
with pytest.raises(litellm.BadRequestError) as e:
_bedrock_converse_messages_pt(