fix(anthropic.py): handle whitespace characters for anthropic calls

This commit is contained in:
Krrish Dholakia 2024-05-03 17:31:34 -07:00
parent 0b9fa53e3e
commit 097714e02f
3 changed files with 26 additions and 34 deletions

View file

@ -101,13 +101,13 @@ class AnthropicConfig:
optional_params["max_tokens"] = value optional_params["max_tokens"] = value
if param == "tools": if param == "tools":
optional_params["tools"] = value optional_params["tools"] = value
if param == "stream": if param == "stream" and value == True:
optional_params["stream"] = value optional_params["stream"] = value
if param == "stop": if param == "stop":
if isinstance(value, str): if isinstance(value, str):
if ( if (
value == "\n" value == "\n"
): # anthropic doesn't allow whitespace characters as stop-sequences ) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue continue
value = [value] value = [value]
elif isinstance(value, list): elif isinstance(value, list):
@ -115,10 +115,13 @@ class AnthropicConfig:
for v in value: for v in value:
if ( if (
v == "\n" v == "\n"
): # anthropic doesn't allow whitespace characters as stop-sequences ) and litellm.drop_params == True: # anthropic doesn't allow whitespace characters as stop-sequences
continue continue
new_v.append(v) new_v.append(v)
value = new_v if len(new_v) > 0:
value = new_v
else:
continue
optional_params["stop_sequences"] = value optional_params["stop_sequences"] = value
if param == "temperature": if param == "temperature":
optional_params["temperature"] = value optional_params["temperature"] = value

View file

@ -5,13 +5,27 @@ import pytest
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
import litellm import litellm
from litellm.utils import get_optional_params_embeddings from litellm.utils import get_optional_params_embeddings, get_optional_params
## get_optional_params_embeddings ## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock ### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True ### Scenarios: w/ optional params + litellm.drop_params = True
@pytest.mark.parametrize(
"stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)]
)
def test_anthropic_optional_params(stop_sequence, expected_count):
"""
Test if whitespace character optional param is dropped by anthropic
"""
litellm.drop_params = True
optional_params = get_optional_params(
model="claude-3", custom_llm_provider="anthropic", stop=stop_sequence
)
assert len(optional_params) == expected_count
def test_bedrock_optional_params_embeddings(): def test_bedrock_optional_params_embeddings():
litellm.drop_params = True litellm.drop_params = True
optional_params = get_optional_params_embeddings( optional_params = get_optional_params_embeddings(

View file

@ -5006,26 +5006,9 @@ def get_optional_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
_check_valid_arg(supported_params=supported_params) _check_valid_arg(supported_params=supported_params)
# handle anthropic params optional_params = litellm.AnthropicConfig().map_openai_params(
if stream: non_default_params=non_default_params, optional_params=optional_params
optional_params["stream"] = stream )
if stop is not None:
if type(stop) == str:
stop = [stop] # openai can accept str/list for stop
optional_params["stop_sequences"] = stop
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if max_tokens is not None:
if (model == "claude-2") or (model == "claude-instant-1"):
# these models use antropic_text.py which only accepts max_tokens_to_sample
optional_params["max_tokens_to_sample"] = max_tokens
else:
optional_params["max_tokens"] = max_tokens
optional_params["max_tokens"] = max_tokens
if tools is not None:
optional_params["tools"] = tools
elif custom_llm_provider == "cohere": elif custom_llm_provider == "cohere":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -5929,15 +5912,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
elif custom_llm_provider == "ollama_chat": elif custom_llm_provider == "ollama_chat":
return litellm.OllamaChatConfig().get_supported_openai_params() return litellm.OllamaChatConfig().get_supported_openai_params()
elif custom_llm_provider == "anthropic": elif custom_llm_provider == "anthropic":
return [ return litellm.AnthropicConfig().get_supported_openai_params()
"stream",
"stop",
"temperature",
"top_p",
"max_tokens",
"tools",
"tool_choice",
]
elif custom_llm_provider == "groq": elif custom_llm_provider == "groq":
return [ return [
"temperature", "temperature",