Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
[Feat] Add GET, DELETE Responses endpoints on LiteLLM Proxy (#10297)
* add GET responses endpoints on router
* add DELETE responses endpoints on proxy
* fixes for testing GET, DELETE endpoints
* test_basic_responses api e2e
parent 0a2c964db7 · commit 5de101ab7b
8 changed files with 182 additions and 20 deletions
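For orientation, here is a minimal sketch of what the new endpoints enable from a client's point of view, mirroring the e2e test added in this commit. The base URL, key, and model name are placeholders for a locally running proxy, not values taken from this diff:

```python
# Sketch only: assumes a LiteLLM proxy running locally on port 4000 with a
# virtual key; the create/retrieve/delete flow mirrors test_basic_response below.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

created = client.responses.create(
    model="gpt-4.0", input="just respond with the word 'ping'"
)
fetched = client.responses.retrieve(created.id)  # served by the new GET endpoint
deleted = client.responses.delete(created.id)    # served by the new DELETE endpoint
print(fetched.id, deleted)
```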
@@ -108,7 +108,13 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -178,7 +184,13 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,
@@ -1,16 +1,8 @@
 model_list:
-  - model_name: azure-computer-use-preview
+  - model_name: openai/*
     litellm_params:
-      model: azure/computer-use-preview
-      api_key: mock-api-key
-      api_version: mock-api-version
-      api_base: https://mock-endpoint.openai.azure.com
-  - model_name: azure-computer-use-preview
-    litellm_params:
-      model: azure/computer-use-preview-2
-      api_key: mock-api-key-2
-      api_version: mock-api-version-2
-      api_base: https://mock-endpoint-2.openai.azure.com
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
 
 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]
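The test config above is simplified to a single wildcard OpenAI deployment, with `responses_api_deployment_check` enabled so that follow-up Responses API calls (such as the new GET/DELETE routes) can be checked against the deployment that served the original request. A rough Python equivalent of the same setup, assuming `router_settings` keys map onto `Router` keyword arguments of the same name:

```python
# Rough equivalent of the YAML config above, expressed via the Router API.
# Assumption: optional_pre_call_checks is accepted as a Router kwarg matching
# the router_settings key of the same name.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai/*",
            "litellm_params": {"model": "openai/*", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ],
    optional_pre_call_checks=["responses_api_deployment_check"],
)
```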
@@ -106,8 +106,50 @@ async def get_response(
     -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response retrieval logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="aget_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )
 
 
 @router.delete(
@@ -136,8 +178,50 @@ async def delete_response(
     -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response deletion logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="adelete_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )
 
 
 @router.get(
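Both handlers follow the same pattern: read the request body, attach `response_id`, and delegate to `ProxyBaseLLMRequestProcessing.base_process_llm_request` with the new route types. A raw-HTTP sketch of exercising them, assuming the proxy mirrors the OpenAI-style path `/v1/responses/{response_id}` (the route decorators with the exact paths are truncated in this view, and the docstrings only show the Authorization header from the example curl commands):

```python
# Sketch only: path, port, and response_id are assumptions, not values from this diff.
import requests

BASE = "http://localhost:4000"
HEADERS = {"Authorization": "Bearer sk-1234"}
response_id = "resp_abc123"  # placeholder id returned by a previous create call

get_resp = requests.get(f"{BASE}/v1/responses/{response_id}", headers=HEADERS)
del_resp = requests.delete(f"{BASE}/v1/responses/{response_id}", headers=HEADERS)
print(get_resp.status_code, del_resp.status_code)
```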
@@ -47,6 +47,8 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
+        "aget_responses",
+        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):
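Adding the two literals here lets `route_request` dispatch the new route types like any other router call. Conceptually, and as a hypothetical simplification rather than the actual function body, the dispatch boils down to looking up the matching router coroutine by name:

```python
# Hypothetical simplification of what route_request does once "aget_responses"
# and "adelete_responses" are valid route_type values.
async def dispatch(llm_router, route_type: str, data: dict):
    # e.g. llm_router.aget_responses(response_id=...) or llm_router.adelete_responses(...)
    return await getattr(llm_router, route_type)(**data)
```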
@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )
 
+    @staticmethod
+    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
+        """Get the model_id from the response_id"""
+        if response_id is None:
+            return None
+        decoded_response_id = (
+            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
+        )
+        return decoded_response_id.get("model_id") or None
+
 
 class ResponseAPILoggingUtils:
     @staticmethod
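The new helper relies on `_decode_responses_api_response_id`, which is not part of this diff. The idea is that LiteLLM wraps the provider's response id together with routing metadata, so the model_id can later be recovered from the id alone. A self-contained illustration of that round trip with a hypothetical base64/JSON encoding (the real wire format may differ):

```python
# Illustration only: a hypothetical encoding that shows why a model_id can be
# recovered from a response_id; LiteLLM's actual scheme lives in
# ResponsesAPIRequestUtils and may use a different format.
import base64
import json
from typing import Optional


def encode_response_id(model_id: str, provider_response_id: str) -> str:
    payload = json.dumps({"model_id": model_id, "response_id": provider_response_id})
    return base64.urlsafe_b64encode(payload.encode()).decode()


def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
    if response_id is None:
        return None
    decoded = json.loads(base64.urlsafe_b64decode(response_id.encode()).decode())
    return decoded.get("model_id") or None


combined = encode_response_id("azure-deployment-1", "resp_abc123")
assert get_model_id_from_response_id(combined) == "azure-deployment-1"
```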
@@ -739,6 +739,12 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
+        self.aget_responses = self.factory_function(
+            litellm.aget_responses, call_type="aget_responses"
+        )
+        self.adelete_responses = self.factory_function(
+            litellm.adelete_responses, call_type="adelete_responses"
+        )
 
     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3081,6 +3087,8 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
+            "aget_responses",
+            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3135,6 +3143,11 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
+            elif call_type in ("aget_responses", "adelete_responses"):
+                return await self._init_responses_api_endpoints(
+                    original_function=original_function,
+                    **kwargs,
+                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3145,6 +3158,28 @@ class Router:
 
         return async_wrapper
 
+    async def _init_responses_api_endpoints(
+        self,
+        original_function: Callable,
+        **kwargs,
+    ):
+        """
+        Initialize the Responses API endpoints on the router.
+
+        GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id.
+        """
+        from litellm.responses.utils import ResponsesAPIRequestUtils
+
+        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
+            kwargs.get("response_id")
+        )
+        if model_id is not None:
+            kwargs["model"] = model_id
+        return await self._ageneric_api_call_with_fallbacks(
+            original_function=original_function,
+            **kwargs,
+        )
+
     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
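With `_init_responses_api_endpoints` in place, a GET or DELETE issued through the router first tries to recover the `model_id` embedded in the `response_id` and, when found, pins the call to that deployment before handing off to the generic fallback path. A usage sketch, assuming the `response_id` came from an earlier `router.aresponses(...)` call (router construction mirrors the test config above):

```python
# Sketch: calling the new router methods directly. The response_id is assumed
# to carry an encoded model_id from a previous router.aresponses(...) call.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai/*",
            "litellm_params": {"model": "openai/*", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ]
)


async def fetch_and_delete(response_id: str):
    fetched = await router.aget_responses(response_id=response_id)
    deleted = await router.adelete_responses(response_id=response_id)
    return fetched, deleted


# asyncio.run(fetch_and_delete("resp_abc123"))  # placeholder id
```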
@@ -73,15 +73,31 @@ def validate_stream_chunk(chunk):
 def test_basic_response():
     client = get_test_client()
     response = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'"
+        model="gpt-4.0", input="just respond with the word 'ping'"
     )
     print("basic response=", response)
 
+    # get the response
+    response = client.responses.retrieve(response.id)
+    print("GET response=", response)
+
+
+    # delete the response
+    delete_response = client.responses.delete(response.id)
+    print("DELETE response=", delete_response)
+
+    # try getting the response again, we should not get it back
+    get_response = client.responses.retrieve(response.id)
+    print("GET response after delete=", get_response)
+
+    with pytest.raises(Exception):
+        get_response = client.responses.retrieve(response.id)
+
 
 def test_streaming_response():
     client = get_test_client()
     stream = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'", stream=True
+        model="gpt-4.0", input="just respond with the word 'ping'", stream=True
     )
 
     collected_chunks = []
@@ -104,5 +120,5 @@ def test_bad_request_bad_param_error():
     with pytest.raises(BadRequestError):
         # Trigger error with invalid model name
         client.responses.create(
-            model="gpt-4o", input="This should fail", temperature=2000
+            model="gpt-4.0", input="This should fail", temperature=2000
         )
@@ -1157,3 +1157,14 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
+
+
+def test_init_responses_api_endpoints(model_list):
+    """Test if the '_init_responses_api_endpoints' function is working correctly"""
+    from typing import Callable
+    router = Router(model_list=model_list)
+
+    assert router.aget_responses is not None
+    assert isinstance(router.aget_responses, Callable)
+    assert router.adelete_responses is not None
+    assert isinstance(router.adelete_responses, Callable)