mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
This commit is contained in:
parent
bdad9961e3
commit
d640bc0a00
6 changed files with 135 additions and 14 deletions
|
@ -1,17 +1,18 @@
|
||||||
"""
|
"""
|
||||||
Support for OpenAI's `/v1/chat/completions` endpoint.
|
Support for OpenAI's `/v1/chat/completions` endpoint.
|
||||||
|
|
||||||
Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
|
Calls done in OpenAI/openai.py as OpenRouter is openai-compatible.
|
||||||
|
|
||||||
Docs: https://openrouter.ai/docs/parameters
|
Docs: https://openrouter.ai/docs/parameters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Iterator, Optional, Union
|
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||||
|
from litellm.types.llms.openai import AllMessageValues
|
||||||
from litellm.types.llms.openrouter import OpenRouterErrorMessage
|
from litellm.types.llms.openrouter import OpenRouterErrorMessage
|
||||||
from litellm.types.utils import ModelResponse, ModelResponseStream
|
from litellm.types.utils import ModelResponse, ModelResponseStream
|
||||||
|
|
||||||
|
@ -47,6 +48,27 @@ class OpenrouterConfig(OpenAIGPTConfig):
|
||||||
] = extra_body # openai client supports `extra_body` param
|
] = extra_body # openai client supports `extra_body` param
|
||||||
return mapped_openai_params
|
return mapped_openai_params
|
||||||
|
|
||||||
|
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Transform the overall request to be sent to the API.

    The `extra_body` entry, if present, is removed from `optional_params`
    (the caller's dict is mutated in place) before delegating to the parent
    implementation. Its keys are then merged into the top level of the
    request returned by the parent, replacing any keys of the same name.

    Args:
        model: Model name, forwarded unchanged to the parent implementation.
        messages: Chat messages, forwarded unchanged.
        optional_params: Request parameters; `extra_body` is popped from it.
        litellm_params: Forwarded unchanged to the parent implementation.
        headers: Forwarded unchanged to the parent implementation.

    Returns:
        dict: The transformed request. Sent as the body of the API call.
    """
    # `extra_body` may be missing or explicitly None; `dict.update(None)`
    # raises TypeError, so normalise both cases to an empty dict.
    extra_body = optional_params.pop("extra_body", None) or {}
    response = super().transform_request(
        model, messages, optional_params, litellm_params, headers
    )
    # Flatten the extra_body keys (e.g. `provider`) into the request body
    # rather than leaving them nested under an `extra_body` key.
    response.update(extra_body)
    return response
|
||||||
|
|
||||||
def get_error_class(
|
def get_error_class(
|
||||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||||
) -> BaseLLMException:
|
) -> BaseLLMException:
|
||||||
|
|
|
@ -452,7 +452,7 @@ async def acompletion(
|
||||||
fallbacks = fallbacks or litellm.model_fallbacks
|
fallbacks = fallbacks or litellm.model_fallbacks
|
||||||
if fallbacks is not None:
|
if fallbacks is not None:
|
||||||
response = await async_completion_with_fallbacks(
|
response = await async_completion_with_fallbacks(
|
||||||
**completion_kwargs, kwargs={"fallbacks": fallbacks}
|
**completion_kwargs, kwargs={"fallbacks": fallbacks, **kwargs}
|
||||||
)
|
)
|
||||||
if response is None:
|
if response is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
|
12
poetry.lock
generated
12
poetry.lock
generated
|
@ -3105,17 +3105,17 @@ requests = "2.31.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "respx"
|
name = "respx"
|
||||||
version = "0.20.2"
|
version = "0.22.0"
|
||||||
description = "A utility for mocking out the Python HTTPX and HTTP Core libraries."
|
description = "A utility for mocking out the Python HTTPX and HTTP Core libraries."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "respx-0.20.2-py2.py3-none-any.whl", hash = "sha256:ab8e1cf6da28a5b2dd883ea617f8130f77f676736e6e9e4a25817ad116a172c9"},
|
{file = "respx-0.22.0-py2.py3-none-any.whl", hash = "sha256:631128d4c9aba15e56903fb5f66fb1eff412ce28dd387ca3a81339e52dbd3ad0"},
|
||||||
{file = "respx-0.20.2.tar.gz", hash = "sha256:07cf4108b1c88b82010f67d3c831dae33a375c7b436e54d87737c7f9f99be643"},
|
{file = "respx-0.22.0.tar.gz", hash = "sha256:3c8924caa2a50bd71aefc07aa812f2466ff489f1848c96e954a5362d17095d91"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
httpx = ">=0.21.0"
|
httpx = ">=0.25.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rpds-py"
|
name = "rpds-py"
|
||||||
|
@ -4056,4 +4056,4 @@ proxy = ["PyJWT", "apscheduler", "backoff", "boto3", "cryptography", "fastapi",
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.8.1,<4.0, !=3.9.7"
|
python-versions = ">=3.8.1,<4.0, !=3.9.7"
|
||||||
content-hash = "524b2f8276ba057f8dc8a79dd460c1a243ef4aece7c08a8bf344e029e07b8841"
|
content-hash = "27c2090e5190d8b37948419dd8dd6234dd0ab7ea81a222aa81601596382472fc"
|
||||||
|
|
|
@ -101,7 +101,7 @@ mypy = "^1.0"
|
||||||
pytest = "^7.4.3"
|
pytest = "^7.4.3"
|
||||||
pytest-mock = "^3.12.0"
|
pytest-mock = "^3.12.0"
|
||||||
pytest-asyncio = "^0.21.1"
|
pytest-asyncio = "^0.21.1"
|
||||||
respx = "^0.20.2"
|
respx = "^0.22.0"
|
||||||
ruff = "^0.1.0"
|
ruff = "^0.1.0"
|
||||||
types-requests = "*"
|
types-requests = "*"
|
||||||
types-setuptools = "*"
|
types-setuptools = "*"
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../../../../..")
|
0, os.path.abspath("../../../../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
@ -13,6 +11,7 @@ sys.path.insert(
|
||||||
from litellm.llms.openrouter.chat.transformation import (
|
from litellm.llms.openrouter.chat.transformation import (
|
||||||
OpenRouterChatCompletionStreamingHandler,
|
OpenRouterChatCompletionStreamingHandler,
|
||||||
OpenRouterException,
|
OpenRouterException,
|
||||||
|
OpenrouterConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,3 +78,20 @@ class TestOpenRouterChatCompletionStreamingHandler:
|
||||||
|
|
||||||
assert "KeyError" in str(exc_info.value)
|
assert "KeyError" in str(exc_info.value)
|
||||||
assert exc_info.value.status_code == 400
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_extra_body_transformation():
    """
    `extra_body` contents must land at the top level of the transformed request.

    Regression check for https://github.com/BerriAI/litellm/issues/8425.
    """
    request_params = {"extra_body": {"provider": {"order": ["DeepSeek"]}}}
    chat_messages = [{"role": "user", "content": "Hello, world!"}]

    config = OpenrouterConfig()
    transformed_request = config.transform_request(
        model="openrouter/deepseek/deepseek-chat",
        messages=chat_messages,
        optional_params=request_params,
        litellm_params={},
        headers={},
    )

    # Validate that `provider` is no longer nested inside `extra_body`.
    provider_settings = transformed_request["provider"]
    assert provider_settings["order"] == ["DeepSeek"]

    # The messages must pass through the transformation unchanged.
    expected_messages = [{"role": "user", "content": "Hello, world!"}]
    assert transformed_request["messages"] == expected_messages
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
import respx
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
|
@ -259,3 +261,84 @@ def test_bedrock_latency_optimized_inference():
|
||||||
mock_post.assert_called_once()
|
mock_post.assert_called_once()
|
||||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
assert json_data["performanceConfig"]["latency"] == "optimized"
|
assert json_data["performanceConfig"]["latency"] == "optimized"
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def set_openrouter_api_key():
    """
    Set a fake `OPENROUTER_API_KEY` for the duration of each test.

    The previous value (if any) is captured before the test and restored
    afterwards; if the variable was not set beforehand it is removed again.
    """
    original_api_key = os.environ.get("OPENROUTER_API_KEY")
    os.environ["OPENROUTER_API_KEY"] = "fake-key-for-testing"
    yield
    if original_api_key is not None:
        os.environ["OPENROUTER_API_KEY"] = original_api_key
    else:
        # pop() rather than `del`: a test may already have removed the
        # variable itself, and teardown must not fail with KeyError then.
        os.environ.pop("OPENROUTER_API_KEY", None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_extra_body_with_fallback(
    respx_mock: respx.MockRouter, set_openrouter_api_key
):
    """
    Regression test for https://github.com/BerriAI/litellm/issues/8425.

    Calls `litellm.acompletion` with both `extra_body` and `fallbacks` and
    checks, against a mocked OpenRouter endpoint, that the `extra_body`
    contents appear at the top level of the outgoing request body.
    """
    # Test inputs
    model = "openrouter/deepseek/deepseek-chat"
    messages = [{"role": "user", "content": "Hello, world!"}]
    provider_options = {
        "order": ["DeepSeek"],
        "allow_fallbacks": False,
        "require_parameters": True,
    }
    extra_body = {"provider": provider_options}
    fallbacks = [{"model": "openrouter/google/gemini-flash-1.5-8b"}]

    # Canned chat-completion payload returned by the mocked endpoint
    mocked_completion = {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello from mocked response!",
                },
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }
    chat_route = respx_mock.post("https://openrouter.ai/api/v1/chat/completions")
    chat_route.respond(json=mocked_completion)

    response = await litellm.acompletion(
        model=model,
        messages=messages,
        extra_body=extra_body,
        fallbacks=fallbacks,
        api_key="fake-openrouter-api-key",
    )

    # Decode the body of the first request captured by the mock
    first_request: httpx.Request = respx_mock.calls[0].request
    request_body = json.loads(first_request.read())

    # Basic parameters
    assert request_body["model"] == "deepseek/deepseek-chat"
    assert request_body["messages"] == messages

    # The extra_body parameters must sit under the top-level `provider` key
    sent_provider = request_body["provider"]
    assert sent_provider["order"] == ["DeepSeek"]
    assert sent_provider["allow_fallbacks"] is False
    assert sent_provider["require_parameters"] is True

    # The mocked response must be returned to the caller
    assert response is not None
    assert response.choices[0].message.content == "Hello from mocked response!"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue