forked from phoenix/litellm-mirror
fix(bedrock.py): working image calls to claude 3
This commit is contained in: parent 818c29516d, commit caa17d484a
3 changed files with 114 additions and 22 deletions
Changed file 1 of 3: bedrock.py

@@ -1,7 +1,7 @@
 import json, copy, types
 import os
 from enum import Enum
-import time
+import time, uuid
 from typing import Callable, Optional, Any, Union, List
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
@@ -118,12 +118,14 @@ class AmazonAnthropicClaude3Config:
         }

     def get_supported_openai_params(self):
-        return ["max_tokens"]
+        return ["max_tokens", "tools", "tool_choice", "stream"]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
         return optional_params

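Not part of the diff: a minimal sketch of what the expanded mapping above does, assuming the config is instantiated the way the get_optional_params hunk further down does it; the tool definition is a placeholder.

import litellm

config = litellm.AmazonAnthropicClaude3Config()
mapped = config.map_openai_params(
    non_default_params={
        "max_tokens": 256,
        # placeholder tool definition
        "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
    },
    optional_params={},
)
# mapped now holds {"max_tokens": 256, "tools": [...]}; tool_choice and stream are
# advertised in get_supported_openai_params but are not remapped by this method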
@@ -897,7 +899,37 @@ def completion(
         elif provider == "anthropic":
             if model.startswith("anthropic.claude-3"):
                 outputText = response_body.get("content")[0].get("text", None)
+                if "<invoke>" in outputText:  # OUTPUT PARSE FUNCTION CALL
+                    function_name = extract_between_tags("tool_name", outputText)[0]
+                    function_arguments_str = extract_between_tags("invoke", outputText)[
+                        0
+                    ].strip()
+                    function_arguments_str = (
+                        f"<invoke>{function_arguments_str}</invoke>"
+                    )
+                    function_arguments = parse_xml_params(function_arguments_str)
+                    _message = litellm.Message(
+                        tool_calls=[
+                            {
+                                "id": f"call_{uuid.uuid4()}",
+                                "type": "function",
+                                "function": {
+                                    "name": function_name,
+                                    "arguments": json.dumps(function_arguments),
+                                },
+                            }
+                        ],
+                        content=None,
+                    )
+                    model_response.choices[0].message = _message  # type: ignore
                 model_response["finish_reason"] = response_body["stop_reason"]
+                _usage = litellm.Usage(
+                    prompt_tokens=response_body["usage"]["input_tokens"],
+                    completion_tokens=response_body["usage"]["output_tokens"],
+                    total_tokens=response_body["usage"]["input_tokens"]
+                    + response_body["usage"]["output_tokens"],
+                )
+                model_response.usage = _usage
             else:
                 outputText = response_body["completion"]
                 model_response["finish_reason"] = response_body["stop_reason"]
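The new branch relies on extract_between_tags and parse_xml_params, which this diff does not define; they are imported from elsewhere in the codebase. As a rough, non-authoritative sketch of what helpers with these names usually do for Anthropic-style XML tool output (the real implementations may differ):

import re
from xml.etree import ElementTree as ET


def extract_between_tags(tag: str, string: str, strip: bool = False) -> list:
    # collect every substring wrapped in <tag>...</tag>
    found = re.findall(f"<{tag}>(.+?)</{tag}>", string, re.DOTALL)
    return [f.strip() for f in found] if strip else found


def parse_xml_params(xml: str) -> dict:
    # flatten <invoke><tool_name>get_weather</tool_name><parameters><city>SF</city></parameters></invoke>
    # into {"city": "SF"}
    params = {}
    root = ET.fromstring(xml)
    for child in root.findall(".//parameters/*"):
        params[child.tag] = child.text
    return params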
@@ -919,8 +951,17 @@ def completion(
                 )
         else:
             try:
-                if len(outputText) > 0:
+                if (
+                    len(outputText) > 0
+                    and hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is None
+                ):
                     model_response["choices"][0]["message"]["content"] = outputText
+                elif (
+                    hasattr(model_response.choices[0], "message")
+                    and model_response.choices[0].message.tool_calls is not None
+                ):
+                    pass
                 else:
                     raise Exception()
             except:
@@ -930,26 +971,28 @@ def completion(
             )

         ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
-        prompt_tokens = response_metadata.get(
-            "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
-        )
-        completion_tokens = response_metadata.get(
-            "x-amzn-bedrock-output-token-count",
-            len(
-                encoding.encode(
-                    model_response["choices"][0]["message"].get("content", "")
-                )
-            ),
-        )
+        if getattr(model_response.usage, "total_tokens", None) is None:
+            prompt_tokens = response_metadata.get(
+                "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt))
+            )
+            completion_tokens = response_metadata.get(
+                "x-amzn-bedrock-output-token-count",
+                len(
+                    encoding.encode(
+                        model_response["choices"][0]["message"].get("content", "")
+                    )
+                ),
+            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            )
+            model_response.usage = usage

         model_response["created"] = int(time.time())
         model_response["model"] = model
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-        model_response.usage = usage
         model_response._hidden_params["region_name"] = client.meta.region_name
         print_verbose(f"model_response._hidden_params: {model_response._hidden_params}")
         return model_response
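For context, the Claude 3 (messages API) response body that the anthropic branch reads from looks roughly like the dict below; the values are illustrative only. Because that branch already populates model_response.usage from the usage block, the new getattr guard above skips the header-based token fallback in that case.

# illustrative shape only; field values are made up
response_body = {
    "content": [{"type": "text", "text": "The image shows a logo ..."}],
    "stop_reason": "end_turn",
    "usage": {"input_tokens": 1470, "output_tokens": 42},
}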
Changed file 2 of 3: tests (filename not shown)

@@ -271,6 +271,50 @@ def test_bedrock_claude_3_tool_calling():
         pytest.fail(f"Error occurred: {e}")


+def encode_image(image_path):
+    import base64
+
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode("utf-8")
+
+
+@pytest.mark.skip(
+    reason="we already test claude-3, this is just another way to pass images"
+)
+def test_completion_claude_3_base64():
+    try:
+        litellm.set_verbose = True
+        litellm.num_retries = 3
+        image_path = "../proxy/cached_logo.jpg"
+        # Getting the base64 string
+        base64_image = encode_image(image_path)
+        resp = litellm.completion(
+            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Whats in this image?"},
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": "data:image/jpeg;base64," + base64_image
+                            },
+                        },
+                    ],
+                }
+            ],
+        )
+
+        prompt_tokens = resp.usage.prompt_tokens
+        raise Exception("it worked!")
+    except Exception as e:
+        if "500 Internal error encountered.'" in str(e):
+            pass
+        else:
+            pytest.fail(f"An exception occurred - {str(e)}")
+
+
 def test_provisioned_throughput():
     try:
         litellm.set_verbose = True
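Outside the test suite, the same kind of call works as ordinary application code. A minimal sketch, assuming AWS credentials for Bedrock are configured in the environment; local_image.jpg is a placeholder path.

import base64

import litellm

# placeholder path; any JPEG works with the data URL below
with open("local_image.jpg", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64," + b64}},
            ],
        }
    ],
)
print(response.choices[0].message.content)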
Changed file 3 of 3: utils (filename not shown)

@@ -245,10 +245,14 @@ class Message(OpenAIObject):
         self.role = role
         if function_call is not None:
             self.function_call = FunctionCall(**function_call)
+        else:
+            self.function_call = function_call
         if tool_calls is not None:
             self.tool_calls = []
             for tool_call in tool_calls:
                 self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
+        else:
+            self.tool_calls = tool_calls
         if logprobs is not None:
             self._logprobs = logprobs

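These else branches back the bedrock change above: the text-vs-tool-call guard checks message.tool_calls is None, so the attribute has to exist (as None) even when nothing was passed in. A minimal illustration, assuming Message accepts a bare content string:

import litellm

msg = litellm.Message(content="plain text reply")
# with the new defaults both attributes exist and are None rather than missing,
# so `tool_calls is None` checks in the bedrock handler are safe
assert msg.tool_calls is None
assert msg.function_call is None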
@@ -4111,6 +4115,7 @@ def get_optional_params(
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
+        and custom_llm_provider != "bedrock"
     ):
         if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
             # ollama actually supports json output
@@ -4521,13 +4526,13 @@ def get_optional_params(
         # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
         if model.startswith("anthropic.claude-3"):
             optional_params = (
-                litellm.AmazonAnthropicClaude3Config.map_openai_params(
+                litellm.AmazonAnthropicClaude3Config().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
                 )
             )
         else:
-            optional_params = litellm.AmazonAnthropicConfig.map_openai_params(
+            optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                 non_default_params=non_default_params,
                 optional_params=optional_params,
             )
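The only change in this last hunk is adding parentheses so the config classes are instantiated before map_openai_params is called. Since the method takes self (see the class hunk near the top), calling it on the class with only keyword arguments would fail; roughly:

import litellm

# before: unbound call on the class raises something like
#   TypeError: map_openai_params() missing 1 required positional argument: 'self'
# litellm.AmazonAnthropicClaude3Config.map_openai_params(
#     non_default_params={"max_tokens": 256}, optional_params={}
# )

# after: instantiating first binds `self`, as the updated code does
optional_params = litellm.AmazonAnthropicClaude3Config().map_openai_params(
    non_default_params={"max_tokens": 256}, optional_params={}
)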