(feat) maintain anthropic text completion

This commit is contained in:
ishaan-jaff 2024-03-04 11:16:34 -08:00
parent 9094be7fbd
commit 1183e5f2e5
4 changed files with 59 additions and 25 deletions

View file

@@ -56,6 +56,7 @@ for chunk in response:
| claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
## Advanced ## Advanced

View file

@@ -39,6 +39,7 @@ from litellm.utils import (
) )
from .llms import ( from .llms import (
anthropic, anthropic,
anthropic_text,
together_ai, together_ai,
ai21, ai21,
sagemaker, sagemaker,
@@ -1018,13 +1019,40 @@ def completion(
or litellm.api_key or litellm.api_key
or os.environ.get("ANTHROPIC_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")
) )
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if (model == "claude-2") or (model == "claude-instant-1"):
# call anthropic /completion, only use this route for claude-2, claude-instant-1
api_base = (
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/complete"
)
response = anthropic_text.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=api_key,
logging_obj=logging,
headers=headers,
)
else:
# call /messages
# default route for all anthropic models
api_base = ( api_base = (
api_base api_base
or litellm.api_base or litellm.api_base
or get_secret("ANTHROPIC_API_BASE") or get_secret("ANTHROPIC_API_BASE")
or "https://api.anthropic.com/v1/messages" or "https://api.anthropic.com/v1/messages"
) )
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = anthropic.completion( response = anthropic.completion(
model=model, model=model,
messages=messages, messages=messages,

View file

@@ -56,7 +56,7 @@ def test_completion_custom_provider_model_name():
def test_completion_claude(): def test_completion_claude():
litellm.set_verbose = True litellm.set_verbose = True
litellm.cache = None litellm.cache = None
litellm.AnthropicConfig(max_tokens=200, metadata={"user_id": "1224"}) litellm.AnthropicTextConfig(max_tokens_to_sample=200, metadata={"user_id": "1224"})
messages = [ messages = [
{ {
"role": "system", "role": "system",
@@ -67,9 +67,10 @@ def test_completion_claude():
try: try:
# test without max tokens # test without max tokens
response = completion( response = completion(
model="claude-instant-1.2", model="claude-instant-1",
messages=messages, messages=messages,
request_timeout=10, request_timeout=10,
max_tokens=10,
) )
# Add any assertions, here to check response args # Add any assertions, here to check response args
print(response) print(response)

View file

@@ -4200,6 +4200,10 @@ def get_optional_params(
if top_p is not None: if top_p is not None:
optional_params["top_p"] = top_p optional_params["top_p"] = top_p
if max_tokens is not None: if max_tokens is not None:
if (model == "claude-2") or (model == "claude-instant-1"):
# these models use anthropic_text.py which only accepts max_tokens_to_sample
optional_params["max_tokens_to_sample"] = max_tokens
else:
optional_params["max_tokens"] = max_tokens optional_params["max_tokens"] = max_tokens
elif custom_llm_provider == "cohere": elif custom_llm_provider == "cohere":
## check if unsupported param passed in ## check if unsupported param passed in