mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Handle azure deepseek reasoning response (#8288)

* Handle deepseek reasoning response
* Add helper method + unit test
* Fix: Follow infinity api url format (#8346)
* Follow infinity api url format
* Update test_infinity.py
* fix(infinity/transformation.py): fix linting error

---------

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
parent f651d51f26
commit b5850b6b65

6 changed files with 86 additions and 5 deletions
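In rough terms, the behavior this commit targets looks like the sketch below: an Azure AI DeepSeek-R1 deployment returns its chain of thought wrapped in <think>...</think> inside the message content, and litellm now splits that into reasoning_content and content. The deployment name, api_base, and api_key are illustrative placeholders mirroring the new test, not a real endpoint.

import litellm

# Sketch only: endpoint and key are placeholders (taken from the added test).
response = litellm.completion(
    model="azure_ai/deepseek-r1",
    messages=[{"role": "user", "content": "Hello, world!"}],
    api_base="https://example.services.ai.azure.com/models/chat/completions",
    api_key="my-fake-api-key",
)

# DeepSeek wraps its reasoning in <think>...</think>; after this commit litellm
# surfaces it separately from the final answer text.
print(response.choices[0].message.reasoning_content)
print(response.choices[0].message.content)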
@@ -3,7 +3,8 @@ import json
 import time
 import traceback
 import uuid
-from typing import Dict, Iterable, List, Literal, Optional, Union
+import re
+from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple
 
 import litellm
 from litellm._logging import verbose_logger
@@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
         # if there is a JSONDecodeError, return the original tool_calls
         return tool_calls
 
 
+def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
+    if not message_text:
+        return None, None
+
+    reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
+
+    if reasoning_match:
+        return reasoning_match.group(1), reasoning_match.group(2)
+
+    return None, message_text
+
+
 class LiteLLMResponseObjectHandler:
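The helper above can be exercised on its own. Below is a minimal, self-contained copy of it (same regex as in the diff) together with the inputs used by the new unit test, to show what it returns.

import re
from typing import Optional, Tuple


def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    # Split "<think>reasoning</think>answer" into (reasoning, answer); pass other content through.
    if not message_text:
        return None, None
    reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
    if reasoning_match:
        return reasoning_match.group(1), reasoning_match.group(2)
    return None, message_text


print(_parse_content_for_reasoning("<think>I am thinking here</think>The sky is a canvas of blue"))
# -> ('I am thinking here', 'The sky is a canvas of blue')
print(_parse_content_for_reasoning("I am a regular response"))
# -> (None, 'I am a regular response')
print(_parse_content_for_reasoning(None))
# -> (None, None)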
@@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
                 for field in choice["message"].keys():
                     if field not in message_keys:
                         provider_specific_fields[field] = choice["message"][field]
+
+                # Handle reasoning models that display `reasoning_content` within `content`
+                reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
+                if reasoning_content:
+                    provider_specific_fields["reasoning_content"] = reasoning_content
+
                 message = Message(
-                    content=choice["message"].get("content", None),
+                    content=content,
                     role=choice["message"]["role"] or "assistant",
                     function_call=choice["message"].get("function_call", None),
                     tool_calls=tool_calls,
@@ -20,6 +20,15 @@ from .common_utils import InfinityError
 
 
 class InfinityRerankConfig(CohereRerankConfig):
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError("api_base is required for Infinity rerank")
+        # Remove trailing slashes and ensure clean base URL
+        api_base = api_base.rstrip("/")
+        if not api_base.endswith("/rerank"):
+            api_base = f"{api_base}/rerank"
+        return api_base
+
     def validate_environment(
         self,
         headers: dict,
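For illustration, the URL handling added to InfinityRerankConfig boils down to the normalization below; infinity_rerank_url is a hypothetical standalone helper (not part of litellm) that applies the same steps as get_complete_url above.

from typing import Optional


def infinity_rerank_url(api_base: Optional[str]) -> str:
    # Same steps as get_complete_url: require api_base, strip trailing
    # slashes, and append the Infinity "/rerank" route exactly once.
    if api_base is None:
        raise ValueError("api_base is required for Infinity rerank")
    api_base = api_base.rstrip("/")
    if not api_base.endswith("/rerank"):
        api_base = f"{api_base}/rerank"
    return api_base


print(infinity_rerank_url("https://api.infinity.ai"))         # https://api.infinity.ai/rerank
print(infinity_rerank_url("https://api.infinity.ai/"))        # https://api.infinity.ai/rerank
print(infinity_rerank_url("https://api.infinity.ai/rerank"))  # https://api.infinity.ai/rerank (unchanged)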
@@ -89,6 +89,7 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
     convert_to_model_response_object,
     convert_to_streaming_response,
     convert_to_streaming_response_async,
+    _parse_content_for_reasoning,
 )
 from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
 from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
@@ -864,6 +864,18 @@ def test_convert_model_response_object():
         == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
     )
 
+
+@pytest.mark.parametrize(
+    "content, expected_reasoning, expected_content",
+    [
+        (None, None, None),
+        ("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
+        ("I am a regular response", None, "I am a regular response"),
+    ]
+)
+def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
+    assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
+
+
 @pytest.mark.parametrize(
     "model, expected_bool",
@@ -13,7 +13,7 @@ import litellm.types.utils
 from litellm.llms.anthropic.chat import ModelResponseIterator
 import httpx
 import json
-from respx import MockRouter
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
 load_dotenv()
 import io
@@ -184,3 +184,45 @@ def test_completion_azure_ai_command_r():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def test_azure_deepseek_reasoning_content():
+    import json
+
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        mock_response = MagicMock()
+
+        mock_response.text = json.dumps(
+            {
+                "choices": [
+                    {
+                        "finish_reason": "stop",
+                        "index": 0,
+                        "message": {
+                            "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
+                            "role": "assistant",
+                        }
+                    }
+                ],
+            }
+        )
+
+        mock_response.status_code = 200
+        # Add required response attributes
+        mock_response.headers = {"Content-Type": "application/json"}
+        mock_response.json = lambda: json.loads(mock_response.text)
+        mock_post.return_value = mock_response
+
+        response = litellm.completion(
+            model='azure_ai/deepseek-r1',
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
+            api_key="my-fake-api-key",
+            client=client
+        )
+
+        print(response)
+        assert(response.choices[0].message.reasoning_content == "I am thinking here")
+        assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
@@ -69,7 +69,7 @@ async def test_infinity_rerank():
         _url = mock_post.call_args.kwargs["url"]
         print("Arguments passed to API=", args_to_api)
         print("url = ", _url)
-        assert _url == "https://api.infinity.ai/v1/rerank"
+        assert _url == "https://api.infinity.ai/rerank"
 
         request_data = json.loads(args_to_api)
         assert request_data["query"] == expected_payload["query"]
@@ -133,7 +133,7 @@ async def test_infinity_rerank_with_env(monkeypatch):
         _url = mock_post.call_args.kwargs["url"]
         print("Arguments passed to API=", args_to_api)
         print("url = ", _url)
-        assert _url == "https://env.infinity.ai/v1/rerank"
+        assert _url == "https://env.infinity.ai/rerank"
 
         request_data = json.loads(args_to_api)
         assert request_data["query"] == expected_payload["query"]