Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge branch 'BerriAI:main' into main

Commit cf5a0dee6f: 25 changed files with 449 additions and 67 deletions
@@ -46,7 +46,7 @@ For security inquiries, please contact us at support@berri.ai
|-------------------|-------------------------------------------------------------------------------------------------|
| SOC 2 Type I | Certified. Report available upon request on Enterprise plan. |
| SOC 2 Type II | In progress. Certificate available by April 15th, 2025 |
-| ISO27001 | In progress. Certificate available by February 7th, 2025 |
+| ISO 27001 | Certified. Report available upon request on Enterprise |

## Supported Data Regions for LiteLLM Cloud
@@ -137,7 +137,7 @@ Point of contact email address for general security-related questions: krrish@be
Has the Vendor been audited / certified?
- SOC 2 Type I. Certified. Report available upon request on Enterprise plan.
- SOC 2 Type II. In progress. Certificate available by April 15th, 2025.
-- ISO27001. In progress. Certificate available by February 7th, 2025.
+- ISO 27001. Certified. Report available upon request on Enterprise plan.

Has an information security management system been implemented?
- Yes - [CodeQL](https://codeql.github.com/) and a comprehensive ISMS covering multiple security domains.
@@ -40,6 +40,7 @@ in_memory_dynamic_logger_cache = DynamicLoggingCache()

def langfuse_client_init(
    langfuse_public_key=None,
    langfuse_secret=None,
+    langfuse_secret_key=None,
    langfuse_host=None,
    flush_interval=1,
) -> LangfuseClass:

@@ -67,7 +68,10 @@ def langfuse_client_init(
    )

    # Instance variables
-    secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
+    secret_key = (
+        langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY")
+    )
    public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
    langfuse_host = langfuse_host or os.getenv(
        "LANGFUSE_HOST", "https://cloud.langfuse.com"

@@ -190,6 +194,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
        langfuse_client = langfuse_client_init(
            langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
            langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
+            langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
            langfuse_host=dynamic_callback_params.get("langfuse_host"),
        )
        langfuse_prompt_client = self._get_prompt_from_id(

@@ -206,6 +211,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
        langfuse_client = langfuse_client_init(
            langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
            langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
+            langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
            langfuse_host=dynamic_callback_params.get("langfuse_host"),
        )
        langfuse_prompt_client = self._get_prompt_from_id(
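Below is a minimal usage sketch of the new langfuse_secret_key parameter, assuming the SDK surfaces per-request kwargs to the callback as the dynamic_callback_params read above; the keys shown are placeholders.

import litellm

litellm.success_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello",  # avoid a real provider call
    # per-request Langfuse credentials; langfuse_secret_key is now accepted
    # alongside the existing langfuse_secret parameter name
    langfuse_public_key="pk-lf-placeholder",
    langfuse_secret_key="sk-lf-placeholder",
    langfuse_host="https://cloud.langfuse.com",
)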
@@ -1,34 +1,9 @@
model_list:
  - model_name: claude-3.7
  - model_name: my-langfuse-model
    litellm_params:
      model: langfuse/openai-model
      api_key: os.environ/OPENAI_API_KEY
  - model_name: openai-model
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY
      api_base: http://0.0.0.0:8090
  - model_name: deepseek-r1
    litellm_params:
      model: bedrock/deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf
  - model_name: deepseek-r1-api
    litellm_params:
      model: deepseek/deepseek-reasoner
  - model_name: cohere.embed-english-v3
    litellm_params:
      model: bedrock/cohere.embed-english-v3
      api_key: os.environ/COHERE_API_KEY
  - model_name: bedrock-claude-3-7
    litellm_params:
      model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
  - model_name: bedrock-claude-3-5-sonnet
    litellm_params:
      model: bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0
  - model_name: bedrock-nova
    litellm_params:
      model: bedrock/us.amazon.nova-pro-v1:0
  - model_name: gpt-4o
    litellm_params:
      model: openai/gpt-4o

litellm_settings:
  cache: true
  cache_params: # set cache params for redis
    type: redis
    namespace: "litellm.caching"
  api_key: os.environ/OPENAI_API_KEY
@@ -1003,6 +1003,7 @@ class AddTeamCallback(LiteLLMPydanticObjectBase):

class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
    success_callback: Optional[List[str]] = []
    failure_callback: Optional[List[str]] = []
+    callbacks: Optional[List[str]] = []
    # for now - only supported for langfuse
    callback_vars: Optional[Dict[str, str]] = {}

@@ -1015,6 +1016,9 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
        failure_callback = values.get("failure_callback", [])
        if failure_callback is None:
            values.pop("failure_callback", None)
+        callbacks = values.get("callbacks", [])
+        if callbacks is None:
+            values.pop("callbacks", None)

        callback_vars = values.get("callback_vars", {})
        if callback_vars is None:

@@ -1023,6 +1027,7 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
            return {
                "success_callback": [],
                "failure_callback": [],
+                "callbacks": [],
                "callback_vars": {},
            }
        valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
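A short sketch of the new callbacks field on TeamCallbackMetadata, assuming the class lives in litellm.proxy._types as the hunk context suggests; the values are placeholders.

from litellm.proxy._types import TeamCallbackMetadata

team_metadata = TeamCallbackMetadata(
    success_callback=["langfuse"],
    callbacks=["langfuse"],  # new generic callbacks list, alongside success/failure
    callback_vars={"langfuse_public_key": "pk-lf-placeholder"},
)
print(team_metadata.callbacks)  # ["langfuse"]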
@@ -102,11 +102,15 @@ def convert_key_logging_metadata_to_callback(

        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)
-    elif data.callback_type == "success_and_failure":
+    elif (
+        not data.callback_type or data.callback_type == "success_and_failure"
+    ):  # assume 'success_and_failure' = litellm.callbacks
        if team_callback_settings_obj.success_callback is None:
            team_callback_settings_obj.success_callback = []
        if team_callback_settings_obj.failure_callback is None:
            team_callback_settings_obj.failure_callback = []
+        if team_callback_settings_obj.callbacks is None:
+            team_callback_settings_obj.callbacks = []

        if data.callback_name not in team_callback_settings_obj.success_callback:
            team_callback_settings_obj.success_callback.append(data.callback_name)

@@ -114,6 +118,9 @@ def convert_key_logging_metadata_to_callback(
        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)

+        if data.callback_name not in team_callback_settings_obj.callbacks:
+            team_callback_settings_obj.callbacks.append(data.callback_name)
+
        for var, value in data.callback_vars.items():
            if team_callback_settings_obj.callback_vars is None:
                team_callback_settings_obj.callback_vars = {}
@@ -11,7 +11,6 @@ import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse

import litellm
from litellm._logging import verbose_proxy_logger
@@ -53,7 +53,7 @@ async def route_request(
    """
    router_model_names = llm_router.model_names if llm_router is not None else []
    if "api_key" in data or "api_base" in data:
-        return getattr(litellm, f"{route_type}")(**data)
+        return getattr(llm_router, f"{route_type}")(**data)

    elif "user_config" in data:
        router_config = data.pop("user_config")
@@ -67,6 +67,10 @@ from litellm.router_utils.batch_utils import (
    replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
+from litellm.router_utils.clientside_credential_handler import (
+    get_dynamic_litellm_params,
+    is_clientside_credential,
+)
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_handlers import (
    DEFAULT_COOLDOWN_TIME_SECONDS,

@@ -1067,20 +1071,61 @@ class Router:
            elif k == "metadata":
                kwargs[k].update(v)

+    def _handle_clientside_credential(
+        self, deployment: dict, kwargs: dict
+    ) -> Deployment:
+        """
+        Handle clientside credential
+        """
+        model_info = deployment.get("model_info", {}).copy()
+        litellm_params = deployment["litellm_params"].copy()
+        dynamic_litellm_params = get_dynamic_litellm_params(
+            litellm_params=litellm_params, request_kwargs=kwargs
+        )
+        metadata = kwargs.get("metadata", {})
+        model_group = cast(str, metadata.get("model_group"))
+        _model_id = self._generate_model_id(
+            model_group=model_group, litellm_params=dynamic_litellm_params
+        )
+        original_model_id = model_info.get("id")
+        model_info["id"] = _model_id
+        model_info["original_model_id"] = original_model_id
+        deployment_pydantic_obj = Deployment(
+            model_name=model_group,
+            litellm_params=LiteLLM_Params(**dynamic_litellm_params),
+            model_info=model_info,
+        )
+        self.upsert_deployment(
+            deployment=deployment_pydantic_obj
+        )  # add new deployment to router
+        return deployment_pydantic_obj
+
    def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
        """
        2 jobs:
        - Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
        - Adds default litellm params to kwargs, if set.
        """
+        model_info = deployment.get("model_info", {}).copy()
+        deployment_model_name = deployment["litellm_params"]["model"]
+        deployment_api_base = deployment["litellm_params"].get("api_base")
+        if is_clientside_credential(request_kwargs=kwargs):
+            deployment_pydantic_obj = self._handle_clientside_credential(
+                deployment=deployment, kwargs=kwargs
+            )
+            model_info = deployment_pydantic_obj.model_info.model_dump()
+            deployment_model_name = deployment_pydantic_obj.litellm_params.model
+            deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
+
        kwargs.setdefault("metadata", {}).update(
            {
-                "deployment": deployment["litellm_params"]["model"],
-                "model_info": deployment.get("model_info", {}),
-                "api_base": deployment.get("litellm_params", {}).get("api_base"),
+                "deployment": deployment_model_name,
+                "model_info": model_info,
+                "api_base": deployment_api_base,
            }
        )
-        kwargs["model_info"] = deployment.get("model_info", {})
+        kwargs["model_info"] = model_info

        kwargs["timeout"] = self._get_timeout(
            kwargs=kwargs, data=deployment["litellm_params"]
        )

@@ -1705,6 +1750,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "prompt"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

@@ -1818,6 +1864,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "prompt"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )

        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

@@ -1916,6 +1963,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "prompt"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
        data = deployment["litellm_params"].copy()

@@ -1991,6 +2039,7 @@ class Router:
        deployment = await self.async_get_available_deployment(
            model=model,
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
        data = deployment["litellm_params"].copy()

@@ -2035,6 +2084,7 @@ class Router:
            model=model,
            messages=messages,
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )

        data = deployment["litellm_params"].copy()

@@ -2140,6 +2190,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": prompt}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

@@ -2238,6 +2289,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "default text"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

@@ -2407,6 +2459,7 @@ class Router:
            model=model,
            input=input,
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
        data = deployment["litellm_params"].copy()

@@ -2504,6 +2557,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "files-api-fake-text"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)

@@ -2609,6 +2663,7 @@ class Router:
            model=model,
            messages=[{"role": "user", "content": "files-api-fake-text"}],
            specific_deployment=kwargs.pop("specific_deployment", None),
+            request_kwargs=kwargs,
        )
        metadata_variable_name = _get_router_metadata_variable_name(
            function_name="_acreate_batch"

@@ -2805,7 +2860,8 @@ class Router:
        ):
            if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]):
                deployment = await self.async_get_available_deployment(
-                    model=kwargs["model"]
+                    model=kwargs["model"],
+                    request_kwargs=kwargs,
                )
                kwargs["model"] = deployment["litellm_params"]["model"]
            return await original_function(**kwargs)

@@ -3601,6 +3657,7 @@ class Router:
        - True if the deployment should be put in cooldown
        - False if the deployment should not be put in cooldown
        """
+        verbose_router_logger.debug("Router: Entering 'deployment_callback_on_failure'")
        try:
            exception = kwargs.get("exception", None)
            exception_status = getattr(exception, "status_code", "")

@@ -3642,6 +3699,9 @@ class Router:

                return result
            else:
+                verbose_router_logger.debug(
+                    "Router: Exiting 'deployment_callback_on_failure' without cooldown. No model_info found."
+                )
                return False

        except Exception as e:

@@ -5541,10 +5601,10 @@ class Router:
    async def async_get_available_deployment(
        self,
        model: str,
+        request_kwargs: Dict,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
-        request_kwargs: Optional[Dict] = None,
    ):
        """
        Async implementation of 'get_available_deployments'.
litellm/router_utils/clientside_credential_handler.py (new file, 37 lines)

@@ -0,0 +1,37 @@
"""
Utils for handling clientside credentials

Supported clientside credentials:
- api_key
- api_base
- base_url

If given, generate a unique model_id for the deployment.

Ensures cooldowns are applied correctly.
"""

clientside_credential_keys = ["api_key", "api_base", "base_url"]


def is_clientside_credential(request_kwargs: dict) -> bool:
    """
    Check if the credential is a clientside credential.
    """
    return any(key in request_kwargs for key in clientside_credential_keys)


def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
    """
    Generate a unique model_id for the deployment.

    Returns
    - litellm_params: dict

    for generating a unique model_id.
    """
    # update litellm_params with clientside credentials
    for key in clientside_credential_keys:
        if key in request_kwargs:
            litellm_params[key] = request_kwargs[key]
    return litellm_params
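An illustrative sketch of the two helpers above, using hand-rolled dicts; the key values are placeholders.

from litellm.router_utils.clientside_credential_handler import (
    get_dynamic_litellm_params,
    is_clientside_credential,
)

request_kwargs = {"api_key": "sk-user-supplied", "metadata": {"model_group": "gpt-3.5-turbo"}}
litellm_params = {"model": "openai/gpt-3.5-turbo", "api_key": "sk-proxy-default"}

if is_clientside_credential(request_kwargs=request_kwargs):
    # the clientside api_key overrides the deployment's configured key, so the
    # router can hash the updated params into a distinct model_id
    dynamic_params = get_dynamic_litellm_params(
        litellm_params=litellm_params.copy(), request_kwargs=request_kwargs
    )
    print(dynamic_params["api_key"])  # -> "sk-user-supplied"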
@@ -112,12 +112,19 @@ def _should_run_cooldown_logic(
        deployment is None
        or litellm_router_instance.get_model_group(id=deployment) is None
    ):
+        verbose_router_logger.debug(
+            "Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
+        )
        return False

    if litellm_router_instance.disable_cooldowns:
+        verbose_router_logger.debug(
+            "Should Not Run Cooldown Logic: disable_cooldowns is True"
+        )
        return False

    if deployment is None:
+        verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None")
        return False

    if not _is_cooldown_required(

@@ -126,9 +133,15 @@ def _should_run_cooldown_logic(
        exception_status=exception_status,
        exception_str=str(original_exception),
    ):
+        verbose_router_logger.debug(
+            "Should Not Run Cooldown Logic: _is_cooldown_required returned False"
+        )
        return False

    if deployment in litellm_router_instance.provider_default_deployment_ids:
+        verbose_router_logger.debug(
+            "Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
+        )
        return False

    return True

@@ -244,6 +257,8 @@ def _set_cooldown_deployments(
    - True if the deployment should be put in cooldown
    - False if the deployment should not be put in cooldown
    """
+    verbose_router_logger.debug("checks 'should_run_cooldown_logic'")
+
    if (
        _should_run_cooldown_logic(
            litellm_router_instance, deployment, exception_status, original_exception

@@ -251,6 +266,7 @@ def _set_cooldown_deployments(
        is False
        or deployment is None
    ):
+        verbose_router_logger.debug("should_run_cooldown_logic returned False")
        return False

    exception_status_int = cast_exception_status_to_int(exception_status)
@@ -451,6 +451,15 @@ def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
    return applied_guardrails


+def get_dynamic_callbacks(
+    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]]
+) -> List:
+    returned_callbacks = litellm.callbacks.copy()
+    if dynamic_callbacks:
+        returned_callbacks.extend(dynamic_callbacks)  # type: ignore
+    return returned_callbacks
+
+
def function_setup(  # noqa: PLR0915
    original_function: str, rules_obj, start_time, *args, **kwargs
):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.

@@ -475,12 +484,18 @@ def function_setup(  # noqa: PLR0915
    ## LOGGING SETUP
    function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

-    if len(litellm.callbacks) > 0:
-        for callback in litellm.callbacks:
+    ## DYNAMIC CALLBACKS ##
+    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
+        kwargs.pop("callbacks", None)
+    )
+    all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
+
+    if len(all_callbacks) > 0:
+        for callback in all_callbacks:
            # check if callback is a string - e.g. "lago", "openmeter"
            if isinstance(callback, str):
                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
-                    callback, internal_usage_cache=None, llm_router=None
+                    callback, internal_usage_cache=None, llm_router=None  # type: ignore
                )
                if callback is None or any(
                    isinstance(cb, type(callback))
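A sketch of what the get_dynamic_callbacks helper enables: a per-request callbacks kwarg that is merged with the global litellm.callbacks list inside function_setup. The logger class and mock_response value are placeholders.

import litellm
from litellm.integrations.custom_logger import CustomLogger


class RequestScopedLogger(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("request finished for model:", kwargs.get("model"))


response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello",  # avoid a real provider call
    callbacks=[RequestScopedLogger()],  # merged with litellm.callbacks for this call only
)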
@@ -96,7 +96,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"

[tool.commitizen]
-version = "1.62.2"
+version = "1.62.3"
version_files = [
    "pyproject.toml:^version"
]
tests/litellm/proxy/test_route_llm_request.py (new file, 46 lines)

@@ -0,0 +1,46 @@
import json
import os
import sys

import pytest
from fastapi.testclient import TestClient

sys.path.insert(
    0, os.path.abspath("../../..")
)  # Adds the parent directory to the system path


from unittest.mock import MagicMock

from litellm.proxy.route_llm_request import route_request


@pytest.mark.parametrize(
    "route_type",
    [
        "atext_completion",
        "acompletion",
        "aembedding",
        "aimage_generation",
        "aspeech",
        "atranscription",
        "amoderation",
        "arerank",
    ],
)
@pytest.mark.asyncio
async def test_route_request_dynamic_credentials(route_type):
    data = {
        "model": "openai/gpt-4o-mini-2024-07-18",
        "api_key": "my-bad-key",
        "api_base": "https://api.openai.com/v1 ",
    }
    llm_router = MagicMock()
    # Ensure that the dynamic method exists on the llm_router mock.
    getattr(llm_router, route_type).return_value = "fake_response"

    response = await route_request(data, llm_router, None, route_type)
    # Optionally verify the response if needed:
    assert response == "fake_response"
    # Now assert that the dynamic method was called once with the expected kwargs.
    getattr(llm_router, route_type).assert_called_once_with(**data)
@@ -119,7 +119,7 @@ async def test_router_get_available_deployments(async_test):
    if async_test is True:
        await router.cache.async_set_cache(key=cache_key, value=request_count_dict)
        deployment = await router.async_get_available_deployment(
-            model=model_group, messages=None
+            model=model_group, messages=None, request_kwargs={}
        )
    else:
        router.cache.set_cache(key=cache_key, value=request_count_dict)
@@ -2777,3 +2777,46 @@ def test_router_get_model_list_from_model_alias():
        model_name="gpt-3.5-turbo"
    )
    assert len(model_alias_list) == 0


def test_router_dynamic_credentials():
    """
    Assert model id for dynamic api key 1 != model id for dynamic api key 2
    """
    original_model_id = "123"
    original_api_key = "my-bad-key"
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "openai/gpt-3.5-turbo",
                    "api_key": original_api_key,
                    "mock_response": "fake_response",
                },
                "model_info": {"id": original_model_id},
            }
        ]
    )

    deployment = router.get_deployment(model_id=original_model_id)
    assert deployment is not None
    assert deployment.litellm_params.api_key == original_api_key

    response = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        api_key="my-bad-key-2",
    )

    response_2 = router.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        api_key="my-bad-key-3",
    )

    assert response_2._hidden_params["model_id"] != response._hidden_params["model_id"]

    deployment = router.get_deployment(model_id=original_model_id)
    assert deployment is not None
    assert deployment.litellm_params.api_key == original_api_key
@@ -692,3 +692,50 @@ def test_router_fallbacks_with_cooldowns_and_model_id():
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )


@pytest.mark.asyncio()
async def test_router_fallbacks_with_cooldowns_and_dynamic_credentials():
    """
    Ensure cooldown on credential 1 does not affect credential 2
    """
    from litellm.router_utils.cooldown_handlers import _async_get_cooldown_deployments

    litellm._turn_on_debug()
    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
                "model_info": {
                    "id": "123",
                },
            }
        ]
    )

    ## trigger ratelimit
    try:
        await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            api_key="my-bad-key-1",
            mock_response="litellm.RateLimitError",
        )
        pytest.fail("Expected RateLimitError")
    except litellm.RateLimitError:
        pass

    await asyncio.sleep(1)

    cooldown_list = await _async_get_cooldown_deployments(
        litellm_router_instance=router, parent_otel_span=None
    )
    print("cooldown_list: ", cooldown_list)
    assert len(cooldown_list) == 1

    await router.acompletion(
        model="gpt-3.5-turbo",
        api_key=os.getenv("OPENAI_API_KEY"),
        messages=[{"role": "user", "content": "hi"}],
    )
@@ -569,7 +569,7 @@ async def test_weighted_selection_router_async(rpm_list, tpm_list):
        # call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time
        for _ in range(1000):
            selected_model = await router.async_get_available_deployment(
-                "gpt-3.5-turbo"
+                "gpt-3.5-turbo", request_kwargs={}
            )
            selected_model_id = selected_model["litellm_params"]["model"]
            selected_model_name = selected_model_id
@@ -26,11 +26,6 @@ import litellm
from litellm import Router
-from litellm._logging import verbose_logger

-verbose_logger.setLevel(logging.DEBUG)


load_dotenv()


@pytest.mark.asyncio()
async def test_router_free_paid_tier():

@@ -93,6 +88,69 @@ async def test_router_free_paid_tier():
    assert response_extra_info["model_id"] == "very-expensive-model"


@pytest.mark.asyncio()
async def test_router_free_paid_tier_embeddings():
    """
    Pass list of orgs in 1 model definition,
    expect a unique deployment for each to be created
    """
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                    "tags": ["free"],
                    "mock_response": ["1", "2", "3"],
                },
                "model_info": {"id": "very-cheap-model"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {
                    "model": "gpt-4o-mini",
                    "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
                    "tags": ["paid"],
                    "mock_response": ["1", "2", "3"],
                },
                "model_info": {"id": "very-expensive-model"},
            },
        ],
        enable_tag_filtering=True,
    )

    for _ in range(1):
        # this should pick model with id == very-cheap-model
        response = await router.aembedding(
            model="gpt-4",
            input="Tell me a joke.",
            metadata={"tags": ["free"]},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-cheap-model"

    for _ in range(5):
        # this should pick model with id == very-cheap-model
        response = await router.aembedding(
            model="gpt-4",
            input="Tell me a joke.",
            metadata={"tags": ["paid"]},
        )

        print("Response: ", response)

        response_extra_info = response._hidden_params
        print("response_extra_info: ", response_extra_info)

        assert response_extra_info["model_id"] == "very-expensive-model"


@pytest.mark.asyncio()
async def test_default_tagged_deployments():
    """
@@ -396,3 +396,25 @@ def test_router_redis_cache():
    router._update_redis_cache(cache=redis_cache)

    assert router.cache.redis_cache == redis_cache


def test_router_handle_clientside_credential():
    deployment = {
        "model_name": "gemini/*",
        "litellm_params": {"model": "gemini/*"},
        "model_info": {
            "id": "1",
        },
    }
    router = Router(model_list=[deployment])

    new_deployment = router._handle_clientside_credential(
        deployment=deployment,
        kwargs={
            "api_key": "123",
            "metadata": {"model_group": "gemini/gemini-1.5-flash"},
        },
    )

    assert new_deployment.litellm_params.api_key == "123"
    assert len(router.get_model_list()) == 2
@@ -377,6 +377,7 @@ async def test_multiple_potential_deployments(sync_mode):
        deployment = await router.async_get_available_deployment(
            model="azure-model",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            request_kwargs={},
        )

        ## get id ##
@@ -20,14 +20,7 @@ import {
} from "@/components/networking";
import { jwtDecode } from "jwt-decode";
import { Form, Button as Button2, message } from "antd";

-function getCookie(name: string) {
-  console.log("COOKIES", document.cookie)
-  const cookieValue = document.cookie
-    .split('; ')
-    .find(row => row.startsWith(name + '='));
-  return cookieValue ? cookieValue.split('=')[1] : null;
-}
+import { getCookie } from "@/utils/cookieUtils";

export default function Onboarding() {
  const [form] = Form.useForm();
@@ -211,6 +211,7 @@ export default function CreateKeyPage() {
              userID={userID}
              userRole={userRole}
              premiumUser={premiumUser}
+              userEmail={userEmail}
              setProxySettings={setProxySettings}
              proxySettings={proxySettings}
            />
@@ -8,8 +8,10 @@ import {
  UserOutlined,
  LogoutOutlined
} from '@ant-design/icons';
+import { clearTokenCookies } from "@/utils/cookieUtils";
interface NavbarProps {
  userID: string | null;
+  userEmail: string | null;
  userRole: string | null;
  premiumUser: boolean;
  setProxySettings: React.Dispatch<React.SetStateAction<any>>;

@@ -18,6 +20,7 @@ interface NavbarProps {

const Navbar: React.FC<NavbarProps> = ({
  userID,
+  userEmail,
  userRole,
  premiumUser,
  proxySettings,

@@ -27,7 +30,7 @@ const Navbar: React.FC<NavbarProps> = ({
  let logoutUrl = proxySettings?.PROXY_LOGOUT_URL || "";

  const handleLogout = () => {
-    document.cookie = "token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;";
+    clearTokenCookies();
    window.location.href = logoutUrl;
  };

@@ -37,6 +40,7 @@ const Navbar: React.FC<NavbarProps> = ({
      label: (
        <div className="py-1">
          <p className="text-sm text-gray-600">Role: {userRole}</p>
+          <p className="text-sm text-gray-600">Email: {userEmail || "Unknown"}</p>
          <p className="text-sm text-gray-600"><UserOutlined /> {userID}</p>
          <p className="text-sm text-gray-600">Premium User: {String(premiumUser)}</p>
        </div>
|
@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation";
|
|||
import { Team } from "./key_team_helpers/key_list";
|
||||
import { jwtDecode } from "jwt-decode";
|
||||
import { Typography } from "antd";
|
||||
import { clearTokenCookies } from "@/utils/cookieUtils";
|
||||
const isLocal = process.env.NODE_ENV === "development";
|
||||
if (isLocal != true) {
|
||||
console.log = function() {};
|
||||
|
@ -295,14 +296,15 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
|
||||
if (userID == null || token == null) {
|
||||
// user is not logged in as yet
|
||||
console.log("All cookies before redirect:", document.cookie);
|
||||
|
||||
// Clear token cookies using the utility function
|
||||
clearTokenCookies();
|
||||
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/sso/key/generate`
|
||||
: `/sso/key/generate`;
|
||||
|
||||
|
||||
// clear cookie called "token" since user will be logging in again
|
||||
document.cookie = "token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;";
|
||||
|
||||
console.log("Full URL:", url);
|
||||
window.location.href = url;
|
||||
|
||||
|
@ -326,6 +328,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
|
|||
}
|
||||
|
||||
console.log("inside user dashboard, selected team", selectedTeam);
|
||||
console.log("All cookies after redirect:", document.cookie);
|
||||
return (
|
||||
<div className="w-full mx-4 h-[75vh]">
|
||||
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">
|
||||
|
|
ui/litellm-dashboard/src/utils/cookieUtils.ts (new file, 44 lines)

@@ -0,0 +1,44 @@
/**
 * Utility functions for managing cookies
 */

/**
 * Clears the token cookie from both root and /ui paths
 */
export function clearTokenCookies() {
  // Get the current domain
  const domain = window.location.hostname;

  // Clear with various combinations of path and SameSite
  const paths = ['/', '/ui'];
  const sameSiteValues = ['Lax', 'Strict', 'None'];

  paths.forEach(path => {
    // Basic clearing
    document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;

    // With domain
    document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;

    // Try different SameSite values
    sameSiteValues.forEach(sameSite => {
      const secureFlag = sameSite === 'None' ? ' Secure;' : '';
      document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
      document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
    });
  });

  console.log("After clearing cookies:", document.cookie);
}

/**
 * Gets a cookie value by name
 * @param name The name of the cookie to retrieve
 * @returns The cookie value or null if not found
 */
export function getCookie(name: string) {
  const cookieValue = document.cookie
    .split('; ')
    .find(row => row.startsWith(name + '='));
  return cookieValue ? cookieValue.split('=')[1] : null;
}