feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param

This commit is contained in:
Krrish Dholakia 2024-11-14 12:07:23 +05:30
parent b6c9032454
commit 756d838dfa
4 changed files with 62 additions and 24 deletions

View file

@ -440,8 +440,8 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers={},
client=None,
@ -464,6 +464,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
_is_function_call=_is_function_call,
is_vertex_request=is_vertex_request,

View file

@ -91,6 +91,7 @@ class AnthropicConfig:
"extra_headers",
"parallel_tool_calls",
"response_format",
"user",
]
def get_cache_control_headers(self) -> dict:
@ -246,6 +247,28 @@ class AnthropicConfig:
anthropic_tools.append(new_tool)
return anthropic_tools
def _map_stop_sequences(
self, stop: Optional[Union[str, List[str]]]
) -> Optional[List[str]]:
new_stop: Optional[List[str]] = None
if isinstance(stop, str):
if (
stop == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
return new_stop
new_stop = [stop]
elif isinstance(stop, list):
new_v = []
for v in stop:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
new_stop = new_v
return new_stop
def map_openai_params(
self,
non_default_params: dict,
@ -271,26 +294,8 @@ class AnthropicConfig:
optional_params["tool_choice"] = _tool_choice
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
if (
value == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
optional_params["stop_sequences"] = self._map_stop_sequences(value)
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
@ -314,7 +319,8 @@ class AnthropicConfig:
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
## VALIDATE REQUEST
"""
Anthropic doesn't support tool calling without `tools=` param specified.
@ -465,6 +471,7 @@ class AnthropicConfig:
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
_is_function_call: bool,
is_vertex_request: bool,
@ -502,6 +509,15 @@ class AnthropicConfig:
if "tools" in optional_params:
_is_function_call = True
## Handle user_id in metadata
_litellm_metadata = litellm_params.get("metadata", None)
if (
_litellm_metadata
and isinstance(_litellm_metadata, dict)
and "user_id" in _litellm_metadata
):
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
data = {
"messages": anthropic_messages,
**optional_params,

View file

@ -13,8 +13,11 @@ sys.path.insert(
import litellm
from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
from abc import ABC, abstractmethod
@ -25,6 +28,11 @@ class BaseLLMChatTest(ABC):
Abstract base test class that enforces a common test across all test classes.
"""
@abstractmethod
def get_default_model_name(self) -> str:
"""Must return the default model name"""
pass
@abstractmethod
def get_base_completion_call_args(self) -> dict:
"""Must return the base completion call args"""

View file

@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
)
print(optional_params)
assert optional_params["top_k"] == 10
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
model = "claude-3-5-sonnet-20240620"
optional_params = get_optional_params(
model=model,
user="test_user",
custom_llm_provider="anthropic",
)
assert optional_params["metadata"]["user_id"] == "test_user"