From 76795dba39a60088891306198f06796d4319c36e Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Tue, 21 Jan 2025 23:13:15 -0800
Subject: [PATCH] Deepseek r1 support + watsonx qa improvements (#7907)

* fix(types/utils.py): support returning 'reasoning_content' for deepseek models

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218

* fix(convert_dict_to_response.py): return deepseek response in provider_specific_field

allows for separating openai vs. non-openai params in model response

* fix(utils.py): support 'provider_specific_field' in delta chunk as well

allows deepseek reasoning content chunk to be returned to user from stream as well

Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218

* fix(watsonx/chat/handler.py): fix passing space id to watsonx on chat route

* fix(watsonx/): fix watsonx_text/ route with space id

* fix(watsonx/): qa item - also adds better unit testing for watsonx embedding calls

* fix(utils.py): rename to '..fields'

* fix: fix linting errors

* fix(utils.py): fix typing - don't show provider-specific field if none or empty - prevents default response from being non-oai compatible

* fix: cleanup unused imports

* docs(deepseek.md): add docs for deepseek reasoning model
---
 docs/my-website/docs/providers/deepseek.md    |  72 ++++++++++
 .../convert_dict_to_response.py               |   7 +-
 litellm/llms/watsonx/chat/handler.py          |  14 +-
 litellm/llms/watsonx/chat/transformation.py   |  18 +--
 litellm/llms/watsonx/common_utils.py          |  14 ++
 .../llms/watsonx/completion/transformation.py |  16 +--
 litellm/llms/watsonx/embed/transformation.py  |  19 +--
 litellm/proxy/_new_secret_config.yaml         |   2 +-
 litellm/types/utils.py                        |  21 ++-
 tests/llm_translation/test_watsonx.py         | 130 ++++++++++++++++--
 tests/local_testing/test_completion.py        |  13 ++
 tests/local_testing/test_streaming.py         |  18 +++
 12 files changed, 281 insertions(+), 63 deletions(-)

diff --git a/docs/my-website/docs/providers/deepseek.md b/docs/my-website/docs/providers/deepseek.md
index dfe51e6c2e..9f48e87123 100644
--- a/docs/my-website/docs/providers/deepseek.md
+++ b/docs/my-website/docs/providers/deepseek.md
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # Deepseek
 https://deepseek.com/
 
@@ -52,3 +55,72 @@ We support ALL Deepseek models, just set `deepseek/` as a prefix when sending co
 | deepseek-coder | `completion(model="deepseek/deepseek-coder", messages)` |
 
 
+## Reasoning Models
+| Model Name         | Function Call                                               |
+|--------------------|-------------------------------------------------------------|
+| deepseek-reasoner  | `completion(model="deepseek/deepseek-reasoner", messages)`  |
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import completion
+import os
+
+os.environ['DEEPSEEK_API_KEY'] = ""
+resp = completion(
+    model="deepseek/deepseek-reasoner",
+    messages=[{"role": "user", "content": "Tell me a joke."}],
+)
+
+print(
+    resp.choices[0].message.provider_specific_fields["reasoning_content"]
+)
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: deepseek-reasoner
+    litellm_params:
+      model: deepseek/deepseek-reasoner
+      api_key: os.environ/DEEPSEEK_API_KEY
+```
+
+2. Run proxy
+
+```bash
+python litellm/proxy/main.py
+```
+
+3. Test it!
+
+```bash
+curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
+-H 'Content-Type: application/json' \
+-H 'Authorization: Bearer sk-1234' \
+-d '{
+    "model": "deepseek-reasoner",
+    "messages": [
+      {
+        "role": "user",
+        "content": [
+          {
+            "type": "text",
+            "text": "Hi, how are you ?"
+          }
+        ]
+      }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
\ No newline at end of file
diff --git a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py
index 93926a81f4..28d546796d 100644
--- a/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py
+++ b/litellm/litellm_core_utils/llm_response_utils/convert_dict_to_response.py
@@ -337,7 +337,6 @@ def convert_to_model_response_object(  # noqa: PLR0915
     ] = None,  # used for supporting 'json_schema' on older models
 ):
     received_args = locals()
-
     additional_headers = get_response_headers(_response_headers)
 
     if hidden_params is None:
@@ -411,12 +410,18 @@
                     message = litellm.Message(content=json_mode_content_str)
                     finish_reason = "stop"
                 if message is None:
+                    provider_specific_fields = {}
+                    message_keys = Message.model_fields.keys()
+                    for field in choice["message"].keys():
+                        if field not in message_keys:
+                            provider_specific_fields[field] = choice["message"][field]
                     message = Message(
                         content=choice["message"].get("content", None),
                         role=choice["message"]["role"] or "assistant",
                         function_call=choice["message"].get("function_call", None),
                         tool_calls=tool_calls,
                         audio=choice["message"].get("audio", None),
+                        provider_specific_fields=provider_specific_fields,
                     )
                 finish_reason = choice.get("finish_reason", None)
                 if finish_reason is None:
diff --git a/litellm/llms/watsonx/chat/handler.py b/litellm/llms/watsonx/chat/handler.py
index 4f2d36d7a8..fd195214db 100644
--- a/litellm/llms/watsonx/chat/handler.py
+++ b/litellm/llms/watsonx/chat/handler.py
@@ -51,6 +51,13 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
             api_key=api_key,
         )
 
+        ## UPDATE PAYLOAD (optional params)
+        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+            model=model,
+            api_params=api_params,
+        )
+        optional_params.update(watsonx_auth_payload)
+
         ## GET API URL
         api_base = watsonx_chat_transformation.get_complete_url(
             api_base=api_base,
@@ -59,13 +66,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
             stream=optional_params.get("stream", False),
         )
 
-        ## UPDATE PAYLOAD (optional params)
-        watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
-            model=model,
-            api_params=api_params,
-        )
-        optional_params.update(watsonx_auth_payload)
-
         return super().completion(
             model=model,
             messages=messages,
diff --git a/litellm/llms/watsonx/chat/transformation.py b/litellm/llms/watsonx/chat/transformation.py
index 5d0c432c56..208da82ef5 100644
--- a/litellm/llms/watsonx/chat/transformation.py
+++ b/litellm/llms/watsonx/chat/transformation.py
@@ -7,11 +7,11 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
 from typing import List, Optional, Tuple, Union
 
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
 
 from ....utils import _remove_additional_properties, _remove_strict_from_schema
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
-from ..common_utils import IBMWatsonXMixin, WatsonXAIError
+from ..common_utils import IBMWatsonXMixin
 
 
 class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
@@ -87,12 +87,6 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
     ) -> str:
         url = self._get_base_url(api_base=api_base)
         if model.startswith("deployment/"):
-            # deployment models are passed in as 'deployment/'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
             deployment_id = "/".join(model.split("/")[1:])
             endpoint = (
                 WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
@@ -113,11 +107,3 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
             url=url, api_version=optional_params.pop("api_version", None)
         )
         return url
-
-    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
-        payload: dict = {}
-        if model.startswith("deployment/"):
-            return payload
-        payload["model_id"] = model
-        payload["project_id"] = api_params["project_id"]
-        return payload
diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py
index 62f141c474..4916cd1c75 100644
--- a/litellm/llms/watsonx/common_utils.py
+++ b/litellm/llms/watsonx/common_utils.py
@@ -275,3 +275,17 @@ class IBMWatsonXMixin:
         return WatsonXCredentials(
             api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
         )
+
+    def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
+        payload: dict = {}
+        if model.startswith("deployment/"):
+            if api_params["space_id"] is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            payload["space_id"] = api_params["space_id"]
+            return payload
+        payload["model_id"] = model
+        payload["project_id"] = api_params["project_id"]
+        return payload
diff --git a/litellm/llms/watsonx/completion/transformation.py b/litellm/llms/watsonx/completion/transformation.py
index e214a945d2..7e6a8a525d 100644
--- a/litellm/llms/watsonx/completion/transformation.py
+++ b/litellm/llms/watsonx/completion/transformation.py
@@ -246,17 +246,20 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
         extra_body_params = optional_params.pop("extra_body", {})
         optional_params.update(extra_body_params)
         watsonx_api_params = _get_api_params(params=optional_params)
+
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
         # init the payload to the text generation call
         payload = {
             "input": prompt,
             "moderations": optional_params.pop("moderations", {}),
             "parameters": optional_params,
+            **watsonx_auth_payload,
         }
-        if not model.startswith("deployment/"):
-            payload["model_id"] = model
-            payload["project_id"] = watsonx_api_params["project_id"]
-
         return payload
 
     def transform_response(
@@ -320,11 +323,6 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
         url = self._get_base_url(api_base=api_base)
         if model.startswith("deployment/"):
             # deployment models are passed in as 'deployment/'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
             deployment_id = "/".join(model.split("/")[1:])
             endpoint = (
                 WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
diff --git a/litellm/llms/watsonx/embed/transformation.py b/litellm/llms/watsonx/embed/transformation.py
index a41a0d6f2a..69c1f8fffa 100644
--- a/litellm/llms/watsonx/embed/transformation.py
+++ b/litellm/llms/watsonx/embed/transformation.py
@@ -14,7 +14,7 @@ from litellm.types.llms.openai import AllEmbeddingInputValues
 from litellm.types.llms.watsonx import WatsonXAIEndpoint
 from litellm.types.utils import EmbeddingResponse, Usage
 
-from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
+from ..common_utils import IBMWatsonXMixin, _get_api_params
 
 
 class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
@@ -38,14 +38,15 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
         headers: dict,
     ) -> dict:
         watsonx_api_params = _get_api_params(params=optional_params)
-        project_id = watsonx_api_params["project_id"]
-        if not project_id:
-            raise ValueError("project_id is required")
+        watsonx_auth_payload = self._prepare_payload(
+            model=model,
+            api_params=watsonx_api_params,
+        )
+
         return {
             "inputs": input,
-            "model_id": model,
-            "project_id": project_id,
             "parameters": optional_params,
+            **watsonx_auth_payload,
         }
 
     def get_complete_url(
@@ -58,12 +59,6 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
         url = self._get_base_url(api_base=api_base)
         endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
         if model.startswith("deployment/"):
-            # deployment models are passed in as 'deployment/'
-            if optional_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
             deployment_id = "/".join(model.split("/")[1:])
             endpoint = endpoint.format(deployment_id=deployment_id)
         url = url.rstrip("/") + endpoint
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index aaab76842e..1722a4d796 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -6,7 +6,7 @@ model_list:
       api_base: https://exampleopenaiendpoint-production.up.railway.app
   - model_name: openai-o1
     litellm_params:
-      model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
+      model: openai/random_sleep
       api_base: http://0.0.0.0:8090
       timeout: 2
       num_retries: 0
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index c657143cd5..d60b263052 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -18,7 +18,7 @@ from openai.types.moderation import (
     CategoryScores,
 )
 from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
-from pydantic import BaseModel, ConfigDict, PrivateAttr
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
 from typing_extensions import Callable, Dict, Required, TypedDict, override
 
 from ..litellm_core_utils.core_helpers import map_finish_reason
@@ -438,6 +438,9 @@ class Message(OpenAIObject):
     tool_calls: Optional[List[ChatCompletionMessageToolCall]]
     function_call: Optional[FunctionCall]
     audio: Optional[ChatCompletionAudioResponse] = None
+    provider_specific_fields: Optional[Dict[str, Any]] = Field(
+        default=None, exclude=True
+    )
 
     def __init__(
         self,
@@ -446,6 +449,7 @@ class Message(OpenAIObject):
         function_call=None,
         tool_calls: Optional[list] = None,
         audio: Optional[ChatCompletionAudioResponse] = None,
+        provider_specific_fields: Optional[Dict[str, Any]] = None,
         **params,
     ):
         init_values: Dict[str, Any] = {
@@ -481,6 +485,9 @@ class Message(OpenAIObject):
             # OpenAI compatible APIs like mistral API will raise an error if audio is passed in
             del self.audio
 
+        if provider_specific_fields:  # set if provider_specific_fields is not empty
+            self.provider_specific_fields = provider_specific_fields
+
     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
         return getattr(self, key, default)
@@ -502,6 +509,10 @@ class Message(OpenAIObject):
 
 
 class Delta(OpenAIObject):
+    provider_specific_fields: Optional[Dict[str, Any]] = Field(
+        default=None, exclude=True
+    )
+
     def __init__(
         self,
         content=None,
@@ -511,14 +522,20 @@ class Delta(OpenAIObject):
         audio: Optional[ChatCompletionAudioResponse] = None,
         **params,
     ):
+        provider_specific_fields: Dict[str, Any] = {}
+        if "reasoning_content" in params:
+            provider_specific_fields["reasoning_content"] = params["reasoning_content"]
+            del params["reasoning_content"]
         super(Delta, self).__init__(**params)
         self.content = content
         self.role = role
-
+        self.provider_specific_fields = provider_specific_fields
         # Set default values and correct types
         self.function_call: Optional[Union[FunctionCall, Any]] = None
         self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
         self.audio: Optional[ChatCompletionAudioResponse] = None
+        if provider_specific_fields:  # set if provider_specific_fields is not empty
+            self.provider_specific_fields = provider_specific_fields
 
         if function_call is not None and isinstance(function_call, dict):
             self.function_call = FunctionCall(**function_call)
diff --git a/tests/llm_translation/test_watsonx.py b/tests/llm_translation/test_watsonx.py
index 2246a3fd87..d788541dad 100644
--- a/tests/llm_translation/test_watsonx.py
+++ b/tests/llm_translation/test_watsonx.py
@@ -8,11 +8,12 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import litellm
-from litellm import completion
+from litellm import completion, embedding
 from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
 from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
 from unittest.mock import patch, MagicMock, AsyncMock, Mock
 import pytest
+from typing import Optional
 
 
 @pytest.fixture
@@ -21,6 +22,7 @@ def watsonx_chat_completion_call():
         model="watsonx/my-test-model",
         messages=None,
         api_key="test_api_key",
+        space_id: Optional[str] = None,
         headers=None,
         client=None,
         patch_token_call=True,
@@ -41,24 +43,90 @@ def watsonx_chat_completion_call():
             with patch.object(client, "post") as mock_post, patch.object(
                 litellm.module_level_client, "post", return_value=mock_response
             ) as mock_get:
-                completion(
-                    model=model,
-                    messages=messages,
-                    api_key=api_key,
-                    headers=headers or {},
-                    client=client,
-                )
+                try:
+                    completion(
+                        model=model,
+                        messages=messages,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
 
                 return mock_post, mock_get
         else:
             with patch.object(client, "post") as mock_post:
-                completion(
-                    model=model,
-                    messages=messages,
-                    api_key=api_key,
-                    headers=headers or {},
-                    client=client,
-                )
+                try:
+                    completion(
+                        model=model,
+                        messages=messages,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
+
                 return mock_post, None
 
     return _call
+
+
+@pytest.fixture
+def watsonx_embedding_call():
+    def _call(
+        model="watsonx/my-test-model",
+        input=None,
+        api_key="test_api_key",
+        space_id: Optional[str] = None,
+        headers=None,
+        client=None,
+        patch_token_call=True,
+    ):
+        if input is None:
+            input = ["Hello, how are you?"]
+        if client is None:
+            client = HTTPHandler()
+
+        if patch_token_call:
+            mock_response = Mock()
+            mock_response.json.return_value = {
+                "access_token": "mock_access_token",
+                "expires_in": 3600,
+            }
+            mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+
+            with patch.object(client, "post") as mock_post, patch.object(
+                litellm.module_level_client, "post", return_value=mock_response
+            ) as mock_get:
+                try:
+                    embedding(
+                        model=model,
+                        input=input,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
+
+                return mock_post, mock_get
+        else:
+            with patch.object(client, "post") as mock_post:
+                try:
+                    embedding(
+                        model=model,
+                        input=input,
+                        api_key=api_key,
+                        headers=headers or {},
+                        client=client,
+                        space_id=space_id,
+                    )
+                except Exception as e:
+                    print(e)
+
+                return mock_post, None
+
+    return _call
@@ -118,3 +186,35 @@ def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
 
     assert mock_post.call_count == 1
     assert "deployment" not in mock_post.call_args.kwargs["url"]
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "watsonx/deployment/",
+        "watsonx_text/deployment/",
+    ],
+)
+def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call, model):
+    my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
+    monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
+
+    mock_post, _ = watsonx_chat_completion_call(
+        model=model,
+        messages=[{"content": "Hello, how are you?", "role": "user"}],
+    )
+
+    assert mock_post.call_count == 1
+    json_data = json.loads(mock_post.call_args.kwargs["data"])
+    assert my_fake_space_id == json_data["space_id"]
+
+
+def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
+    my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
+    monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
+
+    mock_post, _ = watsonx_embedding_call(model="watsonx/deployment/my-test-model")
+
+    assert mock_post.call_count == 1
+    json_data = json.loads(mock_post.call_args.kwargs["data"])
+    assert my_fake_space_id == json_data["space_id"]
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index 02d91cfb65..466369ef6e 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -4537,3 +4537,16 @@ def test_humanloop_completion(monkeypatch):
         prompt_variables={"person": "John"},
         messages=[{"role": "user", "content": "Tell me a joke."}],
     )
+
+
+def test_deepseek_reasoning_content_completion():
+    litellm.set_verbose = True
+    resp = litellm.completion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Tell me a joke."}],
+    )
+
+    assert (
+        resp.choices[0].message.provider_specific_fields["reasoning_content"]
+        is not None
+    )
diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py
index 793106368d..06e2b9156d 100644
--- a/tests/local_testing/test_streaming.py
+++ b/tests/local_testing/test_streaming.py
@@ -4063,3 +4063,21 @@ def test_mock_response_iterator_tool_use():
     response_chunk = completion_stream._chunk_parser(chunk_data=response)
 
     assert response_chunk["tool_use"] is not None
+
+
+def test_deepseek_reasoning_content_completion():
+    litellm.set_verbose = True
+    resp = litellm.completion(
+        model="deepseek/deepseek-reasoner",
+        messages=[{"role": "user", "content": "Tell me a joke."}],
+        stream=True,
+    )
+
+    reasoning_content_exists = False
+    for chunk in resp:
+        print(f"chunk: {chunk}")
+        if chunk.choices[0].delta.content is not None:
+            if "reasoning_content" in chunk.choices[0].delta.provider_specific_fields:
+                reasoning_content_exists = True
+                break
+    assert reasoning_content_exists
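For reviewers who want to exercise the change end to end, here is a small usage sketch (an editor's illustration, not part of the patch) of how a caller reads the `reasoning_content` this change surfaces through the new `provider_specific_fields` attributes on `Message` and `Delta`. It assumes `DEEPSEEK_API_KEY` is set in the environment, as in the docs added above.

```python
# Illustrative sketch only; mirrors the docs and tests in this patch.
# Assumes DEEPSEEK_API_KEY is exported in the environment.
import litellm

# Non-streaming: the extra deepseek field lands on message.provider_specific_fields
resp = litellm.completion(
    model="deepseek/deepseek-reasoner",
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
fields = resp.choices[0].message.provider_specific_fields or {}
print(fields.get("reasoning_content"))

# Streaming: each delta may carry provider_specific_fields["reasoning_content"]
stream = litellm.completion(
    model="deepseek/deepseek-reasoner",
    messages=[{"role": "user", "content": "Tell me a joke."}],
    stream=True,
)
for chunk in stream:
    delta_fields = chunk.choices[0].delta.provider_specific_fields or {}
    if "reasoning_content" in delta_fields:
        print(delta_fields["reasoning_content"], end="")
```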
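The watsonx side of the patch can be tried the same way: `_prepare_payload` now lives on `IBMWatsonXMixin`, so `deployment/` models send `space_id` in the request body (taken from the `space_id` kwarg or the environment, as the new tests do with `WATSONX_SPACE_ID`), while non-deployment models keep sending `model_id` and `project_id`. A hedged sketch, with a placeholder deployment id and credentials assumed to be configured elsewhere:

```python
# Illustrative sketch only; "my-deployment-id" is a placeholder and the usual
# watsonx credentials are assumed to already be configured for litellm.
import os
import litellm

os.environ["WATSONX_SPACE_ID"] = "my-space-id"  # read for deployment/ models, per the tests above

resp = litellm.completion(
    model="watsonx/deployment/my-deployment-id",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
# After this patch the request body for deployment/ models includes "space_id";
# non-deployment watsonx models continue to send "model_id" and "project_id".
```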