forked from phoenix/litellm-mirror
fix(bedrock.py): working image calls to claude 3
parent 818c29516d
commit caa17d484a
3 changed files with 114 additions and 22 deletions
```diff
@@ -1,7 +1,7 @@
 import json, copy, types
 import os
 from enum import Enum
-import time
+import time, uuid
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
```
```diff
@@ -118,12 +118,14 @@ class AmazonAnthropicClaude3Config:
         }

     def get_supported_openai_params(self):
-        return ["max_tokens"]
+        return ["max_tokens", "tools", "tool_choice", "stream"]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
         return optional_params
```
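With `tools` now accepted and mapped for Bedrock claude-3, an OpenAI-style function-calling request can be routed through `litellm.completion` (this is what `test_bedrock_claude_3_tool_calling`, referenced below, exercises). A minimal sketch, assuming AWS Bedrock credentials are configured; the tool schema and prompt are illustrative:

```python
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Illustrative call; requires Bedrock access to the claude-3 model.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "What's the weather like in Boston?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)
```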
```diff
@@ -897,7 +899,37 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
+                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                    function_name = extract_between_tags("tool_name", outputText)[0]
+                    function_arguments_str = extract_between_tags("invoke", outputText)[
+                        0
+                    ].strip()
+                    function_arguments_str = (
+                        f"<invoke>{function_arguments_str}</invoke>"
+                    )
+                    function_arguments = parse_xml_params(function_arguments_str)
+                    _message = litellm.Message(
+                        tool_calls=[
+                            {
+                                "id": f"call_{uuid.uuid4()}",
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": json.dumps(function_arguments),
+                                },
+                            }
+                        ],
+                        content=None,
+                    )
+                    model_response.choices[0].message = _message  # type: ignore
                 model_response["finish_reason"] = response_body["stop_reason"]
+                _usage = litellm.Usage(
+                    prompt_tokens=response_body["usage"]["input_tokens"],
+                    completion_tokens=response_body["usage"]["output_tokens"],
+                    total_tokens=response_body["usage"]["input_tokens"]
+                    + response_body["usage"]["output_tokens"],
+                )
+                model_response.usage = _usage
             else:
                 outputText = response_body["completion"]
                 model_response["finish_reason"] = response_body["stop_reason"]
```
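For intuition: when claude-3 decides to call a tool, its text output contains an XML `<invoke>` block, which the code above converts into an OpenAI-style `tool_calls` entry. Below is a self-contained sketch of that conversion; the two helpers are simplified stand-ins for litellm's `extract_between_tags` / `parse_xml_params` (whose implementations are not shown in this diff), and the sample XML is illustrative:

```python
import json
import re
import uuid

# Illustrative raw model output for a tool call.
outputText = """<invoke>
<tool_name>get_current_weather</tool_name>
<parameters>
<location>Boston, MA</location>
</parameters>
</invoke>"""

def extract_between_tags(tag: str, text: str) -> list:
    # Stand-in: return the inner text of every <tag>...</tag> occurrence.
    return re.findall(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)

def parse_xml_params(xml: str) -> dict:
    # Stand-in: flatten the leaf <key>value</key> pairs inside <parameters>.
    block = re.search(r"<parameters>(.*?)</parameters>", xml, re.DOTALL)
    body = block.group(1) if block else xml
    return dict(re.findall(r"<(\w+)>([^<>]+)</\1>", body))

function_name = extract_between_tags("tool_name", outputText)[0]
invoke_body = f"<invoke>{extract_between_tags('invoke', outputText)[0].strip()}</invoke>"
function_arguments = parse_xml_params(invoke_body)

tool_call = {
    "id": f"call_{uuid.uuid4()}",
    "type": "function",
    "function": {"name": function_name, "arguments": json.dumps(function_arguments)},
}
print(tool_call)  # {'id': 'call_...', 'type': 'function', 'function': {'name': 'get_current_weather', ...}}
```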
```diff
@@ -919,8 +951,17 @@ def completion(
             )
         else:
             try:
-                if len(outputText) > 0:
+                if (
+                    len(outputText) > 0
+                    and hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is None
+                ):
                     model_response["choices"][0]["message"]["content"] = outputText
+                elif (
+                    hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is not None
+                ):
+                    pass
                 else:
                     raise Exception()
             except:
```
```diff
@@ -930,6 +971,7 @@ def completion(
             )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
+        if getattr(model_response.usage, "total_tokens", None) is None:
             prompt_tokens = response_metadata.get(
                 "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
             )
```
```diff
@@ -941,15 +983,16 @@ def completion(
                     )
                 ),
             )

-        model_response["created"] = int(time.time())
-        model_response["model"] = model
             usage = Usage(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=prompt_tokens + completion_tokens,
             )
             model_response.usage = usage

+        model_response["created"] = int(time.time())
+        model_response["model"] = model

         model_response._hidden_params["region_name"] = client.meta.region_name
         print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
         return model_response
```
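For the models that do not return usage in the response body, token counts fall back to the `x-amzn-bedrock-*-token-count` response headers and, failing that, to tokenizing the text locally. A rough sketch of that fallback order with made-up header values (tiktoken stands in for litellm's internal `encoding`):

```python
import tiktoken
from litellm.utils import Usage

enc = tiktoken.get_encoding("cl100k_base")  # stand-in for litellm's default encoding

# Illustrative values; in bedrock.py these come from the InvokeModel response headers.
response_metadata = {
    "x-amzn-bedrock-input-token-count": 211,
    "x-amzn-bedrock-output-token-count": 87,
}
prompt = "Whats in this image?"
completion_text = "The image shows a wooden boardwalk through a marsh."

prompt_tokens = response_metadata.get(
    "x-amzn-bedrock-input-token-count", len(enc.encode(prompt))
)
completion_tokens = response_metadata.get(
    "x-amzn-bedrock-output-token-count", len(enc.encode(completion_text))
)

usage = Usage(
    prompt_tokens=prompt_tokens,
    completion_tokens=completion_tokens,
    total_tokens=prompt_tokens + completion_tokens,
)
print(usage)
```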
```diff
@@ -271,6 +271,50 @@ def test_bedrock_claude_3_tool_calling():
         pytest.fail(f"Error occurred: {e}")


+def encode_image(image_path):
+    import base64
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+@pytest.mark.skip(
+    reason="we already test claude-3, this is just another way to pass images"
+)
+def test_completion_claude_3_base64():
+    try:
+        litellm.set_verbose = True
+        litellm.num_retries = 3
+        image_path = "../proxy/cached_logo.jpg"
+        # Getting the base64 string
+        base64_image = encode_image(image_path)
+        resp = litellm.completion(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Whats in this image?"},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "data:image/jpeg;base64," + base64_image
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+
+        prompt_tokens = resp.usage.prompt_tokens
+        raise Exception("it worked!")
+    except Exception as e:
+        if "500 Internal error encountered.'" in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
+
+
 def test_provisioned_throughput():
     try:
         litellm.set_verbose = True
```
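The base64 test above follows the OpenAI vision message format: the image bytes are embedded in a `data:` URL. A small sketch of building that payload (the helper name `build_image_message` is illustrative, not part of the diff):

```python
import base64

def build_image_message(image_path: str, prompt: str) -> dict:
    # Mirror encode_image() above: read the file and base64-encode it.
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode("utf-8")
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64," + base64_image},
            },
        ],
    }

# message = build_image_message("../proxy/cached_logo.jpg", "Whats in this image?")
```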
```diff
@@ -245,10 +245,14 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls
         if logprobs is not None:
             self._logprobs = logprobs
```
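The added `else` branches mean `function_call` and `tool_calls` are always set on a `Message` (to `None` when absent), which is what the `tool_calls is None` checks in the bedrock completion path above depend on. A quick sketch of both message shapes; the exact printed repr depends on the litellm version:

```python
import json
import litellm

# Text-only message: tool_calls is now explicitly None rather than missing.
plain = litellm.Message(content="Hello!", role="assistant")
print(plain.tool_calls)  # None

# Tool-call message, mirroring what the bedrock claude-3 parser builds.
tool_msg = litellm.Message(
    content=None,
    tool_calls=[
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": json.dumps({"location": "Boston, MA"}),
            },
        }
    ],
)
print(tool_msg.tool_calls[0])
```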
```diff
@@ -4111,6 +4115,7 @@ def get_optional_params(
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
+        and custom_llm_provider != "bedrock"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
```
```diff
@@ -4521,13 +4526,13 @@ def get_optional_params(
             # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
             if model.startswith("anthropic.claude-3"):
                 optional_params = (
-                    litellm.AmazonAnthropicClaude3Config.map_openai_params(
+                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
                         non_default_params=non_default_params,
                         optional_params=optional_params,
                     )
                 )
             else:
-                optional_params = litellm.AmazonAnthropicConfig.map_openai_params(
+                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
                 )
```
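The `()` added to both call sites matters because `map_openai_params` is an instance method (it takes `self`, as shown in the `AmazonAnthropicClaude3Config` hunk above); calling it directly on the class raises a `TypeError`. A short sketch of the before/after behavior:

```python
import litellm

# Old call site: unbound method on the class -> missing 'self'.
try:
    litellm.AmazonAnthropicClaude3Config.map_openai_params(
        non_default_params={"max_tokens": 256},
        optional_params={},
    )
except TypeError as e:
    print(e)  # ... missing 1 required positional argument: 'self'

# Fixed call site: instantiate the config first.
mapped = litellm.AmazonAnthropicClaude3Config().map_openai_params(
    non_default_params={"max_tokens": 256},
    optional_params={},
)
print(mapped)  # {'max_tokens': 256}
```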