factory_function

This commit is contained in:
Ishaan Jaff 2025-03-12 15:27:34 -07:00
parent 631a97b64d
commit 4e51321e24

View file

@ -3037,14 +3037,42 @@ class Router:
self, self,
original_function: Callable, original_function: Callable,
call_type: Literal[ call_type: Literal[
"assistants", "moderation", "anthropic_messages", "responses", "aresponses" "assistants",
"moderation",
"anthropic_messages",
"aresponses",
"responses",
] = "assistants", ] = "assistants",
): ):
async def new_function( """
Creates appropriate wrapper functions for different API call types.
Returns:
- A synchronous function for synchronous call types
- An asynchronous function for asynchronous call types
"""
# Handle synchronous call types
if call_type == "responses":
def sync_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional[Any] = None,
**kwargs,
):
return self._generic_api_call_with_fallbacks(
original_function=original_function, **kwargs
)
return sync_wrapper
# Handle asynchronous call types
async def async_wrapper(
custom_llm_provider: Optional[ custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"] Literal["openai", "azure", "anthropic"]
] = None, ] = None,
client: Optional["AsyncOpenAI"] = None, client: Optional[Any] = None,
**kwargs, **kwargs,
): ):
if call_type == "assistants": if call_type == "assistants":
@ -3055,18 +3083,16 @@ class Router:
**kwargs, **kwargs,
) )
elif call_type == "moderation": elif call_type == "moderation":
return await self._pass_through_moderation_endpoint_factory(
return await self._pass_through_moderation_endpoint_factory( # type: ignore original_function=original_function, **kwargs
original_function=original_function,
**kwargs,
) )
elif call_type == "anthropic_messages": elif call_type in ("anthropic_messages", "aresponses"):
return await self._ageneric_api_call_with_fallbacks( # type: ignore return await self._ageneric_api_call_with_fallbacks(
original_function=original_function, original_function=original_function,
**kwargs, **kwargs,
) )
return new_function return async_wrapper
async def _pass_through_assistants_endpoint_factory( async def _pass_through_assistants_endpoint_factory(
self, self,