mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(bedrock_httpx.py): working claude 3 function calling
This commit is contained in:
parent
a76a9b7d11
commit
a995a0b172
11 changed files with 321 additions and 108 deletions
|
@ -16,11 +16,11 @@ repos:
|
||||||
name: Check if files match
|
name: Check if files match
|
||||||
entry: python3 ci_cd/check_files_match.py
|
entry: python3 ci_cd/check_files_match.py
|
||||||
language: system
|
language: system
|
||||||
- repo: local
|
# - repo: local
|
||||||
hooks:
|
# hooks:
|
||||||
- id: mypy
|
# - id: mypy
|
||||||
name: mypy
|
# name: mypy
|
||||||
entry: python3 -m mypy --ignore-missing-imports
|
# entry: python3 -m mypy --ignore-missing-imports
|
||||||
language: system
|
# language: system
|
||||||
types: [python]
|
# types: [python]
|
||||||
files: ^litellm/
|
# files: ^litellm/
|
|
@ -47,6 +47,11 @@ import httpx # type: ignore
|
||||||
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
|
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
|
||||||
from litellm.types.llms.bedrock import *
|
from litellm.types.llms.bedrock import *
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
ChatCompletionResponseMessage,
|
||||||
|
ChatCompletionToolCallChunk,
|
||||||
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AmazonCohereChatConfig:
|
class AmazonCohereChatConfig:
|
||||||
|
@ -1004,12 +1009,12 @@ class BedrockLLM(BaseLLM):
|
||||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
timeout = httpx.Timeout(timeout)
|
timeout = httpx.Timeout(timeout)
|
||||||
_params["timeout"] = timeout
|
_params["timeout"] = timeout
|
||||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||||
else:
|
else:
|
||||||
self.client = client # type: ignore
|
client = client # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except httpx.HTTPStatusError as err:
|
except httpx.HTTPStatusError as err:
|
||||||
error_code = err.response.status_code
|
error_code = err.response.status_code
|
||||||
|
@ -1125,11 +1130,55 @@ class AmazonConverseConfig:
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def map_tool_choice_values(
|
||||||
|
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||||
|
) -> Optional[ToolChoiceValuesBlock]:
|
||||||
|
if not model.startswith("anthropic") and not model.startswith("mistral"):
|
||||||
|
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
|
if drop_params == True or litellm.drop_params == True:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`",
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if tool_choice == "none":
|
||||||
|
if litellm.drop_params is True or drop_params is True:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
tool_choice
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
elif tool_choice == "required":
|
||||||
|
return ToolChoiceValuesBlock(any={})
|
||||||
|
elif tool_choice == "auto":
|
||||||
|
return ToolChoiceValuesBlock(auto={})
|
||||||
|
elif isinstance(tool_choice, dict):
|
||||||
|
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||||
|
specific_tool = SpecificToolChoiceBlock(
|
||||||
|
name=tool_choice.get("function", {}).get("name", "")
|
||||||
|
)
|
||||||
|
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||||
|
else:
|
||||||
|
raise litellm.utils.UnsupportedParamsError(
|
||||||
|
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||||
|
tool_choice
|
||||||
|
),
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
|
||||||
def get_supported_image_types(self) -> List[str]:
|
def get_supported_image_types(self) -> List[str]:
|
||||||
return ["png", "jpeg", "gif", "webp"]
|
return ["png", "jpeg", "gif", "webp"]
|
||||||
|
|
||||||
def map_openai_params(
|
def map_openai_params(
|
||||||
self, non_default_params: dict, optional_params: dict
|
self,
|
||||||
|
model: str,
|
||||||
|
non_default_params: dict,
|
||||||
|
optional_params: dict,
|
||||||
|
drop_params: bool,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
for param, value in non_default_params.items():
|
for param, value in non_default_params.items():
|
||||||
if param == "max_tokens":
|
if param == "max_tokens":
|
||||||
|
@ -1144,6 +1193,14 @@ class AmazonConverseConfig:
|
||||||
optional_params["temperature"] = value
|
optional_params["temperature"] = value
|
||||||
if param == "top_p":
|
if param == "top_p":
|
||||||
optional_params["topP"] = value
|
optional_params["topP"] = value
|
||||||
|
if param == "tools":
|
||||||
|
optional_params["tools"] = value
|
||||||
|
if param == "tool_choice":
|
||||||
|
_tool_choice_value = self.map_tool_choice_values(
|
||||||
|
model=model, tool_choice=value, drop_params=drop_params
|
||||||
|
)
|
||||||
|
if _tool_choice_value is not None:
|
||||||
|
optional_params["tool_choice"] = _tool_choice_value
|
||||||
return optional_params
|
return optional_params
|
||||||
|
|
||||||
|
|
||||||
|
@ -1151,6 +1208,124 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
response: Union[requests.Response, httpx.Response],
|
||||||
|
model_response: ModelResponse,
|
||||||
|
stream: bool,
|
||||||
|
logging_obj: Logging,
|
||||||
|
optional_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
data: Union[dict, str],
|
||||||
|
messages: List,
|
||||||
|
print_verbose,
|
||||||
|
encoding,
|
||||||
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=response.text,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
print_verbose(f"raw model_response: {response.text}")
|
||||||
|
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
try:
|
||||||
|
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise BedrockError(
|
||||||
|
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||||
|
response.text, str(e)
|
||||||
|
),
|
||||||
|
status_code=422,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Bedrock Response Object has optional message block
|
||||||
|
|
||||||
|
completion_response["output"].get("message", None)
|
||||||
|
|
||||||
|
A message block looks like this (Example 1):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
(Example 2):
|
||||||
|
"output": {
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"toolUse": {
|
||||||
|
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||||
|
"name": "top_song",
|
||||||
|
"input": {
|
||||||
|
"sign": "WZPZ"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
|
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||||
|
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||||
|
content_str = ""
|
||||||
|
tools: List[ChatCompletionToolCallChunk] = []
|
||||||
|
if message is not None:
|
||||||
|
for content in message["content"]:
|
||||||
|
"""
|
||||||
|
- Content is either a tool response or text
|
||||||
|
"""
|
||||||
|
if "text" in content:
|
||||||
|
content_str += content["text"]
|
||||||
|
if "toolUse" in content:
|
||||||
|
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=content["toolUse"]["name"],
|
||||||
|
arguments=json.dumps(content["toolUse"]["input"]),
|
||||||
|
)
|
||||||
|
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||||
|
id=content["toolUse"]["toolUseId"],
|
||||||
|
type="function",
|
||||||
|
function=_function_chunk,
|
||||||
|
)
|
||||||
|
tools.append(_tool_response_chunk)
|
||||||
|
chat_completion_message["content"] = content_str
|
||||||
|
chat_completion_message["tool_calls"] = tools
|
||||||
|
|
||||||
|
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||||
|
input_tokens = completion_response["usage"]["inputTokens"]
|
||||||
|
output_tokens = completion_response["usage"]["outputTokens"]
|
||||||
|
total_tokens = completion_response["usage"]["totalTokens"]
|
||||||
|
|
||||||
|
model_response.choices = [
|
||||||
|
litellm.Choices(
|
||||||
|
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||||
|
index=0,
|
||||||
|
message=litellm.Message(**chat_completion_message),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
model_response["created"] = int(time.time())
|
||||||
|
model_response["model"] = model
|
||||||
|
usage = Usage(
|
||||||
|
prompt_tokens=input_tokens,
|
||||||
|
completion_tokens=output_tokens,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
)
|
||||||
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
return model_response
|
||||||
|
|
||||||
def encode_model_id(self, model_id: str) -> str:
|
def encode_model_id(self, model_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Double encode the model ID to ensure it matches the expected double-encoded format.
|
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||||
|
@ -1387,11 +1562,14 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
additional_request_keys = []
|
additional_request_keys = []
|
||||||
additional_request_params = {}
|
additional_request_params = {}
|
||||||
supported_converse_params = AmazonConverseConfig().get_config().keys()
|
supported_converse_params = AmazonConverseConfig().get_config().keys()
|
||||||
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
# send all model-specific params in 'additional_request_params'
|
# send all model-specific params in 'additional_request_params'
|
||||||
for k, v in inference_params.items():
|
for k, v in inference_params.items():
|
||||||
if k not in supported_converse_params:
|
if (
|
||||||
|
k not in supported_converse_params
|
||||||
|
and k not in supported_tool_call_params
|
||||||
|
):
|
||||||
additional_request_params[k] = v
|
additional_request_params[k] = v
|
||||||
additional_request_keys.append(k)
|
additional_request_keys.append(k)
|
||||||
for key in additional_request_keys:
|
for key in additional_request_keys:
|
||||||
|
@ -1401,23 +1579,27 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
messages=messages
|
messages=messages
|
||||||
)
|
)
|
||||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||||
inference_params.get("tools", [])
|
inference_params.pop("tools", [])
|
||||||
)
|
)
|
||||||
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||||
if len(bedrock_tools) > 0:
|
if len(bedrock_tools) > 0:
|
||||||
|
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||||
|
"tool_choice", None
|
||||||
|
)
|
||||||
bedrock_tool_config = ToolConfigBlock(
|
bedrock_tool_config = ToolConfigBlock(
|
||||||
tools=bedrock_tools,
|
tools=bedrock_tools,
|
||||||
toolChoice=inference_params.get("tool_choice", None),
|
|
||||||
)
|
)
|
||||||
|
if tool_choice_values is not None:
|
||||||
|
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||||
|
|
||||||
data: RequestObject = {
|
_data: RequestObject = {
|
||||||
"messages": bedrock_messages,
|
"messages": bedrock_messages,
|
||||||
"additionalModelRequestFields": additional_request_params,
|
"additionalModelRequestFields": additional_request_params,
|
||||||
"system": system_content_blocks,
|
"system": system_content_blocks,
|
||||||
}
|
}
|
||||||
if bedrock_tool_config is not None:
|
if bedrock_tool_config is not None:
|
||||||
data["toolConfig"] = bedrock_tool_config
|
_data["toolConfig"] = bedrock_tool_config
|
||||||
|
data = json.dumps(_data)
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
|
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
@ -1441,8 +1623,18 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||||
|
### COMPLETION
|
||||||
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||||
|
_params = {}
|
||||||
|
if timeout is not None:
|
||||||
|
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||||
|
timeout = httpx.Timeout(timeout)
|
||||||
|
_params["timeout"] = timeout
|
||||||
|
client = HTTPHandler(**_params) # type: ignore
|
||||||
|
else:
|
||||||
|
client = client
|
||||||
try:
|
try:
|
||||||
response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except httpx.HTTPStatusError as err:
|
except httpx.HTTPStatusError as err:
|
||||||
error_code = err.response.status_code
|
error_code = err.response.status_code
|
||||||
|
@ -1450,6 +1642,20 @@ class BedrockConverseLLM(BaseLLM):
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||||
|
|
||||||
|
return self.process_response(
|
||||||
|
model=model,
|
||||||
|
response=response,
|
||||||
|
model_response=model_response,
|
||||||
|
stream=stream,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
optional_params=optional_params,
|
||||||
|
api_key="",
|
||||||
|
data=data,
|
||||||
|
messages=messages,
|
||||||
|
print_verbose=print_verbose,
|
||||||
|
encoding=encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_response_stream_shape():
|
def get_response_stream_shape():
|
||||||
from botocore.model import ServiceModel
|
from botocore.model import ServiceModel
|
||||||
|
|
|
@ -156,12 +156,13 @@ class HTTPHandler:
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
data: Optional[Union[dict, str]] = None,
|
data: Optional[Union[dict, str]] = None,
|
||||||
|
json: Optional[Union[dict, str]] = None,
|
||||||
params: Optional[dict] = None,
|
params: Optional[dict] = None,
|
||||||
headers: Optional[dict] = None,
|
headers: Optional[dict] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
req = self.client.build_request(
|
req = self.client.build_request(
|
||||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
|
||||||
)
|
)
|
||||||
response = self.client.send(req, stream=stream)
|
response = self.client.send(req, stream=stream)
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -1617,6 +1617,7 @@ from litellm.types.llms.bedrock import (
|
||||||
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||||
ToolSpecBlock as BedrockToolSpecBlock,
|
ToolSpecBlock as BedrockToolSpecBlock,
|
||||||
ToolBlock as BedrockToolBlock,
|
ToolBlock as BedrockToolBlock,
|
||||||
|
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1814,7 +1815,7 @@ def _convert_to_bedrock_tool_call_result(
|
||||||
|
|
||||||
tool_result_content_block = BedrockToolResultContentBlock(text=content)
|
tool_result_content_block = BedrockToolResultContentBlock(text=content)
|
||||||
tool_result = BedrockToolResultBlock(
|
tool_result = BedrockToolResultBlock(
|
||||||
content=tool_result_content_block,
|
content=[tool_result_content_block],
|
||||||
toolUseId=id,
|
toolUseId=id,
|
||||||
)
|
)
|
||||||
content_block = BedrockContentBlock(toolResult=tool_result)
|
content_block = BedrockContentBlock(toolResult=tool_result)
|
||||||
|
|
|
@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
||||||
from .llms.anthropic_text import AnthropicTextCompletion
|
from .llms.anthropic_text import AnthropicTextCompletion
|
||||||
from .llms.huggingface_restapi import Huggingface
|
from .llms.huggingface_restapi import Huggingface
|
||||||
from .llms.predibase import PredibaseChatCompletion
|
from .llms.predibase import PredibaseChatCompletion
|
||||||
from .llms.bedrock_httpx import BedrockLLM
|
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
|
||||||
from .llms.vertex_httpx import VertexLLM
|
from .llms.vertex_httpx import VertexLLM
|
||||||
from .llms.triton import TritonChatCompletion
|
from .llms.triton import TritonChatCompletion
|
||||||
from .llms.prompt_templates.factory import (
|
from .llms.prompt_templates.factory import (
|
||||||
|
@ -121,7 +121,7 @@ azure_text_completions = AzureTextCompletion()
|
||||||
huggingface = Huggingface()
|
huggingface = Huggingface()
|
||||||
predibase_chat_completions = PredibaseChatCompletion()
|
predibase_chat_completions = PredibaseChatCompletion()
|
||||||
triton_chat_completions = TritonChatCompletion()
|
triton_chat_completions = TritonChatCompletion()
|
||||||
bedrock_chat_completion = BedrockLLM()
|
bedrock_chat_completion = BedrockConverseLLM()
|
||||||
vertex_chat_completion = VertexLLM()
|
vertex_chat_completion = VertexLLM()
|
||||||
####### COMPLETION ENDPOINTS ################
|
####### COMPLETION ENDPOINTS ################
|
||||||
|
|
||||||
|
|
|
@ -302,10 +302,7 @@ def test_completion_claude_3():
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
[
|
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||||
# "anthropic/claude-3-opus-20240229",
|
|
||||||
"cohere.command-r-plus-v1:0"
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
def test_completion_claude_3_function_call(model):
|
def test_completion_claude_3_function_call(model):
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
@ -345,6 +342,7 @@ def test_completion_claude_3_function_call(model):
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {"name": "get_current_weather"},
|
"function": {"name": "get_current_weather"},
|
||||||
},
|
},
|
||||||
|
drop_params=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add any assertions, here to check response args
|
# Add any assertions, here to check response args
|
||||||
|
@ -375,6 +373,7 @@ def test_completion_claude_3_function_call(model):
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
tool_choice="auto",
|
tool_choice="auto",
|
||||||
|
drop_params=True,
|
||||||
)
|
)
|
||||||
print(second_response)
|
print(second_response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
|
||||||
claude_2_1_pt,
|
claude_2_1_pt,
|
||||||
llama_2_chat_pt,
|
llama_2_chat_pt,
|
||||||
prompt_factory,
|
prompt_factory,
|
||||||
|
_bedrock_tools_pt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,3 +129,29 @@ def test_anthropic_messages_pt():
|
||||||
|
|
||||||
|
|
||||||
# codellama_prompt_format()
|
# codellama_prompt_format()
|
||||||
|
def test_bedrock_tool_calling_pt():
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
converted_tools = _bedrock_tools_pt(tools=tools)
|
||||||
|
|
||||||
|
print(converted_tools)
|
||||||
|
|
||||||
|
assert False
|
||||||
|
|
|
@ -31,7 +31,7 @@ class ToolResultContentBlock(TypedDict, total=False):
|
||||||
|
|
||||||
|
|
||||||
class ToolResultBlock(TypedDict, total=False):
|
class ToolResultBlock(TypedDict, total=False):
|
||||||
content: Required[ToolResultContentBlock]
|
content: Required[List[ToolResultContentBlock]]
|
||||||
toolUseId: Required[str]
|
toolUseId: Required[str]
|
||||||
status: Literal["success", "error"]
|
status: Literal["success", "error"]
|
||||||
|
|
||||||
|
@ -54,6 +54,30 @@ class MessageBlock(TypedDict):
|
||||||
role: Literal["user", "assistant"]
|
role: Literal["user", "assistant"]
|
||||||
|
|
||||||
|
|
||||||
|
class ConverseMetricsBlock(TypedDict):
|
||||||
|
latencyMs: float # time in ms
|
||||||
|
|
||||||
|
|
||||||
|
class ConverseResponseOutputBlock(TypedDict):
|
||||||
|
message: Optional[MessageBlock]
|
||||||
|
|
||||||
|
|
||||||
|
class ConverseTokenUsageBlock(TypedDict):
|
||||||
|
inputTokens: int
|
||||||
|
outputTokens: int
|
||||||
|
totalTokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class ConverseResponseBlock(TypedDict):
|
||||||
|
additionalModelResponseFields: dict
|
||||||
|
metrics: ConverseMetricsBlock
|
||||||
|
output: ConverseResponseOutputBlock
|
||||||
|
stopReason: (
|
||||||
|
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
|
||||||
|
)
|
||||||
|
usage: ConverseTokenUsageBlock
|
||||||
|
|
||||||
|
|
||||||
class ToolInputSchemaBlock(TypedDict):
|
class ToolInputSchemaBlock(TypedDict):
|
||||||
json: Optional[dict]
|
json: Optional[dict]
|
||||||
|
|
||||||
|
@ -72,9 +96,15 @@ class SpecificToolChoiceBlock(TypedDict):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolChoiceValuesBlock(TypedDict, total=False):
|
||||||
|
any: dict
|
||||||
|
auto: dict
|
||||||
|
tool: SpecificToolChoiceBlock
|
||||||
|
|
||||||
|
|
||||||
class ToolConfigBlock(TypedDict, total=False):
|
class ToolConfigBlock(TypedDict, total=False):
|
||||||
tools: Required[List[ToolBlock]]
|
tools: Required[List[ToolBlock]]
|
||||||
toolChoice: Union[str, SpecificToolChoiceBlock]
|
toolChoice: Union[str, ToolChoiceValuesBlock]
|
||||||
|
|
||||||
|
|
||||||
class RequestObject(TypedDict, total=False):
|
class RequestObject(TypedDict, total=False):
|
||||||
|
|
|
@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
|
||||||
extra_headers: Optional[Dict[str, str]]
|
extra_headers: Optional[Dict[str, str]]
|
||||||
extra_body: Optional[Dict[str, str]]
|
extra_body: Optional[Dict[str, str]]
|
||||||
timeout: Optional[float]
|
timeout: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionToolCallFunctionChunk(TypedDict):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionToolCallChunk(TypedDict):
|
||||||
|
id: str
|
||||||
|
type: Literal["function"]
|
||||||
|
function: ChatCompletionToolCallFunctionChunk
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponseMessage(TypedDict, total=False):
|
||||||
|
content: Optional[str]
|
||||||
|
tool_calls: List[ChatCompletionToolCallChunk]
|
||||||
|
role: Literal["assistant"]
|
||||||
|
|
|
@ -5618,84 +5618,13 @@ def get_optional_params(
|
||||||
supported_params = get_supported_openai_params(
|
supported_params = get_supported_openai_params(
|
||||||
model=model, custom_llm_provider=custom_llm_provider
|
model=model, custom_llm_provider=custom_llm_provider
|
||||||
)
|
)
|
||||||
if "ai21" in model:
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
_check_valid_arg(supported_params=supported_params)
|
||||||
# params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
|
optional_params = litellm.AmazonConverseConfig().map_openai_params(
|
||||||
# https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
model=model,
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["maxTokens"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["topP"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "anthropic" in model:
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# anthropic params on bedrock
|
|
||||||
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
|
|
||||||
if model.startswith("anthropic.claude-3"):
|
|
||||||
optional_params = (
|
|
||||||
litellm.AmazonAnthropicClaude3Config().map_openai_params(
|
|
||||||
non_default_params=non_default_params,
|
non_default_params=non_default_params,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
|
drop_params=drop_params,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
|
||||||
non_default_params=non_default_params,
|
|
||||||
optional_params=optional_params,
|
|
||||||
)
|
|
||||||
elif "amazon" in model: # amazon titan llms
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["maxTokenCount"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if stop is not None:
|
|
||||||
filtered_stop = _map_and_modify_arg(
|
|
||||||
{"stop": stop}, provider="bedrock", model=model
|
|
||||||
)
|
|
||||||
optional_params["stopSequences"] = filtered_stop["stop"]
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["topP"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "meta" in model: # amazon / meta llms
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_gen_len"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif "cohere" in model: # cohere models on bedrock
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# handle cohere params
|
|
||||||
if stream:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
elif "mistral" in model:
|
|
||||||
_check_valid_arg(supported_params=supported_params)
|
|
||||||
# mistral params on bedrock
|
|
||||||
# \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
|
|
||||||
if max_tokens is not None:
|
|
||||||
optional_params["max_tokens"] = max_tokens
|
|
||||||
if temperature is not None:
|
|
||||||
optional_params["temperature"] = temperature
|
|
||||||
if top_p is not None:
|
|
||||||
optional_params["top_p"] = top_p
|
|
||||||
if stop is not None:
|
|
||||||
optional_params["stop"] = stop
|
|
||||||
if stream is not None:
|
|
||||||
optional_params["stream"] = stream
|
|
||||||
elif custom_llm_provider == "aleph_alpha":
|
elif custom_llm_provider == "aleph_alpha":
|
||||||
supported_params = [
|
supported_params = [
|
||||||
"max_tokens",
|
"max_tokens",
|
||||||
|
|
3
ruff.toml
Normal file
3
ruff.toml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
ignore = ["F405"]
|
||||||
|
extend-select = ["E501"]
|
||||||
|
line-length = 120
|
Loading…
Add table
Add a link
Reference in a new issue