fix(bedrock.py): working image calls to claude 3

Krrish Dholakia 2024-03-04 18:12:47 -08:00
parent 818c29516d
commit caa17d484a
3 changed files with 114 additions and 22 deletions

View file

@@ -1,7 +1,7 @@
 import json, copy, types
 import os
 from enum import Enum
-import time
+import time, uuid
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
@@ -118,12 +118,14 @@ class AmazonAnthropicClaude3Config:
         }

     def get_supported_openai_params(self):
-        return ["max_tokens"]
+        return ["max_tokens", "tools", "tool_choice", "stream"]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
         return optional_params
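
For reference, a standalone sketch of the mapping added above: OpenAI-style kwargs are copied into the Bedrock/Anthropic request body. This mirrors the diff's logic with a placeholder tool definition; note that "tool_choice" and "stream" are now advertised as supported but are not mapped here (streaming is presumably handled later in the Bedrock request path).

    # Standalone sketch mirroring the mapping in the diff above.
    def map_openai_params(non_default_params: dict, optional_params: dict) -> dict:
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
        return optional_params

    # Placeholder tool definition, for illustration only.
    tools = [{"type": "function", "function": {"name": "get_weather"}}]
    body = map_openai_params({"max_tokens": 256, "tools": tools}, {})
    assert body == {"max_tokens": 256, "tools": tools}
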
@@ -897,7 +899,37 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
+                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                    function_name = extract_between_tags("tool_name", outputText)[0]
+                    function_arguments_str = extract_between_tags("invoke", outputText)[
+                        0
+                    ].strip()
+                    function_arguments_str = (
+                        f"<invoke>{function_arguments_str}</invoke>"
+                    )
+                    function_arguments = parse_xml_params(function_arguments_str)
+                    _message = litellm.Message(
+                        tool_calls=[
+                            {
+                                "id": f"call_{uuid.uuid4()}",
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": json.dumps(function_arguments),
+                                },
+                            }
+                        ],
+                        content=None,
+                    )
+                    model_response.choices[0].message = _message  # type: ignore
                 model_response["finish_reason"] = response_body["stop_reason"]
+                _usage = litellm.Usage(
+                    prompt_tokens=response_body["usage"]["input_tokens"],
+                    completion_tokens=response_body["usage"]["output_tokens"],
+                    total_tokens=response_body["usage"]["input_tokens"]
+                    + response_body["usage"]["output_tokens"],
+                )
+                model_response.usage = _usage
             else:
                 outputText = response_body["completion"]
                 model_response["finish_reason"] = response_body["stop_reason"]
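
Context for the branch above: at the time of this commit, Bedrock's invoke API had no native tool-use response format for Claude 3, so tool calls arrive as XML embedded in the text output, and this code re-parses them into an OpenAI-style tool_calls entry. The helpers extract_between_tags and parse_xml_params come from elsewhere in litellm; minimal stand-ins that behave the way this branch assumes (illustrative sketches, not the library's implementations):

    import re
    from xml.etree import ElementTree as ET

    def extract_between_tags(tag: str, text: str) -> list:
        # Return the contents of every <tag>...</tag> pair found in text.
        return re.findall(f"<{tag}>(.+?)</{tag}>", text, re.DOTALL)

    def parse_xml_params(xml: str) -> dict:
        # Flatten <parameters><name>value</name>...</parameters> into a dict.
        root = ET.fromstring(xml)
        return {child.tag: child.text for child in root.findall(".//parameters/*")}

    raw = (
        "<invoke><tool_name>get_current_weather</tool_name>"
        "<parameters><location>Boston, MA</location></parameters></invoke>"
    )
    print(extract_between_tags("tool_name", raw))  # ['get_current_weather']
    print(parse_xml_params(raw))                   # {'location': 'Boston, MA'}
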
@@ -919,8 +951,17 @@ def completion(
             )
         else:
             try:
-                if len(outputText) > 0:
+                if (
+                    len(outputText) > 0
+                    and hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is None
+                ):
                     model_response["choices"][0]["message"]["content"] = outputText
+                elif (
+                    hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is not None
+                ):
+                    pass
                 else:
                     raise Exception()
             except:
@@ -930,26 +971,28 @@ def completion(
             )
     ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-    prompt_tokens = response_metadata.get(
-        "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
-    )
-    completion_tokens = response_metadata.get(
-        "x-amzn-bedrock-output-token-count",
-        len(
-            encoding.encode(
-                model_response["choices"][0]["message"].get("content", "")
-            )
-        ),
-    )
+    if getattr(model_response.usage, "total_tokens", None) is None:
+        prompt_tokens = response_metadata.get(
+            "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
+        )
+        completion_tokens = response_metadata.get(
+            "x-amzn-bedrock-output-token-count",
+            len(
+                encoding.encode(
+                    model_response["choices"][0]["message"].get("content", "")
+                )
+            ),
+        )
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        model_response.usage = usage

     model_response["created"] = int(time.time())
     model_response["model"] = model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    model_response.usage = usage
     model_response._hidden_params["region_name"] = client.meta.region_name
     print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
     return model_response
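
Net effect of this hunk: the Claude 3 branch attaches the provider-reported token counts from the response body, and the older header/tokenizer estimation now runs only as a fallback when nothing has set usage yet. A standalone sketch of that precedence (all names are local to the sketch):

    def resolve_usage(reported, estimate):
        # Prefer provider-reported counts; estimate only as a fallback.
        if reported is not None and reported.get("total_tokens") is not None:
            return reported
        return estimate()

    usage = resolve_usage(
        {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
        estimate=lambda: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    )
    assert usage["total_tokens"] == 46  # the fallback estimator never ran
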

View file

@@ -271,6 +271,50 @@ def test_bedrock_claude_3_tool_calling():
         pytest.fail(f"Error occurred: {e}")


+def encode_image(image_path):
+    import base64
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+@pytest.mark.skip(
+    reason="we already test claude-3, this is just another way to pass images"
+)
+def test_completion_claude_3_base64():
+    try:
+        litellm.set_verbose = True
+        litellm.num_retries = 3
+        image_path = "../proxy/cached_logo.jpg"
+        # Getting the base64 string
+        base64_image = encode_image(image_path)
+        resp = litellm.completion(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Whats in this image?"},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "data:image/jpeg;base64," + base64_image
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+        prompt_tokens = resp.usage.prompt_tokens
+        raise Exception("it worked!")
+    except Exception as e:
+        if "500 Internal error encountered.'" in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
+
+
 def test_provisioned_throughput():
     try:
         litellm.set_verbose = True
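
The "other way to pass images" the skip reason alludes to is the same OpenAI-style content list with a plain HTTPS image URL instead of a base64 data URL. A minimal usage sketch, assuming valid AWS credentials are configured (the URL is a placeholder):

    import litellm

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What's in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/some_image.jpg"},
                    },
                ],
            }
        ],
    )
    print(response.choices[0].message.content)
    print(response.usage.prompt_tokens)
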

View file

@@ -245,10 +245,14 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call

         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls

         if logprobs is not None:
             self._logprobs = logprobs
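
These else branches make function_call and tool_calls always-present attributes, explicitly None when absent. That matters for the new Bedrock code above, which tests message.tool_calls is None / is not None. A standalone illustration of the failure mode avoided:

    class Before:
        def __init__(self, tool_calls=None):
            if tool_calls is not None:
                self.tool_calls = tool_calls

    class After:
        def __init__(self, tool_calls=None):
            if tool_calls is not None:
                self.tool_calls = tool_calls
            else:
                self.tool_calls = tool_calls  # explicitly None

    try:
        Before().tool_calls
    except AttributeError:
        print("Before: attribute does not exist, 'is None' checks blow up")
    print(After().tool_calls is None)  # True -> the new checks work
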
@@ -4111,6 +4115,7 @@ def get_optional_params(
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
+        and custom_llm_provider != "bedrock"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -4521,13 +4526,13 @@ def get_optional_params(
         # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
         if model.startswith("anthropic.claude-3"):
             optional_params = (
-                litellm.AmazonAnthropicClaude3Config.map_openai_params(
+                litellm.AmazonAnthropicClaude3Config().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
                 )
             )
         else:
-            optional_params = litellm.AmazonAnthropicConfig.map_openai_params(
+            optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                 non_default_params=non_default_params,
                 optional_params=optional_params,
             )
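
The added parentheses fix an unbound-method bug: map_openai_params is an instance method, and both arguments are passed by keyword, so calling it on the class itself left self unfilled. A standalone reproduction:

    class Config:
        def map_openai_params(self, non_default_params: dict, optional_params: dict):
            optional_params.update(non_default_params)
            return optional_params

    try:
        Config.map_openai_params(non_default_params={"max_tokens": 10}, optional_params={})
    except TypeError as e:
        print(e)  # map_openai_params() missing 1 required positional argument: 'self'

    print(Config().map_openai_params(non_default_params={"max_tokens": 10}, optional_params={}))
    # {'max_tokens': 10}
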