Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
[Feat] Add GET, DELETE Responses endpoints on LiteLLM Proxy (#10297)
* add GET responses endpoints on router
* add DELETE responses endpoints on proxy
* fixes for testing GET, DELETE endpoints
* test_basic_responses api e2e
Parent: 0a2c964db7
Commit: 5de101ab7b
8 changed files with 182 additions and 20 deletions
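With this change, a client can retrieve and delete stored responses through the proxy. A minimal usage sketch with the OpenAI Python SDK, mirroring the e2e test added below (the base URL, virtual key, and model name are placeholders for your own proxy deployment):

```python
from openai import OpenAI

# Placeholder proxy URL, virtual key, and model name — substitute your own values.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

response = client.responses.create(model="gpt-4o", input="just respond with the word 'ping'")
print("created:", response.id)

# The two proxy routes added in this PR:
fetched = client.responses.retrieve(response.id)  # GET a stored response
deleted = client.responses.delete(response.id)    # DELETE a stored response
print("retrieved:", fetched.id, "deleted:", deleted)
```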
@@ -108,7 +108,13 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -178,7 +184,13 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,
@@ -1,16 +1,8 @@
 model_list:
-  - model_name: azure-computer-use-preview
+  - model_name: openai/*
     litellm_params:
-      model: azure/computer-use-preview
-      api_key: mock-api-key
-      api_version: mock-api-version
-      api_base: https://mock-endpoint.openai.azure.com
-  - model_name: azure-computer-use-preview
-    litellm_params:
-      model: azure/computer-use-preview-2
-      api_key: mock-api-key-2
-      api_version: mock-api-version-2
-      api_base: https://mock-endpoint-2.openai.azure.com
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY
 
 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]
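For reference, the updated config above registers a single wildcard `openai/*` deployment whose key is resolved from the `OPENAI_API_KEY` environment variable, and keeps the `responses_api_deployment_check` pre-call check. A rough programmatic equivalent using the Router directly — an illustrative sketch, not part of this commit:

```python
from litellm import Router

# Illustrative only: mirrors the YAML config above in Python.
router = Router(
    model_list=[
        {
            "model_name": "openai/*",
            "litellm_params": {
                "model": "openai/*",
                "api_key": "os.environ/OPENAI_API_KEY",  # resolved from the environment by LiteLLM
            },
        }
    ],
    optional_pre_call_checks=["responses_api_deployment_check"],
)
```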
@@ -106,8 +106,50 @@ async def get_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response retrieval logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="aget_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )
 
 
 @router.delete(
@@ -136,8 +178,50 @@ async def delete_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response deletion logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="adelete_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )
 
 
 @router.get(
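Both new handlers read the request body, attach the `response_id` path parameter, and delegate to `ProxyBaseLLMRequestProcessing.base_process_llm_request` with the new `aget_responses` / `adelete_responses` route types, which ultimately call `litellm.aget_responses` / `litellm.adelete_responses` through the router. A hedged sketch of calling those SDK functions directly, assuming they accept `response_id` as a keyword argument (as the router wiring later in this diff suggests):

```python
import asyncio

import litellm


async def main() -> None:
    # Assumes OPENAI_API_KEY is set in the environment; the id comes from the create call.
    created = await litellm.aresponses(model="openai/gpt-4o", input="ping")
    fetched = await litellm.aget_responses(response_id=created.id)
    deleted = await litellm.adelete_responses(response_id=created.id)
    print(fetched.id, deleted)


asyncio.run(main())
```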
@@ -47,6 +47,8 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
+        "aget_responses",
+        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):
@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )
 
+    @staticmethod
+    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
+        """Get the model_id from the response_id"""
+        if response_id is None:
+            return None
+        decoded_response_id = (
+            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
+        )
+        return decoded_response_id.get("model_id") or None
+
 
 class ResponseAPILoggingUtils:
     @staticmethod
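A brief illustration of the new helper (the id string is hypothetical; behavior inferred from the code above): it returns a deployment `model_id` only when one was encoded into the response id by the LiteLLM router, and `None` otherwise.

```python
from litellm.responses.utils import ResponsesAPIRequestUtils

# No id at all -> None
print(ResponsesAPIRequestUtils.get_model_id_from_response_id(None))
# A plain provider response id (hypothetical value) carries no encoded model_id -> None
print(ResponsesAPIRequestUtils.get_model_id_from_response_id("resp_abc123"))
# An id returned through a LiteLLM Router has the deployment's model_id encoded into it,
# so the helper recovers that model_id and the GET/DELETE call can be routed accordingly.
```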
@@ -739,6 +739,12 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
+        self.aget_responses = self.factory_function(
+            litellm.aget_responses, call_type="aget_responses"
+        )
+        self.adelete_responses = self.factory_function(
+            litellm.adelete_responses, call_type="adelete_responses"
+        )
 
     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3081,6 +3087,8 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
+            "aget_responses",
+            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3135,6 +3143,11 @@
                     original_function=original_function,
                     **kwargs,
                 )
+            elif call_type in ("aget_responses", "adelete_responses"):
+                return await self._init_responses_api_endpoints(
+                    original_function=original_function,
+                    **kwargs,
+                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3145,6 +3158,28 @@
 
         return async_wrapper
 
+    async def _init_responses_api_endpoints(
+        self,
+        original_function: Callable,
+        **kwargs,
+    ):
+        """
+        Initialize the Responses API endpoints on the router.
+
+        GET, DELETE Responses API requests encode the model_id in the response_id; this function decodes the response_id and sets the model to that model_id.
+        """
+        from litellm.responses.utils import ResponsesAPIRequestUtils
+
+        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
+            kwargs.get("response_id")
+        )
+        if model_id is not None:
+            kwargs["model"] = model_id
+        return await self._ageneric_api_call_with_fallbacks(
+            original_function=original_function,
+            **kwargs,
+        )
+
     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
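In practice this means a response created through the router can later be fetched or deleted through the same router without re-specifying a model: the deployment's `model_id` is decoded from the response id and the call is pinned to that deployment. A minimal sketch under those assumptions (the model list entry and key are placeholders):

```python
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai/gpt-4o",  # placeholder deployment
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ]
)


async def main() -> None:
    created = await router.aresponses(model="openai/gpt-4o", input="ping")
    # The router encoded its model_id into created.id, so these calls are routed
    # back to the same deployment via _init_responses_api_endpoints.
    await router.aget_responses(response_id=created.id)
    await router.adelete_responses(response_id=created.id)


asyncio.run(main())
```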
@@ -73,15 +73,31 @@ def validate_stream_chunk(chunk):
 def test_basic_response():
     client = get_test_client()
     response = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'"
+        model="gpt-4.0", input="just respond with the word 'ping'"
     )
     print("basic response=", response)
 
+    # get the response
+    response = client.responses.retrieve(response.id)
+    print("GET response=", response)
+
+
+    # delete the response
+    delete_response = client.responses.delete(response.id)
+    print("DELETE response=", delete_response)
+
+    # try getting the response again, we should not get it back
+    get_response = client.responses.retrieve(response.id)
+    print("GET response after delete=", get_response)
+
+    with pytest.raises(Exception):
+        get_response = client.responses.retrieve(response.id)
+
 
 def test_streaming_response():
     client = get_test_client()
     stream = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'", stream=True
+        model="gpt-4.0", input="just respond with the word 'ping'", stream=True
     )
 
     collected_chunks = []
@@ -104,5 +120,5 @@ def test_bad_request_bad_param_error():
     with pytest.raises(BadRequestError):
         # Trigger error with invalid model name
         client.responses.create(
-            model="gpt-4o", input="This should fail", temperature=2000
+            model="gpt-4.0", input="This should fail", temperature=2000
         )
@@ -1157,3 +1157,14 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
+
+
+def test_init_responses_api_endpoints(model_list):
+    """Test if the '_init_responses_api_endpoints' function is working correctly"""
+    from typing import Callable
+    router = Router(model_list=model_list)
+
+    assert router.aget_responses is not None
+    assert isinstance(router.aget_responses, Callable)
+    assert router.adelete_responses is not None
+    assert isinstance(router.adelete_responses, Callable)