diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index 150ae0e07..829a8becd 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -20,7 +20,7 @@ class AnthropicError(Exception):
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
-            method="POST", url="https://api.anthropic.com/v1/complete"
+            method="POST", url="https://api.anthropic.com/v1/messages"
         )
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
@@ -35,9 +35,7 @@ class AnthropicConfig:
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """

-    max_tokens_to_sample: Optional[
-        int
-    ] = litellm.max_tokens  # anthropic requires a default
+    max_tokens: Optional[int] = litellm.max_tokens  # anthropic requires a default
     stop_sequences: Optional[list] = None
     temperature: Optional[int] = None
     top_p: Optional[int] = None
@@ -46,7 +44,7 @@ class AnthropicConfig:

     def __init__(
         self,
-        max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
+        max_tokens: Optional[int] = 256,  # anthropic requires a default
         stop_sequences: Optional[list] = None,
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
@@ -124,6 +122,10 @@ def completion(
             model=model, messages=messages, custom_llm_provider="anthropic"
         )

+    for message in messages:
+        if message["role"] == "system":
+            message["role"] = "assistant"
+
     ## Load Config
     config = litellm.AnthropicConfig.get_config()
     for k, v in config.items():
@@ -134,7 +136,7 @@ def completion(

     data = {
         "model": model,
-        "prompt": prompt,
+        "messages": messages,
         **optional_params,
     }

@@ -173,7 +175,7 @@ def completion(

         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=messages,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
@@ -192,19 +194,14 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            if len(completion_response["completion"]) > 0:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["completion"]
+            text_content = completion_response["content"][0].get("text", None)
+            model_response.choices[0].message.content = text_content  # type: ignore
             model_response.choices[0].finish_reason = completion_response["stop_reason"]

         ## CALCULATING USAGE
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )  ##[TODO] use the anthropic tokenizer here
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )  ##[TODO] use the anthropic tokenizer here
+        prompt_tokens = completion_response["usage"]["input_tokens"]
+        completion_tokens = completion_response["usage"]["output_tokens"]
+        total_tokens = prompt_tokens + completion_tokens

         model_response["created"] = int(time.time())
         model_response["model"] = model
diff --git a/litellm/main.py b/litellm/main.py
index 67586603d..9fbd1b828 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1023,7 +1023,7 @@ def completion(
             api_base
             or litellm.api_base
             or get_secret("ANTHROPIC_API_BASE")
-            or "https://api.anthropic.com/v1/complete"
+            or "https://api.anthropic.com/v1/messages"
         )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         response = anthropic.completion(
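Context for the changes above: the legacy /v1/complete endpoint took a single prompt string, while /v1/messages takes a messages array and makes max_tokens mandatory, which is why the config field and the payload are renamed. Below is a minimal standalone sketch of the new wire format, not litellm code; it assumes ANTHROPIC_API_KEY is exported, and the header names and "anthropic-version" value follow Anthropic's published API docs.

# Standalone sketch of a raw /v1/messages call (illustrative, not part of this diff).
import os

import httpx

data = {
    "model": "claude-3-opus-20240229",
    "max_tokens": 256,  # required by /v1/messages, hence the config rename above
    "messages": [{"role": "user", "content": "Hello, world"}],
}
response = httpx.post(
    "https://api.anthropic.com/v1/messages",
    headers={
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    },
    json=data,
    timeout=60.0,
)
completion_response = response.json()
# The body carries a list of content blocks plus a usage block, which is
# what the new response parsing and token accounting above rely on.
print(completion_response["content"][0].get("text"))
print(completion_response["usage"])  # {"input_tokens": ..., "output_tokens": ...}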
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index f6be6e952..200b0ae58 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -84,6 +84,22 @@ def test_completion_claude():

 # test_completion_claude()

+def test_completion_claude_3():
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "Hello, world"}]
+    try:
+        # test with an explicit max_tokens value
+        response = completion(
+            model="anthropic/claude-3-opus-20240229",
+            messages=messages,
+            max_tokens=10,
+        )
+        # Add any assertions here to check response args
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api():
     try:
         litellm.set_verbose = True
diff --git a/litellm/utils.py b/litellm/utils.py
index b3e197c1a..173f5e79e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4200,7 +4200,7 @@ def get_optional_params(
         if top_p is not None:
             optional_params["top_p"] = top_p
         if max_tokens is not None:
-            optional_params["max_tokens_to_sample"] = max_tokens
+            optional_params["max_tokens"] = max_tokens
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
         supported_params = [
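End to end, the change means a caller's max_tokens now passes straight through to the API, and usage numbers come back from Anthropic's response instead of a local tokenizer estimate. A short usage sketch under those assumptions, mirroring test_completion_claude_3 above (requires ANTHROPIC_API_KEY):

# Usage sketch for the updated provider path (illustrative).
import litellm

response = litellm.completion(
    model="anthropic/claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hello, world"}],
    max_tokens=10,  # now mapped to "max_tokens", not "max_tokens_to_sample"
)
print(response.choices[0].message.content)
# Token counts are read from the API's usage block rather than computed
# locally, so they should match Anthropic's own accounting.
print(response.usage.prompt_tokens, response.usage.completion_tokens)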