feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param

This commit is contained in:
Krrish Dholakia 2024-11-14 12:07:23 +05:30
parent b6c9032454
commit 756d838dfa
4 changed files with 62 additions and 24 deletions

View file

@ -440,8 +440,8 @@ class AnthropicChatCompletion(BaseLLM):
logging_obj, logging_obj,
optional_params: dict, optional_params: dict,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
litellm_params: dict,
acompletion=None, acompletion=None,
litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
client=None, client=None,
@ -464,6 +464,7 @@ class AnthropicChatCompletion(BaseLLM):
model=model, model=model,
messages=messages, messages=messages,
optional_params=optional_params, optional_params=optional_params,
litellm_params=litellm_params,
headers=headers, headers=headers,
_is_function_call=_is_function_call, _is_function_call=_is_function_call,
is_vertex_request=is_vertex_request, is_vertex_request=is_vertex_request,

View file

@ -91,6 +91,7 @@ class AnthropicConfig:
"extra_headers", "extra_headers",
"parallel_tool_calls", "parallel_tool_calls",
"response_format", "response_format",
"user",
] ]
def get_cache_control_headers(self) -> dict: def get_cache_control_headers(self) -> dict:
@ -246,6 +247,28 @@ class AnthropicConfig:
anthropic_tools.append(new_tool) anthropic_tools.append(new_tool)
return anthropic_tools return anthropic_tools
def _map_stop_sequences(
self, stop: Optional[Union[str, List[str]]]
) -> Optional[List[str]]:
new_stop: Optional[List[str]] = None
if isinstance(stop, str):
if (
stop == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
return new_stop
new_stop = [stop]
elif isinstance(stop, list):
new_v = []
for v in stop:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
new_stop = new_v
return new_stop
def map_openai_params( def map_openai_params(
self, self,
non_default_params: dict, non_default_params: dict,
@ -271,26 +294,8 @@ class AnthropicConfig:
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
if param == "stream" and value is True: if param == "stream" and value is True:
optional_params["stream"] = value optional_params["stream"] = value
if param == "stop": if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
if isinstance(value, str): optional_params["stop_sequences"] = self._map_stop_sequences(value)
if (
value == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
value = [value]
elif isinstance(value, list):
new_v = []
for v in value:
if (
v == "\n"
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
continue
new_v.append(v)
if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value
if param == "temperature": if param == "temperature":
optional_params["temperature"] = value optional_params["temperature"] = value
if param == "top_p": if param == "top_p":
@ -314,7 +319,8 @@ class AnthropicConfig:
optional_params["tools"] = [_tool] optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True optional_params["json_mode"] = True
if param == "user":
optional_params["metadata"] = {"user_id": value}
## VALIDATE REQUEST ## VALIDATE REQUEST
""" """
Anthropic doesn't support tool calling without `tools=` param specified. Anthropic doesn't support tool calling without `tools=` param specified.
@ -465,6 +471,7 @@ class AnthropicConfig:
model: str, model: str,
messages: List[AllMessageValues], messages: List[AllMessageValues],
optional_params: dict, optional_params: dict,
litellm_params: dict,
headers: dict, headers: dict,
_is_function_call: bool, _is_function_call: bool,
is_vertex_request: bool, is_vertex_request: bool,
@ -502,6 +509,15 @@ class AnthropicConfig:
if "tools" in optional_params: if "tools" in optional_params:
_is_function_call = True _is_function_call = True
## Handle user_id in metadata
_litellm_metadata = litellm_params.get("metadata", None)
if (
_litellm_metadata
and isinstance(_litellm_metadata, dict)
and "user_id" in _litellm_metadata
):
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
data = { data = {
"messages": anthropic_messages, "messages": anthropic_messages,
**optional_params, **optional_params,

View file

@ -13,8 +13,11 @@ sys.path.insert(
import litellm import litellm
from litellm.exceptions import BadRequestError from litellm.exceptions import BadRequestError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper from litellm.utils import (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py # test_example.py
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -25,6 +28,11 @@ class BaseLLMChatTest(ABC):
Abstract base test class that enforces a common test across all test classes. Abstract base test class that enforces a common test across all test classes.
""" """
@abstractmethod
def get_default_model_name(self) -> str:
"""Must return the default model name"""
pass
@abstractmethod @abstractmethod
def get_base_completion_call_args(self) -> dict: def get_base_completion_call_args(self) -> dict:
"""Must return the base completion call args""" """Must return the base completion call args"""

View file

@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
) )
print(optional_params) print(optional_params)
assert optional_params["top_k"] == 10 assert optional_params["top_k"] == 10
def test_forward_user_param():
from litellm.utils import get_supported_openai_params, get_optional_params
model = "claude-3-5-sonnet-20240620"
optional_params = get_optional_params(
model=model,
user="test_user",
custom_llm_provider="anthropic",
)
assert optional_params["metadata"]["user_id"] == "test_user"