fix #8425, passthrough kwargs during acompletion, and unwrap extra_body for openrouter (#9747)

This commit is contained in:
Adrian Lyjak 2025-04-04 01:19:40 -04:00 committed by GitHub
parent bdad9961e3
commit d640bc0a00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 135 additions and 14 deletions

View file

@ -1,17 +1,18 @@
""" """
Support for OpenAI's `/v1/chat/completions` endpoint. Support for OpenAI's `/v1/chat/completions` endpoint.
Calls done in OpenAI/openai.py as OpenRouter is openai-compatible. Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
Docs: https://openrouter.ai/docs/parameters Docs: https://openrouter.ai/docs/parameters
""" """
from typing import Any, AsyncIterator, Iterator, Optional, Union from typing import Any, AsyncIterator, Iterator, List, Optional, Union
import httpx import httpx
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openrouter import OpenRouterErrorMessage from litellm.types.llms.openrouter import OpenRouterErrorMessage
from litellm.types.utils import ModelResponse, ModelResponseStream from litellm.types.utils import ModelResponse, ModelResponseStream
@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig):
] = extra_body # openai client supports `extra_body` param ] = extra_body # openai client supports `extra_body` param
return mapped_openai_params return mapped_openai_params
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Transform the overall request to be sent to the API.

    OpenRouter-specific: the OpenAI client nests custom routing keys (e.g.
    ``provider``) under ``extra_body``, but OpenRouter expects them at the
    top level of the request body. Unwrap ``extra_body`` so those keys are
    merged into the top-level request (see issue #8425).

    Args:
        model: Model name to send to the provider.
        messages: Chat messages in OpenAI format.
        optional_params: Provider params; ``extra_body`` (if present) is
            consumed here and merged into the top-level request body.
        litellm_params: litellm-internal params, forwarded to the base class.
        headers: Request headers, forwarded to the base class.

    Returns:
        dict: The transformed request. Sent as the body of the API call.
    """
    # Pop before delegating so the base transform never serializes a nested
    # "extra_body" key. Guard against an explicit `extra_body=None`: pop's
    # default only applies when the key is absent, and update(None) raises.
    extra_body = optional_params.pop("extra_body", None) or {}
    response = super().transform_request(
        model, messages, optional_params, litellm_params, headers
    )
    # extra_body keys win on conflict, matching dict.update semantics.
    response.update(extra_body)
    return response
def get_error_class( def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException: ) -> BaseLLMException:

View file

@ -452,7 +452,7 @@ async def acompletion(
fallbacks = fallbacks or litellm.model_fallbacks fallbacks = fallbacks or litellm.model_fallbacks
if fallbacks is not None: if fallbacks is not None:
response = await async_completion_with_fallbacks( response = await async_completion_with_fallbacks(
**completion_kwargs, kwargs={"fallbacks": fallbacks} **completion_kwargs, kwargs={"fallbacks": fallbacks, **kwargs}
) )
if response is None: if response is None:
raise Exception( raise Exception(

12
poetry.lock generated
View file

@ -3105,17 +3105,17 @@ requests = "2.31.0"
[[package]] [[package]]
name = "respx" name = "respx"
version = "0.20.2" version = "0.22.0"
description = "A utility for mocking out the Python HTTPX and HTTP Core libraries." description = "A utility for mocking out the Python HTTPX and HTTP Core libraries."
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.8"
files = [ files = [
{file = "respx-0.20.2-py2.py3-none-any.whl", hash = "sha256:ab8e1cf6da28a5b2dd883ea617f8130f77f676736e6e9e4a25817ad116a172c9"}, {file = "respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0"},
{file = "respx-0.20.2.tar.gz", hash = "sha256:07cf4108b1c88b82010f67d3c831dae33a375c7b436e54d87737c7f9f99be643"}, {file = "respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91"},
] ]
[package.dependencies] [package.dependencies]
httpx = ">=0.21.0" httpx = ">=0.25.0"
[[package]] [[package]]
name = "rpds-py" name = "rpds-py"
@ -4056,4 +4056,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi",
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.8.1,<4.0, !=3.9.7" python-versions = ">=3.8.1,<4.0, !=3.9.7"
content-hash = "524b2f8276ba057f8dc8a79dd460c1a243ef4aece7c08a8bf344e029e07b8841" content-hash = "27c2090e5190d8b37948419dd8dd6234dd0ab7ea81a222aa81601596382472fc"

View file

@ -101,7 +101,7 @@ mypy = "^1.0"
pytest = "^7.4.3" pytest = "^7.4.3"
pytest-mock = "^3.12.0" pytest-mock = "^3.12.0"
pytest-asyncio = "^0.21.1" pytest-asyncio = "^0.21.1"
respx = "^0.20.2" respx = "^0.22.0"
ruff = "^0.1.0" ruff = "^0.1.0"
types-requests = "*" types-requests = "*"
types-setuptools = "*" types-setuptools = "*"

View file

@ -1,11 +1,9 @@
import json
import os import os
import sys import sys
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest import pytest
sys.path.insert( sys.path.insert(
0, os.path.abspath("../../../../..") 0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
@ -13,6 +11,7 @@ sys.path.insert(
from litellm.llms.openrouter.chat.transformation import ( from litellm.llms.openrouter.chat.transformation import (
OpenRouterChatCompletionStreamingHandler, OpenRouterChatCompletionStreamingHandler,
OpenRouterException, OpenRouterException,
OpenrouterConfig,
) )
@ -79,3 +78,20 @@ class TestOpenRouterChatCompletionStreamingHandler:
assert "KeyError" in str(exc_info.value) assert "KeyError" in str(exc_info.value)
assert exc_info.value.status_code == 400 assert exc_info.value.status_code == 400
def test_openrouter_extra_body_transformation():
    """extra_body contents must be hoisted to the top level of the request.

    Regression coverage for https://github.com/BerriAI/litellm/issues/8425:
    the provider routing config must not stay nested under "extra_body".
    """
    config = OpenrouterConfig()
    request_body = config.transform_request(
        model="openrouter/deepseek/deepseek-chat",
        messages=[{"role": "user", "content": "Hello, world!"}],
        optional_params={"extra_body": {"provider": {"order": ["DeepSeek"]}}},
        litellm_params={},
        headers={},
    )
    expected_messages = [{"role": "user", "content": "Hello, world!"}]
    assert request_body["provider"]["order"] == ["DeepSeek"]
    assert request_body["messages"] == expected_messages

View file

@ -1,8 +1,10 @@
import json import json
import os import os
import sys import sys
import httpx
import pytest import pytest
import respx
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
sys.path.insert( sys.path.insert(
@ -259,3 +261,84 @@ def test_bedrock_latency_optimized_inference():
mock_post.assert_called_once() mock_post.assert_called_once()
json_data = json.loads(mock_post.call_args.kwargs["data"]) json_data = json.loads(mock_post.call_args.kwargs["data"])
assert json_data["performanceConfig"]["latency"] == "optimized" assert json_data["performanceConfig"]["latency"] == "optimized"
@pytest.fixture(autouse=True)
def set_openrouter_api_key():
    """Force a fake OPENROUTER_API_KEY for every test, restoring the original after."""
    saved_key = os.environ.get("OPENROUTER_API_KEY")
    os.environ["OPENROUTER_API_KEY"] = "fake-key-for-testing"
    yield
    # Teardown: put the environment back exactly as we found it.
    if saved_key is None:
        del os.environ["OPENROUTER_API_KEY"]
    else:
        os.environ["OPENROUTER_API_KEY"] = saved_key
@pytest.mark.asyncio
async def test_extra_body_with_fallback(
    respx_mock: respx.MockRouter, set_openrouter_api_key
):
    """Regression test for https://github.com/BerriAI/litellm/issues/8425.

    When fallbacks are supplied, acompletion must still forward kwargs such
    as extra_body to the underlying provider call, and the extra_body
    contents must end up unwrapped at the top level of the request body.
    """
    model = "openrouter/deepseek/deepseek-chat"
    messages = [{"role": "user", "content": "Hello, world!"}]
    provider_routing = {
        "order": ["DeepSeek"],
        "allow_fallbacks": False,
        "require_parameters": True,
    }
    mock_completion = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello from mocked response!",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }
    respx_mock.post("https://openrouter.ai/api/v1/chat/completions").respond(
        json=mock_completion
    )

    response = await litellm.acompletion(
        model=model,
        messages=messages,
        extra_body={"provider": provider_routing},
        fallbacks=[{"model": "openrouter/google/gemini-flash-1.5-8b"}],
        api_key="fake-openrouter-api-key",
    )

    # Inspect exactly what was sent to the (mocked) OpenRouter endpoint.
    sent_request: httpx.Request = respx_mock.calls[0].request
    sent_body = json.loads(sent_request.read())

    # Basic parameters survived the fallback code path.
    assert sent_body["model"] == "deepseek/deepseek-chat"
    assert sent_body["messages"] == messages
    # extra_body contents were unwrapped to the top level, not nested.
    assert sent_body["provider"]["order"] == ["DeepSeek"]
    assert sent_body["provider"]["allow_fallbacks"] is False
    assert sent_body["provider"]["require_parameters"] is True
    # The mocked completion was returned to the caller.
    assert response is not None
    assert response.choices[0].message.content == "Hello from mocked response!"