factory_function

This commit is contained in:
Ishaan Jaff 2025-03-12 15:27:34 -07:00
parent 631a97b64d
commit 4e51321e24

View file

@ -3037,14 +3037,42 @@ class Router:
self,
original_function: Callable,
call_type: Literal[
"assistants", "moderation", "anthropic_messages", "responses", "aresponses"
"assistants",
"moderation",
"anthropic_messages",
"aresponses",
"responses",
] = "assistants",
):
async def new_function(
"""
Creates appropriate wrapper functions for different API call types.
Returns:
- A synchronous function for synchronous call types
- An asynchronous function for asynchronous call types
"""
# Handle synchronous call types
if call_type == "responses":
def sync_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional[Any] = None,
**kwargs,
):
return self._generic_api_call_with_fallbacks(
original_function=original_function, **kwargs
)
return sync_wrapper
# Handle asynchronous call types
async def async_wrapper(
custom_llm_provider: Optional[
Literal["openai", "azure", "anthropic"]
] = None,
client: Optional["AsyncOpenAI"] = None,
client: Optional[Any] = None,
**kwargs,
):
if call_type == "assistants":
@ -3055,18 +3083,16 @@ class Router:
**kwargs,
)
elif call_type == "moderation":
return await self._pass_through_moderation_endpoint_factory( # type: ignore
original_function=original_function,
**kwargs,
return await self._pass_through_moderation_endpoint_factory(
original_function=original_function, **kwargs
)
elif call_type == "anthropic_messages":
return await self._ageneric_api_call_with_fallbacks( # type: ignore
elif call_type in ("anthropic_messages", "aresponses"):
return await self._ageneric_api_call_with_fallbacks(
original_function=original_function,
**kwargs,
)
return new_function
return async_wrapper
async def _pass_through_assistants_endpoint_factory(
self,