diff --git a/litellm/llms/openrouter/chat/transformation.py b/litellm/llms/openrouter/chat/transformation.py index 452921f551..77f402a131 100644 --- a/litellm/llms/openrouter/chat/transformation.py +++ b/litellm/llms/openrouter/chat/transformation.py @@ -1,17 +1,18 @@ """ -Support for OpenAI's `/v1/chat/completions` endpoint. +Support for OpenAI's `/v1/chat/completions` endpoint. Calls done in OpenAI/openai.py as OpenRouter is openai-compatible. Docs: https://openrouter.ai/docs/parameters """ -from typing import Any, AsyncIterator, Iterator, Optional, Union +from typing import Any, AsyncIterator, Iterator, List, Optional, Union import httpx from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.types.llms.openai import AllMessageValues from litellm.types.llms.openrouter import OpenRouterErrorMessage from litellm.types.utils import ModelResponse, ModelResponseStream @@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig): ] = extra_body # openai client supports `extra_body` param return mapped_openai_params + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + """ + Transform the overall request to be sent to the API. + + Returns: + dict: The transformed request. Sent as the body of the API call. + """ + extra_body = optional_params.pop("extra_body", {}) + response = super().transform_request( + model, messages, optional_params, litellm_params, headers + ) + response.update(extra_body) + return response + def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: diff --git a/litellm/main.py b/litellm/main.py index 11aa7a78d4..dcc277343e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -452,7 +452,7 @@ async def acompletion( fallbacks = fallbacks or litellm.model_fallbacks if fallbacks is not None: response = await async_completion_with_fallbacks( - **completion_kwargs, kwargs={"fallbacks": fallbacks} + **completion_kwargs, kwargs={"fallbacks": fallbacks, **kwargs} ) if response is None: raise Exception( diff --git a/poetry.lock b/poetry.lock index b6200d3180..7983887ecd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3105,17 +3105,17 @@ requests = "2.31.0" [[package]] name = "respx" -version = "0.20.2" +version = "0.22.0" description = "A utility for mocking out the Python HTTPX and HTTP Core libraries." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "respx-0.20.2-py2.py3-none-any.whl", hash = "sha256:ab8e1cf6da28a5b2dd883ea617f8130f77f676736e6e9e4a25817ad116a172c9"}, - {file = "respx-0.20.2.tar.gz", hash = "sha256:07cf4108b1c88b82010f67d3c831dae33a375c7b436e54d87737c7f9f99be643"}, + {file = "respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0"}, + {file = "respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91"}, ] [package.dependencies] -httpx = ">=0.21.0" +httpx = ">=0.25.0" [[package]] name = "rpds-py" @@ -4056,4 +4056,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi", [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "524b2f8276ba057f8dc8a79dd460c1a243ef4aece7c08a8bf344e029e07b8841" +content-hash = "27c2090e5190d8b37948419dd8dd6234dd0ab7ea81a222aa81601596382472fc" diff --git a/pyproject.toml b/pyproject.toml index 37870631d2..ac14a9af51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ mypy = "^1.0" pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-asyncio = "^0.21.1" -respx = "^0.20.2" +respx = "^0.22.0" ruff = "^0.1.0" types-requests = "*" types-setuptools = "*" diff --git a/tests/litellm/llms/openrouter/chat/test_openrouter_chat_transformation.py b/tests/litellm/llms/openrouter/chat/test_openrouter_chat_transformation.py index c5d7a2c278..de0b284f0a 100644 --- a/tests/litellm/llms/openrouter/chat/test_openrouter_chat_transformation.py +++ b/tests/litellm/llms/openrouter/chat/test_openrouter_chat_transformation.py @@ -1,11 +1,9 @@ -import json import os import sys -from unittest.mock import AsyncMock, MagicMock, patch -import httpx import pytest + sys.path.insert( 0, os.path.abspath("../../../../..") ) # Adds the parent directory to the system path @@ -13,6 +11,7 @@ sys.path.insert( from litellm.llms.openrouter.chat.transformation import ( OpenRouterChatCompletionStreamingHandler, OpenRouterException, + OpenrouterConfig, ) @@ -79,3 +78,20 @@ class TestOpenRouterChatCompletionStreamingHandler: assert "KeyError" in str(exc_info.value) assert exc_info.value.status_code == 400 + + +def test_openrouter_extra_body_transformation(): + + transformed_request = OpenrouterConfig().transform_request( + model="openrouter/deepseek/deepseek-chat", + messages=[{"role": "user", "content": "Hello, world!"}], + optional_params={"extra_body": {"provider": {"order": ["DeepSeek"]}}}, + litellm_params={}, + headers={}, + ) + + # https://github.com/BerriAI/litellm/issues/8425, validate its not contained in extra_body still + assert transformed_request["provider"]["order"] == ["DeepSeek"] + assert transformed_request["messages"] == [ + {"role": "user", "content": "Hello, world!"} + ] diff --git a/tests/litellm/test_main.py b/tests/litellm/test_main.py index 57161e7dd7..b3e085df6c 100644 --- a/tests/litellm/test_main.py +++ b/tests/litellm/test_main.py @@ -1,8 +1,10 @@ import json import os import sys - +import httpx import pytest +import respx + from fastapi.testclient import TestClient sys.path.insert( @@ -259,3 +261,84 @@ def test_bedrock_latency_optimized_inference(): mock_post.assert_called_once() json_data = json.loads(mock_post.call_args.kwargs["data"]) assert json_data["performanceConfig"]["latency"] == "optimized" + +@pytest.fixture(autouse=True) +def set_openrouter_api_key(): + original_api_key = os.environ.get("OPENROUTER_API_KEY") + os.environ["OPENROUTER_API_KEY"] = "fake-key-for-testing" + yield + if original_api_key is not None: + os.environ["OPENROUTER_API_KEY"] = original_api_key + else: + del os.environ["OPENROUTER_API_KEY"] + + +@pytest.mark.asyncio +async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openrouter_api_key): + """ + test regression for https://github.com/BerriAI/litellm/issues/8425. + + This was perhaps a wider issue with the acompletion function not passing kwargs such as extra_body correctly when fallbacks are specified. + """ + # Set up test parameters + model = "openrouter/deepseek/deepseek-chat" + messages = [{"role": "user", "content": "Hello, world!"}] + extra_body = { + "provider": { + "order": ["DeepSeek"], + "allow_fallbacks": False, + "require_parameters": True + } + } + fallbacks = [ + { + "model": "openrouter/google/gemini-flash-1.5-8b" + } + ] + + respx_mock.post("https://openrouter.ai/api/v1/chat/completions").respond( + json={ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello from mocked response!", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21}, + } + ) + + response = await litellm.acompletion( + model=model, + messages=messages, + extra_body=extra_body, + fallbacks=fallbacks, + api_key="fake-openrouter-api-key", + ) + + # Get the request from the mock + request: httpx.Request = respx_mock.calls[0].request + request_body = request.read() + request_body = json.loads(request_body) + + # Verify basic parameters + assert request_body["model"] == "deepseek/deepseek-chat" + assert request_body["messages"] == messages + + # Verify the extra_body parameters remain under the provider key + assert request_body["provider"]["order"] == ["DeepSeek"] + assert request_body["provider"]["allow_fallbacks"] is False + assert request_body["provider"]["require_parameters"] is True + + # Verify the response + assert response is not None + assert response.choices[0].message.content == "Hello from mocked response!" + \ No newline at end of file