mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
* Handle azure deepseek reasoning response (#8288) * Handle deepseek reasoning response * Add helper method + unit test * Fix: Follow infinity api url format (#8346) * Follow infinity api url format * Update test_infinity.py * fix(infinity/transformation.py): fix linting error --------- Co-authored-by: vibhavbhat <vibhavb00@gmail.com> Co-authored-by: Hao Shan <53949959+haoshan98@users.noreply.github.com>
This commit is contained in:
parent
f651d51f26
commit
b5850b6b65
6 changed files with 86 additions and 5 deletions
|
@ -3,7 +3,8 @@ import json
|
|||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Union
|
||||
import re
|
||||
from typing import Dict, Iterable, List, Literal, Optional, Union, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
@ -220,6 +221,16 @@ def _handle_invalid_parallel_tool_calls(
|
|||
# if there is a JSONDecodeError, return the original tool_calls
|
||||
return tool_calls
|
||||
|
||||
def _parse_content_for_reasoning(message_text: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
|
||||
if not message_text:
|
||||
return None, None
|
||||
|
||||
reasoning_match = re.match(r"<think>(.*?)</think>(.*)", message_text, re.DOTALL)
|
||||
|
||||
if reasoning_match:
|
||||
return reasoning_match.group(1), reasoning_match.group(2)
|
||||
|
||||
return None, message_text
|
||||
|
||||
class LiteLLMResponseObjectHandler:
|
||||
|
||||
|
@ -432,8 +443,14 @@ def convert_to_model_response_object( # noqa: PLR0915
|
|||
for field in choice["message"].keys():
|
||||
if field not in message_keys:
|
||||
provider_specific_fields[field] = choice["message"][field]
|
||||
|
||||
# Handle reasoning models that display `reasoning_content` within `content`
|
||||
reasoning_content, content = _parse_content_for_reasoning(choice["message"].get("content", None))
|
||||
if reasoning_content:
|
||||
provider_specific_fields["reasoning_content"] = reasoning_content
|
||||
|
||||
message = Message(
|
||||
content=choice["message"].get("content", None),
|
||||
content=content,
|
||||
role=choice["message"]["role"] or "assistant",
|
||||
function_call=choice["message"].get("function_call", None),
|
||||
tool_calls=tool_calls,
|
||||
|
|
|
@ -20,6 +20,15 @@ from .common_utils import InfinityError
|
|||
|
||||
|
||||
class InfinityRerankConfig(CohereRerankConfig):
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Infinity rerank")
|
||||
# Remove trailing slashes and ensure clean base URL
|
||||
api_base = api_base.rstrip("/")
|
||||
if not api_base.endswith("/rerank"):
|
||||
api_base = f"{api_base}/rerank"
|
||||
return api_base
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
|
|
|
@ -89,6 +89,7 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response impo
|
|||
convert_to_model_response_object,
|
||||
convert_to_streaming_response,
|
||||
convert_to_streaming_response_async,
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
|
||||
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
|
||||
|
|
|
@ -864,6 +864,18 @@ def test_convert_model_response_object():
|
|||
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content, expected_reasoning, expected_content",
|
||||
[
|
||||
(None, None, None),
|
||||
("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
|
||||
("I am a regular response", None, "I am a regular response"),
|
||||
|
||||
]
|
||||
)
|
||||
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
|
||||
assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_bool",
|
||||
|
|
|
@ -13,7 +13,7 @@ import litellm.types.utils
|
|||
from litellm.llms.anthropic.chat import ModelResponseIterator
|
||||
import httpx
|
||||
import json
|
||||
from respx import MockRouter
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
load_dotenv()
|
||||
import io
|
||||
|
@ -184,3 +184,45 @@ def test_completion_azure_ai_command_r():
|
|||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_azure_deepseek_reasoning_content():
|
||||
import json
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
mock_response = MagicMock()
|
||||
|
||||
mock_response.text = json.dumps(
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "<think>I am thinking here</think>\n\nThe sky is a canvas of blue",
|
||||
"role": "assistant",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
mock_response.status_code = 200
|
||||
# Add required response attributes
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
|
||||
response = litellm.completion(
|
||||
model='azure_ai/deepseek-r1',
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
api_base="https://litellm8397336933.services.ai.azure.com/models/chat/completions",
|
||||
api_key="my-fake-api-key",
|
||||
client=client
|
||||
)
|
||||
|
||||
print(response)
|
||||
assert(response.choices[0].message.reasoning_content == "I am thinking here")
|
||||
assert(response.choices[0].message.content == "\n\nThe sky is a canvas of blue")
|
||||
|
|
|
@ -69,7 +69,7 @@ async def test_infinity_rerank():
|
|||
_url = mock_post.call_args.kwargs["url"]
|
||||
print("Arguments passed to API=", args_to_api)
|
||||
print("url = ", _url)
|
||||
assert _url == "https://api.infinity.ai/v1/rerank"
|
||||
assert _url == "https://api.infinity.ai/rerank"
|
||||
|
||||
request_data = json.loads(args_to_api)
|
||||
assert request_data["query"] == expected_payload["query"]
|
||||
|
@ -133,7 +133,7 @@ async def test_infinity_rerank_with_env(monkeypatch):
|
|||
_url = mock_post.call_args.kwargs["url"]
|
||||
print("Arguments passed to API=", args_to_api)
|
||||
print("url = ", _url)
|
||||
assert _url == "https://env.infinity.ai/v1/rerank"
|
||||
assert _url == "https://env.infinity.ai/rerank"
|
||||
|
||||
request_data = json.loads(args_to_api)
|
||||
assert request_data["query"] == expected_payload["query"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue