fix(bedrock.py): working image calls to claude 3

Krrish Dholakia 2024-03-04 18:12:47 -08:00
parent 818c29516d
commit caa17d484a
3 changed files with 114 additions and 22 deletions

litellm/llms/bedrock.py

@@ -1,7 +1,7 @@
 import json, copy, types
 import os
 from enum import Enum
-import time
+import time, uuid
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
@@ -118,12 +118,14 @@ class AmazonAnthropicClaude3Config:
         }

     def get_supported_openai_params(self):
-        return ["max_tokens"]
+        return ["max_tokens", "tools", "tool_choice", "stream"]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
         return optional_params
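
Note on the hunk above: with "tools", "tool_choice", and "stream" added to the supported-params list, these arguments are no longer rejected for Claude 3 on Bedrock, and map_openai_params copies them into the provider payload. A minimal sketch of the mapping's behavior, using a made-up tool schema (the schema is illustrative, not from this commit):

    import litellm

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # illustrative tool, not in the commit
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }

    mapped = litellm.AmazonAnthropicClaude3Config().map_openai_params(
        non_default_params={"max_tokens": 256, "tools": [weather_tool]},
        optional_params={},
    )
    assert mapped == {"max_tokens": 256, "tools": [weather_tool]}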
@@ -897,7 +899,37 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
+                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                    function_name = extract_between_tags("tool_name", outputText)[0]
+                    function_arguments_str = extract_between_tags("invoke", outputText)[
+                        0
+                    ].strip()
+                    function_arguments_str = (
+                        f"<invoke>{function_arguments_str}</invoke>"
+                    )
+                    function_arguments = parse_xml_params(function_arguments_str)
+                    _message = litellm.Message(
+                        tool_calls=[
+                            {
+                                "id": f"call_{uuid.uuid4()}",
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": json.dumps(function_arguments),
+                                },
+                            }
+                        ],
+                        content=None,
+                    )
+                    model_response.choices[0].message = _message  # type: ignore
+                model_response["finish_reason"] = response_body["stop_reason"]
+                _usage = litellm.Usage(
+                    prompt_tokens=response_body["usage"]["input_tokens"],
+                    completion_tokens=response_body["usage"]["output_tokens"],
+                    total_tokens=response_body["usage"]["input_tokens"]
+                    + response_body["usage"]["output_tokens"],
+                )
+                model_response.usage = _usage
             else:
                 outputText = response_body["completion"]
                 model_response["finish_reason"] = response_body["stop_reason"]
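
The new branch depends on two helpers, extract_between_tags and parse_xml_params, which are defined elsewhere in litellm and not shown in this diff. A hedged sketch of what they plausibly do, so the parsing logic reads on its own; the real implementations may differ:

    import re
    from xml.etree import ElementTree as ET

    def extract_between_tags(tag: str, string: str) -> list:
        # Every span of text enclosed in <tag>...</tag>.
        return re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)

    def parse_xml_params(xml: str) -> dict:
        # Flatten <invoke><parameters><city>SF</city></parameters></invoke>
        # into {"city": "SF"} for the json.dumps call above.
        root = ET.fromstring(xml)
        return {child.tag: child.text for child in root.findall(".//parameters/*")}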
@@ -919,8 +951,17 @@
             )
         else:
             try:
-                if len(outputText) > 0:
+                if (
+                    len(outputText) > 0
+                    and hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is None
+                ):
                     model_response["choices"][0]["message"]["content"] = outputText
+                elif (
+                    hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is not None
+                ):
+                    pass
                 else:
                     raise Exception()
             except:
@@ -930,6 +971,7 @@
             )
         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        if getattr(model_response.usage, "total_tokens", None) is None:
             prompt_tokens = response_metadata.get(
                 "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
             )
@@ -941,15 +983,16 @@
                     )
                 ),
             )
-        model_response["created"] = int(time.time())
-        model_response["model"] = model
             usage = Usage(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=prompt_tokens + completion_tokens,
             )
             model_response.usage = usage
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
         model_response._hidden_params["region_name"] = client.meta.region_name
+        print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
         return model_response
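
Two behavioral notes on this file's changes: token usage is now only recomputed from the x-amzn-bedrock-*-token-count response headers (falling back to local tokenization) when the Claude 3 branch has not already filled it from the response body, and the resolved AWS region is surfaced on the response. A rough usage sketch (model id taken from the tests below; printed values illustrative):

    import litellm

    resp = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(resp.usage.total_tokens)             # from the response body for Claude 3
    print(resp._hidden_params["region_name"])  # e.g. "us-west-2" (illustrative)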

litellm/tests/test_completion.py

@@ -271,6 +271,50 @@ def test_bedrock_claude_3_tool_calling():
         pytest.fail(f"Error occurred: {e}")


+def encode_image(image_path):
+    import base64
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+@pytest.mark.skip(
+    reason="we already test claude-3, this is just another way to pass images"
+)
+def test_completion_claude_3_base64():
+    try:
+        litellm.set_verbose = True
+        litellm.num_retries = 3
+        image_path = "../proxy/cached_logo.jpg"
+        # Getting the base64 string
+        base64_image = encode_image(image_path)
+        resp = litellm.completion(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Whats in this image?"},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "data:image/jpeg;base64," + base64_image
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+        prompt_tokens = resp.usage.prompt_tokens
+        raise Exception("it worked!")
+    except Exception as e:
+        if "500 Internal error encountered.'" in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
+
+
 def test_provisioned_throughput():
     try:
         litellm.set_verbose = True
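
The skip reason points at an existing claude-3 test that already exercises images, passing the image by URL rather than as a data URL. A condensed, hedged sketch of that variant (the URL is illustrative; the existing test itself is not shown in this diff):

    resp = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Whats in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/logo.jpg"},  # illustrative
                    },
                ],
            }
        ],
    )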

litellm/utils.py

@@ -245,10 +245,14 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls
         if logprobs is not None:
             self._logprobs = logprobs
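
The new else-branches guarantee that function_call and tool_calls are always present (as None) on a Message instead of being left unset, which is what lets the bedrock handler above probe message.tool_calls is None without an AttributeError. A small sketch of the resulting behavior (values illustrative):

    import litellm

    plain = litellm.Message(content="hello")
    assert plain.tool_calls is None  # previously the attribute was simply absent

    tool_msg = litellm.Message(
        content=None,
        tool_calls=[
            {
                "id": "call_abc123",  # illustrative id
                "type": "function",
                "function": {"name": "get_current_weather", "arguments": "{}"},
            }
        ],
    )
    assert tool_msg.tool_calls is not None  # the probe the bedrock handler relies on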
@@ -4111,6 +4115,7 @@ def get_optional_params(
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
+        and custom_llm_provider != "bedrock"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
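
For context (an inference from the surrounding code, not shown in the hunk): this condition guards litellm's generic fallback for providers without native function calling, so adding "bedrock" to the exclusion list lets tools flow through to the Claude 3 param mapper below instead of being handled by that fallback. A hypothetical call this line unblocks (tool schema illustrative, as before):

    litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "Weather in Boston?"}],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "parameters": {"type": "object", "properties": {}},
                },
            }
        ],
    )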
@@ -4521,13 +4526,13 @@
             # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
             if model.startswith("anthropic.claude-3"):
                 optional_params = (
-                    litellm.AmazonAnthropicClaude3Config.map_openai_params(
+                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
                         non_default_params=non_default_params,
                         optional_params=optional_params,
                     )
                 )
             else:
-                optional_params = litellm.AmazonAnthropicConfig.map_openai_params(
+                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
                 )
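
The last two changes are one-character fixes: map_openai_params is an instance method (it takes self, per the first file), so calling it on the class itself fails before any mapping runs. A sketch of the failure mode and the fix:

    import litellm

    try:
        litellm.AmazonAnthropicClaude3Config.map_openai_params(
            non_default_params={"max_tokens": 10}, optional_params={}
        )
    except TypeError as e:
        print(e)  # missing 1 required positional argument: 'self'

    # after the fix: instantiate first
    fixed = litellm.AmazonAnthropicClaude3Config().map_openai_params(
        non_default_params={"max_tokens": 10}, optional_params={}
    )
    print(fixed)  # {'max_tokens': 10}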