Deepseek r1 support + watsonx qa improvements (#7907)

* fix(types/utils.py): support returning 'reasoning_content' for deepseek models

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218

* fix(convert_dict_to_response.py): return deepseek response in provider_specific_field

allows for separating openai vs. non-openai params in model response
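
a minimal sketch of the resulting (non-streaming) response shape - assumes `DEEPSEEK_API_KEY` is set and `deepseek/deepseek-reasoner` is the target model:

```python
# sketch: openai-compatible fields stay on the message itself, while
# provider-only fields (e.g. deepseek's reasoning trace) are grouped
# under `provider_specific_fields`
from litellm import completion

resp = completion(
    model="deepseek/deepseek-reasoner",
    messages=[{"role": "user", "content": "Tell me a joke."}],
)

print(resp.choices[0].message.content)  # standard openai field
print(resp.choices[0].message.provider_specific_fields["reasoning_content"])
```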

* fix(utils.py): support 'provider_specific_field' in delta chunk as well

allows deepseek reasoning content chunk to be returned to user from stream as well

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218
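
a minimal streaming sketch under the same assumptions - `reasoning_content` is read off each delta's `provider_specific_fields`:

```python
# sketch: reasoning content surfaces on the delta chunks while streaming
from litellm import completion

resp = completion(
    model="deepseek/deepseek-reasoner",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
)

for chunk in resp:
    fields = chunk.choices[0].delta.provider_specific_fields or {}
    if "reasoning_content" in fields:
        print(fields["reasoning_content"], end="")
```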

* fix(watsonx/chat/handler.py): fix passing space id to watsonx on chat route

* fix(watsonx/): fix watsonx_text/ route with space id
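
a rough sketch of the call shape this covers - assumes watsonx credentials are configured in the environment; the deployment id and space id below are placeholders:

```python
# sketch: for 'deployment/' models, space_id can be passed per-call (or set via
# the WATSONX_SPACE_ID environment variable) and is forwarded in the request
# payload on both the chat ('watsonx/') and text ('watsonx_text/') routes
from litellm import completion

resp = completion(
    model="watsonx/deployment/<deployment_id>",  # placeholder deployment id
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    space_id="<space_id>",  # placeholder space id
)
```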

* fix(watsonx/): qa item - also adds better unit testing for watsonx embedding calls
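
for reference, a sketch of the embedding call shape exercised by the new tests - same assumptions and placeholders as above:

```python
# sketch: watsonx embedding against a deployment, with space_id resolved from
# the WATSONX_SPACE_ID environment variable (as in the new unit tests)
import os
from litellm import embedding

os.environ["WATSONX_SPACE_ID"] = "<space_id>"  # placeholder

resp = embedding(
    model="watsonx/deployment/<deployment_id>",  # placeholder deployment id
    input=["Hello, how are you?"],
)
```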

* fix(utils.py): rename to '..fields'

* fix: fix linting errors

* fix(utils.py): fix typing - don't show provider-specific field if none or empty - prevents default response from being non-oai compatible

* fix: cleanup unused imports

* docs(deepseek.md): add docs for deepseek reasoning model
Krish Dholakia 2025-01-21 23:13:15 -08:00 committed by GitHub
parent 26a79a533d
commit 76795dba39
12 changed files with 281 additions and 63 deletions

View file

@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
# Deepseek

https://deepseek.com/
@@ -52,3 +55,72 @@ We support ALL Deepseek models, just set `deepseek/` as a prefix when sending co
| deepseek-coder | `completion(model="deepseek/deepseek-coder", messages)` |
+
+## Reasoning Models
+
+| Model Name         | Function Call                                               |
+|--------------------|-------------------------------------------------------------|
+| deepseek-reasoner  | `completion(model="deepseek/deepseek-reasoner", messages)`  |
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ['DEEPSEEK_API_KEY'] = ""
+
+resp = completion(
+    model="deepseek/deepseek-reasoner",
+    messages=[{"role": "user", "content": "Tell me a joke."}],
+)
+
+print(
+    resp.choices[0].message.provider_specific_fields["reasoning_content"]
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: deepseek-reasoner
+    litellm_params:
+      model: deepseek/deepseek-reasoner
+      api_key: os.environ/DEEPSEEK_API_KEY
+```
+
+2. Run proxy
+
+```bash
+python litellm/proxy/main.py
+```
+
+3. Test it!
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+    "model": "deepseek-reasoner",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "Hi, how are you ?"
+          }
+        ]
+      }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>

View file

@@ -337,7 +337,6 @@ def convert_to_model_response_object(  # noqa: PLR0915
    ] = None,  # used for supporting 'json_schema' on older models
):
    received_args = locals()
    additional_headers = get_response_headers(_response_headers)

    if hidden_params is None:
@@ -411,12 +410,18 @@
                    message = litellm.Message(content=json_mode_content_str)
                    finish_reason = "stop"
                if message is None:
+                    provider_specific_fields = {}
+                    message_keys = Message.model_fields.keys()
+                    for field in choice["message"].keys():
+                        if field not in message_keys:
+                            provider_specific_fields[field] = choice["message"][field]
+
                    message = Message(
                        content=choice["message"].get("content", None),
                        role=choice["message"]["role"] or "assistant",
                        function_call=choice["message"].get("function_call", None),
                        tool_calls=tool_calls,
                        audio=choice["message"].get("audio", None),
+                        provider_specific_fields=provider_specific_fields,
                    )
                finish_reason = choice.get("finish_reason", None)
                if finish_reason is None:

View file

@@ -51,6 +51,13 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
            api_key=api_key,
        )

+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
        ## GET API URL
        api_base = watsonx_chat_transformation.get_complete_url(
            api_base=api_base,
@@ -59,13 +66,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
            stream=optional_params.get("stream", False),
        )

-        ## UPDATE PAYLOAD (optional params)
-        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
-            model=model,
-            api_params=api_params,
-        )
-        optional_params.update(watsonx_auth_payload)
-
        return super().completion(
            model=model,
            messages=messages,

View file

@@ -7,11 +7,11 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
from typing import List, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
+from litellm.types.llms.watsonx import WatsonXAIEndpoint

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
-from ..common_utils import IBMWatsonXMixin, WatsonXAIError
+from ..common_utils import IBMWatsonXMixin


class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
@@ -87,12 +87,6 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
-            # deployment models are passed in as 'deployment/<deployment_id>'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
@@ -113,11 +107,3 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url
-
-    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
-        payload: dict = {}
-        if model.startswith("deployment/"):
-            return payload
-        payload["model_id"] = model
-        payload["project_id"] = api_params["project_id"]
-        return payload

View file

@@ -275,3 +275,17 @@ class IBMWatsonXMixin:
        return WatsonXCredentials(
            api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
        )
+
+    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
+        payload: dict = {}
+        if model.startswith("deployment/"):
+            if api_params["space_id"] is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            payload["space_id"] = api_params["space_id"]
+            return payload
+        payload["model_id"] = model
+        payload["project_id"] = api_params["project_id"]
+        return payload

View file

@@ -246,17 +246,20 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        watsonx_api_params = _get_api_params(params=optional_params)
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
        # init the payload to the text generation call
        payload = {
            "input": prompt,
            "moderations": optional_params.pop("moderations", {}),
            "parameters": optional_params,
+            **watsonx_auth_payload,
        }
-        if not model.startswith("deployment/"):
-            payload["model_id"] = model
-            payload["project_id"] = watsonx_api_params["project_id"]
        return payload

    def transform_response(
@@ -320,11 +323,6 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value

View file

@@ -14,7 +14,7 @@ from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import EmbeddingResponse, Usage

-from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
+from ..common_utils import IBMWatsonXMixin, _get_api_params


class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
@@ -38,14 +38,15 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
        headers: dict,
    ) -> dict:
        watsonx_api_params = _get_api_params(params=optional_params)
-        project_id = watsonx_api_params["project_id"]
-        if not project_id:
-            raise ValueError("project_id is required")
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
        return {
            "inputs": input,
-            "model_id": model,
-            "project_id": project_id,
            "parameters": optional_params,
+            **watsonx_auth_payload,
        }

    def get_complete_url(
@@ -58,12 +59,6 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
        url = self._get_base_url(api_base=api_base)
        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
        if model.startswith("deployment/"):
-            # deployment models are passed in as 'deployment/<deployment_id>'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = endpoint.format(deployment_id=deployment_id)
        url = url.rstrip("/") + endpoint

View file

@@ -6,7 +6,7 @@ model_list:
      api_base: https://exampleopenaiendpoint-production.up.railway.app
  - model_name: openai-o1
    litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+      model: openai/random_sleep
      api_base: http://0.0.0.0:8090
      timeout: 2
      num_retries: 0

View file

@@ -18,7 +18,7 @@ from openai.types.moderation import (
    CategoryScores,
)
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
-from pydantic import BaseModel, ConfigDict, PrivateAttr
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from typing_extensions import Callable, Dict, Required, TypedDict, override

from ..litellm_core_utils.core_helpers import map_finish_reason
@@ -438,6 +438,9 @@ class Message(OpenAIObject):
    tool_calls: Optional[List[ChatCompletionMessageToolCall]]
    function_call: Optional[FunctionCall]
    audio: Optional[ChatCompletionAudioResponse] = None
+    provider_specific_fields: Optional[Dict[str, Any]] = Field(
+        default=None, exclude=True
+    )

    def __init__(
        self,
@@ -446,6 +449,7 @@
        function_call=None,
        tool_calls: Optional[list] = None,
        audio: Optional[ChatCompletionAudioResponse] = None,
+        provider_specific_fields: Optional[Dict[str, Any]] = None,
        **params,
    ):
        init_values: Dict[str, Any] = {
@@ -481,6 +485,9 @@
            # OpenAI compatible APIs like mistral API will raise an error if audio is passed in
            del self.audio

+        if provider_specific_fields:  # set if provider_specific_fields is not empty
+            self.provider_specific_fields = provider_specific_fields
+
    def get(self, key, default=None):
        # Custom .get() method to access attributes with a default value if the attribute doesn't exist
        return getattr(self, key, default)
@@ -502,6 +509,10 @@

class Delta(OpenAIObject):
+    provider_specific_fields: Optional[Dict[str, Any]] = Field(
+        default=None, exclude=True
+    )
+
    def __init__(
        self,
        content=None,
@@ -511,14 +522,20 @@
        audio: Optional[ChatCompletionAudioResponse] = None,
        **params,
    ):
+        provider_specific_fields: Dict[str, Any] = {}
+        if "reasoning_content" in params:
+            provider_specific_fields["reasoning_content"] = params["reasoning_content"]
+            del params["reasoning_content"]
        super(Delta, self).__init__(**params)
        self.content = content
        self.role = role
+        self.provider_specific_fields = provider_specific_fields
        # Set default values and correct types
        self.function_call: Optional[Union[FunctionCall, Any]] = None
        self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
        self.audio: Optional[ChatCompletionAudioResponse] = None
+        if provider_specific_fields:  # set if provider_specific_fields is not empty
+            self.provider_specific_fields = provider_specific_fields

        if function_call is not None and isinstance(function_call, dict):
            self.function_call = FunctionCall(**function_call)

View file

@@ -8,11 +8,12 @@ sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
-from litellm import completion
+from litellm import completion, embedding
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
from unittest.mock import patch, MagicMock, AsyncMock, Mock
import pytest
+from typing import Optional


@pytest.fixture
@@ -21,6 +22,7 @@ def watsonx_chat_completion_call():
        model="watsonx/my-test-model",
        messages=None,
        api_key="test_api_key",
+        space_id: Optional[str] = None,
        headers=None,
        client=None,
        patch_token_call=True,
@@ -41,24 +43,90 @@ def watsonx_chat_completion_call():
            with patch.object(client, "post") as mock_post, patch.object(
                litellm.module_level_client, "post", return_value=mock_response
            ) as mock_get:
-                completion(
-                    model=model,
-                    messages=messages,
-                    api_key=api_key,
-                    headers=headers or {},
-                    client=client,
-                )
+                try:
+                    completion(
+                        model=model,
+                        messages=messages,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
                return mock_post, mock_get
        else:
            with patch.object(client, "post") as mock_post:
-                completion(
-                    model=model,
-                    messages=messages,
-                    api_key=api_key,
-                    headers=headers or {},
-                    client=client,
-                )
+                try:
+                    completion(
+                        model=model,
+                        messages=messages,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
+                return mock_post, None
+
+    return _call
+
+
+@pytest.fixture
+def watsonx_embedding_call():
+    def _call(
+        model="watsonx/my-test-model",
+        input=None,
+        api_key="test_api_key",
+        space_id: Optional[str] = None,
+        headers=None,
+        client=None,
+        patch_token_call=True,
+    ):
+        if input is None:
+            input = ["Hello, how are you?"]
+        if client is None:
+            client = HTTPHandler()
+
+        if patch_token_call:
+            mock_response = Mock()
+            mock_response.json.return_value = {
+                "access_token": "mock_access_token",
+                "expires_in": 3600,
+            }
+            mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+
+            with patch.object(client, "post") as mock_post, patch.object(
+                litellm.module_level_client, "post", return_value=mock_response
+            ) as mock_get:
+                try:
+                    embedding(
+                        model=model,
+                        input=input,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
+                return mock_post, mock_get
+        else:
+            with patch.object(client, "post") as mock_post:
+                try:
+                    embedding(
+                        model=model,
+                        input=input,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
                return mock_post, None

    return _call
@@ -118,3 +186,35 @@ def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
    assert mock_post.call_count == 1
    assert "deployment" not in mock_post.call_args.kwargs["url"]
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "watsonx/deployment/<xxxx.xxx.xxx.xxxx>",
+        "watsonx_text/deployment/<xxxx.xxx.xxx.xxxx>",
+    ],
+)
+def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call, model):
+    my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
+    monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
+
+    mock_post, _ = watsonx_chat_completion_call(
+        model=model,
+        messages=[{"content": "Hello, how are you?", "role": "user"}],
+    )
+
+    assert mock_post.call_count == 1
+    json_data = json.loads(mock_post.call_args.kwargs["data"])
+    assert my_fake_space_id == json_data["space_id"]
+
+
+def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
+    my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
+    monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
+
+    mock_post, _ = watsonx_embedding_call(model="watsonx/deployment/my-test-model")
+
+    assert mock_post.call_count == 1
+    json_data = json.loads(mock_post.call_args.kwargs["data"])
+    assert my_fake_space_id == json_data["space_id"]

View file

@@ -4537,3 +4537,16 @@ def test_humanloop_completion(monkeypatch):
        prompt_variables={"person": "John"},
        messages=[{"role": "user", "content": "Tell me a joke."}],
    )
+
+
+def test_deepseek_reasoning_content_completion():
+    litellm.set_verbose = True
+    resp = litellm.completion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Tell me a joke."}],
+    )
+
+    assert (
+        resp.choices[0].message.provider_specific_fields["reasoning_content"]
+        is not None
+    )

View file

@@ -4063,3 +4063,21 @@ def test_mock_response_iterator_tool_use():
    response_chunk = completion_stream._chunk_parser(chunk_data=response)
    assert response_chunk["tool_use"] is not None
+
+
+def test_deepseek_reasoning_content_completion():
+    litellm.set_verbose = True
+    resp = litellm.completion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Tell me a joke."}],
+        stream=True,
+    )
+
+    reasoning_content_exists = False
+    for chunk in resp:
+        print(f"chunk: {chunk}")
+        if chunk.choices[0].delta.content is not None:
+            if "reasoning_content" in chunk.choices[0].delta.provider_specific_fields:
+                reasoning_content_exists = True
+                break
+    assert reasoning_content_exists