From 1183e5f2e5c4c71482034f1c20af5e9c671d40fb Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Mon, 4 Mar 2024 11:16:34 -0800
Subject: [PATCH] (feat) maintain anthropic text completion

---
 docs/my-website/docs/providers/anthropic.md |  1 +
 litellm/main.py                             | 70 ++++++++++++++-------
 litellm/tests/test_completion.py            |  5 +-
 litellm/utils.py                            |  8 ++-
 4 files changed, 59 insertions(+), 25 deletions(-)

diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md
index 198a6a03d..aff3415d3 100644
--- a/docs/my-website/docs/providers/anthropic.md
+++ b/docs/my-website/docs/providers/anthropic.md
@@ -56,6 +56,7 @@ for chunk in response:
 | claude-2.1 | `completion('claude-2.1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-2 | `completion('claude-2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 | claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
+| claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` |
 
 ## Advanced
 
diff --git a/litellm/main.py b/litellm/main.py
index b7707b722..60effd96f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -39,6 +39,7 @@ from litellm.utils import (
 )
 from .llms import (
     anthropic,
+    anthropic_text,
     together_ai,
     ai21,
     sagemaker,
@@ -1018,28 +1019,55 @@ def completion(
             or litellm.api_key
             or os.environ.get("ANTHROPIC_API_KEY")
         )
-        api_base = (
-            api_base
-            or litellm.api_base
-            or get_secret("ANTHROPIC_API_BASE")
-            or "https://api.anthropic.com/v1/messages"
-        )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = anthropic.completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_prompt_dict=litellm.custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            encoding=encoding,  # for calculating input/output tokens
-            api_key=api_key,
-            logging_obj=logging,
-            headers=headers,
-        )
+
+        if (model == "claude-2") or (model == "claude-instant-1"):
+            # call the legacy anthropic /complete endpoint; only claude-2 and claude-instant-1 use this route
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("ANTHROPIC_API_BASE")
+                or "https://api.anthropic.com/v1/complete"
+            )
+            response = anthropic_text.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,  # for calculating input/output tokens
+                api_key=api_key,
+                logging_obj=logging,
+                headers=headers,
+            )
+        else:
+            # call /messages
+            # default route for all other anthropic models
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("ANTHROPIC_API_BASE")
+                or "https://api.anthropic.com/v1/messages"
+            )
+            response = anthropic.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                custom_prompt_dict=litellm.custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,  # for calculating input/output tokens
+                api_key=api_key,
+                logging_obj=logging,
+                headers=headers,
+            )
         if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
             response = CustomStreamWrapper(
diff --git
 a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index a9d41be8d..13a08689c 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -56,7 +56,7 @@ def test_completion_custom_provider_model_name():
 def test_completion_claude():
     litellm.set_verbose = True
     litellm.cache = None
-    litellm.AnthropicConfig(max_tokens=200, metadata={"user_id": "1224"})
+    litellm.AnthropicTextConfig(max_tokens_to_sample=200, metadata={"user_id": "1224"})
     messages = [
         {
             "role": "system",
@@ -67,9 +67,10 @@
     try:
         # test without max tokens
         response = completion(
-            model="claude-instant-1.2",
+            model="claude-instant-1",
             messages=messages,
             request_timeout=10,
+            max_tokens=10,
         )
         # Add any assertions, here to check response args
         print(response)
diff --git a/litellm/utils.py b/litellm/utils.py
index 233fd6bae..4b9b0c8a4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4200,7 +4200,11 @@ def get_optional_params(
         if top_p is not None:
             optional_params["top_p"] = top_p
         if max_tokens is not None:
-            optional_params["max_tokens"] = max_tokens
+            if (model == "claude-2") or (model == "claude-instant-1"):
+                # these models use anthropic_text.py, which only accepts max_tokens_to_sample
+                optional_params["max_tokens_to_sample"] = max_tokens
+            else:
+                optional_params["max_tokens"] = max_tokens
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
         supported_params = [
@@ -9704,4 +9708,4 @@ def _get_base_model_from_metadata(model_call_details=None):
         base_model = model_info.get("base_model", None)
         if base_model is not None:
             return base_model
-    return None
\ No newline at end of file
+    return None
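
For reviewers who want to see the new routing end to end, here is a minimal sketch of how a caller would exercise both code paths after this patch. It is illustrative only: the placeholder API key and the `claude-2.1` counter-example are assumptions for the demo, not part of the change; `claude-2` and `claude-instant-1` are the two model names the patch special-cases onto the legacy text-completion route.

```python
import os

from litellm import completion

# Placeholder key for illustration; any real Anthropic key works.
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."

messages = [{"role": "user", "content": "Say hello in one word."}]

# Special-cased by this patch: claude-2 and claude-instant-1 are dispatched
# to anthropic_text.completion against https://api.anthropic.com/v1/complete,
# and get_optional_params rewrites max_tokens -> max_tokens_to_sample.
legacy_response = completion(
    model="claude-instant-1",
    messages=messages,
    max_tokens=10,
)
print(legacy_response)

# Any other Anthropic model (claude-2.1 here, as listed in the docs table)
# keeps the default anthropic.completion route against /v1/messages,
# with max_tokens passed through unchanged.
messages_response = completion(
    model="claude-2.1",
    messages=messages,
    max_tokens=10,
)
print(messages_response)
```

Note that the dispatch compares exact model names, so `claude-2.1` does not match the `model == "claude-2"` check and stays on the default `/v1/messages` route.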