Merge branch 'BerriAI:main' into main

Sunny Wan 2025-03-04 18:00:58 -05:00 committed by GitHub
commit cf5a0dee6f
25 changed files with 449 additions and 67 deletions

View file

@ -46,7 +46,7 @@ For security inquiries, please contact us at support@berri.ai
|-------------------|-------------------------------------------------------------------------------------------------|
| SOC 2 Type I | Certified. Report available upon request on Enterprise plan. |
| SOC 2 Type II | In progress. Certificate available by April 15th, 2025 |
| ISO27001 | In progress. Certificate available by February 7th, 2025 |
| ISO 27001 | Certified. Report available upon request on Enterprise plan. |
## Supported Data Regions for LiteLLM Cloud
@ -137,7 +137,7 @@ Point of contact email address for general security-related questions: krrish@be
Has the Vendor been audited / certified?
- SOC 2 Type I. Certified. Report available upon request on Enterprise plan.
- SOC 2 Type II. In progress. Certificate available by April 15th, 2025.
- ISO27001. In progress. Certificate available by February 7th, 2025.
- ISO 27001. Certified. Report available upon request on Enterprise plan.
Has an information security management system been implemented?
- Yes - [CodeQL](https://codeql.github.com/) and a comprehensive ISMS covering multiple security domains.

View file

@ -40,6 +40,7 @@ in_memory_dynamic_logger_cache = DynamicLoggingCache()
def langfuse_client_init(
langfuse_public_key=None,
langfuse_secret=None,
langfuse_secret_key=None,
langfuse_host=None,
flush_interval=1,
) -> LangfuseClass:
@ -67,7 +68,10 @@ def langfuse_client_init(
)
# Instance variables
secret_key = langfuse_secret or os.getenv("LANGFUSE_SECRET_KEY")
secret_key = (
langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY")
)
public_key = langfuse_public_key or os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_host = langfuse_host or os.getenv(
"LANGFUSE_HOST", "https://cloud.langfuse.com"
@ -190,6 +194,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
langfuse_client = langfuse_client_init(
langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
langfuse_host=dynamic_callback_params.get("langfuse_host"),
)
langfuse_prompt_client = self._get_prompt_from_id(
@ -206,6 +211,7 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
langfuse_client = langfuse_client_init(
langfuse_public_key=dynamic_callback_params.get("langfuse_public_key"),
langfuse_secret=dynamic_callback_params.get("langfuse_secret"),
langfuse_secret_key=dynamic_callback_params.get("langfuse_secret_key"),
langfuse_host=dynamic_callback_params.get("langfuse_host"),
)
langfuse_prompt_client = self._get_prompt_from_id(
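
The new `langfuse_secret_key` parameter acts as an alias for `langfuse_secret`; per the diff, the resolution order is `langfuse_secret`, then `langfuse_secret_key`, then the `LANGFUSE_SECRET_KEY` environment variable. A minimal standalone sketch of that precedence (it mirrors the diff rather than importing the LiteLLM module):

```python
import os

def resolve_langfuse_secret(langfuse_secret=None, langfuse_secret_key=None):
    # Mirrors the precedence in langfuse_client_init: an explicit
    # `langfuse_secret` wins, then the `langfuse_secret_key` alias,
    # then the LANGFUSE_SECRET_KEY environment variable.
    return langfuse_secret or langfuse_secret_key or os.getenv("LANGFUSE_SECRET_KEY")

assert resolve_langfuse_secret("sk-a", "sk-b") == "sk-a"
assert resolve_langfuse_secret(langfuse_secret_key="sk-b") == "sk-b"
```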

View file

@ -1,34 +1,9 @@
model_list:
- model_name: claude-3.7
- model_name: my-langfuse-model
litellm_params:
model: langfuse/openai-model
api_key: os.environ/OPENAI_API_KEY
- model_name: openai-model
litellm_params:
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
api_base: http://0.0.0.0:8090
- model_name: deepseek-r1
litellm_params:
model: bedrock/deepseek_r1/arn:aws:bedrock:us-west-2:888602223428:imported-model/bnnr6463ejgf
- model_name: deepseek-r1-api
litellm_params:
model: deepseek/deepseek-reasoner
- model_name: cohere.embed-english-v3
litellm_params:
model: bedrock/cohere.embed-english-v3
api_key: os.environ/COHERE_API_KEY
- model_name: bedrock-claude-3-7
litellm_params:
model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
- model_name: bedrock-claude-3-5-sonnet
litellm_params:
model: bedrock/invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0
- model_name: bedrock-nova
litellm_params:
model: bedrock/us.amazon.nova-pro-v1:0
- model_name: gpt-4o
litellm_params:
model: openai/gpt-4o
litellm_settings:
cache: true
cache_params: # set cache params for redis
type: redis
namespace: "litellm.caching"
api_key: os.environ/OPENAI_API_KEY

View file

@ -1003,6 +1003,7 @@ class AddTeamCallback(LiteLLMPydanticObjectBase):
class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
success_callback: Optional[List[str]] = []
failure_callback: Optional[List[str]] = []
callbacks: Optional[List[str]] = []
# for now - only supported for langfuse
callback_vars: Optional[Dict[str, str]] = {}
@ -1015,6 +1016,9 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
failure_callback = values.get("failure_callback", [])
if failure_callback is None:
values.pop("failure_callback", None)
callbacks = values.get("callbacks", [])
if callbacks is None:
values.pop("callbacks", None)
callback_vars = values.get("callback_vars", {})
if callback_vars is None:
@ -1023,6 +1027,7 @@ class TeamCallbackMetadata(LiteLLMPydanticObjectBase):
return {
"success_callback": [],
"failure_callback": [],
"callbacks": [],
"callback_vars": {},
}
valid_keys = set(StandardCallbackDynamicParams.__annotations__.keys())
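
`TeamCallbackMetadata` gains a `callbacks` list alongside `success_callback` and `failure_callback`, and the validator now also drops an explicit `callbacks=None` so the empty-list default applies. A self-contained Pydantic sketch of that behaviour (a stand-in model for illustration, not the proxy type itself):

```python
from typing import Dict, List, Optional
from pydantic import BaseModel, model_validator

class TeamCallbackMetadataSketch(BaseModel):
    success_callback: Optional[List[str]] = []
    failure_callback: Optional[List[str]] = []
    callbacks: Optional[List[str]] = []
    callback_vars: Optional[Dict[str, str]] = {}

    @model_validator(mode="before")
    @classmethod
    def drop_explicit_none(cls, values: dict) -> dict:
        # Popping keys that are explicitly None lets the [] / {} defaults
        # apply, matching the validator behaviour shown in the diff.
        for field in ("success_callback", "failure_callback", "callbacks", "callback_vars"):
            if field in values and values[field] is None:
                values.pop(field)
        return values

assert TeamCallbackMetadataSketch(callbacks=None).callbacks == []
```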

View file

@ -102,11 +102,15 @@ def convert_key_logging_metadata_to_callback(
if data.callback_name not in team_callback_settings_obj.failure_callback:
team_callback_settings_obj.failure_callback.append(data.callback_name)
elif data.callback_type == "success_and_failure":
elif (
not data.callback_type or data.callback_type == "success_and_failure"
): # assume 'success_and_failure' = litellm.callbacks
if team_callback_settings_obj.success_callback is None:
team_callback_settings_obj.success_callback = []
if team_callback_settings_obj.failure_callback is None:
team_callback_settings_obj.failure_callback = []
if team_callback_settings_obj.callbacks is None:
team_callback_settings_obj.callbacks = []
if data.callback_name not in team_callback_settings_obj.success_callback:
team_callback_settings_obj.success_callback.append(data.callback_name)
@ -114,6 +118,9 @@ def convert_key_logging_metadata_to_callback(
if data.callback_name not in team_callback_settings_obj.failure_callback:
team_callback_settings_obj.failure_callback.append(data.callback_name)
if data.callback_name not in team_callback_settings_obj.callbacks:
team_callback_settings_obj.callbacks.append(data.callback_name)
for var, value in data.callback_vars.items():
if team_callback_settings_obj.callback_vars is None:
team_callback_settings_obj.callback_vars = {}
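
With this change, a callback entry whose `callback_type` is unset is treated the same as `"success_and_failure"`, and the callback name is also recorded in the new `callbacks` list. A tiny illustration of the default (the helper below is illustrative, not part of the codebase):

```python
from typing import Optional

def applies_to_success_and_failure(callback_type: Optional[str]) -> bool:
    # Matches the new elif condition: no callback_type means both success and failure.
    return not callback_type or callback_type == "success_and_failure"

assert applies_to_success_and_failure(None) is True
assert applies_to_success_and_failure("success_and_failure") is True
assert applies_to_success_and_failure("success") is False
```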

View file

@ -11,7 +11,6 @@ import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
import litellm
from litellm._logging import verbose_proxy_logger

View file

@ -53,7 +53,7 @@ async def route_request(
"""
router_model_names = llm_router.model_names if llm_router is not None else []
if "api_key" in data or "api_base" in data:
return getattr(litellm, f"{route_type}")(**data)
return getattr(llm_router, f"{route_type}")(**data)
elif "user_config" in data:
router_config = data.pop("user_config")

View file

@ -67,6 +67,10 @@ from litellm.router_utils.batch_utils import (
replace_model_in_jsonl,
)
from litellm.router_utils.client_initalization_utils import InitalizeOpenAISDKClient
from litellm.router_utils.clientside_credential_handler import (
get_dynamic_litellm_params,
is_clientside_credential,
)
from litellm.router_utils.cooldown_cache import CooldownCache
from litellm.router_utils.cooldown_handlers import (
DEFAULT_COOLDOWN_TIME_SECONDS,
@ -1067,20 +1071,61 @@ class Router:
elif k == "metadata":
kwargs[k].update(v)
def _handle_clientside_credential(
self, deployment: dict, kwargs: dict
) -> Deployment:
"""
Handle a clientside credential: copy the deployment, apply the request's credentials to its litellm_params, generate a new model id, register the resulting deployment on the router, and return it.
"""
model_info = deployment.get("model_info", {}).copy()
litellm_params = deployment["litellm_params"].copy()
dynamic_litellm_params = get_dynamic_litellm_params(
litellm_params=litellm_params, request_kwargs=kwargs
)
metadata = kwargs.get("metadata", {})
model_group = cast(str, metadata.get("model_group"))
_model_id = self._generate_model_id(
model_group=model_group, litellm_params=dynamic_litellm_params
)
original_model_id = model_info.get("id")
model_info["id"] = _model_id
model_info["original_model_id"] = original_model_id
deployment_pydantic_obj = Deployment(
model_name=model_group,
litellm_params=LiteLLM_Params(**dynamic_litellm_params),
model_info=model_info,
)
self.upsert_deployment(
deployment=deployment_pydantic_obj
) # add new deployment to router
return deployment_pydantic_obj
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
"""
2 jobs:
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
- Adds default litellm params to kwargs, if set.
"""
model_info = deployment.get("model_info", {}).copy()
deployment_model_name = deployment["litellm_params"]["model"]
deployment_api_base = deployment["litellm_params"].get("api_base")
if is_clientside_credential(request_kwargs=kwargs):
deployment_pydantic_obj = self._handle_clientside_credential(
deployment=deployment, kwargs=kwargs
)
model_info = deployment_pydantic_obj.model_info.model_dump()
deployment_model_name = deployment_pydantic_obj.litellm_params.model
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
kwargs.setdefault("metadata", {}).update(
{
"deployment": deployment["litellm_params"]["model"],
"model_info": deployment.get("model_info", {}),
"api_base": deployment.get("litellm_params", {}).get("api_base"),
"deployment": deployment_model_name,
"model_info": model_info,
"api_base": deployment_api_base,
}
)
kwargs["model_info"] = deployment.get("model_info", {})
kwargs["model_info"] = model_info
kwargs["timeout"] = self._get_timeout(
kwargs=kwargs, data=deployment["litellm_params"]
)
@ -1705,6 +1750,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
@ -1818,6 +1864,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
@ -1916,6 +1963,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "prompt"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
data = deployment["litellm_params"].copy()
@ -1991,6 +2039,7 @@ class Router:
deployment = await self.async_get_available_deployment(
model=model,
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
@ -2035,6 +2084,7 @@ class Router:
model=model,
messages=messages,
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
data = deployment["litellm_params"].copy()
@ -2140,6 +2190,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": prompt}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
@ -2238,6 +2289,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "default text"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
@ -2407,6 +2459,7 @@ class Router:
model=model,
input=input,
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
data = deployment["litellm_params"].copy()
@ -2504,6 +2557,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "files-api-fake-text"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
@ -2609,6 +2663,7 @@ class Router:
model=model,
messages=[{"role": "user", "content": "files-api-fake-text"}],
specific_deployment=kwargs.pop("specific_deployment", None),
request_kwargs=kwargs,
)
metadata_variable_name = _get_router_metadata_variable_name(
function_name="_acreate_batch"
@ -2805,7 +2860,8 @@ class Router:
):
if kwargs.get("model") and self.get_model_list(model_name=kwargs["model"]):
deployment = await self.async_get_available_deployment(
model=kwargs["model"]
model=kwargs["model"],
request_kwargs=kwargs,
)
kwargs["model"] = deployment["litellm_params"]["model"]
return await original_function(**kwargs)
@ -3601,6 +3657,7 @@ class Router:
- True if the deployment should be put in cooldown
- False if the deployment should not be put in cooldown
"""
verbose_router_logger.debug("Router: Entering 'deployment_callback_on_failure'")
try:
exception = kwargs.get("exception", None)
exception_status = getattr(exception, "status_code", "")
@ -3642,6 +3699,9 @@ class Router:
return result
else:
verbose_router_logger.debug(
"Router: Exiting 'deployment_callback_on_failure' without cooldown. No model_info found."
)
return False
except Exception as e:
@ -5541,10 +5601,10 @@ class Router:
async def async_get_available_deployment(
self,
model: str,
request_kwargs: Dict,
messages: Optional[List[Dict[str, str]]] = None,
input: Optional[Union[str, List]] = None,
specific_deployment: Optional[bool] = False,
request_kwargs: Optional[Dict] = None,
):
"""
Async implementation of 'get_available_deployments'.

View file

@ -0,0 +1,37 @@
"""
Utils for handling clientside credentials
Supported clientside credentials:
- api_key
- api_base
- base_url
If given, generate a unique model_id for the deployment.
Ensures cooldowns are applied correctly.
"""
clientside_credential_keys = ["api_key", "api_base", "base_url"]
def is_clientside_credential(request_kwargs: dict) -> bool:
"""
Check if the credential is a clientside credential.
"""
return any(key in request_kwargs for key in clientside_credential_keys)
def get_dynamic_litellm_params(litellm_params: dict, request_kwargs: dict) -> dict:
"""
Overwrite litellm_params with any clientside credentials found in request_kwargs.
Returns
- litellm_params: dict, updated with the clientside credentials and used downstream to generate a unique model_id for the deployment.
"""
# update litellm_params with clientside credentials
for key in clientside_credential_keys:
if key in request_kwargs:
litellm_params[key] = request_kwargs[key]
return litellm_params
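
A rough usage sketch of the two helpers above (the request kwargs are illustrative; the module path is taken from the router import added in this commit):

```python
from litellm.router_utils.clientside_credential_handler import (
    get_dynamic_litellm_params,
    is_clientside_credential,
)

# Illustrative request kwargs carrying a clientside credential.
request_kwargs = {"api_key": "sk-user-supplied", "metadata": {"model_group": "gpt-4o"}}
configured_params = {"model": "openai/gpt-4o"}

if is_clientside_credential(request_kwargs=request_kwargs):
    dynamic_params = get_dynamic_litellm_params(
        litellm_params=configured_params.copy(), request_kwargs=request_kwargs
    )
    # The clientside api_key is now part of the params, so the model id the
    # router derives from them differs from the configured deployment's id,
    # keeping cooldowns for the two credentials separate.
    assert dynamic_params["api_key"] == "sk-user-supplied"
```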

View file

@ -112,12 +112,19 @@ def _should_run_cooldown_logic(
deployment is None
or litellm_router_instance.get_model_group(id=deployment) is None
):
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: deployment id is none or model group can't be found."
)
return False
if litellm_router_instance.disable_cooldowns:
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: disable_cooldowns is True"
)
return False
if deployment is None:
verbose_router_logger.debug("Should Not Run Cooldown Logic: deployment is None")
return False
if not _is_cooldown_required(
@ -126,9 +133,15 @@ def _should_run_cooldown_logic(
exception_status=exception_status,
exception_str=str(original_exception),
):
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: _is_cooldown_required returned False"
)
return False
if deployment in litellm_router_instance.provider_default_deployment_ids:
verbose_router_logger.debug(
"Should Not Run Cooldown Logic: deployment is in provider_default_deployment_ids"
)
return False
return True
@ -244,6 +257,8 @@ def _set_cooldown_deployments(
- True if the deployment should be put in cooldown
- False if the deployment should not be put in cooldown
"""
verbose_router_logger.debug("checks 'should_run_cooldown_logic'")
if (
_should_run_cooldown_logic(
litellm_router_instance, deployment, exception_status, original_exception
@ -251,6 +266,7 @@ def _set_cooldown_deployments(
is False
or deployment is None
):
verbose_router_logger.debug("should_run_cooldown_logic returned False")
return False
exception_status_int = cast_exception_status_to_int(exception_status)

View file

@ -451,6 +451,15 @@ def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
return applied_guardrails
def get_dynamic_callbacks(
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]]
) -> List:
returned_callbacks = litellm.callbacks.copy()
if dynamic_callbacks:
returned_callbacks.extend(dynamic_callbacks) # type: ignore
return returned_callbacks
def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@ -475,12 +484,18 @@ def function_setup( # noqa: PLR0915
## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
if len(all_callbacks) > 0:
for callback in all_callbacks:
# check if callback is a string - e.g. "lago", "openmeter"
if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
callback, internal_usage_cache=None, llm_router=None
callback, internal_usage_cache=None, llm_router=None # type: ignore
)
if callback is None or any(
isinstance(cb, type(callback))

View file

@ -96,7 +96,7 @@ requires = ["poetry-core", "wheel"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "1.62.2"
version = "1.62.3"
version_files = [
"pyproject.toml:^version"
]

View file

@ -0,0 +1,46 @@
import json
import os
import sys
import pytest
from fastapi.testclient import TestClient
sys.path.insert(
0, os.path.abspath("../../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock
from litellm.proxy.route_llm_request import route_request
@pytest.mark.parametrize(
"route_type",
[
"atext_completion",
"acompletion",
"aembedding",
"aimage_generation",
"aspeech",
"atranscription",
"amoderation",
"arerank",
],
)
@pytest.mark.asyncio
async def test_route_request_dynamic_credentials(route_type):
data = {
"model": "openai/gpt-4o-mini-2024-07-18",
"api_key": "my-bad-key",
"api_base": "https://api.openai.com/v1 ",
}
llm_router = MagicMock()
# Ensure that the dynamic method exists on the llm_router mock.
getattr(llm_router, route_type).return_value = "fake_response"
response = await route_request(data, llm_router, None, route_type)
# Verify the mocked response is returned unchanged.
assert response == "fake_response"
# Now assert that the dynamic method was called once with the expected kwargs.
getattr(llm_router, route_type).assert_called_once_with(**data)

View file

@ -119,7 +119,7 @@ async def test_router_get_available_deployments(async_test):
if async_test is True:
await router.cache.async_set_cache(key=cache_key, value=request_count_dict)
deployment = await router.async_get_available_deployment(
model=model_group, messages=None
model=model_group, messages=None, request_kwargs={}
)
else:
router.cache.set_cache(key=cache_key, value=request_count_dict)

View file

@ -2777,3 +2777,46 @@ def test_router_get_model_list_from_model_alias():
model_name="gpt-3.5-turbo"
)
assert len(model_alias_list) == 0
def test_router_dynamic_credentials():
"""
Assert model id for dynamic api key 1 != model id for dynamic api key 2
"""
original_model_id = "123"
original_api_key = "my-bad-key"
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": original_api_key,
"mock_response": "fake_response",
},
"model_info": {"id": original_model_id},
}
]
)
deployment = router.get_deployment(model_id=original_model_id)
assert deployment is not None
assert deployment.litellm_params.api_key == original_api_key
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
api_key="my-bad-key-2",
)
response_2 = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
api_key="my-bad-key-3",
)
assert response_2._hidden_params["model_id"] != response._hidden_params["model_id"]
deployment = router.get_deployment(model_id=original_model_id)
assert deployment is not None
assert deployment.litellm_params.api_key == original_api_key

View file

@ -692,3 +692,50 @@ def test_router_fallbacks_with_cooldowns_and_model_id():
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
)
@pytest.mark.asyncio()
async def test_router_fallbacks_with_cooldowns_and_dynamic_credentials():
"""
Ensure cooldown on credential 1 does not affect credential 2
"""
from litellm.router_utils.cooldown_handlers import _async_get_cooldown_deployments
litellm._turn_on_debug()
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {"model": "gpt-3.5-turbo", "rpm": 1},
"model_info": {
"id": "123",
},
}
]
)
## trigger ratelimit
try:
await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "hi"}],
api_key="my-bad-key-1",
mock_response="litellm.RateLimitError",
)
pytest.fail("Expected RateLimitError")
except litellm.RateLimitError:
pass
await asyncio.sleep(1)
cooldown_list = await _async_get_cooldown_deployments(
litellm_router_instance=router, parent_otel_span=None
)
print("cooldown_list: ", cooldown_list)
assert len(cooldown_list) == 1
await router.acompletion(
model="gpt-3.5-turbo",
api_key=os.getenv("OPENAI_API_KEY"),
messages=[{"role": "user", "content": "hi"}],
)

View file

@ -569,7 +569,7 @@ async def test_weighted_selection_router_async(rpm_list, tpm_list):
# call get_available_deployment 1k times, it should pick azure/chatgpt-v-2 about 90% of the time
for _ in range(1000):
selected_model = await router.async_get_available_deployment(
"gpt-3.5-turbo"
"gpt-3.5-turbo", request_kwargs={}
)
selected_model_id = selected_model["litellm_params"]["model"]
selected_model_name = selected_model_id

View file

@ -26,11 +26,6 @@ import litellm
from litellm import Router
from litellm._logging import verbose_logger
verbose_logger.setLevel(logging.DEBUG)
load_dotenv()
@pytest.mark.asyncio()
async def test_router_free_paid_tier():
@ -93,6 +88,69 @@ async def test_router_free_paid_tier():
assert response_extra_info["model_id"] == "very-expensive-model"
@pytest.mark.asyncio()
async def test_router_free_paid_tier_embeddings():
"""
Pass 'free' and 'paid' tagged deployments under one model name,
expect embedding requests to be routed to the deployment whose tag matches the request metadata
"""
router = litellm.Router(
model_list=[
{
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4o",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"tags": ["free"],
"mock_response": ["1", "2", "3"],
},
"model_info": {"id": "very-cheap-model"},
},
{
"model_name": "gpt-4",
"litellm_params": {
"model": "gpt-4o-mini",
"api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
"tags": ["paid"],
"mock_response": ["1", "2", "3"],
},
"model_info": {"id": "very-expensive-model"},
},
],
enable_tag_filtering=True,
)
for _ in range(1):
# this should pick model with id == very-cheap-model
response = await router.aembedding(
model="gpt-4",
input="Tell me a joke.",
metadata={"tags": ["free"]},
)
print("Response: ", response)
response_extra_info = response._hidden_params
print("response_extra_info: ", response_extra_info)
assert response_extra_info["model_id"] == "very-cheap-model"
for _ in range(5):
# this should pick model with id == very-cheap-model
response = await router.aembedding(
model="gpt-4",
input="Tell me a joke.",
metadata={"tags": ["paid"]},
)
print("Response: ", response)
response_extra_info = response._hidden_params
print("response_extra_info: ", response_extra_info)
assert response_extra_info["model_id"] == "very-expensive-model"
@pytest.mark.asyncio()
async def test_default_tagged_deployments():
"""

View file

@ -396,3 +396,25 @@ def test_router_redis_cache():
router._update_redis_cache(cache=redis_cache)
assert router.cache.redis_cache == redis_cache
def test_router_handle_clientside_credential():
deployment = {
"model_name": "gemini/*",
"litellm_params": {"model": "gemini/*"},
"model_info": {
"id": "1",
},
}
router = Router(model_list=[deployment])
new_deployment = router._handle_clientside_credential(
deployment=deployment,
kwargs={
"api_key": "123",
"metadata": {"model_group": "gemini/gemini-1.5-flash"},
},
)
assert new_deployment.litellm_params.api_key == "123"
assert len(router.get_model_list()) == 2

View file

@ -377,6 +377,7 @@ async def test_multiple_potential_deployments(sync_mode):
deployment = await router.async_get_available_deployment(
model="azure-model",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
request_kwargs={},
)
## get id ##

View file

@ -20,14 +20,7 @@ import {
} from "@/components/networking";
import { jwtDecode } from "jwt-decode";
import { Form, Button as Button2, message } from "antd";
function getCookie(name: string) {
console.log("COOKIES", document.cookie)
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith(name + '='));
return cookieValue ? cookieValue.split('=')[1] : null;
}
import { getCookie } from "@/utils/cookieUtils";
export default function Onboarding() {
const [form] = Form.useForm();

View file

@ -211,6 +211,7 @@ export default function CreateKeyPage() {
userID={userID}
userRole={userRole}
premiumUser={premiumUser}
userEmail={userEmail}
setProxySettings={setProxySettings}
proxySettings={proxySettings}
/>

View file

@ -8,8 +8,10 @@ import {
UserOutlined,
LogoutOutlined
} from '@ant-design/icons';
import { clearTokenCookies } from "@/utils/cookieUtils";
interface NavbarProps {
userID: string | null;
userEmail: string | null;
userRole: string | null;
premiumUser: boolean;
setProxySettings: React.Dispatch<React.SetStateAction<any>>;
@ -18,6 +20,7 @@ interface NavbarProps {
const Navbar: React.FC<NavbarProps> = ({
userID,
userEmail,
userRole,
premiumUser,
proxySettings,
@ -27,7 +30,7 @@ const Navbar: React.FC<NavbarProps> = ({
let logoutUrl = proxySettings?.PROXY_LOGOUT_URL || "";
const handleLogout = () => {
document.cookie = "token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;";
clearTokenCookies();
window.location.href = logoutUrl;
};
@ -37,6 +40,7 @@ const Navbar: React.FC<NavbarProps> = ({
label: (
<div className="py-1">
<p className="text-sm text-gray-600">Role: {userRole}</p>
<p className="text-sm text-gray-600">Email: {userEmail || "Unknown"}</p>
<p className="text-sm text-gray-600"><UserOutlined /> {userID}</p>
<p className="text-sm text-gray-600">Premium User: {String(premiumUser)}</p>
</div>

View file

@ -21,6 +21,7 @@ import { useSearchParams, useRouter } from "next/navigation";
import { Team } from "./key_team_helpers/key_list";
import { jwtDecode } from "jwt-decode";
import { Typography } from "antd";
import { clearTokenCookies } from "@/utils/cookieUtils";
const isLocal = process.env.NODE_ENV === "development";
if (isLocal != true) {
console.log = function() {};
@ -295,14 +296,15 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
if (userID == null || token == null) {
// user is not logged in as yet
console.log("All cookies before redirect:", document.cookie);
// Clear token cookies using the utility function
clearTokenCookies();
const url = proxyBaseUrl
? `${proxyBaseUrl}/sso/key/generate`
: `/sso/key/generate`;
// clear cookie called "token" since user will be logging in again
document.cookie = "token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;";
console.log("Full URL:", url);
window.location.href = url;
@ -326,6 +328,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
}
console.log("inside user dashboard, selected team", selectedTeam);
console.log("All cookies after redirect:", document.cookie);
return (
<div className="w-full mx-4 h-[75vh]">
<Grid numItems={1} className="gap-2 p-8 w-full mt-2">

View file

@ -0,0 +1,44 @@
/**
* Utility functions for managing cookies
*/
/**
* Clears the token cookie from both root and /ui paths
*/
export function clearTokenCookies() {
// Get the current domain
const domain = window.location.hostname;
// Clear with various combinations of path and SameSite
const paths = ['/', '/ui'];
const sameSiteValues = ['Lax', 'Strict', 'None'];
paths.forEach(path => {
// Basic clearing
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path};`;
// With domain
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain};`;
// Try different SameSite values
sameSiteValues.forEach(sameSite => {
const secureFlag = sameSite === 'None' ? ' Secure;' : '';
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; SameSite=${sameSite};${secureFlag}`;
document.cookie = `token=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=${path}; domain=${domain}; SameSite=${sameSite};${secureFlag}`;
});
});
console.log("After clearing cookies:", document.cookie);
}
/**
* Gets a cookie value by name
* @param name The name of the cookie to retrieve
* @returns The cookie value or null if not found
*/
export function getCookie(name: string) {
const cookieValue = document.cookie
.split('; ')
.find(row => row.startsWith(name + '='));
return cookieValue ? cookieValue.split('=')[1] : null;
}