feat(endpoints.py): support adding credentials by model id

Allows users to reuse an existing model deployment's credentials when creating a stored credential.
Krrish Dholakia 2025-03-14 12:32:32 -07:00
parent 913dc5b73b
commit b75cd3b887
5 changed files with 99 additions and 47 deletions

View file

@@ -2401,8 +2401,8 @@ def _get_masked_values(
     sensitive_object: dict,
     ignore_sensitive_values: bool = False,
     mask_all_values: bool = False,
-    unmasked_length: int = 44,
-    number_of_asterisks: Optional[int] = None,
+    unmasked_length: int = 4,
+    number_of_asterisks: Optional[int] = 4,
 ) -> dict:
     """
     Internal debugging helper function
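
For intuition, the new defaults keep the first four characters of a value visible and append a fixed four-asterisk mask. Below is a minimal standalone sketch of that behavior; the real _get_masked_values also walks nested objects and sensitive-key lists, so treat this as an illustration under those assumptions, not the actual implementation:

from typing import Optional


def mask_value(
    value: str,
    unmasked_length: int = 4,
    number_of_asterisks: Optional[int] = 4,
) -> str:
    # Keep the first `unmasked_length` characters visible.
    visible = value[:unmasked_length]
    # A fixed asterisk count hides the secret's true length;
    # falling back to per-character masking when it is None.
    if number_of_asterisks is not None:
        return visible + "*" * number_of_asterisks
    return visible + "*" * max(len(value) - unmasked_length, 0)


print(mask_value("sk-1234567890abcdef"))  # -> sk-1****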

View file

@@ -2,7 +2,7 @@
 CRUD endpoints for storing reusable credentials.
 """
-from typing import Optional
+from typing import Optional, Union

 from fastapi import APIRouter, Depends, HTTPException, Request, Response
@@ -14,8 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
 from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
-from litellm.types.router import CredentialLiteLLMParams
-from litellm.types.utils import CredentialItem
+from litellm.types.utils import CreateCredentialItem, CredentialItem

 router = APIRouter()
@@ -39,7 +38,7 @@ class CredentialHelperUtils:
 async def create_credential(
     request: Request,
     fastapi_response: Response,
-    credential: CredentialItem,
+    credential: CreateCredentialItem,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -47,7 +46,7 @@ async def create_credential(
     Stores credential in DB.
     Reloads credentials in memory.
     """
-    from litellm.proxy.proxy_server import prisma_client
+    from litellm.proxy.proxy_server import llm_router, prisma_client

     try:
         if prisma_client is None:
@@ -55,9 +54,35 @@ async def create_credential(
                 status_code=500,
                 detail={"error": CommonProxyErrors.db_not_connected_error.value},
             )
+
+        if credential.model_id:
+            if llm_router is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail="LLM router not found. Please ensure you have a valid router instance.",
+                )
+            # get model from router
+            model = llm_router.get_deployment(credential.model_id)
+            if model is None:
+                raise HTTPException(status_code=404, detail="Model not found")
+            credential_values = llm_router.get_deployment_credentials(
+                credential.model_id
+            )
+            if credential_values is None:
+                raise HTTPException(status_code=404, detail="Model not found")
+            credential.credential_values = credential_values
+
+        if credential.credential_values is None:
+            raise HTTPException(
+                status_code=400,
+                detail="Credential values are required. Unable to infer credential values from model ID.",
+            )
+
+        processed_credential = CredentialItem(
+            credential_name=credential.credential_name,
+            credential_values=credential.credential_values,
+            credential_info=credential.credential_info,
+        )
         encrypted_credential = CredentialHelperUtils.encrypt_credential_values(
-            credential
+            processed_credential
         )
         credentials_dict = encrypted_credential.model_dump()
         credentials_dict_jsonified = jsonify_object(credentials_dict)
@@ -70,7 +95,7 @@ async def create_credential(
         )

         ## ADD TO LITELLM ##
-        CredentialAccessor.upsert_credentials([credential])
+        CredentialAccessor.upsert_credentials([processed_credential])

         return {"success": True, "message": "Credential created successfully"}
     except Exception as e:
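
As a usage sketch, create_credential can now be driven two ways. The base URL, admin key, and /credentials route below are assumptions about a typical litellm proxy deployment:

import requests

BASE_URL = "http://localhost:4000"  # assumed local proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed proxy admin key

# Option 1: store explicit credential values.
requests.post(
    f"{BASE_URL}/credentials",
    headers=HEADERS,
    json={
        "credential_name": "my-azure-credential",
        "credential_info": {},
        "credential_values": {"api_key": "my-azure-key"},
    },
)

# Option 2 (this commit): pass a model_id and let the proxy copy that
# deployment's credentials via llm_router.get_deployment_credentials.
requests.post(
    f"{BASE_URL}/credentials",
    headers=HEADERS,
    json={
        "credential_name": "my-azure-credential",
        "credential_info": {},
        "model_id": "<existing model id>",  # hypothetical deployment id
    },
)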
@@ -95,6 +120,7 @@ async def get_credentials(
         masked_credentials = [
             {
                 "credential_name": credential.credential_name,
+                "credential_values": _get_masked_values(credential.credential_values),
                 "credential_info": credential.credential_info,
             }
             for credential in litellm.credential_list
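
With credential_values now included, each list entry would look roughly like the following; the exact keys depend on what was stored, and the masked rendering assumes the new _get_masked_values defaults:

{
    "credential_name": "my-azure-credential",
    "credential_values": {"api_key": "my-a****"},  # masked, never plaintext
    "credential_info": {},
}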
@@ -132,23 +158,20 @@ async def get_credential(
     if model_id:
         if llm_router is None:
             raise HTTPException(status_code=500, detail="LLM router not found")
-        # get model from router
         model = llm_router.get_deployment(model_id)
         if model is None:
-            raise HTTPException(
-                status_code=404, detail="Model not found. Got model ID: " + model_id
-            )
-        # get credential object from model
-        credential_values = _get_masked_values(
-            CredentialLiteLLMParams(**model.litellm_params.model_dump()).model_dump(
-                exclude_none=True
-            ),
+            raise HTTPException(status_code=404, detail="Model not found")
+        credential_values = llm_router.get_deployment_credentials(model_id)
+        if credential_values is None:
+            raise HTTPException(status_code=404, detail="Model not found")
+        masked_credential_values = _get_masked_values(
+            credential_values,
             unmasked_length=4,
             number_of_asterisks=4,
         )
         credential = CredentialItem(
             credential_name="{}-credential-{}".format(model.model_name, model_id),
-            credential_values=credential_values,
+            credential_values=masked_credential_values,
             credential_info={},
         )
         # return credential object
@@ -159,7 +182,9 @@ async def get_credential(
         masked_credential = CredentialItem(
             credential_name=credential.credential_name,
             credential_values=_get_masked_values(
-                credential.credential_values
+                credential.credential_values,
+                unmasked_length=4,
+                number_of_asterisks=4,
             ),
             credential_info=credential.credential_info,
         )

View file

@@ -111,6 +111,7 @@ from litellm.types.router import (
     AlertingConfig,
     AllowedFailsPolicy,
     AssistantsTypedDict,
+    CredentialLiteLLMParams,
     CustomRoutingStrategyBase,
     Deployment,
     DeploymentTypedDict,
@@ -636,29 +637,6 @@ class Router:
         if self.cache.redis_cache is None:
             self.cache.redis_cache = cache

-    def initialize_assistants_endpoint(self):
-        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
-        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
-        self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
-        self.aget_assistants = self.factory_function(litellm.aget_assistants)
-        self.acreate_thread = self.factory_function(litellm.acreate_thread)
-        self.aget_thread = self.factory_function(litellm.aget_thread)
-        self.a_add_message = self.factory_function(litellm.a_add_message)
-        self.aget_messages = self.factory_function(litellm.aget_messages)
-        self.arun_thread = self.factory_function(litellm.arun_thread)
-
-    def initialize_router_endpoints(self):
-        self.amoderation = self.factory_function(
-            litellm.amoderation, call_type="moderation"
-        )
-        self.aanthropic_messages = self.factory_function(
-            litellm.anthropic_messages, call_type="anthropic_messages"
-        )
-        self.aresponses = self.factory_function(
-            litellm.aresponses, call_type="aresponses"
-        )
-        self.responses = self.factory_function(litellm.responses, call_type="responses")
-
     def routing_strategy_init(
         self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
     ):
@@ -724,6 +702,29 @@ class Router:
         else:
             pass

+    def initialize_assistants_endpoint(self):
+        ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
+        self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
+        self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
+        self.aget_assistants = self.factory_function(litellm.aget_assistants)
+        self.acreate_thread = self.factory_function(litellm.acreate_thread)
+        self.aget_thread = self.factory_function(litellm.aget_thread)
+        self.a_add_message = self.factory_function(litellm.a_add_message)
+        self.aget_messages = self.factory_function(litellm.aget_messages)
+        self.arun_thread = self.factory_function(litellm.arun_thread)
+
+    def initialize_router_endpoints(self):
+        self.amoderation = self.factory_function(
+            litellm.amoderation, call_type="moderation"
+        )
+        self.aanthropic_messages = self.factory_function(
+            litellm.anthropic_messages, call_type="anthropic_messages"
+        )
+        self.aresponses = self.factory_function(
+            litellm.aresponses, call_type="aresponses"
+        )
+        self.responses = self.factory_function(litellm.responses, call_type="responses")
+
     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
         Validate the fallbacks parameter.
@@ -4625,6 +4626,17 @@ class Router:
             raise Exception("Model invalid format - {}".format(type(model)))
         return None

+    def get_deployment_credentials(self, model_id: str) -> Optional[dict]:
+        """
+        Returns -> dict of credentials for a given model id
+        """
+        deployment = self.get_deployment(model_id=model_id)
+        if deployment is None:
+            return None
+        return CredentialLiteLLMParams(
+            **deployment.litellm_params.model_dump(exclude_none=True)
+        ).model_dump(exclude_none=True)
+
     def get_deployment_by_model_group_name(
         self, model_group_name: str
     ) -> Optional[Deployment]:
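
A short sketch of how the new Router.get_deployment_credentials helper behaves; the model list and deployment id below are made up for illustration:

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {
                "model": "azure/gpt-4o",
                "api_key": "my-azure-key",
                "api_base": "https://example.openai.azure.com",
            },
            "model_info": {"id": "deploy-1"},  # hypothetical deployment id
        }
    ]
)

# Returns only the credential-shaped fields (CredentialLiteLLMParams),
# e.g. api_key / api_base, with unset fields excluded.
creds = router.get_deployment_credentials(model_id="deploy-1")
print(creds)  # roughly: {"api_key": "my-azure-key", "api_base": "https://example.openai.azure.com"}

# Unknown ids fall through get_deployment and return None.
assert router.get_deployment_credentials(model_id="missing") is None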

View file

@@ -1,6 +1,6 @@
 from typing import Iterable, List, Optional, Union

-from pydantic import BaseModel, ConfigDict, validator
+from pydantic import BaseModel, ConfigDict
 from typing_extensions import Literal, Required, TypedDict

View file

@@ -18,7 +18,7 @@ from openai.types.moderation import (
     CategoryScores,
 )
 from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
 from typing_extensions import Callable, Dict, Required, TypedDict, override

 import litellm
@@ -2053,7 +2053,22 @@ class RawRequestTypedDict(TypedDict, total=False):
     error: Optional[str]


-class CredentialItem(BaseModel):
+class CredentialBase(BaseModel):
     credential_name: str
-    credential_values: dict
     credential_info: dict


+class CredentialItem(CredentialBase):
+    credential_values: dict
+
+
+class CreateCredentialItem(CredentialBase):
+    credential_values: Optional[dict] = None
+    model_id: Optional[str] = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_credential_params(cls, values):
+        if not values.get("credential_values") and not values.get("model_id"):
+            raise ValueError("Either credential_values or model_id must be set")
+        return values
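
The model_validator makes the two creation modes explicit. A quick sketch of the resulting behavior (assuming pydantic v2 semantics, as the mode="before" signature implies; pydantic's ValidationError subclasses ValueError, so the except clause below catches it):

from litellm.types.utils import CreateCredentialItem

# Valid: explicit credential values.
CreateCredentialItem(
    credential_name="my-credential",
    credential_info={},
    credential_values={"api_key": "my-key"},
)

# Valid: defer to an existing deployment's credentials.
CreateCredentialItem(
    credential_name="my-credential",
    credential_info={},
    model_id="deploy-1",
)

# Invalid: neither provided -> the validator raises
# "Either credential_values or model_id must be set".
try:
    CreateCredentialItem(credential_name="my-credential", credential_info={})
except ValueError as e:
    print(e)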