forked from phoenix/litellm-mirror
feat(anthropic/chat/transformation.py): support passing user id to anthropic via openai 'user' param
This commit is contained in:
parent
b6c9032454
commit
756d838dfa
4 changed files with 62 additions and 24 deletions
|
@ -440,8 +440,8 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
logging_obj,
|
||||
optional_params: dict,
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
litellm_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client=None,
|
||||
|
@ -464,6 +464,7 @@ class AnthropicChatCompletion(BaseLLM):
|
|||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
_is_function_call=_is_function_call,
|
||||
is_vertex_request=is_vertex_request,
|
||||
|
|
|
@ -91,6 +91,7 @@ class AnthropicConfig:
|
|||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"response_format",
|
||||
"user",
|
||||
]
|
||||
|
||||
def get_cache_control_headers(self) -> dict:
|
||||
|
@ -246,6 +247,28 @@ class AnthropicConfig:
|
|||
anthropic_tools.append(new_tool)
|
||||
return anthropic_tools
|
||||
|
||||
def _map_stop_sequences(
|
||||
self, stop: Optional[Union[str, List[str]]]
|
||||
) -> Optional[List[str]]:
|
||||
new_stop: Optional[List[str]] = None
|
||||
if isinstance(stop, str):
|
||||
if (
|
||||
stop == "\n"
|
||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
return new_stop
|
||||
new_stop = [stop]
|
||||
elif isinstance(stop, list):
|
||||
new_v = []
|
||||
for v in stop:
|
||||
if (
|
||||
v == "\n"
|
||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
new_v.append(v)
|
||||
if len(new_v) > 0:
|
||||
new_stop = new_v
|
||||
return new_stop
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
|
@ -271,26 +294,8 @@ class AnthropicConfig:
|
|||
optional_params["tool_choice"] = _tool_choice
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
if (
|
||||
value == "\n"
|
||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
value = [value]
|
||||
elif isinstance(value, list):
|
||||
new_v = []
|
||||
for v in value:
|
||||
if (
|
||||
v == "\n"
|
||||
) and litellm.drop_params is True: # anthropic doesn't allow whitespace characters as stop-sequences
|
||||
continue
|
||||
new_v.append(v)
|
||||
if len(new_v) > 0:
|
||||
value = new_v
|
||||
else:
|
||||
continue
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
|
||||
optional_params["stop_sequences"] = self._map_stop_sequences(value)
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
|
@ -314,7 +319,8 @@ class AnthropicConfig:
|
|||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
|
||||
if param == "user":
|
||||
optional_params["metadata"] = {"user_id": value}
|
||||
## VALIDATE REQUEST
|
||||
"""
|
||||
Anthropic doesn't support tool calling without `tools=` param specified.
|
||||
|
@ -465,6 +471,7 @@ class AnthropicConfig:
|
|||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
_is_function_call: bool,
|
||||
is_vertex_request: bool,
|
||||
|
@ -502,6 +509,15 @@ class AnthropicConfig:
|
|||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
|
||||
## Handle user_id in metadata
|
||||
_litellm_metadata = litellm_params.get("metadata", None)
|
||||
if (
|
||||
_litellm_metadata
|
||||
and isinstance(_litellm_metadata, dict)
|
||||
and "user_id" in _litellm_metadata
|
||||
):
|
||||
optional_params["metadata"] = {"user_id": _litellm_metadata["user_id"]}
|
||||
|
||||
data = {
|
||||
"messages": anthropic_messages,
|
||||
**optional_params,
|
||||
|
|
|
@ -13,8 +13,11 @@ sys.path.insert(
|
|||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from litellm.utils import (
|
||||
CustomStreamWrapper,
|
||||
get_supported_openai_params,
|
||||
get_optional_params,
|
||||
)
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -25,6 +28,11 @@ class BaseLLMChatTest(ABC):
|
|||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model_name(self) -> str:
|
||||
"""Must return the default model name"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
"""Must return the base completion call args"""
|
||||
|
|
|
@ -921,3 +921,16 @@ def test_watsonx_text_top_k():
|
|||
)
|
||||
print(optional_params)
|
||||
assert optional_params["top_k"] == 10
|
||||
|
||||
|
||||
def test_forward_user_param():
|
||||
from litellm.utils import get_supported_openai_params, get_optional_params
|
||||
|
||||
model = "claude-3-5-sonnet-20240620"
|
||||
optional_params = get_optional_params(
|
||||
model=model,
|
||||
user="test_user",
|
||||
custom_llm_provider="anthropic",
|
||||
)
|
||||
|
||||
assert optional_params["metadata"]["user_id"] == "test_user"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue