mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
test_openai_responses_litellm_router
This commit is contained in:
parent
4e51321e24
commit
bfc928192d
3 changed files with 114 additions and 5 deletions
|
@ -1127,7 +1127,12 @@ class Router:
|
||||||
) # add new deployment to router
|
) # add new deployment to router
|
||||||
return deployment_pydantic_obj
|
return deployment_pydantic_obj
|
||||||
|
|
||||||
def _update_kwargs_with_deployment(self, deployment: dict, kwargs: dict) -> None:
|
def _update_kwargs_with_deployment(
|
||||||
|
self,
|
||||||
|
deployment: dict,
|
||||||
|
kwargs: dict,
|
||||||
|
function_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
2 jobs:
|
2 jobs:
|
||||||
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
|
- Adds selected deployment, model_info and api_base to kwargs["metadata"] (used for logging)
|
||||||
|
@ -1144,7 +1149,10 @@ class Router:
|
||||||
deployment_model_name = deployment_pydantic_obj.litellm_params.model
|
deployment_model_name = deployment_pydantic_obj.litellm_params.model
|
||||||
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
|
deployment_api_base = deployment_pydantic_obj.litellm_params.api_base
|
||||||
|
|
||||||
kwargs.setdefault("metadata", {}).update(
|
metadata_variable_name = _get_router_metadata_variable_name(
|
||||||
|
function_name=function_name,
|
||||||
|
)
|
||||||
|
kwargs.setdefault(metadata_variable_name, {}).update(
|
||||||
{
|
{
|
||||||
"deployment": deployment_model_name,
|
"deployment": deployment_model_name,
|
||||||
"model_info": model_info,
|
"model_info": model_info,
|
||||||
|
@ -2402,7 +2410,9 @@ class Router:
|
||||||
messages=kwargs.get("messages", None),
|
messages=kwargs.get("messages", None),
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
)
|
)
|
||||||
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
|
self._update_kwargs_with_deployment(
|
||||||
|
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
|
||||||
|
)
|
||||||
|
|
||||||
data = deployment["litellm_params"].copy()
|
data = deployment["litellm_params"].copy()
|
||||||
model_name = data["model"]
|
model_name = data["model"]
|
||||||
|
@ -2481,7 +2491,9 @@ class Router:
|
||||||
messages=kwargs.get("messages", None),
|
messages=kwargs.get("messages", None),
|
||||||
specific_deployment=kwargs.pop("specific_deployment", None),
|
specific_deployment=kwargs.pop("specific_deployment", None),
|
||||||
)
|
)
|
||||||
self._update_kwargs_with_deployment(deployment=deployment, kwargs=kwargs)
|
self._update_kwargs_with_deployment(
|
||||||
|
deployment=deployment, kwargs=kwargs, function_name="generic_api_call"
|
||||||
|
)
|
||||||
|
|
||||||
data = deployment["litellm_params"].copy()
|
data = deployment["litellm_params"].copy()
|
||||||
model_name = data["model"]
|
model_name = data["model"]
|
||||||
|
|
|
@ -56,7 +56,8 @@ def _get_router_metadata_variable_name(function_name) -> str:
|
||||||
|
|
||||||
For ALL other endpoints we call this "metadata
|
For ALL other endpoints we call this "metadata
|
||||||
"""
|
"""
|
||||||
if "batch" in function_name:
|
ROUTER_METHODS_USING_LITELLM_METADATA = set(["batch", "generic_api_call"])
|
||||||
|
if function_name in ROUTER_METHODS_USING_LITELLM_METADATA:
|
||||||
return "litellm_metadata"
|
return "litellm_metadata"
|
||||||
else:
|
else:
|
||||||
return "metadata"
|
return "metadata"
|
||||||
|
|
|
@ -503,3 +503,99 @@ async def test_openai_responses_api_streaming_validation(sync_mode):
|
||||||
assert not missing_events, f"Missing required event types: {missing_events}"
|
assert not missing_events, f"Missing required event types: {missing_events}"
|
||||||
|
|
||||||
print(f"Successfully validated all event types: {event_types_seen}")
|
print(f"Successfully validated all event types: {event_types_seen}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_responses_litellm_router(sync_mode):
|
||||||
|
"""
|
||||||
|
Test the OpenAI responses API with LiteLLM Router in both sync and async modes
|
||||||
|
"""
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt4o-special-alias",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
if sync_mode:
|
||||||
|
response = router.responses(
|
||||||
|
model="gpt4o-special-alias",
|
||||||
|
input="Hello, can you tell me a short joke?",
|
||||||
|
max_output_tokens=100,
|
||||||
|
)
|
||||||
|
print("SYNC MODE RESPONSE=", response)
|
||||||
|
else:
|
||||||
|
response = await router.aresponses(
|
||||||
|
model="gpt4o-special-alias",
|
||||||
|
input="Hello, can you tell me a short joke?",
|
||||||
|
max_output_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Router {'sync' if sync_mode else 'async'} response=",
|
||||||
|
json.dumps(response, indent=4, default=str),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the helper function to validate the response
|
||||||
|
validate_responses_api_response(response, final_chunk=True)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_responses_litellm_router_streaming(sync_mode):
|
||||||
|
"""
|
||||||
|
Test the OpenAI responses API with streaming through LiteLLM Router
|
||||||
|
"""
|
||||||
|
litellm._turn_on_debug()
|
||||||
|
router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gpt4o-special-alias",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
event_types_seen = set()
|
||||||
|
|
||||||
|
if sync_mode:
|
||||||
|
response = router.responses(
|
||||||
|
model="gpt4o-special-alias",
|
||||||
|
input="Tell me about artificial intelligence in 2 sentences.",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
for event in response:
|
||||||
|
print(f"Validating event type: {event.type}")
|
||||||
|
validate_stream_event(event)
|
||||||
|
event_types_seen.add(event.type)
|
||||||
|
else:
|
||||||
|
response = await router.aresponses(
|
||||||
|
model="gpt4o-special-alias",
|
||||||
|
input="Tell me about artificial intelligence in 2 sentences.",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
async for event in response:
|
||||||
|
print(f"Validating event type: {event.type}")
|
||||||
|
validate_stream_event(event)
|
||||||
|
event_types_seen.add(event.type)
|
||||||
|
|
||||||
|
# At minimum, we should see these core event types
|
||||||
|
required_events = {"response.created", "response.completed"}
|
||||||
|
|
||||||
|
missing_events = required_events - event_types_seen
|
||||||
|
assert not missing_events, f"Missing required event types: {missing_events}"
|
||||||
|
|
||||||
|
print(f"Successfully validated all event types: {event_types_seen}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue