Compare commits

...

6 commits

Author SHA1 Message Date
Ishaan Jaff
164017119d
[Bug Fix] Timestamp Granularities are not properly passed to whisper in Azure (#10299)
* test fix form data parsing

* test fix form data parsing

* fix types
2025-04-24 18:57:11 -07:00
Ishaan Jaff
5de101ab7b
[Feat] Add GET, DELETE Responses endpoints on LiteLLM Proxy (#10297)
* add GET responses endpoints on router

* add GET responses endpoints on router

* add GET responses endpoints on router

* add DELETE responses endpoints on proxy

* fixes for testing GET, DELETE endpoints

* test_basic_responses api e2e
2025-04-24 17:34:26 -07:00
Ryan Chase
0a2c964db7
adding support for cohere command-a-03-2025 (#10295) 2025-04-24 17:07:29 -07:00
Marc Abramowitz
56d00c43f7
Keys page: Use keys rather than logs terminology (#10253) 2025-04-24 14:25:59 -07:00
Christian Owusu
b82af5b826
Fix UI Flicker in Dashboard (#10261)
2025-04-23 23:27:44 -07:00
Krrish Dholakia
2adb2fc6a5 test: handle service unavailable error 2025-04-23 22:10:46 -07:00
17 changed files with 297 additions and 42 deletions

View file

@@ -108,7 +108,13 @@ class ProxyBaseLLMRequestProcessing:
user_api_key_dict: UserAPIKeyAuth,
proxy_logging_obj: ProxyLogging,
proxy_config: ProxyConfig,
route_type: Literal["acompletion", "aresponses", "_arealtime"],
route_type: Literal[
"acompletion",
"aresponses",
"_arealtime",
"aget_responses",
"adelete_responses",
],
version: Optional[str] = None,
user_model: Optional[str] = None,
user_temperature: Optional[float] = None,
@@ -178,7 +184,13 @@ class ProxyBaseLLMRequestProcessing:
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth,
route_type: Literal["acompletion", "aresponses", "_arealtime"],
route_type: Literal[
"acompletion",
"aresponses",
"_arealtime",
"aget_responses",
"adelete_responses",
],
proxy_logging_obj: ProxyLogging,
general_settings: dict,
proxy_config: ProxyConfig,

View file

@@ -1,5 +1,5 @@
import json
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
import orjson
from fastapi import Request, UploadFile, status
@@ -147,11 +147,11 @@ def check_file_size_under_limit(
if llm_router is not None and request_data["model"] in router_model_names:
try:
deployment: Optional[
Deployment
] = llm_router.get_deployment_by_model_group_name(
deployment: Optional[Deployment] = (
llm_router.get_deployment_by_model_group_name(
model_group_name=request_data["model"]
)
)
if (
deployment
and deployment.litellm_params is not None
@@ -185,3 +185,23 @@ def check_file_size_under_limit(
)
return True
async def get_form_data(request: Request) -> Dict[str, Any]:
"""
Read form data from request
Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
"""
form = await request.form()
form_data = dict(form)
parsed_form_data: dict[str, Any] = {}
for key, value in form_data.items():
# OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
if key.endswith("[]"):
clean_key = key[:-2]
parsed_form_data.setdefault(clean_key, []).append(value)
else:
parsed_form_data[key] = value
return parsed_form_data
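
A minimal, hedged sketch of how the bracket-suffix handling above behaves. The `_FakeRequest` stub and the example form values are invented for illustration; only the parsing loop mirrors the new `get_form_data` helper.

```python
import asyncio
from typing import Any, Dict


class _FakeRequest:
    """Stand-in for fastapi.Request; only the awaited .form() call is exercised."""

    def __init__(self, form: Dict[str, Any]) -> None:
        self._form = form

    async def form(self) -> Dict[str, Any]:
        return self._form


async def demo() -> None:
    request = _FakeRequest(
        {
            "model": "whisper-1",
            "timestamp_granularities[]": "word",  # OpenAI SDK array notation
            "response_format": "json",
        }
    )
    form_data = dict(await request.form())
    parsed: Dict[str, Any] = {}
    for key, value in form_data.items():
        # Same rule as get_form_data: strip the trailing "[]" and collect into a list.
        if key.endswith("[]"):
            parsed.setdefault(key[:-2], []).append(value)
        else:
            parsed[key] = value
    # {'model': 'whisper-1', 'timestamp_granularities': ['word'], 'response_format': 'json'}
    print(parsed)


asyncio.run(demo())
```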

View file

@@ -1,16 +1,8 @@
model_list:
- model_name: azure-computer-use-preview
- model_name: openai/*
litellm_params:
model: azure/computer-use-preview
api_key: mock-api-key
api_version: mock-api-version
api_base: https://mock-endpoint.openai.azure.com
- model_name: azure-computer-use-preview
litellm_params:
model: azure/computer-use-preview-2
api_key: mock-api-key-2
api_version: mock-api-version-2
api_base: https://mock-endpoint-2.openai.azure.com
model: openai/*
api_key: os.environ/OPENAI_API_KEY
router_settings:
optional_pre_call_checks: ["responses_api_deployment_check"]

View file

@@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form
from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
check_file_size_under_limit,
get_form_data,
)
from litellm.proxy.common_utils.load_config_utils import (
get_config_file_contents_from_gcs,
@@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[
RedisCache
] = None # redis cache used for tracking spend, tpm/rpm limits
redis_usage_cache: Optional[RedisCache] = (
None # redis cache used for tracking spend, tpm/rpm limits
)
user_custom_auth = None
user_custom_key_generate = None
user_custom_sso = None
@@ -1132,9 +1133,9 @@ async def update_cache( # noqa: PLR0915
_id = "team_id:{}".format(team_id)
try:
# Fetch the existing cost for the given user
existing_spend_obj: Optional[
LiteLLM_TeamTable
] = await user_api_key_cache.async_get_cache(key=_id)
existing_spend_obj: Optional[LiteLLM_TeamTable] = (
await user_api_key_cache.async_get_cache(key=_id)
)
if existing_spend_obj is None:
# do nothing if team not in api key cache
return
@@ -2806,9 +2807,9 @@ async def initialize( # noqa: PLR0915
user_api_base = api_base
dynamic_config[user_model]["api_base"] = api_base
if api_version:
os.environ[
"AZURE_API_VERSION"
] = api_version # set this for azure - litellm can read this from the env
os.environ["AZURE_API_VERSION"] = (
api_version # set this for azure - litellm can read this from the env
)
if max_tokens: # model-specific param
dynamic_config[user_model]["max_tokens"] = max_tokens
if temperature: # model-specific param
@@ -4120,7 +4121,7 @@ async def audio_transcriptions(
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
form_data = await request.form()
form_data = await get_form_data(request)
data = {key: value for key, value in form_data.items() if key != "file"}
# Include original request and headers in the data
@@ -7758,9 +7759,9 @@ async def get_config_list(
hasattr(sub_field_info, "description")
and sub_field_info.description is not None
):
nested_fields[
idx
].field_description = sub_field_info.description
nested_fields[idx].field_description = (
sub_field_info.description
)
idx += 1
_stored_in_db = None

View file

@@ -106,8 +106,50 @@ async def get_response(
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response retrieval logic
pass
from litellm.proxy.proxy_server import (
_read_request_body,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
select_data_generator,
user_api_base,
user_max_tokens,
user_model,
user_request_timeout,
user_temperature,
version,
)
data = await _read_request_body(request=request)
data["response_id"] = response_id
processor = ProxyBaseLLMRequestProcessing(data=data)
try:
return await processor.base_process_llm_request(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
route_type="aget_responses",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=None,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
except Exception as e:
raise await processor._handle_llm_api_exception(
e=e,
user_api_key_dict=user_api_key_dict,
proxy_logging_obj=proxy_logging_obj,
version=version,
)
@router.delete(
@@ -136,8 +178,50 @@ async def delete_response(
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response deletion logic
pass
from litellm.proxy.proxy_server import (
_read_request_body,
general_settings,
llm_router,
proxy_config,
proxy_logging_obj,
select_data_generator,
user_api_base,
user_max_tokens,
user_model,
user_request_timeout,
user_temperature,
version,
)
data = await _read_request_body(request=request)
data["response_id"] = response_id
processor = ProxyBaseLLMRequestProcessing(data=data)
try:
return await processor.base_process_llm_request(
request=request,
fastapi_response=fastapi_response,
user_api_key_dict=user_api_key_dict,
route_type="adelete_responses",
proxy_logging_obj=proxy_logging_obj,
llm_router=llm_router,
general_settings=general_settings,
proxy_config=proxy_config,
select_data_generator=select_data_generator,
model=None,
user_model=user_model,
user_temperature=user_temperature,
user_request_timeout=user_request_timeout,
user_max_tokens=user_max_tokens,
user_api_base=user_api_base,
version=version,
)
except Exception as e:
raise await processor._handle_llm_api_exception(
e=e,
user_api_key_dict=user_api_key_dict,
proxy_logging_obj=proxy_logging_obj,
version=version,
)
@router.get(
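
The two handlers above can be exercised end to end the same way the updated e2e test further down does. A hedged sketch, assuming an OpenAI SDK client pointed at a locally running proxy; the base URL, virtual key, and model name are placeholders.

```python
from openai import OpenAI

# Placeholders: point the SDK at the LiteLLM proxy with a virtual key.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# Create a response, then hit the new GET and DELETE proxy routes.
created = client.responses.create(model="gpt-4o", input="just respond with the word 'ping'")
fetched = client.responses.retrieve(created.id)   # handled via route_type="aget_responses"
deleted = client.responses.delete(created.id)     # handled via route_type="adelete_responses"
print(fetched.id, deleted)
```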

View file

@@ -47,6 +47,8 @@ async def route_request(
"amoderation",
"arerank",
"aresponses",
"aget_responses",
"adelete_responses",
"_arealtime", # private function for realtime API
],
):
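
Not the actual dispatch code — a simplified, hedged illustration of why the Literal gains the two new members: downstream, the proxy resolves `route_type` to the matching coroutine on the Router (the `aget_responses` / `adelete_responses` attributes added in the Router diff below), so the string value effectively doubles as the attribute name.

```python
from typing import Any, Literal

RouteType = Literal["aresponses", "aget_responses", "adelete_responses"]


async def dispatch(llm_router: Any, route_type: RouteType, **data: Any) -> Any:
    # Simplified: look up the coroutine named by route_type and await it.
    handler = getattr(llm_router, route_type)
    return await handler(**data)
```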

View file

@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
response_id=response_id,
)
@staticmethod
def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
"""Get the model_id from the response_id"""
if response_id is None:
return None
decoded_response_id = (
ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
)
return decoded_response_id.get("model_id") or None
class ResponseAPILoggingUtils:
@staticmethod
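
The `_decode_responses_api_response_id` helper is not part of this diff, so the sketch below only illustrates the intent of `get_model_id_from_response_id`: the proxy embeds the serving deployment's `model_id` inside the `response_id` it returns, so a later GET or DELETE can be pinned to the same deployment. The packing format here is invented for illustration and is not LiteLLM's actual encoding.

```python
import base64
import json
from typing import Optional


def _encode_response_id(model_id: str, upstream_response_id: str) -> str:
    # Invented format: stash the routing hint alongside the provider's own id.
    payload = json.dumps({"model_id": model_id, "response_id": upstream_response_id})
    return "resp_" + base64.urlsafe_b64encode(payload.encode()).decode()


def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
    if response_id is None or not response_id.startswith("resp_"):
        return None
    decoded = json.loads(base64.urlsafe_b64decode(response_id[len("resp_"):]))
    return decoded.get("model_id") or None


rid = _encode_response_id("azure-gpt-4o-deployment-1", "resp_abc123")
assert get_model_id_from_response_id(rid) == "azure-gpt-4o-deployment-1"
assert get_model_id_from_response_id(None) is None
```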

View file

@@ -739,6 +739,12 @@ class Router:
litellm.afile_content, call_type="afile_content"
)
self.responses = self.factory_function(litellm.responses, call_type="responses")
self.aget_responses = self.factory_function(
litellm.aget_responses, call_type="aget_responses"
)
self.adelete_responses = self.factory_function(
litellm.adelete_responses, call_type="adelete_responses"
)
def validate_fallbacks(self, fallback_param: Optional[List]):
"""
@@ -3081,6 +3087,8 @@ class Router:
"anthropic_messages",
"aresponses",
"responses",
"aget_responses",
"adelete_responses",
"afile_delete",
"afile_content",
] = "assistants",
@@ -3135,6 +3143,11 @@ class Router:
original_function=original_function,
**kwargs,
)
elif call_type in ("aget_responses", "adelete_responses"):
return await self._init_responses_api_endpoints(
original_function=original_function,
**kwargs,
)
elif call_type in ("afile_delete", "afile_content"):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
@@ -3145,6 +3158,28 @@ class Router:
return async_wrapper
async def _init_responses_api_endpoints(
self,
original_function: Callable,
**kwargs,
):
"""
Initialize the Responses API endpoints on the router.
GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
"""
from litellm.responses.utils import ResponsesAPIRequestUtils
model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
kwargs.get("response_id")
)
if model_id is not None:
kwargs["model"] = model_id
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
**kwargs,
)
async def _pass_through_assistants_endpoint_factory(
self,
original_function: Callable,
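
A hedged usage sketch of the new Router surface. It assumes the factory-built coroutines accept `response_id` as a keyword (mirroring the `kwargs.get("response_id")` lookup above) and that a wildcard OpenAI deployment is configured; treat the names and exact signature as assumptions.

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai/*",
            "litellm_params": {"model": "openai/*", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ]
)


async def demo() -> None:
    # response_id comes from an earlier Responses API call through the router;
    # the model_id encoded inside it lets the router pick the same deployment.
    fetched = await router.aget_responses(response_id="resp_...")
    deleted = await router.adelete_responses(response_id="resp_...")
    print(fetched, deleted)


# asyncio.run(demo())  # requires OPENAI_API_KEY and a real response_id
```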

View file

@@ -7058,6 +7058,17 @@
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models",
"supports_tool_choice": true
},
"command-a-03-2025": {
"max_tokens": 8000,
"max_input_tokens": 256000,
"max_output_tokens": 8000,
"input_cost_per_token": 0.0000025,
"output_cost_per_token": 0.00001,
"litellm_provider": "cohere_chat",
"mode": "chat",
"supports_function_calling": true,
"supports_tool_choice": true
},
"command-r": {
"max_tokens": 4096,
"max_input_tokens": 128000,

View file

@@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import (
_read_request_body,
_safe_get_request_parsed_body,
_safe_set_request_parsed_body,
get_form_data,
)
@@ -147,3 +148,53 @@ async def test_circular_reference_handling():
assert (
"proxy_server_request" not in result2
) # This will pass, showing the cache pollution
@pytest.mark.asyncio
async def test_get_form_data():
"""
Test that get_form_data correctly handles form data with array notation.
Tests audio transcription parameters as a specific example.
"""
# Create a mock request with transcription form data
mock_request = MagicMock()
# Create mock form data with array notation for timestamp_granularities
mock_form_data = {
"file": "file_object", # In a real request this would be an UploadFile
"model": "gpt-4o-transcribe",
"include[]": "logprobs", # Array notation
"language": "en",
"prompt": "Transcribe this audio file",
"response_format": "json",
"stream": "false",
"temperature": "0.2",
"timestamp_granularities[]": "word", # First array item
"timestamp_granularities[]": "segment", # Second array item (would overwrite in dict, but handled by the function)
}
# Mock the form method to return the test data
mock_request.form = AsyncMock(return_value=mock_form_data)
# Call the function being tested
result = await get_form_data(mock_request)
# Verify regular form fields are preserved
assert result["file"] == "file_object"
assert result["model"] == "gpt-4o-transcribe"
assert result["language"] == "en"
assert result["prompt"] == "Transcribe this audio file"
assert result["response_format"] == "json"
assert result["stream"] == "false"
assert result["temperature"] == "0.2"
# Verify array fields are correctly parsed
assert "include" in result
assert isinstance(result["include"], list)
assert "logprobs" in result["include"]
assert "timestamp_granularities" in result
assert isinstance(result["timestamp_granularities"], list)
# Note: In a real MultiDict, both values would be present
# But in our mock dictionary the second value overwrites the first
assert "segment" in result["timestamp_granularities"]

View file

@@ -947,6 +947,8 @@ class BaseLLMChatTest(ABC):
second_response.choices[0].message.content is not None
or second_response.choices[0].message.tool_calls is not None
)
except litellm.ServiceUnavailableError:
pytest.skip("Model is overloaded")
except litellm.InternalServerError:
pytest.skip("Model is overloaded")
except litellm.RateLimitError:

View file

@@ -73,15 +73,31 @@ def validate_stream_chunk(chunk):
def test_basic_response():
client = get_test_client()
response = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'"
model="gpt-4.0", input="just respond with the word 'ping'"
)
print("basic response=", response)
# get the response
response = client.responses.retrieve(response.id)
print("GET response=", response)
# delete the response
delete_response = client.responses.delete(response.id)
print("DELETE response=", delete_response)
# try getting the response again, we should not get it back
get_response = client.responses.retrieve(response.id)
print("GET response after delete=", get_response)
with pytest.raises(Exception):
get_response = client.responses.retrieve(response.id)
def test_streaming_response():
client = get_test_client()
stream = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'", stream=True
model="gpt-4.0", input="just respond with the word 'ping'", stream=True
)
collected_chunks = []
@@ -104,5 +120,5 @@ def test_bad_request_bad_param_error():
with pytest.raises(BadRequestError):
# Trigger error with invalid model name
client.responses.create(
model="gpt-4o", input="This should fail", temperature=2000
model="gpt-4.0", input="This should fail", temperature=2000
)

View file

@@ -1157,3 +1157,14 @@ def test_cached_get_model_group_info(model_list):
# Verify the cache info shows hits
cache_info = router._cached_get_model_group_info.cache_info()
assert cache_info.hits > 0 # Should have at least one cache hit
def test_init_responses_api_endpoints(model_list):
"""Test if the '_init_responses_api_endpoints' function is working correctly"""
from typing import Callable
router = Router(model_list=model_list)
assert router.aget_responses is not None
assert isinstance(router.aget_responses, Callable)
assert router.adelete_responses is not None
assert isinstance(router.adelete_responses, Callable)

View file

@@ -151,7 +151,7 @@ export default function CreateKeyPage() {
if (redirectToLogin) {
window.location.href = (proxyBaseUrl || "") + "/sso/key/generate"
}
}, [token, authLoading])
}, [redirectToLogin])
useEffect(() => {
if (!token) {
@@ -223,7 +223,7 @@ export default function CreateKeyPage() {
}
}, [accessToken, userID, userRole]);
if (authLoading) {
if (authLoading || redirectToLogin) {
return <LoadingScreen />
}

View file

@@ -450,6 +450,8 @@ export function AllKeysTable({
columns={columns.filter(col => col.id !== 'expander') as any}
data={filteredKeys as any}
isLoading={isLoading}
loadingMessage="🚅 Loading keys..."
noDataMessage="No keys found"
getRowCanExpand={() => false}
renderSubComponent={() => <></>}
/>

View file

@@ -26,6 +26,8 @@ function DataTableWrapper({
isLoading={isLoading}
renderSubComponent={renderSubComponent}
getRowCanExpand={getRowCanExpand}
loadingMessage="🚅 Loading tools..."
noDataMessage="No tools found"
/>
);
}

View file

@@ -26,6 +26,8 @@ interface DataTableProps<TData, TValue> {
expandedRequestId?: string | null;
onRowExpand?: (requestId: string | null) => void;
setSelectedKeyIdInfoView?: (keyId: string | null) => void;
loadingMessage?: string;
noDataMessage?: string;
}
export function DataTable<TData extends { request_id: string }, TValue>({
@@ -36,6 +38,8 @@ export function DataTable<TData extends { request_id: string }, TValue>({
isLoading = false,
expandedRequestId,
onRowExpand,
loadingMessage = "🚅 Loading logs...",
noDataMessage = "No logs found",
}: DataTableProps<TData, TValue>) {
const table = useReactTable({
data,
@@ -114,7 +118,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
<TableRow>
<TableCell colSpan={columns.length} className="h-8 text-center">
<div className="text-center text-gray-500">
<p>🚅 Loading logs...</p>
<p>{loadingMessage}</p>
</div>
</TableCell>
</TableRow>
@@ -147,7 +151,7 @@ export function DataTable<TData extends { request_id: string }, TValue>({
: <TableRow>
<TableCell colSpan={columns.length} className="h-8 text-center">
<div className="text-center text-gray-500">
<p>No logs found</p>
<p>{noDataMessage}</p>
</div>
</TableCell>
</TableRow>