diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index f0b5afa67f..a369b7f3e3 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -2401,8 +2401,8 @@ def _get_masked_values( sensitive_object: dict, ignore_sensitive_values: bool = False, mask_all_values: bool = False, - unmasked_length: int = 44, - number_of_asterisks: Optional[int] = None, + unmasked_length: int = 4, + number_of_asterisks: Optional[int] = 4, ) -> dict: """ Internal debugging helper function diff --git a/litellm/proxy/credential_endpoints/endpoints.py b/litellm/proxy/credential_endpoints/endpoints.py index 466766fc2d..5a82f44e80 100644 --- a/litellm/proxy/credential_endpoints/endpoints.py +++ b/litellm/proxy/credential_endpoints/endpoints.py @@ -2,7 +2,7 @@ CRUD endpoints for storing reusable credentials. """ -from typing import Optional +from typing import Optional, Union from fastapi import APIRouter, Depends, HTTPException, Request, Response @@ -14,8 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object -from litellm.types.router import CredentialLiteLLMParams -from litellm.types.utils import CredentialItem +from litellm.types.utils import CreateCredentialItem, CredentialItem router = APIRouter() @@ -39,7 +38,7 @@ class CredentialHelperUtils: async def create_credential( request: Request, fastapi_response: Response, - credential: CredentialItem, + credential: CreateCredentialItem, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -47,7 +46,7 @@ async def create_credential( Stores credential in DB. Reloads credentials in memory. 
""" - from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.proxy_server import llm_router, prisma_client try: if prisma_client is None: @@ -55,9 +54,35 @@ async def create_credential( status_code=500, detail={"error": CommonProxyErrors.db_not_connected_error.value}, ) + if credential.model_id: + if llm_router is None: + raise HTTPException( + status_code=500, + detail="LLM router not found. Please ensure you have a valid router instance.", + ) + # get model from router + model = llm_router.get_deployment(credential.model_id) + if model is None: + raise HTTPException(status_code=404, detail="Model not found") + credential_values = llm_router.get_deployment_credentials( + credential.model_id + ) + if credential_values is None: + raise HTTPException(status_code=404, detail="Model not found") + credential.credential_values = credential_values + if credential.credential_values is None: + raise HTTPException( + status_code=400, + detail="Credential values are required. Unable to infer credential values from model ID.", + ) + processed_credential = CredentialItem( + credential_name=credential.credential_name, + credential_values=credential.credential_values, + credential_info=credential.credential_info, + ) encrypted_credential = CredentialHelperUtils.encrypt_credential_values( - credential + processed_credential ) credentials_dict = encrypted_credential.model_dump() credentials_dict_jsonified = jsonify_object(credentials_dict) @@ -70,7 +95,7 @@ async def create_credential( ) ## ADD TO LITELLM ## - CredentialAccessor.upsert_credentials([credential]) + CredentialAccessor.upsert_credentials([processed_credential]) return {"success": True, "message": "Credential created successfully"} except Exception as e: @@ -95,6 +120,7 @@ async def get_credentials( masked_credentials = [ { "credential_name": credential.credential_name, + "credential_values": _get_masked_values(credential.credential_values), "credential_info": credential.credential_info, } for 
credential in litellm.credential_list @@ -132,23 +158,20 @@ async def get_credential( if model_id: if llm_router is None: raise HTTPException(status_code=500, detail="LLM router not found") - # get model from router model = llm_router.get_deployment(model_id) if model is None: - raise HTTPException( - status_code=404, detail="Model not found. Got model ID: " + model_id - ) - # get credential object from model - credential_values = _get_masked_values( - CredentialLiteLLMParams(**model.litellm_params.model_dump()).model_dump( - exclude_none=True - ), + raise HTTPException(status_code=404, detail="Model not found") + credential_values = llm_router.get_deployment_credentials(model_id) + if credential_values is None: + raise HTTPException(status_code=404, detail="Model not found") + masked_credential_values = _get_masked_values( + credential_values, unmasked_length=4, number_of_asterisks=4, ) credential = CredentialItem( credential_name="{}-credential-{}".format(model.model_name, model_id), - credential_values=credential_values, + credential_values=masked_credential_values, credential_info={}, ) # return credential object @@ -159,7 +182,9 @@ async def get_credential( masked_credential = CredentialItem( credential_name=credential.credential_name, credential_values=_get_masked_values( - credential.credential_values + credential.credential_values, + unmasked_length=4, + number_of_asterisks=4, ), credential_info=credential.credential_info, ) diff --git a/litellm/router.py b/litellm/router.py index f7f361354b..a395c851dd 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -111,6 +111,7 @@ from litellm.types.router import ( AlertingConfig, AllowedFailsPolicy, AssistantsTypedDict, + CredentialLiteLLMParams, CustomRoutingStrategyBase, Deployment, DeploymentTypedDict, @@ -636,29 +637,6 @@ class Router: if self.cache.redis_cache is None: self.cache.redis_cache = cache - def initialize_assistants_endpoint(self): - ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## - 
self.acreate_assistants = self.factory_function(litellm.acreate_assistants) - self.adelete_assistant = self.factory_function(litellm.adelete_assistant) - self.aget_assistants = self.factory_function(litellm.aget_assistants) - self.acreate_thread = self.factory_function(litellm.acreate_thread) - self.aget_thread = self.factory_function(litellm.aget_thread) - self.a_add_message = self.factory_function(litellm.a_add_message) - self.aget_messages = self.factory_function(litellm.aget_messages) - self.arun_thread = self.factory_function(litellm.arun_thread) - - def initialize_router_endpoints(self): - self.amoderation = self.factory_function( - litellm.amoderation, call_type="moderation" - ) - self.aanthropic_messages = self.factory_function( - litellm.anthropic_messages, call_type="anthropic_messages" - ) - self.aresponses = self.factory_function( - litellm.aresponses, call_type="aresponses" - ) - self.responses = self.factory_function(litellm.responses, call_type="responses") - def routing_strategy_init( self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict ): @@ -724,6 +702,29 @@ class Router: else: pass + def initialize_assistants_endpoint(self): + ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## + self.acreate_assistants = self.factory_function(litellm.acreate_assistants) + self.adelete_assistant = self.factory_function(litellm.adelete_assistant) + self.aget_assistants = self.factory_function(litellm.aget_assistants) + self.acreate_thread = self.factory_function(litellm.acreate_thread) + self.aget_thread = self.factory_function(litellm.aget_thread) + self.a_add_message = self.factory_function(litellm.a_add_message) + self.aget_messages = self.factory_function(litellm.aget_messages) + self.arun_thread = self.factory_function(litellm.arun_thread) + + def initialize_router_endpoints(self): + self.amoderation = self.factory_function( + litellm.amoderation, call_type="moderation" + ) + self.aanthropic_messages = self.factory_function( + 
litellm.anthropic_messages, call_type="anthropic_messages" + ) + self.aresponses = self.factory_function( + litellm.aresponses, call_type="aresponses" + ) + self.responses = self.factory_function(litellm.responses, call_type="responses") + def validate_fallbacks(self, fallback_param: Optional[List]): """ Validate the fallbacks parameter. @@ -4625,6 +4626,17 @@ class Router: raise Exception("Model invalid format - {}".format(type(model))) return None + def get_deployment_credentials(self, model_id: str) -> Optional[dict]: + """ + Returns -> dict of credentials for a given model id + """ + deployment = self.get_deployment(model_id=model_id) + if deployment is None: + return None + return CredentialLiteLLMParams( + **deployment.litellm_params.model_dump(exclude_none=True) + ).model_dump(exclude_none=True) + def get_deployment_by_model_group_name( self, model_group_name: str ) -> Optional[Deployment]: diff --git a/litellm/types/completion.py b/litellm/types/completion.py index 7b5ed4e502..b06bb733c4 100644 --- a/litellm/types/completion.py +++ b/litellm/types/completion.py @@ -1,6 +1,6 @@ from typing import Iterable, List, Optional, Union -from pydantic import BaseModel, ConfigDict, validator +from pydantic import BaseModel, ConfigDict from typing_extensions import Literal, Required, TypedDict diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 9608c099a3..a2d41d8fb9 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -18,7 +18,7 @@ from openai.types.moderation import ( CategoryScores, ) from openai.types.moderation_create_response import Moderation, ModerationCreateResponse -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from typing_extensions import Callable, Dict, Required, TypedDict, override import litellm @@ -2053,7 +2053,22 @@ class RawRequestTypedDict(TypedDict, total=False): error: Optional[str] -class CredentialItem(BaseModel): 
class CredentialBase(BaseModel):
    """Fields common to every credential payload."""

    # Name the credential is stored and looked up under.
    credential_name: str
    # Additional, non-secret information attached to the credential.
    credential_info: dict


class CredentialItem(CredentialBase):
    """A credential whose values are known; ``credential_values`` is required."""

    credential_values: dict


class CreateCredentialItem(CredentialBase):
    """
    Request payload for creating a credential.

    The caller supplies either ``credential_values`` directly or a
    ``model_id`` naming the model the values should be taken from.
    At least one of the two must be provided.
    """

    # ``model_id`` starts with pydantic's protected "model_" prefix; clearing
    # the protected namespaces avoids the field-name conflict warning.
    model_config = ConfigDict(protected_namespaces=())

    credential_values: Optional[dict] = None
    model_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_credential_params(cls, values):
        """
        Reject payloads that carry neither ``credential_values`` nor ``model_id``.

        Raises:
            ValueError: if both ``credential_values`` and ``model_id`` are
                missing or empty in a dict payload.
        """
        # A "before" validator sees the raw input, which is not guaranteed to
        # be a dict. Non-dict input is passed through untouched so pydantic's
        # own validation reports it instead of an AttributeError on `.get`.
        if isinstance(values, dict):
            has_values = bool(values.get("credential_values"))
            has_model_id = bool(values.get("model_id"))
            if not has_values and not has_model_id:
                raise ValueError("Either credential_values or model_id must be set")
        return values