mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
Deepseek r1 support + watsonx qa improvements (#7907)
* fix(types/utils.py): support returning 'reasoning_content' for deepseek models Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218 * fix(convert_dict_to_response.py): return deepseek response in provider_specific_field allows for separating openai vs. non-openai params in model response * fix(utils.py): support 'provider_specific_field' in delta chunk as well allows deepseek reasoning content chunk to be returned to user from stream as well Fixes https://github.com/BerriAI/litellm/issues/7877#issuecomment-2603813218 * fix(watsonx/chat/handler.py): fix passing space id to watsonx on chat route * fix(watsonx/): fix watsonx_text/ route with space id * fix(watsonx/): qa item - also adds better unit testing for watsonx embedding calls * fix(utils.py): rename to '..fields' * fix: fix linting errors * fix(utils.py): fix typing - don't show provider-specific field if none or empty - prevents default respons e from being non-oai compatible * fix: cleanup unused imports * docs(deepseek.md): add docs for deepseek reasoning model
This commit is contained in:
parent
26a79a533d
commit
76795dba39
12 changed files with 281 additions and 63 deletions
|
@ -1,3 +1,6 @@
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
# Deepseek
|
# Deepseek
|
||||||
https://deepseek.com/
|
https://deepseek.com/
|
||||||
|
|
||||||
|
@ -52,3 +55,72 @@ We support ALL Deepseek models, just set `deepseek/` as a prefix when sending co
|
||||||
| deepseek-coder | `completion(model="deepseek/deepseek-coder", messages)` |
|
| deepseek-coder | `completion(model="deepseek/deepseek-coder", messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
## Reasoning Models
|
||||||
|
| Model Name | Function Call |
|
||||||
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
|
| deepseek-reasoner | `completion(model="deepseek/deepseek-reasoner", messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="sdk" label="SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from litellm import completion
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['DEEPSEEK_API_KEY'] = ""
|
||||||
|
resp = completion(
|
||||||
|
model="deepseek/deepseek-reasoner",
|
||||||
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
resp.choices[0].message.provider_specific_fields["reasoning_content"]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="proxy" label="PROXY">
|
||||||
|
|
||||||
|
1. Setup config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: deepseek-reasoner
|
||||||
|
litellm_params:
|
||||||
|
model: deepseek/deepseek-reasoner
|
||||||
|
api_key: os.environ/DEEPSEEK_API_KEY
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run proxy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python litellm/proxy/main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Test it!
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -L -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'Authorization: Bearer sk-1234' \
|
||||||
|
-d '{
|
||||||
|
"model": "deepseek-reasoner",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Hi, how are you ?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
|
@ -337,7 +337,6 @@ def convert_to_model_response_object( # noqa: PLR0915
|
||||||
] = None, # used for supporting 'json_schema' on older models
|
] = None, # used for supporting 'json_schema' on older models
|
||||||
):
|
):
|
||||||
received_args = locals()
|
received_args = locals()
|
||||||
|
|
||||||
additional_headers = get_response_headers(_response_headers)
|
additional_headers = get_response_headers(_response_headers)
|
||||||
|
|
||||||
if hidden_params is None:
|
if hidden_params is None:
|
||||||
|
@ -411,12 +410,18 @@ def convert_to_model_response_object( # noqa: PLR0915
|
||||||
message = litellm.Message(content=json_mode_content_str)
|
message = litellm.Message(content=json_mode_content_str)
|
||||||
finish_reason = "stop"
|
finish_reason = "stop"
|
||||||
if message is None:
|
if message is None:
|
||||||
|
provider_specific_fields = {}
|
||||||
|
message_keys = Message.model_fields.keys()
|
||||||
|
for field in choice["message"].keys():
|
||||||
|
if field not in message_keys:
|
||||||
|
provider_specific_fields[field] = choice["message"][field]
|
||||||
message = Message(
|
message = Message(
|
||||||
content=choice["message"].get("content", None),
|
content=choice["message"].get("content", None),
|
||||||
role=choice["message"]["role"] or "assistant",
|
role=choice["message"]["role"] or "assistant",
|
||||||
function_call=choice["message"].get("function_call", None),
|
function_call=choice["message"].get("function_call", None),
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
audio=choice["message"].get("audio", None),
|
audio=choice["message"].get("audio", None),
|
||||||
|
provider_specific_fields=provider_specific_fields,
|
||||||
)
|
)
|
||||||
finish_reason = choice.get("finish_reason", None)
|
finish_reason = choice.get("finish_reason", None)
|
||||||
if finish_reason is None:
|
if finish_reason is None:
|
||||||
|
|
|
@ -51,6 +51,13 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## UPDATE PAYLOAD (optional params)
|
||||||
|
watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
|
||||||
|
model=model,
|
||||||
|
api_params=api_params,
|
||||||
|
)
|
||||||
|
optional_params.update(watsonx_auth_payload)
|
||||||
|
|
||||||
## GET API URL
|
## GET API URL
|
||||||
api_base = watsonx_chat_transformation.get_complete_url(
|
api_base = watsonx_chat_transformation.get_complete_url(
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
@ -59,13 +66,6 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||||
stream=optional_params.get("stream", False),
|
stream=optional_params.get("stream", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
## UPDATE PAYLOAD (optional params)
|
|
||||||
watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
|
|
||||||
model=model,
|
|
||||||
api_params=api_params,
|
|
||||||
)
|
|
||||||
optional_params.update(watsonx_auth_payload)
|
|
||||||
|
|
||||||
return super().completion(
|
return super().completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -7,11 +7,11 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from litellm.secret_managers.main import get_secret_str
|
from litellm.secret_managers.main import get_secret_str
|
||||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
|
from litellm.types.llms.watsonx import WatsonXAIEndpoint
|
||||||
|
|
||||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||||
from ..common_utils import IBMWatsonXMixin, WatsonXAIError
|
from ..common_utils import IBMWatsonXMixin
|
||||||
|
|
||||||
|
|
||||||
class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
|
class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
|
||||||
|
@ -87,12 +87,6 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
|
||||||
) -> str:
|
) -> str:
|
||||||
url = self._get_base_url(api_base=api_base)
|
url = self._get_base_url(api_base=api_base)
|
||||||
if model.startswith("deployment/"):
|
if model.startswith("deployment/"):
|
||||||
# deployment models are passed in as 'deployment/<deployment_id>'
|
|
||||||
if optional_params.get("space_id") is None:
|
|
||||||
raise WatsonXAIError(
|
|
||||||
status_code=401,
|
|
||||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
|
||||||
)
|
|
||||||
deployment_id = "/".join(model.split("/")[1:])
|
deployment_id = "/".join(model.split("/")[1:])
|
||||||
endpoint = (
|
endpoint = (
|
||||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||||
|
@ -113,11 +107,3 @@ class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
|
||||||
url=url, api_version=optional_params.pop("api_version", None)
|
url=url, api_version=optional_params.pop("api_version", None)
|
||||||
)
|
)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
|
|
||||||
payload: dict = {}
|
|
||||||
if model.startswith("deployment/"):
|
|
||||||
return payload
|
|
||||||
payload["model_id"] = model
|
|
||||||
payload["project_id"] = api_params["project_id"]
|
|
||||||
return payload
|
|
||||||
|
|
|
@ -275,3 +275,17 @@ class IBMWatsonXMixin:
|
||||||
return WatsonXCredentials(
|
return WatsonXCredentials(
|
||||||
api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
|
api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
|
||||||
|
payload: dict = {}
|
||||||
|
if model.startswith("deployment/"):
|
||||||
|
if api_params["space_id"] is None:
|
||||||
|
raise WatsonXAIError(
|
||||||
|
status_code=401,
|
||||||
|
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||||
|
)
|
||||||
|
payload["space_id"] = api_params["space_id"]
|
||||||
|
return payload
|
||||||
|
payload["model_id"] = model
|
||||||
|
payload["project_id"] = api_params["project_id"]
|
||||||
|
return payload
|
||||||
|
|
|
@ -246,17 +246,20 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
|
||||||
extra_body_params = optional_params.pop("extra_body", {})
|
extra_body_params = optional_params.pop("extra_body", {})
|
||||||
optional_params.update(extra_body_params)
|
optional_params.update(extra_body_params)
|
||||||
watsonx_api_params = _get_api_params(params=optional_params)
|
watsonx_api_params = _get_api_params(params=optional_params)
|
||||||
|
|
||||||
|
watsonx_auth_payload = self._prepare_payload(
|
||||||
|
model=model,
|
||||||
|
api_params=watsonx_api_params,
|
||||||
|
)
|
||||||
|
|
||||||
# init the payload to the text generation call
|
# init the payload to the text generation call
|
||||||
payload = {
|
payload = {
|
||||||
"input": prompt,
|
"input": prompt,
|
||||||
"moderations": optional_params.pop("moderations", {}),
|
"moderations": optional_params.pop("moderations", {}),
|
||||||
"parameters": optional_params,
|
"parameters": optional_params,
|
||||||
|
**watsonx_auth_payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
if not model.startswith("deployment/"):
|
|
||||||
payload["model_id"] = model
|
|
||||||
payload["project_id"] = watsonx_api_params["project_id"]
|
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def transform_response(
|
def transform_response(
|
||||||
|
@ -320,11 +323,6 @@ class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
|
||||||
url = self._get_base_url(api_base=api_base)
|
url = self._get_base_url(api_base=api_base)
|
||||||
if model.startswith("deployment/"):
|
if model.startswith("deployment/"):
|
||||||
# deployment models are passed in as 'deployment/<deployment_id>'
|
# deployment models are passed in as 'deployment/<deployment_id>'
|
||||||
if optional_params.get("space_id") is None:
|
|
||||||
raise WatsonXAIError(
|
|
||||||
status_code=401,
|
|
||||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
|
||||||
)
|
|
||||||
deployment_id = "/".join(model.split("/")[1:])
|
deployment_id = "/".join(model.split("/")[1:])
|
||||||
endpoint = (
|
endpoint = (
|
||||||
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
|
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
|
||||||
|
|
|
@ -14,7 +14,7 @@ from litellm.types.llms.openai import AllEmbeddingInputValues
|
||||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint
|
from litellm.types.llms.watsonx import WatsonXAIEndpoint
|
||||||
from litellm.types.utils import EmbeddingResponse, Usage
|
from litellm.types.utils import EmbeddingResponse, Usage
|
||||||
|
|
||||||
from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
|
from ..common_utils import IBMWatsonXMixin, _get_api_params
|
||||||
|
|
||||||
|
|
||||||
class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
|
class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
|
||||||
|
@ -38,14 +38,15 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
|
||||||
headers: dict,
|
headers: dict,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
watsonx_api_params = _get_api_params(params=optional_params)
|
watsonx_api_params = _get_api_params(params=optional_params)
|
||||||
project_id = watsonx_api_params["project_id"]
|
watsonx_auth_payload = self._prepare_payload(
|
||||||
if not project_id:
|
model=model,
|
||||||
raise ValueError("project_id is required")
|
api_params=watsonx_api_params,
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"inputs": input,
|
"inputs": input,
|
||||||
"model_id": model,
|
|
||||||
"project_id": project_id,
|
|
||||||
"parameters": optional_params,
|
"parameters": optional_params,
|
||||||
|
**watsonx_auth_payload,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_complete_url(
|
def get_complete_url(
|
||||||
|
@ -58,12 +59,6 @@ class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
|
||||||
url = self._get_base_url(api_base=api_base)
|
url = self._get_base_url(api_base=api_base)
|
||||||
endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
|
endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
|
||||||
if model.startswith("deployment/"):
|
if model.startswith("deployment/"):
|
||||||
# deployment models are passed in as 'deployment/<deployment_id>'
|
|
||||||
if optional_params.get("space_id") is None:
|
|
||||||
raise WatsonXAIError(
|
|
||||||
status_code=401,
|
|
||||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
|
||||||
)
|
|
||||||
deployment_id = "/".join(model.split("/")[1:])
|
deployment_id = "/".join(model.split("/")[1:])
|
||||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||||
url = url.rstrip("/") + endpoint
|
url = url.rstrip("/") + endpoint
|
||||||
|
|
|
@ -6,7 +6,7 @@ model_list:
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
api_base: https://exampleopenaiendpoint-production.up.railway.app
|
||||||
- model_name: openai-o1
|
- model_name: openai-o1
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0
|
model: openai/random_sleep
|
||||||
api_base: http://0.0.0.0:8090
|
api_base: http://0.0.0.0:8090
|
||||||
timeout: 2
|
timeout: 2
|
||||||
num_retries: 0
|
num_retries: 0
|
||||||
|
|
|
@ -18,7 +18,7 @@ from openai.types.moderation import (
|
||||||
CategoryScores,
|
CategoryScores,
|
||||||
)
|
)
|
||||||
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
|
from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
|
||||||
from pydantic import BaseModel, ConfigDict, PrivateAttr
|
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||||
from typing_extensions import Callable, Dict, Required, TypedDict, override
|
from typing_extensions import Callable, Dict, Required, TypedDict, override
|
||||||
|
|
||||||
from ..litellm_core_utils.core_helpers import map_finish_reason
|
from ..litellm_core_utils.core_helpers import map_finish_reason
|
||||||
|
@ -438,6 +438,9 @@ class Message(OpenAIObject):
|
||||||
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
|
tool_calls: Optional[List[ChatCompletionMessageToolCall]]
|
||||||
function_call: Optional[FunctionCall]
|
function_call: Optional[FunctionCall]
|
||||||
audio: Optional[ChatCompletionAudioResponse] = None
|
audio: Optional[ChatCompletionAudioResponse] = None
|
||||||
|
provider_specific_fields: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, exclude=True
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -446,6 +449,7 @@ class Message(OpenAIObject):
|
||||||
function_call=None,
|
function_call=None,
|
||||||
tool_calls: Optional[list] = None,
|
tool_calls: Optional[list] = None,
|
||||||
audio: Optional[ChatCompletionAudioResponse] = None,
|
audio: Optional[ChatCompletionAudioResponse] = None,
|
||||||
|
provider_specific_fields: Optional[Dict[str, Any]] = None,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
init_values: Dict[str, Any] = {
|
init_values: Dict[str, Any] = {
|
||||||
|
@ -481,6 +485,9 @@ class Message(OpenAIObject):
|
||||||
# OpenAI compatible APIs like mistral API will raise an error if audio is passed in
|
# OpenAI compatible APIs like mistral API will raise an error if audio is passed in
|
||||||
del self.audio
|
del self.audio
|
||||||
|
|
||||||
|
if provider_specific_fields: # set if provider_specific_fields is not empty
|
||||||
|
self.provider_specific_fields = provider_specific_fields
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||||
return getattr(self, key, default)
|
return getattr(self, key, default)
|
||||||
|
@ -502,6 +509,10 @@ class Message(OpenAIObject):
|
||||||
|
|
||||||
|
|
||||||
class Delta(OpenAIObject):
|
class Delta(OpenAIObject):
|
||||||
|
provider_specific_fields: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, exclude=True
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
content=None,
|
content=None,
|
||||||
|
@ -511,14 +522,20 @@ class Delta(OpenAIObject):
|
||||||
audio: Optional[ChatCompletionAudioResponse] = None,
|
audio: Optional[ChatCompletionAudioResponse] = None,
|
||||||
**params,
|
**params,
|
||||||
):
|
):
|
||||||
|
provider_specific_fields: Dict[str, Any] = {}
|
||||||
|
if "reasoning_content" in params:
|
||||||
|
provider_specific_fields["reasoning_content"] = params["reasoning_content"]
|
||||||
|
del params["reasoning_content"]
|
||||||
super(Delta, self).__init__(**params)
|
super(Delta, self).__init__(**params)
|
||||||
self.content = content
|
self.content = content
|
||||||
self.role = role
|
self.role = role
|
||||||
|
self.provider_specific_fields = provider_specific_fields
|
||||||
# Set default values and correct types
|
# Set default values and correct types
|
||||||
self.function_call: Optional[Union[FunctionCall, Any]] = None
|
self.function_call: Optional[Union[FunctionCall, Any]] = None
|
||||||
self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
|
self.tool_calls: Optional[List[Union[ChatCompletionDeltaToolCall, Any]]] = None
|
||||||
self.audio: Optional[ChatCompletionAudioResponse] = None
|
self.audio: Optional[ChatCompletionAudioResponse] = None
|
||||||
|
if provider_specific_fields: # set if provider_specific_fields is not empty
|
||||||
|
self.provider_specific_fields = provider_specific_fields
|
||||||
|
|
||||||
if function_call is not None and isinstance(function_call, dict):
|
if function_call is not None and isinstance(function_call, dict):
|
||||||
self.function_call = FunctionCall(**function_call)
|
self.function_call = FunctionCall(**function_call)
|
||||||
|
|
|
@ -8,11 +8,12 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import completion
|
from litellm import completion, embedding
|
||||||
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
|
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
|
||||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler
|
||||||
from unittest.mock import patch, MagicMock, AsyncMock, Mock
|
from unittest.mock import patch, MagicMock, AsyncMock, Mock
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -21,6 +22,7 @@ def watsonx_chat_completion_call():
|
||||||
model="watsonx/my-test-model",
|
model="watsonx/my-test-model",
|
||||||
messages=None,
|
messages=None,
|
||||||
api_key="test_api_key",
|
api_key="test_api_key",
|
||||||
|
space_id: Optional[str] = None,
|
||||||
headers=None,
|
headers=None,
|
||||||
client=None,
|
client=None,
|
||||||
patch_token_call=True,
|
patch_token_call=True,
|
||||||
|
@ -41,24 +43,90 @@ def watsonx_chat_completion_call():
|
||||||
with patch.object(client, "post") as mock_post, patch.object(
|
with patch.object(client, "post") as mock_post, patch.object(
|
||||||
litellm.module_level_client, "post", return_value=mock_response
|
litellm.module_level_client, "post", return_value=mock_response
|
||||||
) as mock_get:
|
) as mock_get:
|
||||||
completion(
|
try:
|
||||||
model=model,
|
completion(
|
||||||
messages=messages,
|
model=model,
|
||||||
api_key=api_key,
|
messages=messages,
|
||||||
headers=headers or {},
|
api_key=api_key,
|
||||||
client=client,
|
headers=headers or {},
|
||||||
)
|
client=client,
|
||||||
|
space_id=space_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
return mock_post, mock_get
|
return mock_post, mock_get
|
||||||
else:
|
else:
|
||||||
with patch.object(client, "post") as mock_post:
|
with patch.object(client, "post") as mock_post:
|
||||||
completion(
|
try:
|
||||||
model=model,
|
completion(
|
||||||
messages=messages,
|
model=model,
|
||||||
api_key=api_key,
|
messages=messages,
|
||||||
headers=headers or {},
|
api_key=api_key,
|
||||||
client=client,
|
headers=headers or {},
|
||||||
)
|
client=client,
|
||||||
|
space_id=space_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return mock_post, None
|
||||||
|
|
||||||
|
return _call
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def watsonx_embedding_call():
|
||||||
|
def _call(
|
||||||
|
model="watsonx/my-test-model",
|
||||||
|
input=None,
|
||||||
|
api_key="test_api_key",
|
||||||
|
space_id: Optional[str] = None,
|
||||||
|
headers=None,
|
||||||
|
client=None,
|
||||||
|
patch_token_call=True,
|
||||||
|
):
|
||||||
|
if input is None:
|
||||||
|
input = ["Hello, how are you?"]
|
||||||
|
if client is None:
|
||||||
|
client = HTTPHandler()
|
||||||
|
|
||||||
|
if patch_token_call:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "mock_access_token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = Mock() # No-op to simulate no exception
|
||||||
|
|
||||||
|
with patch.object(client, "post") as mock_post, patch.object(
|
||||||
|
litellm.module_level_client, "post", return_value=mock_response
|
||||||
|
) as mock_get:
|
||||||
|
try:
|
||||||
|
embedding(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
headers=headers or {},
|
||||||
|
client=client,
|
||||||
|
space_id=space_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
return mock_post, mock_get
|
||||||
|
else:
|
||||||
|
with patch.object(client, "post") as mock_post:
|
||||||
|
try:
|
||||||
|
embedding(
|
||||||
|
model=model,
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
headers=headers or {},
|
||||||
|
client=client,
|
||||||
|
space_id=space_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
return mock_post, None
|
return mock_post, None
|
||||||
|
|
||||||
return _call
|
return _call
|
||||||
|
@ -118,3 +186,35 @@ def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
|
||||||
|
|
||||||
assert mock_post.call_count == 1
|
assert mock_post.call_count == 1
|
||||||
assert "deployment" not in mock_post.call_args.kwargs["url"]
|
assert "deployment" not in mock_post.call_args.kwargs["url"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"watsonx/deployment/<xxxx.xxx.xxx.xxxx>",
|
||||||
|
"watsonx_text/deployment/<xxxx.xxx.xxx.xxxx>",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_watsonx_deployment_space_id(monkeypatch, watsonx_chat_completion_call, model):
|
||||||
|
my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
|
||||||
|
monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
|
||||||
|
|
||||||
|
mock_post, _ = watsonx_chat_completion_call(
|
||||||
|
model=model,
|
||||||
|
messages=[{"content": "Hello, how are you?", "role": "user"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_post.call_count == 1
|
||||||
|
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
|
assert my_fake_space_id == json_data["space_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_watsonx_deployment_space_id_embedding(monkeypatch, watsonx_embedding_call):
|
||||||
|
my_fake_space_id = "xxx-xxx-xxx-xxx-xxx"
|
||||||
|
monkeypatch.setenv("WATSONX_SPACE_ID", my_fake_space_id)
|
||||||
|
|
||||||
|
mock_post, _ = watsonx_embedding_call(model="watsonx/deployment/my-test-model")
|
||||||
|
|
||||||
|
assert mock_post.call_count == 1
|
||||||
|
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||||
|
assert my_fake_space_id == json_data["space_id"]
|
||||||
|
|
|
@ -4537,3 +4537,16 @@ def test_humanloop_completion(monkeypatch):
|
||||||
prompt_variables={"person": "John"},
|
prompt_variables={"person": "John"},
|
||||||
messages=[{"role": "user", "content": "Tell me a joke."}],
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_reasoning_content_completion():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
resp = litellm.completion(
|
||||||
|
model="deepseek/deepseek-reasoner",
|
||||||
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
resp.choices[0].message.provider_specific_fields["reasoning_content"]
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
|
@ -4063,3 +4063,21 @@ def test_mock_response_iterator_tool_use():
|
||||||
response_chunk = completion_stream._chunk_parser(chunk_data=response)
|
response_chunk = completion_stream._chunk_parser(chunk_data=response)
|
||||||
|
|
||||||
assert response_chunk["tool_use"] is not None
|
assert response_chunk["tool_use"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_reasoning_content_completion():
|
||||||
|
litellm.set_verbose = True
|
||||||
|
resp = litellm.completion(
|
||||||
|
model="deepseek/deepseek-reasoner",
|
||||||
|
messages=[{"role": "user", "content": "Tell me a joke."}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning_content_exists = False
|
||||||
|
for chunk in resp:
|
||||||
|
print(f"chunk: {chunk}")
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
if "reasoning_content" in chunk.choices[0].delta.provider_specific_fields:
|
||||||
|
reasoning_content_exists = True
|
||||||
|
break
|
||||||
|
assert reasoning_content_exists
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue