(feat) - add claude 3

ishaan-jaff 2024-03-04 07:13:08 -08:00
parent b283588eb8
commit 19eb9063fb
4 changed files with 32 additions and 19 deletions

View file

@@ -20,7 +20,7 @@ class AnthropicError(Exception):
         self.status_code = status_code
         self.message = message
         self.request = httpx.Request(
-            method="POST", url="https://api.anthropic.com/v1/complete"
+            method="POST", url="https://api.anthropic.com/v1/messages"
         )
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
@@ -35,9 +35,7 @@ class AnthropicConfig:
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """

-    max_tokens_to_sample: Optional[
-        int
-    ] = litellm.max_tokens  # anthropic requires a default
+    max_tokens: Optional[int] = litellm.max_tokens  # anthropic requires a default
     stop_sequences: Optional[list] = None
     temperature: Optional[int] = None
     top_p: Optional[int] = None
@@ -46,7 +44,7 @@ class AnthropicConfig:
     def __init__(
         self,
-        max_tokens_to_sample: Optional[int] = 256,  # anthropic requires a default
+        max_tokens: Optional[int] = 256,  # anthropic requires a default
         stop_sequences: Optional[list] = None,
         temperature: Optional[int] = None,
         top_p: Optional[int] = None,
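
After the rename, the config is constructed with the Messages API parameter name; a minimal sketch, assuming only that litellm.AnthropicConfig is importable as shown in this diff:

import litellm

# Hedged sketch: the field is now `max_tokens` (Messages API name) rather
# than the Text Completions-era `max_tokens_to_sample`.
config = litellm.AnthropicConfig(max_tokens=256)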
@@ -124,6 +122,10 @@ def completion(
             model=model, messages=messages, custom_llm_provider="anthropic"
         )

+    for message in messages:
+        if message["role"] == "system":
+            message["role"] = "assistant"
+
     ## Load Config
     config = litellm.AnthropicConfig.get_config()
     for k, v in config.items():
@@ -134,7 +136,7 @@ def completion(
     data = {
         "model": model,
-        "prompt": prompt,
+        "messages": messages,
         **optional_params,
     }
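
For orientation, a standalone sketch of the request this body now produces against the new endpoint. The header names follow Anthropic's public API docs and are an assumption here (litellm's actual header construction is outside this hunk); "sk-ant-..." is a placeholder key:

import httpx

# Hypothetical reproduction of the request body above, posted to the
# new /v1/messages endpoint. Headers per Anthropic's public docs.
data = {
    "model": "claude-3-opus-20240229",
    "messages": [{"role": "user", "content": "Hello, world"}],
    "max_tokens": 10,
}
response = httpx.post(
    "https://api.anthropic.com/v1/messages",
    headers={
        "x-api-key": "sk-ant-...",
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    },
    json=data,
)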
@@ -173,7 +175,7 @@ def completion(
         ## LOGGING
         logging_obj.post_call(
-            input=prompt,
+            input=messages,
             api_key=api_key,
             original_response=response.text,
             additional_args={"complete_input_dict": data},
@@ -192,19 +194,14 @@ def completion(
                 status_code=response.status_code,
             )
         else:
-            if len(completion_response["completion"]) > 0:
-                model_response["choices"][0]["message"][
-                    "content"
-                ] = completion_response["completion"]
+            text_content = completion_response["content"][0].get("text", None)
+            model_response.choices[0].message.content = text_content  # type: ignore
             model_response.choices[0].finish_reason = completion_response["stop_reason"]

         ## CALCULATING USAGE
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )  ##[TODO] use the anthropic tokenizer here
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )  ##[TODO] use the anthropic tokenizer here
+        prompt_tokens = completion_response["usage"]["input_tokens"]
+        completion_tokens = completion_response["usage"]["output_tokens"]
+        total_tokens = prompt_tokens + completion_tokens

         model_response["created"] = int(time.time())
         model_response["model"] = model

View file

@@ -1023,7 +1023,7 @@ def completion(
             api_base
             or litellm.api_base
             or get_secret("ANTHROPIC_API_BASE")
-            or "https://api.anthropic.com/v1/complete"
+            or "https://api.anthropic.com/v1/messages"
         )
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         response = anthropic.completion(
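
As the fallback chain shows, ANTHROPIC_API_BASE still takes precedence over the new default, so a proxy can be swapped in; a minimal sketch with a placeholder URL:

import os

# Checked via get_secret("ANTHROPIC_API_BASE") before the new default,
# so this placeholder proxy endpoint wins if set.
os.environ["ANTHROPIC_API_BASE"] = "https://my-proxy.example.com/v1/messages"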

View file

@@ -84,6 +84,22 @@ def test_completion_claude():
 # test_completion_claude()


+def test_completion_claude_3():
+    litellm.set_verbose = True
+    messages = [{"role": "user", "content": "Hello, world"}]
+    try:
+        # test with max_tokens set
+        response = completion(
+            model="anthropic/claude-3-opus-20240229",
+            messages=messages,
+            max_tokens=10,
+        )
+        # Add any assertions here to check response args
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api():
     try:
         litellm.set_verbose = True
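
Outside pytest, the same call is the user-facing entry point for Claude 3; a minimal sketch, assuming ANTHROPIC_API_KEY is set in the environment:

from litellm import completion

# Same call the new test exercises, used directly.
response = completion(
    model="anthropic/claude-3-opus-20240229",
    messages=[{"role": "user", "content": "Hello, world"}],
    max_tokens=10,
)
print(response.choices[0].message.content)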

View file

@@ -4200,7 +4200,7 @@ def get_optional_params(
         if top_p is not None:
             optional_params["top_p"] = top_p
         if max_tokens is not None:
-            optional_params["max_tokens_to_sample"] = max_tokens
+            optional_params["max_tokens"] = max_tokens
     elif custom_llm_provider == "cohere":
         ## check if unsupported param passed in
         supported_params = [
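
The net effect of this branch, sketched in isolation (the surrounding get_optional_params machinery is elided):

# The OpenAI-style max_tokens now passes through under the same name
# instead of being renamed to max_tokens_to_sample.
optional_params = {}
max_tokens = 10
if max_tokens is not None:
    optional_params["max_tokens"] = max_tokens
assert optional_params == {"max_tokens": 10}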