test_openai_responses_litellm_router

This commit is contained in:
Ishaan Jaff 2025-03-12 16:13:48 -07:00
parent 89d30d39f6
commit d808fa3c23
3 changed files with 114 additions and 5 deletions

View file

@ -1127,7 +1127,12 @@ class Router:
) # add new deployment to router
return deployment_pydantic_obj
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
def _update_kwargs_with_deployment(
self,
deployment: dict,
kwargs: dict,
function_name: Optional[str] = None,
) -> None:
"""
2 jobs:
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
@ -1144,7 +1149,10 @@ class Router:
deployment_model_name = deployment_pydantic_obj.litellm_params.model
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
kwargs.setdefault("metadata", {}).update(
metadata_variable_name = _get_router_metadata_variable_name(
function_name=function_name,
)
kwargs.setdefault(metadata_variable_name, {}).update(
{
"deployment": deployment_model_name,
"model_info": model_info,
@ -2402,7 +2410,9 @@ class Router:
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)
data = deployment["litellm_params"].copy()
model_name = data["model"]
@ -2481,7 +2491,9 @@ class Router:
messages=kwargs.get("messages", None),
specific_deployment=kwargs.pop("specific_deployment", None),
)
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
)
data = deployment["litellm_params"].copy()
model_name = data["model"]

View file

@ -56,7 +56,8 @@ def _get_router_metadata_variable_name(function_name) -> str:
For ALL other endpoints we call this "metadata"
"""
if "batch" in function_name:
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
return "litellm_metadata"
else:
return "metadata"

View file

@ -503,3 +503,99 @@ async def test_openai_responses_api_streaming_validation(sync_mode):
assert not missing_events, f"Missing required event types: {missing_events}"
print(f"Successfully validated all event types: {event_types_seen}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router(sync_mode):
    """
    Exercise the OpenAI responses API through a LiteLLM Router.

    Runs once with the synchronous ``responses`` entry point and once with the
    asynchronous ``aresponses`` entry point (selected by ``sync_mode``), then
    validates the structure of the returned response.
    """
    litellm._turn_on_debug()

    # Route the alias "gpt4o-special-alias" to a real gpt-4o deployment.
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt4o-special-alias",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ]
    )

    # Identical request arguments for both the sync and async paths.
    request_kwargs = dict(
        model="gpt4o-special-alias",
        input="Hello, can you tell me a short joke?",
        max_output_tokens=100,
    )

    # Call the handler via the path selected by the parametrized mode.
    if sync_mode:
        response = router.responses(**request_kwargs)
        print("SYNC MODE RESPONSE=", response)
    else:
        response = await router.aresponses(**request_kwargs)
        print(
            f"Router {'sync' if sync_mode else 'async'} response=",
            json.dumps(response, indent=4, default=str),
        )

    # Use the helper function to validate the response
    validate_responses_api_response(response, final_chunk=True)
    return response
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_responses_litellm_router_streaming(sync_mode):
    """
    Exercise the OpenAI responses API streaming path through a LiteLLM Router.

    Consumes the stream with either a sync ``for`` loop or an ``async for``
    loop (selected by ``sync_mode``), validating every event and asserting
    that the core lifecycle events were observed.
    """
    litellm._turn_on_debug()

    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt4o-special-alias",
                "litellm_params": {
                    "model": "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            }
        ]
    )

    event_types_seen = set()

    def record(event):
        # Validate each streamed event and remember its type for the final check.
        print(f"Validating event type: {event.type}")
        validate_stream_event(event)
        event_types_seen.add(event.type)

    if sync_mode:
        response = router.responses(
            model="gpt4o-special-alias",
            input="Tell me about artificial intelligence in 2 sentences.",
            stream=True,
        )
        for event in response:
            record(event)
    else:
        response = await router.aresponses(
            model="gpt4o-special-alias",
            input="Tell me about artificial intelligence in 2 sentences.",
            stream=True,
        )
        async for event in response:
            record(event)

    # At minimum, we should see these core event types
    required_events = {"response.created", "response.completed"}
    missing_events = required_events - event_types_seen
    assert not missing_events, f"Missing required event types: {missing_events}"
    print(f"Successfully validated all event types: {event_types_seen}")