Handle azure deepseek reasoning response (#8288) (#8366)

* Handle azure deepseek reasoning response (#8288)

* Handle deepseek reasoning response

* Add helper method + unit test

* Fix: Follow infinity api url format (#8346)

* Follow infinity api url format

* Update test_infinity.py

* fix(infinity/transformation.py): fix linting error

---------

Co-authored-by: vibhavbhat <vibhavb00@gmail.com>
Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
This commit is contained in:
Krish Dholakia 2025-02-07 17:45:51 -08:00 committed by GitHub
parent f651d51f26
commit b5850b6b65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 86 additions and 5 deletions

View file

@@ -3,7 +3,8 @@ import json
import time
import traceback
import uuid
from typing import Dict, Iterable, List, Literal, Optional, Union
import re
from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple
import litellm
from litellm._logging import verbose_logger
@@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
# if there is a JSONDecodeError, return the original tool_calls
return tool_calls
def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
if not message_text:
return None, None
reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
if reasoning_match:
return reasoning_match.group(1), reasoning_match.group(2)
return None, message_text
class LiteLLMResponseObjectHandler:
@@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
for field in choice["message"].keys():
if field not in message_keys:
provider_specific_fields[field] = choice["message"][field]
# Handle reasoning models that display `reasoning_content` within `content`
reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
if reasoning_content:
provider_specific_fields["reasoning_content"] = reasoning_content
message = Message(
content=choice["message"].get("content", None),
content=content,
role=choice["message"]["role"] or "assistant",
function_call=choice["message"].get("function_call", None),
tool_calls=tool_calls,

View file

@@ -20,6 +20,15 @@ from .common_utils import InfinityError
class InfinityRerankConfig(CohereRerankConfig):
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
    """Build the full Infinity rerank endpoint URL from ``api_base``.

    Normalizes trailing slashes and appends ``/rerank`` unless the caller
    already supplied the full endpoint path.

    Raises:
        ValueError: if ``api_base`` is not provided.
    """
    if api_base is None:
        raise ValueError("api_base is required for Infinity rerank")
    # Strip trailing slashes first so we never emit ".../​/rerank".
    base_url = api_base.rstrip("/")
    return base_url if base_url.endswith("/rerank") else f"{base_url}/rerank"
def validate_environment(
self,
headers: dict,

View file

@@ -89,6 +89,7 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response impo
convert_to_model_response_object,
convert_to_streaming_response,
convert_to_streaming_response_async,
_parse_content_for_reasoning,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (

View file

@@ -864,6 +864,18 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)
@pytest.mark.parametrize(
    "message_text, want_reasoning, want_content",
    [
        # Empty input yields no reasoning and no content.
        (None, None, None),
        # A leading <think> tag is split into reasoning + visible content.
        (
            "<think>I am thinking here</think>The sky is a canvas of blue",
            "I am thinking here",
            "The sky is a canvas of blue",
        ),
        # Plain responses pass through untouched.
        ("I am a regular response", None, "I am a regular response"),
    ],
)
def test_parse_content_for_reasoning(message_text, want_reasoning, want_content):
    """_parse_content_for_reasoning splits <think> reasoning from content."""
    result = litellm.utils._parse_content_for_reasoning(message_text)
    assert result == (want_reasoning, want_content)
@pytest.mark.parametrize(
"model, expected_bool",

View file

@@ -13,7 +13,7 @@ import litellm.types.utils
from litellm.llms.anthropic.chat import ModelResponseIterator
import httpx
import json
from respx import MockRouter
from litellm.llms.custom_httpx.http_handler import HTTPHandler
load_dotenv()
import io
@@ -184,3 +184,45 @@ def test_completion_azure_ai_command_r():
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_azure_deepseek_reasoning_content():
    """Azure AI deepseek-r1 embeds reasoning in <think> tags inside content.

    litellm should surface the tag body as ``reasoning_content`` and return
    only the remainder as ``content``.
    """
    import json

    client = HTTPHandler()
    # Raw provider payload: reasoning is inlined in the message content.
    raw_body = {
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                "message": {
                    "content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
                    "role": "assistant",
                },
            }
        ],
    }

    with patch.object(client, "post") as mock_post:
        fake_response = MagicMock()
        fake_response.status_code = 200
        # Required response attributes for litellm's response handling.
        fake_response.headers = {"Content-Type": "application/json"}
        fake_response.text = json.dumps(raw_body)
        fake_response.json = lambda: json.loads(fake_response.text)
        mock_post.return_value = fake_response

        response = litellm.completion(
            model='azure_ai/deepseek-r1',
            messages=[{"role": "user", "content": "Hello, world!"}],
            api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
            api_key="my-fake-api-key",
            client=client,
        )

        print(response)
        # Reasoning is split out; only the visible text remains in content.
        assert response.choices[0].message.reasoning_content == "I am thinking here"
        assert response.choices[0].message.content == "\n\nThe sky is a canvas of blue"

View file

@@ -69,7 +69,7 @@ async def test_infinity_rerank():
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://api.infinity.ai/v1/rerank"
assert _url == "https://api.infinity.ai/rerank"
request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]
@@ -133,7 +133,7 @@ async def test_infinity_rerank_with_env(monkeypatch):
_url = mock_post.call_args.kwargs["url"]
print("Arguments passed to API=", args_to_api)
print("url = ", _url)
assert _url == "https://env.infinity.ai/v1/rerank"
assert _url == "https://env.infinity.ai/rerank"
request_data = json.loads(args_to_api)
assert request_data["query"] == expected_payload["query"]