diff --git a/litellm/router_utils/pre_call_checks/responses_api_deployment_check.py b/litellm/router_utils/pre_call_checks/responses_api_deployment_check.py index 48754a8490..5227fbaa79 100644 --- a/litellm/router_utils/pre_call_checks/responses_api_deployment_check.py +++ b/litellm/router_utils/pre_call_checks/responses_api_deployment_check.py @@ -2,7 +2,7 @@ If previous_response_id is provided, route to the deployment that returned the previous response """ -from typing import List, Optional, cast +from typing import List, Optional from litellm import verbose_logger from litellm.caching.dual_cache import DualCache @@ -32,9 +32,15 @@ class ResponsesApiDeploymentCheck(CustomLogger): if previous_response_id is None: return healthy_deployments - # for deployment in healthy_deployments: - # if deployment["model_info"]["id"] == model_id: - # return [deployment] + model_id = await self.async_get_response_id_from_cache( + response_id=previous_response_id, + ) + if model_id is None: + return healthy_deployments + + for deployment in healthy_deployments: + if deployment["model_info"]["id"] == model_id: + return [deployment] return healthy_deployments @@ -60,6 +66,17 @@ class ResponsesApiDeploymentCheck(CustomLogger): response_id = getattr(response_obj, "id", None) model_id = standard_logging_object["model_id"] + if response_id is None or model_id is None: + verbose_logger.debug( + "litellm.router_utils.pre_call_checks.responses_api_deployment_check: skipping adding response_id to cache, RESPONSE ID OR MODEL ID IS NONE" + ) + return + + await self.async_add_response_id_to_cache( + response_id=response_id, + model_id=model_id, + ) + return async def async_add_response_id_to_cache( @@ -68,9 +85,20 @@ class ResponsesApiDeploymentCheck(CustomLogger): model_id: str, ): await self.cache.async_set_cache( - key=self.RESPONSES_API_RESPONSE_MODEL_ID_CACHE_KEY, + key=self.get_cache_key_for_response_id(response_id), value={ "response_id": response_id, "model_id": model_id, }, ) + + 
async def async_get_response_id_from_cache(self, response_id: str) -> Optional[str]:
    """
    Return the model (deployment) id cached for ``response_id``.

    ``async_add_response_id_to_cache`` stores a dict of the form
    ``{"response_id": ..., "model_id": ...}``, and the caller compares the
    value returned here against ``deployment["model_info"]["id"]``. The
    ``"model_id"`` entry is therefore extracted instead of stringifying the
    whole payload, which could never match a deployment id.

    Args:
        response_id: Id of the previous response to look up.

    Returns:
        The cached model id as a string, or ``None`` when nothing is cached
        for ``response_id`` or the cached dict carries no ``"model_id"``.
    """
    cache_value = await self.cache.async_get_cache(
        key=self.get_cache_key_for_response_id(response_id),
    )
    if cache_value is None:
        return None

    if isinstance(cache_value, dict):
        model_id = cache_value.get("model_id")
        # A payload without a model id is treated like a cache miss so the
        # caller keeps every healthy deployment instead of matching on "None".
        return None if model_id is None else str(model_id)

    # Non-dict payloads (e.g. a bare model id string) keep the previous
    # behaviour of being returned in string form.
    return str(cache_value)


def get_cache_key_for_response_id(self, response_id: str) -> str:
    """
    Build the cache key under which the entry for ``response_id`` is stored.

    The key is the class-level prefix followed by ``:`` and the response id,
    so each response gets its own cache entry.
    """
    return f"{self.RESPONSES_API_RESPONSE_MODEL_ID_CACHE_KEY}:{response_id}"