Compare commits

main..v1.67.3.dev1

No commits in common. "main" and "v1.67.3.dev1" have entirely different histories.

21 changed files with 43 additions and 405 deletions

View file

@@ -37,7 +37,6 @@ class SagemakerConfig(BaseConfig):
     """

     max_new_tokens: Optional[int] = None
-    max_completion_tokens: Optional[int] = None
     top_p: Optional[float] = None
     temperature: Optional[float] = None
     return_full_text: Optional[bool] = None
@@ -45,7 +44,6 @@ class SagemakerConfig(BaseConfig):
     def __init__(
         self,
         max_new_tokens: Optional[int] = None,
-        max_completion_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         temperature: Optional[float] = None,
         return_full_text: Optional[bool] = None,
@@ -67,7 +65,7 @@ class SagemakerConfig(BaseConfig):
         )

     def get_supported_openai_params(self, model: str) -> List:
-        return ["stream", "temperature", "max_tokens", "max_completion_tokens", "top_p", "stop", "n"]
+        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

     def map_openai_params(
         self,
@@ -104,8 +102,6 @@ class SagemakerConfig(BaseConfig):
                 if value == 0:
                     value = 1
                 optional_params["max_new_tokens"] = value
-            if param == "max_completion_tokens":
-                optional_params["max_new_tokens"] = value
         non_default_params.pop("aws_sagemaker_allow_zero_temp", None)
         return optional_params
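
In plain terms, this drops `max_completion_tokens` from the Sagemaker parameter mapping. A minimal sketch, mirroring the unit test removed later in this diff (the call signature comes from that test; per-branch results are annotated in comments):

```python
from litellm.llms.sagemaker.completion.transformation import SagemakerConfig

config = SagemakerConfig()
result = config.map_openai_params(
    non_default_params={"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256},
    optional_params={},
    model="test",
    drop_params=False,
)
print(result)
# On main:          {"temperature": 0.7, "max_new_tokens": 256}  (max_completion_tokens wins)
# On v1.67.3.dev1:  the max_completion_tokens branch above is gone, so only
#                   max_tokens feeds max_new_tokens.
```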

View file

@@ -7058,17 +7058,6 @@
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
        "supports_tool_choice": true
    },
-    "command-a-03-2025": {
-        "max_tokens": 8000,
-        "max_input_tokens": 256000,
-        "max_output_tokens": 8000,
-        "input_cost_per_token": 0.0000025,
-        "output_cost_per_token": 0.00001,
-        "litellm_provider": "cohere_chat",
-        "mode": "chat",
-        "supports_function_calling": true,
-        "supports_tool_choice": true
-    },
    "command-r": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,

View file

@@ -108,13 +108,7 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -184,13 +178,7 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal[
-            "acompletion",
-            "aresponses",
-            "_arealtime",
-            "aget_responses",
-            "adelete_responses",
-        ],
+        route_type: Literal["acompletion", "aresponses", "_arealtime"],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,

View file

@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional

 import orjson
 from fastapi import Request, UploadFile, status
@@ -147,11 +147,11 @@ def check_file_size_under_limit(
     if llm_router is not None and request_data["model"] in router_model_names:
         try:
-            deployment: Optional[Deployment] = (
-                llm_router.get_deployment_by_model_group_name(
-                    model_group_name=request_data["model"]
-                )
-            )
+            deployment: Optional[
+                Deployment
+            ] = llm_router.get_deployment_by_model_group_name(
+                model_group_name=request_data["model"]
+            )
             if (
                 deployment
                 and deployment.litellm_params is not None
@@ -185,23 +185,3 @@ def check_file_size_under_limit(
             )

     return True
-
-
-async def get_form_data(request: Request) -> Dict[str, Any]:
-    """
-    Read form data from request
-
-    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-    """
-    form = await request.form()
-    form_data = dict(form)
-    parsed_form_data: dict[str, Any] = {}
-    for key, value in form_data.items():
-        # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
-        if key.endswith("[]"):
-            clean_key = key[:-2]
-            parsed_form_data.setdefault(clean_key, []).append(value)
-        else:
-            parsed_form_data[key] = value
-    return parsed_form_data
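
To make the dropped behavior concrete, here is a standalone sketch of the array-notation handling that `get_form_data` performed (no FastAPI dependency; the sample form items are hypothetical):

```python
from typing import Any, Dict, List, Tuple

def parse_array_notation(form_items: List[Tuple[str, str]]) -> Dict[str, Any]:
    # Same key handling as the removed get_form_data: "key[]" entries are collected
    # into a list under "key"; plain keys pass through unchanged.
    parsed: Dict[str, Any] = {}
    for key, value in form_items:
        if key.endswith("[]"):
            parsed.setdefault(key[:-2], []).append(value)
        else:
            parsed[key] = value
    return parsed

print(parse_array_notation([
    ("model", "gpt-4o-transcribe"),
    ("timestamp_granularities[]", "word"),
    ("timestamp_granularities[]", "segment"),
]))
# {'model': 'gpt-4o-transcribe', 'timestamp_granularities': ['word', 'segment']}
```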

View file

@@ -1,8 +1,16 @@
 model_list:
-  - model_name: openai/*
+  - model_name: azure-computer-use-preview
     litellm_params:
-      model: openai/*
-      api_key: os.environ/OPENAI_API_KEY
+      model: azure/computer-use-preview
+      api_key: mock-api-key
+      api_version: mock-api-version
+      api_base: https://mock-endpoint.openai.azure.com
+  - model_name: azure-computer-use-preview
+    litellm_params:
+      model: azure/computer-use-preview-2
+      api_key: mock-api-key-2
+      api_version: mock-api-version-2
+      api_base: https://mock-endpoint-2.openai.azure.com

 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]

View file

@@ -179,7 +179,6 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
 from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     check_file_size_under_limit,
-    get_form_data,
 )
 from litellm.proxy.common_utils.load_config_utils import (
     get_config_file_contents_from_gcs,
@@ -805,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1133,9 +1132,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                await user_api_key_cache.async_get_cache(key=_id)
-            )
+            existing_spend_obj: Optional[
+                LiteLLM_TeamTable
+            ] = await user_api_key_cache.async_get_cache(key=_id)
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2807,9 +2806,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -4121,7 +4120,7 @@ async def audio_transcriptions(
     data: Dict = {}
     try:
         # Use orjson to parse JSON data, orjson speeds up requests significantly
-        form_data = await get_form_data(request)
+        form_data = await request.form()
         data = {key: value for key, value in form_data.items() if key != "file"}

         # Include original request and headers in the data
@@ -7759,9 +7758,9 @@ async def get_config_list(
                                 hasattr(sub_field_info, "description")
                                 and sub_field_info.description is not None
                             ):
-                                nested_fields[idx].field_description = (
-                                    sub_field_info.description
-                                )
+                                nested_fields[
+                                    idx
+                                ].field_description = sub_field_info.description
                                 idx += 1

     _stored_in_db = None

View file

@@ -106,50 +106,8 @@ async def get_response(
        -H "Authorization: Bearer sk-1234"
    ```
    """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="aget_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response retrieval logic
+    pass


 @router.delete(
@@ -178,50 +136,8 @@ async def delete_response(
        -H "Authorization: Bearer sk-1234"
    ```
    """
-    from litellm.proxy.proxy_server import (
-        _read_request_body,
-        general_settings,
-        llm_router,
-        proxy_config,
-        proxy_logging_obj,
-        select_data_generator,
-        user_api_base,
-        user_max_tokens,
-        user_model,
-        user_request_timeout,
-        user_temperature,
-        version,
-    )
-
-    data = await _read_request_body(request=request)
-    data["response_id"] = response_id
-    processor = ProxyBaseLLMRequestProcessing(data=data)
-    try:
-        return await processor.base_process_llm_request(
-            request=request,
-            fastapi_response=fastapi_response,
-            user_api_key_dict=user_api_key_dict,
-            route_type="adelete_responses",
-            proxy_logging_obj=proxy_logging_obj,
-            llm_router=llm_router,
-            general_settings=general_settings,
-            proxy_config=proxy_config,
-            select_data_generator=select_data_generator,
-            model=None,
-            user_model=user_model,
-            user_temperature=user_temperature,
-            user_request_timeout=user_request_timeout,
-            user_max_tokens=user_max_tokens,
-            user_api_base=user_api_base,
-            version=version,
-        )
-    except Exception as e:
-        raise await processor._handle_llm_api_exception(
-            e=e,
-            user_api_key_dict=user_api_key_dict,
-            proxy_logging_obj=proxy_logging_obj,
-            version=version,
-        )
+    # TODO: Implement response deletion logic
+    pass


 @router.get(
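
For orientation, these are the client calls the two stubbed routes used to serve, mirroring the removed end-to-end test near the bottom of this diff (base URL, API key, and model name are placeholders):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # placeholder proxy

response = client.responses.create(model="gpt-4o", input="hello")  # POST /v1/responses
retrieved = client.responses.retrieve(response.id)                 # GET /v1/responses/{response_id}
deleted = client.responses.delete(response.id)                     # DELETE /v1/responses/{response_id}
# With the handlers reduced to `pass`, the GET/DELETE calls above no longer
# round-trip a stored response through the proxy.
```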

View file

@@ -47,8 +47,6 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
-        "aget_responses",
-        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):

View file

@@ -176,16 +176,6 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )

-    @staticmethod
-    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
-        """Get the model_id from the response_id"""
-        if response_id is None:
-            return None
-        decoded_response_id = (
-            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
-        )
-        return decoded_response_id.get("model_id") or None
-

 class ResponseAPILoggingUtils:
     @staticmethod

View file

@@ -739,12 +739,6 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
-        self.aget_responses = self.factory_function(
-            litellm.aget_responses, call_type="aget_responses"
-        )
-        self.adelete_responses = self.factory_function(
-            litellm.adelete_responses, call_type="adelete_responses"
-        )

     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3087,8 +3081,6 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
-            "aget_responses",
-            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3143,11 +3135,6 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
-            elif call_type in ("aget_responses", "adelete_responses"):
-                return await self._init_responses_api_endpoints(
-                    original_function=original_function,
-                    **kwargs,
-                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3158,28 +3145,6 @@ class Router:

         return async_wrapper

-    async def _init_responses_api_endpoints(
-        self,
-        original_function: Callable,
-        **kwargs,
-    ):
-        """
-        Initialize the Responses API endpoints on the router.
-
-        GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
-        """
-        from litellm.responses.utils import ResponsesAPIRequestUtils
-
-        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
-            kwargs.get("response_id")
-        )
-        if model_id is not None:
-            kwargs["model"] = model_id
-        return await self._ageneric_api_call_with_fallbacks(
-            original_function=original_function,
-            **kwargs,
-        )
-
     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
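
A small sanity check for which side of this change an installed build is on, using only the Router surface touched here (the model entry is a placeholder; the attribute names come from the removed code and tests):

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "test-model",
            "litellm_params": {"model": "openai/test-model", "api_key": "fake-api-key"},
        }
    ]
)

# True, True on main; False, False once the factory registrations above are removed.
print(hasattr(router, "aget_responses"), hasattr(router, "adelete_responses"))
```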

View file

@@ -7058,17 +7058,6 @@
        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
        "supports_tool_choice": true
    },
-    "command-a-03-2025": {
-        "max_tokens": 8000,
-        "max_input_tokens": 256000,
-        "max_output_tokens": 8000,
-        "input_cost_per_token": 0.0000025,
-        "output_cost_per_token": 0.00001,
-        "litellm_provider": "cohere_chat",
-        "mode": "chat",
-        "supports_function_calling": true,
-        "supports_tool_choice": true
-    },
    "command-r": {
        "max_tokens": 4096,
        "max_input_tokens": 128000,

View file

@@ -8,7 +8,7 @@ import pytest
 sys.path.insert(0, os.path.abspath("../../../../.."))

 from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
-from litellm.llms.sagemaker.completion.transformation import SagemakerConfig


 @pytest.mark.asyncio
 async def test_aiter_bytes_unicode_decode_error():
@@ -95,37 +95,3 @@ async def test_aiter_bytes_valid_chunk_followed_by_unicode_error():
         # Verify we got our valid chunk despite the subsequent error
         assert len(chunks) == 1
         assert chunks[0]["text"] == "hello"  # Verify the content of the valid chunk
-
-
-class TestSagemakerTransform:
-    def setup_method(self):
-        self.config = SagemakerConfig()
-        self.model = "test"
-        self.logging_obj = MagicMock()
-
-    def test_map_mistral_params(self):
-        """Test that parameters are correctly mapped"""
-        test_params = {"temperature": 0.7, "max_tokens": 200, "max_completion_tokens": 256}
-        result = self.config.map_openai_params(
-            non_default_params=test_params,
-            optional_params={},
-            model=self.model,
-            drop_params=False,
-        )
-        # The function should properly map max_completion_tokens to max_tokens and override max_tokens
-        assert result == {"temperature": 0.7, "max_new_tokens": 256}
-
-    def test_mistral_max_tokens_backward_compat(self):
-        """Test that parameters are correctly mapped"""
-        test_params = {"temperature": 0.7, "max_tokens": 200,}
-        result = self.config.map_openai_params(
-            non_default_params=test_params,
-            optional_params={},
-            model=self.model,
-            drop_params=False,
-        )
-        # The function should properly map max_tokens if max_completion_tokens is not provided
-        assert result == {"temperature": 0.7, "max_new_tokens": 200}

View file

@@ -18,7 +18,6 @@ from litellm.proxy.common_utils.http_parsing_utils import (
     _read_request_body,
     _safe_get_request_parsed_body,
     _safe_set_request_parsed_body,
-    get_form_data,
 )
@@ -148,53 +147,3 @@ async def test_circular_reference_handling():
     assert (
         "proxy_server_request" not in result2
     )  # This will pass, showing the cache pollution
-
-
-@pytest.mark.asyncio
-async def test_get_form_data():
-    """
-    Test that get_form_data correctly handles form data with array notation.
-    Tests audio transcription parameters as a specific example.
-    """
-    # Create a mock request with transcription form data
-    mock_request = MagicMock()
-
-    # Create mock form data with array notation for timestamp_granularities
-    mock_form_data = {
-        "file": "file_object",  # In a real request this would be an UploadFile
-        "model": "gpt-4o-transcribe",
-        "include[]": "logprobs",  # Array notation
-        "language": "en",
-        "prompt": "Transcribe this audio file",
-        "response_format": "json",
-        "stream": "false",
-        "temperature": "0.2",
-        "timestamp_granularities[]": "word",  # First array item
-        "timestamp_granularities[]": "segment",  # Second array item (would overwrite in dict, but handled by the function)
-    }
-
-    # Mock the form method to return the test data
-    mock_request.form = AsyncMock(return_value=mock_form_data)
-
-    # Call the function being tested
-    result = await get_form_data(mock_request)
-
-    # Verify regular form fields are preserved
-    assert result["file"] == "file_object"
-    assert result["model"] == "gpt-4o-transcribe"
-    assert result["language"] == "en"
-    assert result["prompt"] == "Transcribe this audio file"
-    assert result["response_format"] == "json"
-    assert result["stream"] == "false"
-    assert result["temperature"] == "0.2"
-
-    # Verify array fields are correctly parsed
-    assert "include" in result
-    assert isinstance(result["include"], list)
-    assert "logprobs" in result["include"]
-    assert "timestamp_granularities" in result
-    assert isinstance(result["timestamp_granularities"], list)
-    # Note: In a real MultiDict, both values would be present
-    # But in our mock dictionary the second value overwrites the first
-    assert "segment" in result["timestamp_granularities"]

View file

@@ -947,8 +947,6 @@ class BaseLLMChatTest(ABC):
                 second_response.choices[0].message.content is not None
                 or second_response.choices[0].message.tool_calls is not None
             )
-        except litellm.ServiceUnavailableError:
-            pytest.skip("Model is overloaded")
         except litellm.InternalServerError:
             pytest.skip("Model is overloaded")
         except litellm.RateLimitError:

View file

@@ -77,22 +77,6 @@ def test_basic_response():
     )
     print("basic response=", response)

-    # get the response
-    response = client.responses.retrieve(response.id)
-    print("GET response=", response)
-
-    # delete the response
-    delete_response = client.responses.delete(response.id)
-    print("DELETE response=", delete_response)
-
-    # try getting the response again, we should not get it back
-    get_response = client.responses.retrieve(response.id)
-    print("GET response after delete=", get_response)
-    with pytest.raises(Exception):
-        get_response = client.responses.retrieve(response.id)
-

 def test_streaming_response():
     client = get_test_client()

View file

@@ -6,7 +6,7 @@ from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
-from unittest.mock import AsyncMock, patch, MagicMock
+from unittest.mock import AsyncMock, patch

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -553,67 +553,9 @@ def test_initialize_router_endpoints():
     assert hasattr(router, "aanthropic_messages")
     assert hasattr(router, "aresponses")
     assert hasattr(router, "responses")
-    assert hasattr(router, "aget_responses")
-    assert hasattr(router, "adelete_responses")

     # Verify the endpoints are callable
     assert callable(router.amoderation)
     assert callable(router.aanthropic_messages)
     assert callable(router.aresponses)
     assert callable(router.responses)
-    assert callable(router.aget_responses)
-    assert callable(router.adelete_responses)
-
-
-@pytest.mark.asyncio
-async def test_init_responses_api_endpoints():
-    """
-    A simpler test for _init_responses_api_endpoints that focuses on the basic functionality
-    """
-    from litellm.responses.utils import ResponsesAPIRequestUtils
-
-    # Create a router with a basic model
-    router = Router(
-        model_list=[
-            {
-                "model_name": "test-model",
-                "litellm_params": {
-                    "model": "openai/test-model",
-                    "api_key": "fake-api-key",
-                },
-            }
-        ]
-    )
-
-    # Just mock the _ageneric_api_call_with_fallbacks method
-    router._ageneric_api_call_with_fallbacks = AsyncMock()
-
-    # Add a mock implementation of _get_model_id_from_response_id to the Router instance
-    ResponsesAPIRequestUtils.get_model_id_from_response_id = MagicMock(return_value=None)
-
-    # Call without a response_id (no model extraction should happen)
-    await router._init_responses_api_endpoints(
-        original_function=AsyncMock(),
-        thread_id="thread_xyz"
-    )
-
-    # Verify _ageneric_api_call_with_fallbacks was called but model wasn't changed
-    first_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
-    assert "model" not in first_call_kwargs
-    assert first_call_kwargs["thread_id"] == "thread_xyz"
-
-    # Reset the mock
-    router._ageneric_api_call_with_fallbacks.reset_mock()
-
-    # Change the return value for the second call
-    ResponsesAPIRequestUtils.get_model_id_from_response_id.return_value = "claude-3-sonnet"
-
-    # Call with a response_id
-    await router._init_responses_api_endpoints(
-        original_function=AsyncMock(),
-        response_id="resp_claude_123"
-    )
-
-    # Verify model was updated in the kwargs
-    second_call_kwargs = router._ageneric_api_call_with_fallbacks.call_args.kwargs
-    assert second_call_kwargs["model"] == "claude-3-sonnet"
-    assert second_call_kwargs["response_id"] == "resp_claude_123"

View file

@@ -1157,14 +1157,3 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
-
-
-def test_init_responses_api_endpoints(model_list):
-    """Test if the '_init_responses_api_endpoints' function is working correctly"""
-    from typing import Callable
-
-    router = Router(model_list=model_list)
-
-    assert router.aget_responses is not None
-    assert isinstance(router.aget_responses, Callable)
-    assert router.adelete_responses is not None
-    assert isinstance(router.adelete_responses, Callable)

View file

@@ -151,7 +151,7 @@ export default function CreateKeyPage() {
     if (redirectToLogin) {
       window.location.href = (proxyBaseUrl || "") + "/sso/key/generate"
     }
-  }, [redirectToLogin])
+  }, [token, authLoading])

   useEffect(() => {
     if (!token) {
@@ -223,7 +223,7 @@ export default function CreateKeyPage() {
     }
   }, [accessToken, userID, userRole]);

-  if (authLoading || redirectToLogin) {
+  if (authLoading) {
     return <LoadingScreen />
   }

View file

@@ -450,8 +450,6 @@ export function AllKeysTable({
           columns={columns.filter(col => col.id !== 'expander') as any}
           data={filteredKeys as any}
           isLoading={isLoading}
-          loadingMessage="🚅 Loading keys..."
-          noDataMessage="No keys found"
           getRowCanExpand={() => false}
           renderSubComponent={() => <></>}
         />

View file

@@ -26,8 +26,6 @@ function DataTableWrapper({
       isLoading={isLoading}
       renderSubComponent={renderSubComponent}
       getRowCanExpand={getRowCanExpand}
-      loadingMessage="🚅 Loading tools..."
-      noDataMessage="No tools found"
     />
   );
 }

View file

@@ -26,8 +26,6 @@ interface DataTableProps<TData, TValue> {
   expandedRequestId?: string | null;
   onRowExpand?: (requestId: string | null) => void;
   setSelectedKeyIdInfoView?: (keyId: string | null) => void;
-  loadingMessage?: string;
-  noDataMessage?: string;
 }

 export function DataTable<TData extends { request_id: string }, TValue>({
@@ -38,8 +36,6 @@ export function DataTable<TData extends { request_id: string }, TValue>({
   isLoading = false,
   expandedRequestId,
   onRowExpand,
-  loadingMessage = "🚅 Loading logs...",
-  noDataMessage = "No logs found",
 }: DataTableProps<TData, TValue>) {
   const table = useReactTable({
     data,
@@ -118,7 +114,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             <TableRow>
               <TableCell colSpan={columns.length} className="h-8 text-center">
                 <div className="text-center text-gray-500">
-                  <p>{loadingMessage}</p>
+                  <p>🚅 Loading logs...</p>
                 </div>
               </TableCell>
             </TableRow>
@@ -151,7 +147,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
             : <TableRow>
                 <TableCell colSpan={columns.length} className="h-8 text-center">
                   <div className="text-center text-gray-500">
-                    <p>{noDataMessage}</p>
+                    <p>No logs found</p>
                   </div>
                 </TableCell>
               </TableRow>