mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(endpoints.py): support adding credentials by model id
Allows users to reuse existing model credentials.
This commit is contained in:
parent
913dc5b73b
commit
b75cd3b887
5 changed files with 99 additions and 47 deletions
|
@ -2401,8 +2401,8 @@ def _get_masked_values(
|
||||||
sensitive_object: dict,
|
sensitive_object: dict,
|
||||||
ignore_sensitive_values: bool = False,
|
ignore_sensitive_values: bool = False,
|
||||||
mask_all_values: bool = False,
|
mask_all_values: bool = False,
|
||||||
unmasked_length: int = 44,
|
unmasked_length: int = 4,
|
||||||
number_of_asterisks: Optional[int] = None,
|
number_of_asterisks: Optional[int] = 4,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Internal debugging helper function
|
Internal debugging helper function
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
CRUD endpoints for storing reusable credentials.
|
CRUD endpoints for storing reusable credentials.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
|
|
||||||
|
@ -14,8 +14,7 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||||
from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
|
from litellm.proxy.utils import handle_exception_on_proxy, jsonify_object
|
||||||
from litellm.types.router import CredentialLiteLLMParams
|
from litellm.types.utils import CreateCredentialItem, CredentialItem
|
||||||
from litellm.types.utils import CredentialItem
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -39,7 +38,7 @@ class CredentialHelperUtils:
|
||||||
async def create_credential(
|
async def create_credential(
|
||||||
request: Request,
|
request: Request,
|
||||||
fastapi_response: Response,
|
fastapi_response: Response,
|
||||||
credential: CredentialItem,
|
credential: CreateCredentialItem,
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -47,7 +46,7 @@ async def create_credential(
|
||||||
Stores credential in DB.
|
Stores credential in DB.
|
||||||
Reloads credentials in memory.
|
Reloads credentials in memory.
|
||||||
"""
|
"""
|
||||||
from litellm.proxy.proxy_server import prisma_client
|
from litellm.proxy.proxy_server import llm_router, prisma_client
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prisma_client is None:
|
if prisma_client is None:
|
||||||
|
@ -55,9 +54,35 @@ async def create_credential(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||||
)
|
)
|
||||||
|
if credential.model_id:
|
||||||
|
if llm_router is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="LLM router not found. Please ensure you have a valid router instance.",
|
||||||
|
)
|
||||||
|
# get model from router
|
||||||
|
model = llm_router.get_deployment(credential.model_id)
|
||||||
|
if model is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
credential_values = llm_router.get_deployment_credentials(
|
||||||
|
credential.model_id
|
||||||
|
)
|
||||||
|
if credential_values is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
credential.credential_values = credential_values
|
||||||
|
|
||||||
|
if credential.credential_values is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Credential values are required. Unable to infer credential values from model ID.",
|
||||||
|
)
|
||||||
|
processed_credential = CredentialItem(
|
||||||
|
credential_name=credential.credential_name,
|
||||||
|
credential_values=credential.credential_values,
|
||||||
|
credential_info=credential.credential_info,
|
||||||
|
)
|
||||||
encrypted_credential = CredentialHelperUtils.encrypt_credential_values(
|
encrypted_credential = CredentialHelperUtils.encrypt_credential_values(
|
||||||
credential
|
processed_credential
|
||||||
)
|
)
|
||||||
credentials_dict = encrypted_credential.model_dump()
|
credentials_dict = encrypted_credential.model_dump()
|
||||||
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
credentials_dict_jsonified = jsonify_object(credentials_dict)
|
||||||
|
@ -70,7 +95,7 @@ async def create_credential(
|
||||||
)
|
)
|
||||||
|
|
||||||
## ADD TO LITELLM ##
|
## ADD TO LITELLM ##
|
||||||
CredentialAccessor.upsert_credentials([credential])
|
CredentialAccessor.upsert_credentials([processed_credential])
|
||||||
|
|
||||||
return {"success": True, "message": "Credential created successfully"}
|
return {"success": True, "message": "Credential created successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -95,6 +120,7 @@ async def get_credentials(
|
||||||
masked_credentials = [
|
masked_credentials = [
|
||||||
{
|
{
|
||||||
"credential_name": credential.credential_name,
|
"credential_name": credential.credential_name,
|
||||||
|
"credential_values": _get_masked_values(credential.credential_values),
|
||||||
"credential_info": credential.credential_info,
|
"credential_info": credential.credential_info,
|
||||||
}
|
}
|
||||||
for credential in litellm.credential_list
|
for credential in litellm.credential_list
|
||||||
|
@ -132,23 +158,20 @@ async def get_credential(
|
||||||
if model_id:
|
if model_id:
|
||||||
if llm_router is None:
|
if llm_router is None:
|
||||||
raise HTTPException(status_code=500, detail="LLM router not found")
|
raise HTTPException(status_code=500, detail="LLM router not found")
|
||||||
# get model from router
|
|
||||||
model = llm_router.get_deployment(model_id)
|
model = llm_router.get_deployment(model_id)
|
||||||
if model is None:
|
if model is None:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
status_code=404, detail="Model not found. Got model ID: " + model_id
|
credential_values = llm_router.get_deployment_credentials(model_id)
|
||||||
)
|
if credential_values is None:
|
||||||
# get credential object from model
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
credential_values = _get_masked_values(
|
masked_credential_values = _get_masked_values(
|
||||||
CredentialLiteLLMParams(**model.litellm_params.model_dump()).model_dump(
|
credential_values,
|
||||||
exclude_none=True
|
|
||||||
),
|
|
||||||
unmasked_length=4,
|
unmasked_length=4,
|
||||||
number_of_asterisks=4,
|
number_of_asterisks=4,
|
||||||
)
|
)
|
||||||
credential = CredentialItem(
|
credential = CredentialItem(
|
||||||
credential_name="{}-credential-{}".format(model.model_name, model_id),
|
credential_name="{}-credential-{}".format(model.model_name, model_id),
|
||||||
credential_values=credential_values,
|
credential_values=masked_credential_values,
|
||||||
credential_info={},
|
credential_info={},
|
||||||
)
|
)
|
||||||
# return credential object
|
# return credential object
|
||||||
|
@ -159,7 +182,9 @@ async def get_credential(
|
||||||
masked_credential = CredentialItem(
|
masked_credential = CredentialItem(
|
||||||
credential_name=credential.credential_name,
|
credential_name=credential.credential_name,
|
||||||
credential_values=_get_masked_values(
|
credential_values=_get_masked_values(
|
||||||
credential.credential_values
|
credential.credential_values,
|
||||||
|
unmasked_length=4,
|
||||||
|
number_of_asterisks=4,
|
||||||
),
|
),
|
||||||
credential_info=credential.credential_info,
|
credential_info=credential.credential_info,
|
||||||
)
|
)
|
||||||
|
|
|
@ -111,6 +111,7 @@ from litellm.types.router import (
|
||||||
AlertingConfig,
|
AlertingConfig,
|
||||||
AllowedFailsPolicy,
|
AllowedFailsPolicy,
|
||||||
AssistantsTypedDict,
|
AssistantsTypedDict,
|
||||||
|
CredentialLiteLLMParams,
|
||||||
CustomRoutingStrategyBase,
|
CustomRoutingStrategyBase,
|
||||||
Deployment,
|
Deployment,
|
||||||
DeploymentTypedDict,
|
DeploymentTypedDict,
|
||||||
|
@ -636,29 +637,6 @@ class Router:
|
||||||
if self.cache.redis_cache is None:
|
if self.cache.redis_cache is None:
|
||||||
self.cache.redis_cache = cache
|
self.cache.redis_cache = cache
|
||||||
|
|
||||||
def initialize_assistants_endpoint(self):
|
|
||||||
## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
|
|
||||||
self.acreate_assistants = self.factory_function(litellm.acreate_assistants)
|
|
||||||
self.adelete_assistant = self.factory_function(litellm.adelete_assistant)
|
|
||||||
self.aget_assistants = self.factory_function(litellm.aget_assistants)
|
|
||||||
self.acreate_thread = self.factory_function(litellm.acreate_thread)
|
|
||||||
self.aget_thread = self.factory_function(litellm.aget_thread)
|
|
||||||
self.a_add_message = self.factory_function(litellm.a_add_message)
|
|
||||||
self.aget_messages = self.factory_function(litellm.aget_messages)
|
|
||||||
self.arun_thread = self.factory_function(litellm.arun_thread)
|
|
||||||
|
|
||||||
def initialize_router_endpoints(self):
|
|
||||||
self.amoderation = self.factory_function(
|
|
||||||
litellm.amoderation, call_type="moderation"
|
|
||||||
)
|
|
||||||
self.aanthropic_messages = self.factory_function(
|
|
||||||
litellm.anthropic_messages, call_type="anthropic_messages"
|
|
||||||
)
|
|
||||||
self.aresponses = self.factory_function(
|
|
||||||
litellm.aresponses, call_type="aresponses"
|
|
||||||
)
|
|
||||||
self.responses = self.factory_function(litellm.responses, call_type="responses")
|
|
||||||
|
|
||||||
def routing_strategy_init(
|
def routing_strategy_init(
|
||||||
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
|
||||||
):
|
):
|
||||||
|
@ -724,6 +702,29 @@ class Router:
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def initialize_assistants_endpoint(self):
    """
    Bind the pass-through assistants endpoints onto this router.

    For each listed name, the litellm function of that name is wrapped with
    `self.factory_function` and the result is stored on the router under the
    same attribute name.
    """
    ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ##
    assistant_endpoint_names = (
        "acreate_assistants",
        "adelete_assistant",
        "aget_assistants",
        "acreate_thread",
        "aget_thread",
        "a_add_message",
        "aget_messages",
        "arun_thread",
    )
    for endpoint_name in assistant_endpoint_names:
        # Router attribute and litellm function share the same name.
        wrapped = self.factory_function(getattr(litellm, endpoint_name))
        setattr(self, endpoint_name, wrapped)
|
||||||
|
|
||||||
|
def initialize_router_endpoints(self):
    """
    Bind the moderation, anthropic-messages and responses endpoints onto this router.

    Each entry names the router attribute to set, the litellm function to
    wrap, and the `call_type` passed to `self.factory_function`.
    """
    # (router attribute, litellm function name, call_type)
    endpoint_specs = (
        ("amoderation", "amoderation", "moderation"),
        ("aanthropic_messages", "anthropic_messages", "anthropic_messages"),
        ("aresponses", "aresponses", "aresponses"),
        ("responses", "responses", "responses"),
    )
    for attr_name, litellm_fn_name, call_type in endpoint_specs:
        wrapped = self.factory_function(
            getattr(litellm, litellm_fn_name), call_type=call_type
        )
        setattr(self, attr_name, wrapped)
|
||||||
|
|
||||||
def validate_fallbacks(self, fallback_param: Optional[List]):
|
def validate_fallbacks(self, fallback_param: Optional[List]):
|
||||||
"""
|
"""
|
||||||
Validate the fallbacks parameter.
|
Validate the fallbacks parameter.
|
||||||
|
@ -4625,6 +4626,17 @@ class Router:
|
||||||
raise Exception("Model invalid format - {}".format(type(model)))
|
raise Exception("Model invalid format - {}".format(type(model)))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_deployment_credentials(self, model_id: str) -> Optional[dict]:
    """
    Return the credential-related litellm params of a deployment as a dict.

    Args:
        model_id: ID of the deployment, as accepted by `self.get_deployment`.

    Returns:
        None when `self.get_deployment` finds no deployment for `model_id`.
        Otherwise the deployment's `litellm_params`, passed through
        `CredentialLiteLLMParams` and dumped with None-valued fields omitted.
    """
    found = self.get_deployment(model_id=model_id)
    if found is None:
        return None

    # Drop unset params before handing them to the credential model.
    raw_params = found.litellm_params.model_dump(exclude_none=True)
    credential_params = CredentialLiteLLMParams(**raw_params)
    return credential_params.model_dump(exclude_none=True)
|
||||||
|
|
||||||
def get_deployment_by_model_group_name(
|
def get_deployment_by_model_group_name(
|
||||||
self, model_group_name: str
|
self, model_group_name: str
|
||||||
) -> Optional[Deployment]:
|
) -> Optional[Deployment]:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Iterable, List, Optional, Union
|
from typing import Iterable, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, validator
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import Literal, Required, TypedDict
|
from typing_extensions import Literal, Required, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from openai.types.moderation import (
|
||||||
CategoryScores,
|
CategoryScores,
|
||||||
)
|
)
|
||||||
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
|
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
|
||||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||||
from typing_extensions import Callable, Dict, Required, TypedDict, override
|
from typing_extensions import Callable, Dict, Required, TypedDict, override
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -2053,7 +2053,22 @@ class RawRequestTypedDict(TypedDict, total=False):
|
||||||
error: Optional[str]
|
error: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class CredentialItem(BaseModel):
|
class CredentialBase(BaseModel):
|
||||||
credential_name: str
|
credential_name: str
|
||||||
credential_values: dict
|
|
||||||
credential_info: dict
|
credential_info: dict
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialItem(CredentialBase):
    """Credential record in which `credential_values` is a required dict."""

    credential_values: dict
|
||||||
|
|
||||||
|
|
||||||
|
class CreateCredentialItem(CredentialBase):
    """
    Payload for creating a credential.

    The caller supplies either explicit `credential_values` or a `model_id`
    from which the credential values are taken; at least one must be set.
    """

    # Explicit credential values; may be omitted when `model_id` is provided.
    credential_values: Optional[dict] = None
    # ID of a model whose credentials should be reused.
    model_id: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_credential_params(cls, values):
        """
        Require that `credential_values` or `model_id` is present in the input.

        Raises:
            ValueError: if the input is a dict with neither key set to a
                truthy value.
        """
        # A "before" validator receives the raw input, which is not guaranteed
        # to be a dict (e.g. a JSON array or string body). Calling `.get` on it
        # would raise AttributeError, which pydantic does not turn into a
        # validation error. Pass such input through so pydantic's own
        # validation rejects it with a proper error.
        if not isinstance(values, dict):
            return values

        if not values.get("credential_values") and not values.get("model_id"):
            raise ValueError("Either credential_values or model_id must be set")
        return values
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue