From 756d838dfa9c888a7edc771166b97161f9c5824b Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 14 Nov 2024 12:07:23 +0530
Subject: [PATCH] feat(anthropic/chat/transformation.py): support passing user
 id to anthropic via openai 'user' param

---
 litellm/llms/anthropic/chat/handler.py        |  3 +-
 litellm/llms/anthropic/chat/transformation.py | 58 ++++++++++++-------
 tests/llm_translation/base_llm_unit_tests.py  | 12 +++-
 tests/llm_translation/test_optional_params.py | 13 +++++
 4 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py
index 2d119a28f..12194533c 100644
--- a/litellm/llms/anthropic/chat/handler.py
+++ b/litellm/llms/anthropic/chat/handler.py
@@ -440,8 +440,8 @@ class AnthropicChatCompletion(BaseLLM):
         logging_obj,
         optional_params: dict,
         timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
         acompletion=None,
-        litellm_params=None,
         logger_fn=None,
         headers={},
         client=None,
@@ -464,6 +464,7 @@ class AnthropicChatCompletion(BaseLLM):
             model=model,
             messages=messages,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             headers=headers,
             _is_function_call=_is_function_call,
             is_vertex_request=is_vertex_request,
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index e222d8721..e12fbd572 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -91,6 +91,7 @@ class AnthropicConfig:
             "extra_headers",
             "parallel_tool_calls",
             "response_format",
+            "user",
         ]

     def get_cache_control_headers(self) -> dict:
@@ -246,6 +247,28 @@ class AnthropicConfig:
                 anthropic_tools.append(new_tool)
         return anthropic_tools

+    def _map_stop_sequences(
+        self, stop: Optional[Union[str, List[str]]]
+    ) -> Optional[List[str]]:
+        new_stop: Optional[List[str]] = None
+        if isinstance(stop, str):
+            if (
+                stop == "\n"
+            ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
+                return new_stop
+            new_stop = [stop]
+        elif isinstance(stop, list):
+            new_v = []
+            for v in stop:
+                if (
+                    v == "\n"
+                ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
+                    continue
+                new_v.append(v)
+            if len(new_v) > 0:
+                new_stop = new_v
+        return new_stop
+
     def map_openai_params(
         self,
         non_default_params: dict,
@@ -271,26 +294,8 @@ class AnthropicConfig:
                 optional_params["tool_choice"] = _tool_choice
             if param == "stream" and value is True:
                 optional_params["stream"] = value
-            if param == "stop":
-                if isinstance(value, str):
-                    if (
-                        value == "\n"
-                    ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
-                        continue
-                    value = [value]
-                elif isinstance(value, list):
-                    new_v = []
-                    for v in value:
-                        if (
-                            v == "\n"
-                        ) and litellm.drop_params is True:  # anthropic doesn't allow whitespace characters as stop-sequences
-                            continue
-                        new_v.append(v)
-                    if len(new_v) > 0:
-                        value = new_v
-                    else:
-                        continue
-            optional_params["stop_sequences"] = value
+            if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
+                optional_params["stop_sequences"] = self._map_stop_sequences(value)
             if param == "temperature":
                 optional_params["temperature"] = value
             if param == "top_p":
@@ -314,7 +319,8 @@ class AnthropicConfig:
                     optional_params["tools"] = [_tool]
                     optional_params["tool_choice"] = _tool_choice
                     optional_params["json_mode"] = True
-
+            if param == "user":
+                optional_params["metadata"] = {"user_id": value}
         ## VALIDATE REQUEST
         """
         Anthropic doesn't support tool calling without `tools=` param specified.
@@ -465,6 +471,7 @@ class AnthropicConfig:
         model: str,
         messages: List[AllMessageValues],
         optional_params: dict,
+        litellm_params: dict,
         headers: dict,
         _is_function_call: bool,
         is_vertex_request: bool,
@@ -502,6 +509,15 @@ class AnthropicConfig:
         if "tools" in optional_params:
             _is_function_call = True

+        ## Handle user_id in metadata
+        _litellm_metadata = litellm_params.get("metadata", None)
+        if (
+            _litellm_metadata
+            and isinstance(_litellm_metadata, dict)
+            and "user_id" in _litellm_metadata
+        ):
+            optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
+
         data = {
             "messages": anthropic_messages,
             **optional_params,
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index acb764ba1..2546344d0 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -13,8 +13,11 @@ sys.path.insert(
 import litellm
 from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.utils import CustomStreamWrapper
-
+from litellm.utils import (
+    CustomStreamWrapper,
+    get_supported_openai_params,
+    get_optional_params,
+)
 # test_example.py
 from abc import ABC, abstractmethod

@@ -25,6 +28,11 @@ class BaseLLMChatTest(ABC):
     Abstract base test class that enforces a common test across all test classes.
     """

+    @abstractmethod
+    def get_default_model_name(self) -> str:
+        """Must return the default model name"""
+        pass
+
     @abstractmethod
     def get_base_completion_call_args(self) -> dict:
         """Must return the base completion call args"""
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 7283e9a39..bea066865 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
     )
     print(optional_params)
     assert optional_params["top_k"] == 10
+
+
+def test_forward_user_param():
+    from litellm.utils import get_supported_openai_params, get_optional_params
+
+    model = "claude-3-5-sonnet-20240620"
+    optional_params = get_optional_params(
+        model=model,
+        user="test_user",
+        custom_llm_provider="anthropic",
+    )
+
+    assert optional_params["metadata"]["user_id"] == "test_user"