_generic_api_call_with_fallbacks

Ishaan Jaff 2025-03-12 15:26:37 -07:00
parent 52e845ec89
commit 631a97b64d


@@ -581,13 +581,7 @@ class Router:
        self._initialize_alerting()

        self.initialize_assistants_endpoint()
-        self.amoderation = self.factory_function(
-            litellm.amoderation, call_type="moderation"
-        )
-
-        self.aanthropic_messages = self.factory_function(
-            litellm.anthropic_messages, call_type="anthropic_messages"
-        )
+        self.initialize_router_endpoints()

    def discard(self):
        """
@@ -653,6 +647,18 @@ class Router:
        self.aget_messages = self.factory_function(litellm.aget_messages)
        self.arun_thread = self.factory_function(litellm.arun_thread)
+
+    def initialize_router_endpoints(self):
+        self.amoderation = self.factory_function(
+            litellm.amoderation, call_type="moderation"
+        )
+        self.aanthropic_messages = self.factory_function(
+            litellm.anthropic_messages, call_type="anthropic_messages"
+        )
+        self.aresponses = self.factory_function(
+            litellm.aresponses, call_type="aresponses"
+        )
+        self.responses = self.factory_function(litellm.responses, call_type="responses")

    def routing_strategy_init(
        self, routing_strategy: Union[RoutingStrategy, str], routing_strategy_args: dict
    ):
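
For context, a quick usage sketch (not part of the diff, and the model alias below is hypothetical): once initialize_router_endpoints() builds these wrappers, responses and aresponses can be called on a configured Router like any other routed endpoint, picking up the router's retry/fallback behavior:

import litellm
from litellm import Router

# Hypothetical single-deployment setup for illustration.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o"},
        }
    ]
)

# Sync Responses API call routed through the router.
response = router.responses(model="gpt-4o", input="What is 2 + 2?")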
@@ -2453,6 +2459,63 @@ class Router:
            self.fail_calls[model] += 1
            raise e
+
+    def _generic_api_call_with_fallbacks(
+        self, model: str, original_function: Callable, **kwargs
+    ):
+        """
+        Make a generic LLM API call through the router; this allows you to use retries/fallbacks with the litellm router.
+
+        Args:
+            model: The model to use
+            original_function: The handler function to call (e.g., litellm.completion)
+            **kwargs: Additional arguments to pass to the handler function
+
+        Returns:
+            The response from the handler function
+        """
+        handler_name = original_function.__name__
+        try:
+            verbose_router_logger.debug(
+                f"Inside _generic_api_call() - handler: {handler_name}, model: {model}; kwargs: {kwargs}"
+            )
+            deployment = self.get_available_deployment(
+                model=model,
+                messages=kwargs.get("messages", None),
+                specific_deployment=kwargs.pop("specific_deployment", None),
+            )
+            self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
+
+            data = deployment["litellm_params"].copy()
+            model_name = data["model"]
+
+            model_client = self._get_client(
+                deployment=deployment, kwargs=kwargs, client_type="sync"
+            )
+
+            self.total_calls[model_name] += 1
+            # Perform pre-call checks for routing strategy
+            self.routing_strategy_pre_call_checks(deployment=deployment)
+            response = original_function(
+                **{
+                    **data,
+                    "caching": self.cache_responses,
+                    "client": model_client,
+                    **kwargs,
+                }
+            )
+            self.success_calls[model_name] += 1
+            verbose_router_logger.info(
+                f"{handler_name}(model={model_name})\033[32m 200 OK\033[0m"
+            )
+            return response
+        except Exception as e:
+            verbose_router_logger.info(
+                f"{handler_name}(model={model})\033[31m Exception {str(e)}\033[0m"
+            )
+            if model is not None:
+                self.fail_calls[model] += 1
+            raise e

    def embedding(
        self,
        model: str,
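
Note the merge order in the original_function call: deployment litellm_params come first, then the router's caching and client settings, and explicit caller kwargs override both. A hedged sketch of how a sync wrapper produced by factory_function might delegate here (the wrapper name and body are illustrative, not from the diff):

import litellm

def sync_responses_endpoint(router, model: str, **kwargs):
    # Illustrative delegation: deployment selection, retries/fallbacks,
    # and success/failure bookkeeping all happen inside the router method.
    return router._generic_api_call_with_fallbacks(
        model=model,
        original_function=litellm.responses,  # handler invoked per deployment
        **kwargs,
    )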
@@ -2974,7 +3037,7 @@ class Router:
        self,
        original_function: Callable,
        call_type: Literal[
-            "assistants", "moderation", "anthropic_messages"
+            "assistants", "moderation", "anthropic_messages", "responses", "aresponses"
        ] = "assistants",
    ):
        async def new_function(
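
As a hedged sketch of the dispatch idea only (every name not shown in the diff is an assumption), a factory along these lines can branch on call_type so that sync "responses" traffic flows through _generic_api_call_with_fallbacks while other call types keep their own paths:

from typing import Callable

def make_endpoint(router, original_function: Callable, call_type: str = "assistants"):
    # Hypothetical stand-in for the sync side of factory_function.
    def endpoint(model: str, **kwargs):
        if call_type == "responses":
            # Route through the router's generic retry/fallback path.
            return router._generic_api_call_with_fallbacks(
                model=model, original_function=original_function, **kwargs
            )
        raise NotImplementedError(f"no sync path sketched for {call_type!r}")

    return endpoint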