diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py
index 60050fbeb2..2ea3c18ea8 100644
--- a/litellm/proxy/common_request_processing.py
+++ b/litellm/proxy/common_request_processing.py
@@ -108,7 +108,13 @@ class ProxyBaseLLMRequestProcessing:
         user_api_key_dict: UserAPIKeyAuth,
         proxy_logging_obj: ProxyLogging,
         proxy_config: ProxyConfig,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         version: Optional[str] = None,
         user_model: Optional[str] = None,
         user_temperature: Optional[float] = None,
@@ -178,7 +184,13 @@ class ProxyBaseLLMRequestProcessing:
         request: Request,
         fastapi_response: Response,
         user_api_key_dict: UserAPIKeyAuth,
-        route_type: Literal["acompletion", "aresponses", "_arealtime"],
+        route_type: Literal[
+            "acompletion",
+            "aresponses",
+            "_arealtime",
+            "aget_responses",
+            "adelete_responses",
+        ],
         proxy_logging_obj: ProxyLogging,
         general_settings: dict,
         proxy_config: ProxyConfig,
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 63d0a3ffe2..75c49211e7 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,16 +1,8 @@
 model_list:
-  - model_name: azure-computer-use-preview
+  - model_name: openai/*
     litellm_params:
-      model: azure/computer-use-preview
-      api_key: mock-api-key
-      api_version: mock-api-version
-      api_base: https://mock-endpoint.openai.azure.com
-  - model_name: azure-computer-use-preview
-    litellm_params:
-      model: azure/computer-use-preview-2
-      api_key: mock-api-key-2
-      api_version: mock-api-version-2
-      api_base: https://mock-endpoint-2.openai.azure.com
+      model: openai/*
+      api_key: os.environ/OPENAI_API_KEY

 router_settings:
   optional_pre_call_checks: ["responses_api_deployment_check"]
diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py
index f9ddf306a7..bef5b4b807 100644
--- a/litellm/proxy/response_api_endpoints/endpoints.py
+++ b/litellm/proxy/response_api_endpoints/endpoints.py
@@ -106,8 +106,50 @@ async def get_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response retrieval logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="aget_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )


 @router.delete(
@@ -136,8 +178,50 @@ async def delete_response(
         -H "Authorization: Bearer sk-1234"
     ```
     """
-    # TODO: Implement response deletion logic
-    pass
+    from litellm.proxy.proxy_server import (
+        _read_request_body,
+        general_settings,
+        llm_router,
+        proxy_config,
+        proxy_logging_obj,
+        select_data_generator,
+        user_api_base,
+        user_max_tokens,
+        user_model,
+        user_request_timeout,
+        user_temperature,
+        version,
+    )
+
+    data = await _read_request_body(request=request)
+    data["response_id"] = response_id
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+    try:
+        return await processor.base_process_llm_request(
+            request=request,
+            fastapi_response=fastapi_response,
+            user_api_key_dict=user_api_key_dict,
+            route_type="adelete_responses",
+            proxy_logging_obj=proxy_logging_obj,
+            llm_router=llm_router,
+            general_settings=general_settings,
+            proxy_config=proxy_config,
+            select_data_generator=select_data_generator,
+            model=None,
+            user_model=user_model,
+            user_temperature=user_temperature,
+            user_request_timeout=user_request_timeout,
+            user_max_tokens=user_max_tokens,
+            user_api_base=user_api_base,
+            version=version,
+        )
+    except Exception as e:
+        raise await processor._handle_llm_api_exception(
+            e=e,
+            user_api_key_dict=user_api_key_dict,
+            proxy_logging_obj=proxy_logging_obj,
+            version=version,
+        )


 @router.get(
diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py
index ac9332b219..ef5edaff99 100644
--- a/litellm/proxy/route_llm_request.py
+++ b/litellm/proxy/route_llm_request.py
@@ -47,6 +47,8 @@ async def route_request(
         "amoderation",
         "arerank",
         "aresponses",
+        "aget_responses",
+        "adelete_responses",
         "_arealtime",  # private function for realtime API
     ],
 ):
diff --git a/litellm/responses/utils.py b/litellm/responses/utils.py
index 9fa455de71..cf51c40b2a 100644
--- a/litellm/responses/utils.py
+++ b/litellm/responses/utils.py
@@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils:
             response_id=response_id,
         )

+    @staticmethod
+    def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
+        """Get the model_id from the response_id"""
+        if response_id is None:
+            return None
+        decoded_response_id = (
+            ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
+        )
+        return decoded_response_id.get("model_id") or None
+

 class ResponseAPILoggingUtils:
     @staticmethod
diff --git a/litellm/router.py b/litellm/router.py
index dba886b856..e7f98fab48 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -739,6 +739,12 @@ class Router:
             litellm.afile_content, call_type="afile_content"
         )
         self.responses = self.factory_function(litellm.responses, call_type="responses")
+        self.aget_responses = self.factory_function(
+            litellm.aget_responses, call_type="aget_responses"
+        )
+        self.adelete_responses = self.factory_function(
+            litellm.adelete_responses, call_type="adelete_responses"
+        )

     def validate_fallbacks(self, fallback_param: Optional[List]):
         """
@@ -3081,6 +3087,8 @@ class Router:
             "anthropic_messages",
             "aresponses",
             "responses",
+            "aget_responses",
+            "adelete_responses",
             "afile_delete",
             "afile_content",
         ] = "assistants",
@@ -3135,6 +3143,11 @@ class Router:
                     original_function=original_function,
                     **kwargs,
                 )
+            elif call_type in ("aget_responses", "adelete_responses"):
+                return await self._init_responses_api_endpoints(
+                    original_function=original_function,
+                    **kwargs,
+                )
             elif call_type in ("afile_delete", "afile_content"):
                 return await self._ageneric_api_call_with_fallbacks(
                     original_function=original_function,
@@ -3145,6 +3158,28 @@ class Router:

         return async_wrapper

+    async def _init_responses_api_endpoints(
+        self,
+        original_function: Callable,
+        **kwargs,
+    ):
+        """
+        Initialize the Responses API endpoints on the router.
+
+        GET and DELETE Responses API requests encode the model_id in the response_id; this function decodes the response_id and sets the model to that model_id.
+        """
+        from litellm.responses.utils import ResponsesAPIRequestUtils
+
+        model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id(
+            kwargs.get("response_id")
+        )
+        if model_id is not None:
+            kwargs["model"] = model_id
+        return await self._ageneric_api_call_with_fallbacks(
+            original_function=original_function,
+            **kwargs,
+        )
+
     async def _pass_through_assistants_endpoint_factory(
         self,
         original_function: Callable,
diff --git a/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py b/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py
index 1dde8ebae6..f87444141e 100644
--- a/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py
+++ b/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py
@@ -73,15 +73,27 @@ def validate_stream_chunk(chunk):
 def test_basic_response():
     client = get_test_client()
     response = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'"
+        model="gpt-4.0", input="just respond with the word 'ping'"
     )
     print("basic response=", response)

+    # get the response
+    response = client.responses.retrieve(response.id)
+    print("GET response=", response)
+
+    # delete the response
+    delete_response = client.responses.delete(response.id)
+    print("DELETE response=", delete_response)
+
+    # try getting the response again, we should not get it back
+    with pytest.raises(Exception):
+        get_response = client.responses.retrieve(response.id)
+

 def test_streaming_response():
     client = get_test_client()
     stream = client.responses.create(
-        model="gpt-4o", input="just respond with the word 'ping'", stream=True
+        model="gpt-4.0", input="just respond with the word 'ping'", stream=True
     )

     collected_chunks = []
@@ -104,5 +116,5 @@ def test_bad_request_bad_param_error():
     with pytest.raises(BadRequestError):
         # Trigger error with invalid model name
         client.responses.create(
-            model="gpt-4o", input="This should fail", temperature=2000
+            model="gpt-4.0", input="This should fail", temperature=2000
         )
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index 782f0d8fbb..b17f0c0a5e 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -1157,3 +1157,14 @@ def test_cached_get_model_group_info(model_list):
     # Verify the cache info shows hits
     cache_info = router._cached_get_model_group_info.cache_info()
     assert cache_info.hits > 0  # Should have at least one cache hit
+
+
+def test_init_responses_api_endpoints(model_list):
+    """Test if the '_init_responses_api_endpoints' function is working correctly"""
+    from typing import Callable
+    router = Router(model_list=model_list)
+
+    assert router.aget_responses is not None
+    assert isinstance(router.aget_responses, Callable)
+    assert router.adelete_responses is not None
+    assert isinstance(router.adelete_responses, Callable)
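
Usage sketch (not part of the patch): the new GET/DELETE endpoints can be exercised with the OpenAI SDK pointed at a running LiteLLM proxy, mirroring the e2e test above. The base URL, master key, and `openai/gpt-4o` model name below are assumptions for a local proxy started with the config in this PR.

```python
# Minimal sketch of retrieving and deleting a response through the LiteLLM proxy.
# Assumes (hypothetical setup): proxy running on http://localhost:4000 with master
# key "sk-1234", and OPENAI_API_KEY exported for the openai/* wildcard deployment.
import openai

client = openai.OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# Create a response; the returned response.id encodes the model_id of the
# deployment that served it, which the router later decodes for routing.
response = client.responses.create(
    model="openai/gpt-4o", input="just respond with the word 'ping'"
)

# GET /v1/responses/{response_id} -> handled by the new "aget_responses" route type.
fetched = client.responses.retrieve(response.id)
print("retrieved:", fetched.id)

# DELETE /v1/responses/{response_id} -> handled by the new "adelete_responses" route type.
deleted = client.responses.delete(response.id)
print("deleted:", deleted)
```

Because the model_id is decoded from the response_id, follow-up GET and DELETE calls are routed to the same deployment that created the response.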