mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
This commit is contained in:
parent
bdad9961e3
commit
d640bc0a00
6 changed files with 135 additions and 14 deletions
|
@ -6,12 +6,13 @@ Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
|
|||
Docs: https://openrouter.ai/docs/parameters
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Iterator, Optional, Union
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.openrouter import OpenRouterErrorMessage
|
||||
from litellm.types.utils import ModelResponse, ModelResponseStream
|
||||
|
||||
|
@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig):
|
|||
] = extra_body # openai client supports `extra_body` param
|
||||
return mapped_openai_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the overall request to be sent to the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed request. Sent as the body of the API call.
|
||||
"""
|
||||
extra_body = optional_params.pop("extra_body", {})
|
||||
response = super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
response.update(extra_body)
|
||||
return response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
|
|
|
@ -452,7 +452,7 @@ async def acompletion(
|
|||
fallbacks = fallbacks or litellm.model_fallbacks
|
||||
if fallbacks is not None:
|
||||
response = await async_completion_with_fallbacks(
|
||||
**completion_kwargs, kwargs={"fallbacks": fallbacks}
|
||||
**completion_kwargs, kwargs={"fallbacks": fallbacks, **kwargs}
|
||||
)
|
||||
if response is None:
|
||||
raise Exception(
|
||||
|
|
12
poetry.lock
generated
12
poetry.lock
generated
|
@ -3105,17 +3105,17 @@ requests = "2.31.0"
|
|||
|
||||
[[package]]
|
||||
name = "respx"
|
||||
version = "0.20.2"
|
||||
version = "0.22.0"
|
||||
description = "A utility for mocking out the Python HTTPX and HTTP Core libraries."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "respx-0.20.2-py2.py3-none-any.whl", hash = "sha256:ab8e1cf6da28a5b2dd883ea617f8130f77f676736e6e9e4a25817ad116a172c9"},
|
||||
{file = "respx-0.20.2.tar.gz", hash = "sha256:07cf4108b1c88b82010f67d3c831dae33a375c7b436e54d87737c7f9f99be643"},
|
||||
{file = "respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0"},
|
||||
{file = "respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.21.0"
|
||||
httpx = ">=0.25.0"
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
|
@ -4056,4 +4056,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi",
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0, !=3.9.7"
|
||||
content-hash = "524b2f8276ba057f8dc8a79dd460c1a243ef4aece7c08a8bf344e029e07b8841"
|
||||
content-hash = "27c2090e5190d8b37948419dd8dd6234dd0ab7ea81a222aa81601596382472fc"
|
||||
|
|
|
@ -101,7 +101,7 @@ mypy = "^1.0"
|
|||
pytest = "^7.4.3"
|
||||
pytest-mock = "^3.12.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
respx = "^0.20.2"
|
||||
respx = "^0.22.0"
|
||||
ruff = "^0.1.0"
|
||||
types-requests = "*"
|
||||
types-setuptools = "*"
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
@ -13,6 +11,7 @@ sys.path.insert(
|
|||
from litellm.llms.openrouter.chat.transformation import (
|
||||
OpenRouterChatCompletionStreamingHandler,
|
||||
OpenRouterException,
|
||||
OpenrouterConfig,
|
||||
)
|
||||
|
||||
|
||||
|
@ -79,3 +78,20 @@ class TestOpenRouterChatCompletionStreamingHandler:
|
|||
|
||||
assert "KeyError" in str(exc_info.value)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
def test_openrouter_extra_body_transformation():
|
||||
|
||||
transformed_request = OpenrouterConfig().transform_request(
|
||||
model="openrouter/deepseek/deepseek-chat",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
optional_params={"extra_body": {"provider": {"order": ["DeepSeek"]}}},
|
||||
litellm_params={},
|
||||
headers={},
|
||||
)
|
||||
|
||||
# https://github.com/BerriAI/litellm/issues/8425, validate its not contained in extra_body still
|
||||
assert transformed_request["provider"]["order"] == ["DeepSeek"]
|
||||
assert transformed_request["messages"] == [
|
||||
{"role": "user", "content": "Hello, world!"}
|
||||
]
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
sys.path.insert(
|
||||
|
@ -259,3 +261,84 @@ def test_bedrock_latency_optimized_inference():
|
|||
mock_post.assert_called_once()
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert json_data["performanceConfig"]["latency"] == "optimized"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_openrouter_api_key():
|
||||
original_api_key = os.environ.get("OPENROUTER_API_KEY")
|
||||
os.environ["OPENROUTER_API_KEY"] = "fake-key-for-testing"
|
||||
yield
|
||||
if original_api_key is not None:
|
||||
os.environ["OPENROUTER_API_KEY"] = original_api_key
|
||||
else:
|
||||
del os.environ["OPENROUTER_API_KEY"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openrouter_api_key):
|
||||
"""
|
||||
test regression for https://github.com/BerriAI/litellm/issues/8425.
|
||||
|
||||
This was perhaps a wider issue with the acompletion function not passing kwargs such as extra_body correctly when fallbacks are specified.
|
||||
"""
|
||||
# Set up test parameters
|
||||
model = "openrouter/deepseek/deepseek-chat"
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
extra_body = {
|
||||
"provider": {
|
||||
"order": ["DeepSeek"],
|
||||
"allow_fallbacks": False,
|
||||
"require_parameters": True
|
||||
}
|
||||
}
|
||||
fallbacks = [
|
||||
{
|
||||
"model": "openrouter/google/gemini-flash-1.5-8b"
|
||||
}
|
||||
]
|
||||
|
||||
respx_mock.post("https://openrouter.ai/api/v1/chat/completions").respond(
|
||||
json={
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello from mocked response!",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
|
||||
}
|
||||
)
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
extra_body=extra_body,
|
||||
fallbacks=fallbacks,
|
||||
api_key="fake-openrouter-api-key",
|
||||
)
|
||||
|
||||
# Get the request from the mock
|
||||
request: httpx.Request = respx_mock.calls[0].request
|
||||
request_body = request.read()
|
||||
request_body = json.loads(request_body)
|
||||
|
||||
# Verify basic parameters
|
||||
assert request_body["model"] == "deepseek/deepseek-chat"
|
||||
assert request_body["messages"] == messages
|
||||
|
||||
# Verify the extra_body parameters remain under the provider key
|
||||
assert request_body["provider"]["order"] == ["DeepSeek"]
|
||||
assert request_body["provider"]["allow_fallbacks"] is False
|
||||
assert request_body["provider"]["require_parameters"] is True
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert response.choices[0].message.content == "Hello from mocked response!"
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue