forked from phoenix/litellm-mirror

fix(bedrock_httpx.py): working claude 3 function calling

This commit is contained in:
parent a76a9b7d11 · commit a995a0b172

11 changed files with 321 additions and 108 deletions

@@ -16,11 +16,11 @@ repos:
         name: Check if files match
         entry: python3 ci_cd/check_files_match.py
         language: system
-  - repo: local
-    hooks:
-      - id: mypy
-        name: mypy
-        entry: python3 -m mypy --ignore-missing-imports
-        language: system
-        types: [python]
-        files: ^litellm/
+  # - repo: local
+  #   hooks:
+  #     - id: mypy
+  #       name: mypy
+  #       entry: python3 -m mypy --ignore-missing-imports
+  #       language: system
+  #       types: [python]
+  #       files: ^litellm/

@@ -47,6 +47,11 @@ import httpx  # type: ignore
 from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
 from litellm.types.llms.bedrock import *
 import urllib.parse
+from litellm.types.llms.openai import (
+    ChatCompletionResponseMessage,
+    ChatCompletionToolCallChunk,
+    ChatCompletionToolCallFunctionChunk,
+)


 class AmazonCohereChatConfig:

@@ -1004,12 +1009,12 @@ class BedrockLLM(BaseLLM):
             if isinstance(timeout, float) or isinstance(timeout, int):
                 timeout = httpx.Timeout(timeout)
             _params["timeout"] = timeout
-            self.client = AsyncHTTPHandler(**_params)  # type: ignore
+            client = AsyncHTTPHandler(**_params)  # type: ignore
         else:
-            self.client = client  # type: ignore
+            client = client  # type: ignore

         try:
-            response = await self.client.post(api_base, headers=headers, data=data)  # type: ignore
+            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code

@@ -1125,11 +1130,55 @@ class AmazonConverseConfig:
             "tool_choice",
         ]

+    def map_tool_choice_values(
+        self, model: str, tool_choice: Union[str, dict], drop_params: bool
+    ) -> Optional[ToolChoiceValuesBlock]:
+        if not model.startswith("anthropic") and not model.startswith("mistral"):
+            # only anthropic and mistral support tool choice config; for other providers (e.g. cohere) the call will fail - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            if drop_params == True or litellm.drop_params == True:
+                return None
+            else:
+                raise litellm.utils.UnsupportedParamsError(
+                    message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True`.",
+                    status_code=400,
+                )
+        if tool_choice == "none":
+            if litellm.drop_params is True or drop_params is True:
+                return None
+            else:
+                raise litellm.utils.UnsupportedParamsError(
+                    message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True`.".format(
+                        tool_choice
+                    ),
+                    status_code=400,
+                )
+        elif tool_choice == "required":
+            return ToolChoiceValuesBlock(any={})
+        elif tool_choice == "auto":
+            return ToolChoiceValuesBlock(auto={})
+        elif isinstance(tool_choice, dict):
+            # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            specific_tool = SpecificToolChoiceBlock(
+                name=tool_choice.get("function", {}).get("name", "")
+            )
+            return ToolChoiceValuesBlock(tool=specific_tool)
+        else:
+            raise litellm.utils.UnsupportedParamsError(
+                message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
+                    tool_choice
+                ),
+                status_code=400,
+            )
+
+    def get_supported_image_types(self) -> List[str]:
+        return ["png", "jpeg", "gif", "webp"]
+
     def map_openai_params(
-        self, non_default_params: dict, optional_params: dict
+        self,
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+        drop_params: bool,
     ) -> dict:
         for param, value in non_default_params.items():
             if param == "max_tokens":

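Note (illustration, not part of the diff): the new map_tool_choice_values translates OpenAI-style tool_choice values into Bedrock's ToolChoiceValuesBlock, mirroring the code above. Model IDs below are just examples:

    config = AmazonConverseConfig()

    # "auto" and "required" map to Bedrock's auto/any blocks
    config.map_tool_choice_values(
        model="anthropic.claude-3-sonnet-20240229-v1:0", tool_choice="auto", drop_params=False
    )  # -> {"auto": {}}
    config.map_tool_choice_values(
        model="anthropic.claude-3-sonnet-20240229-v1:0", tool_choice="required", drop_params=False
    )  # -> {"any": {}}

    # an OpenAI-style dict selects a specific tool by name
    config.map_tool_choice_values(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        tool_choice={"type": "function", "function": {"name": "top_song"}},
        drop_params=False,
    )  # -> {"tool": {"name": "top_song"}}

    # unsupported providers return None when drop_params=True, else raise UnsupportedParamsError
    config.map_tool_choice_values(
        model="cohere.command-r-plus-v1:0", tool_choice="auto", drop_params=True
    )  # -> None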
@@ -1144,6 +1193,14 @@ class AmazonConverseConfig:
                 optional_params["temperature"] = value
             if param == "top_p":
                 optional_params["topP"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice_value = self.map_tool_choice_values(
+                    model=model, tool_choice=value, drop_params=drop_params
+                )
+                if _tool_choice_value is not None:
+                    optional_params["tool_choice"] = _tool_choice_value
         return optional_params

@@ -1151,6 +1208,124 @@ class BedrockConverseLLM(BaseLLM):
     def __init__(self) -> None:
         super().__init__()

+    def process_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: ModelResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: List,
+        print_verbose,
+        encoding,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={"complete_input_dict": data},
+        )
+        print_verbose(f"raw model_response: {response.text}")
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = ConverseResponseBlock(**response.json())  # type: ignore
+        except Exception as e:
+            raise BedrockError(
+                message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+                    response.text, str(e)
+                ),
+                status_code=422,
+            )
+
+        """
+        The Bedrock response object has an optional message block:
+
+        completion_response["output"].get("message", None)
+
+        A message block looks like this (Example 1):
+        "output": {
+            "message": {
+                "role": "assistant",
+                "content": [
+                    {
+                        "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
+                    }
+                ]
+            }
+        },
+        (Example 2):
+        "output": {
+            "message": {
+                "role": "assistant",
+                "content": [
+                    {
+                        "toolUse": {
+                            "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
+                            "name": "top_song",
+                            "input": {
+                                "sign": "WZPZ"
+                            }
+                        }
+                    }
+                ]
+            }
+        }
+        """
+        message: Optional[MessageBlock] = completion_response["output"]["message"]
+        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
+        content_str = ""
+        tools: List[ChatCompletionToolCallChunk] = []
+        if message is not None:
+            for content in message["content"]:
+                """
+                - Content is either a tool response or text
+                """
+                if "text" in content:
+                    content_str += content["text"]
+                if "toolUse" in content:
+                    _function_chunk = ChatCompletionToolCallFunctionChunk(
+                        name=content["toolUse"]["name"],
+                        arguments=json.dumps(content["toolUse"]["input"]),
+                    )
+                    _tool_response_chunk = ChatCompletionToolCallChunk(
+                        id=content["toolUse"]["toolUseId"],
+                        type="function",
+                        function=_function_chunk,
+                    )
+                    tools.append(_tool_response_chunk)
+        chat_completion_message["content"] = content_str
+        chat_completion_message["tool_calls"] = tools
+
+        ## CALCULATING USAGE - bedrock returns usage in the response body
+        input_tokens = completion_response["usage"]["inputTokens"]
+        output_tokens = completion_response["usage"]["outputTokens"]
+        total_tokens = completion_response["usage"]["totalTokens"]
+
+        model_response.choices = [
+            litellm.Choices(
+                finish_reason=map_finish_reason(completion_response["stopReason"]),
+                index=0,
+                message=litellm.Message(**chat_completion_message),
+            )
+        ]
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        usage = Usage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=total_tokens,
+        )
+        setattr(model_response, "usage", usage)
+
+        return model_response
+
     def encode_model_id(self, model_id: str) -> str:
         """
         Double encode the model ID to ensure it matches the expected double-encoded format.

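Note (illustration, not part of the diff): for the Example 2 toolUse payload in the docstring above, process_response produces an OpenAI-style assistant message shaped roughly like this:

    chat_completion_message = {
        "role": "assistant",
        "content": "",  # no text blocks in Example 2
        "tool_calls": [
            {
                "id": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
                "type": "function",
                "function": {"name": "top_song", "arguments": '{"sign": "WZPZ"}'},
            }
        ],
    }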
@@ -1387,11 +1562,14 @@ class BedrockConverseLLM(BaseLLM):
         additional_request_keys = []
         additional_request_params = {}
         supported_converse_params = AmazonConverseConfig().get_config().keys()
+        supported_tool_call_params = ["tools", "tool_choice"]
         ## TRANSFORMATION ##
         # send all model-specific params in 'additional_request_params'
         for k, v in inference_params.items():
-            if k not in supported_converse_params:
+            if (
+                k not in supported_converse_params
+                and k not in supported_tool_call_params
+            ):
                 additional_request_params[k] = v
                 additional_request_keys.append(k)
         for key in additional_request_keys:

@@ -1401,23 +1579,27 @@ class BedrockConverseLLM(BaseLLM):
             messages=messages
         )
         bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
-            inference_params.get("tools", [])
+            inference_params.pop("tools", [])
         )
         bedrock_tool_config: Optional[ToolConfigBlock] = None
         if len(bedrock_tools) > 0:
+            tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
+                "tool_choice", None
+            )
             bedrock_tool_config = ToolConfigBlock(
                 tools=bedrock_tools,
-                toolChoice=inference_params.get("tool_choice", None),
             )
+            if tool_choice_values is not None:
+                bedrock_tool_config["toolChoice"] = tool_choice_values

-        data: RequestObject = {
+        _data: RequestObject = {
             "messages": bedrock_messages,
             "additionalModelRequestFields": additional_request_params,
             "system": system_content_blocks,
         }
         if bedrock_tool_config is not None:
-            data["toolConfig"] = bedrock_tool_config
-
+            _data["toolConfig"] = bedrock_tool_config
+        data = json.dumps(_data)
         ## COMPLETION CALL

         headers = {"Content-Type": "application/json"}

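Note (illustration, not part of the diff): the request body serialized by json.dumps(_data) above looks roughly like the following; the message text and tool entries are placeholders:

    _data = {
        "messages": [{"role": "user", "content": [{"text": "..."}]}],
        "system": [],
        "additionalModelRequestFields": {},
        # only present when tools were passed:
        "toolConfig": {
            "tools": [...],              # ToolBlock entries from _bedrock_tools_pt
            "toolChoice": {"auto": {}},  # ToolChoiceValuesBlock from map_tool_choice_values
        },
    }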
@@ -1441,8 +1623,18 @@ class BedrockConverseLLM(BaseLLM):
         )

-        ### ROUTING (ASYNC, STREAMING, SYNC)
+        ### COMPLETION
+        if client is None or isinstance(client, AsyncHTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
         try:
-            response = self.client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
+            response = client.post(url=prepped.url, headers=prepped.headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code

@@ -1450,6 +1642,20 @@ class BedrockConverseLLM(BaseLLM):
         except httpx.TimeoutException as e:
             raise BedrockError(status_code=408, message="Timeout error occurred.")

+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            encoding=encoding,
+        )
+

 def get_response_stream_shape():
     from botocore.model import ServiceModel

@@ -156,12 +156,13 @@ class HTTPHandler:
         self,
         url: str,
         data: Optional[Union[dict, str]] = None,
+        json: Optional[Union[dict, str]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
         stream: bool = False,
     ):
         req = self.client.build_request(
-            "POST", url, data=data, params=params, headers=headers  # type: ignore
+            "POST", url, data=data, json=json, params=params, headers=headers  # type: ignore
         )
         response = self.client.send(req, stream=stream)
         return response

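Note (illustration, not part of the diff): the new json parameter is forwarded to httpx.Client.build_request, which JSON-encodes a dict body and sets the content type; callers can still pre-serialize via data. The URL below is a placeholder:

    handler = HTTPHandler()
    # pre-serialized body (what the Bedrock code above does):
    resp = handler.post(
        url="https://bedrock.example/converse",
        data=json.dumps({"messages": []}),
        headers={"Content-Type": "application/json"},
    )
    # or let httpx serialize the body:
    resp = handler.post(url="https://bedrock.example/converse", json={"messages": []})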
@@ -1617,6 +1617,7 @@ from litellm.types.llms.bedrock import (
     ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
     ToolSpecBlock as BedrockToolSpecBlock,
     ToolBlock as BedrockToolBlock,
+    ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
 )

@@ -1814,7 +1815,7 @@ def _convert_to_bedrock_tool_call_result(

     tool_result_content_block = BedrockToolResultContentBlock(text=content)
     tool_result = BedrockToolResultBlock(
-        content=tool_result_content_block,
+        content=[tool_result_content_block],
         toolUseId=id,
     )
     content_block = BedrockContentBlock(toolResult=tool_result)

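Note (illustration, not part of the diff): per the ToolResultBlock type change later in this commit, toolResult content is a list of content blocks rather than a single block, so the payload built here now looks roughly like this (text is a placeholder):

    content_block = {
        "toolResult": {
            "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
            "content": [{"text": "..."}],  # now a list, matching ToolResultBlock
        }
    }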
@@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.predibase import PredibaseChatCompletion
-from .llms.bedrock_httpx import BedrockLLM
+from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
 from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
 from .llms.prompt_templates.factory import (

@@ -121,7 +121,7 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
-bedrock_chat_completion = BedrockLLM()
+bedrock_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 ####### COMPLETION ENDPOINTS ################

@@ -302,10 +302,7 @@ def test_completion_claude_3():

 @pytest.mark.parametrize(
     "model",
-    [
-        # "anthropic/claude-3-opus-20240229",
-        "cohere.command-r-plus-v1:0"
-    ],
+    ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
 )
 def test_completion_claude_3_function_call(model):
     litellm.set_verbose = True

@@ -345,6 +342,7 @@ def test_completion_claude_3_function_call(model):
                 "type": "function",
                 "function": {"name": "get_current_weather"},
             },
+            drop_params=True,
         )

         # Add any assertions here to check response args

@@ -375,6 +373,7 @@ def test_completion_claude_3_function_call(model):
             messages=messages,
             tools=tools,
             tool_choice="auto",
+            drop_params=True,
         )
         print(second_response)
     except Exception as e:

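Note (illustration, not part of the diff): the per-call drop_params=True flag exercised by these tests drops tool_choice instead of raising UnsupportedParamsError on models that don't support it; messages and tools below are placeholders:

    response = litellm.completion(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "What's the weather in Boston?"}],
        tools=tools,          # OpenAI-style tool definitions
        tool_choice="auto",   # mapped via AmazonConverseConfig.map_tool_choice_values
        drop_params=True,     # silently drop params unsupported by the target model
    )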
@@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
     claude_2_1_pt,
     llama_2_chat_pt,
     prompt_factory,
+    _bedrock_tools_pt,
 )

@@ -128,3 +129,29 @@ def test_anthropic_messages_pt():


 # codellama_prompt_format()
+def test_bedrock_tool_calling_pt():
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    converted_tools = _bedrock_tools_pt(tools=tools)
+
+    print(converted_tools)
+
+    assert False

@@ -31,7 +31,7 @@ class ToolResultContentBlock(TypedDict, total=False):


 class ToolResultBlock(TypedDict, total=False):
-    content: Required[ToolResultContentBlock]
+    content: Required[List[ToolResultContentBlock]]
     toolUseId: Required[str]
     status: Literal["success", "error"]

@@ -54,6 +54,30 @@ class MessageBlock(TypedDict):
     role: Literal["user", "assistant"]


+class ConverseMetricsBlock(TypedDict):
+    latencyMs: float  # time in ms
+
+
+class ConverseResponseOutputBlock(TypedDict):
+    message: Optional[MessageBlock]
+
+
+class ConverseTokenUsageBlock(TypedDict):
+    inputTokens: int
+    outputTokens: int
+    totalTokens: int
+
+
+class ConverseResponseBlock(TypedDict):
+    additionalModelResponseFields: dict
+    metrics: ConverseMetricsBlock
+    output: ConverseResponseOutputBlock
+    stopReason: (
+        str  # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
+    )
+    usage: ConverseTokenUsageBlock
+
+
 class ToolInputSchemaBlock(TypedDict):
     json: Optional[dict]

@@ -72,9 +96,15 @@ class SpecificToolChoiceBlock(TypedDict):
     name: str


+class ToolChoiceValuesBlock(TypedDict, total=False):
+    any: dict
+    auto: dict
+    tool: SpecificToolChoiceBlock
+
+
 class ToolConfigBlock(TypedDict, total=False):
     tools: Required[List[ToolBlock]]
-    toolChoice: Union[str, SpecificToolChoiceBlock]
+    toolChoice: Union[str, ToolChoiceValuesBlock]


 class RequestObject(TypedDict, total=False):

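Note (illustration, not part of the diff): a ToolConfigBlock built from these types would look roughly like this; the toolSpec body is a hypothetical sketch following the ToolSpecBlock naming imported earlier in this commit:

    tool_config: ToolConfigBlock = {
        "tools": [
            {
                "toolSpec": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "inputSchema": {"json": {"type": "object", "properties": {}}},
                }
            }
        ],
        "toolChoice": {"auto": {}},  # a ToolChoiceValuesBlock
    }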
@@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
     extra_headers: Optional[Dict[str, str]]
     extra_body: Optional[Dict[str, str]]
     timeout: Optional[float]
+
+
+class ChatCompletionToolCallFunctionChunk(TypedDict):
+    name: str
+    arguments: str
+
+
+class ChatCompletionToolCallChunk(TypedDict):
+    id: str
+    type: Literal["function"]
+    function: ChatCompletionToolCallFunctionChunk
+
+
+class ChatCompletionResponseMessage(TypedDict, total=False):
+    content: Optional[str]
+    tool_calls: List[ChatCompletionToolCallChunk]
+    role: Literal["assistant"]

@@ -5618,84 +5618,13 @@ def get_optional_params(
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
-        if "ai21" in model:
-            _check_valid_arg(supported_params=supported_params)
-            # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[],
-            # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
-            if max_tokens is not None:
-                optional_params["maxTokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["topP"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        elif "anthropic" in model:
-            _check_valid_arg(supported_params=supported_params)
-            # anthropic params on bedrock
-            # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
-            if model.startswith("anthropic.claude-3"):
-                optional_params = (
-                    litellm.AmazonAnthropicClaude3Config().map_openai_params(
-                        non_default_params=non_default_params,
-                        optional_params=optional_params,
-                    )
-                )
-            else:
-                optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
-                    non_default_params=non_default_params,
-                    optional_params=optional_params,
-                )
-        elif "amazon" in model:  # amazon titan llms
-            _check_valid_arg(supported_params=supported_params)
-            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
-            if max_tokens is not None:
-                optional_params["maxTokenCount"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if stop is not None:
-                filtered_stop = _map_and_modify_arg(
-                    {"stop": stop}, provider="bedrock", model=model
-                )
-                optional_params["stopSequences"] = filtered_stop["stop"]
-            if top_p is not None:
-                optional_params["topP"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        elif "meta" in model:  # amazon / meta llms
-            _check_valid_arg(supported_params=supported_params)
-            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
-            if max_tokens is not None:
-                optional_params["max_gen_len"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stream:
-                optional_params["stream"] = stream
-        elif "cohere" in model:  # cohere models on bedrock
-            _check_valid_arg(supported_params=supported_params)
-            # handle cohere params
-            if stream:
-                optional_params["stream"] = stream
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if max_tokens is not None:
-                optional_params["max_tokens"] = max_tokens
-        elif "mistral" in model:
-            _check_valid_arg(supported_params=supported_params)
-            # mistral params on bedrock
-            # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
-            if max_tokens is not None:
-                optional_params["max_tokens"] = max_tokens
-            if temperature is not None:
-                optional_params["temperature"] = temperature
-            if top_p is not None:
-                optional_params["top_p"] = top_p
-            if stop is not None:
-                optional_params["stop"] = stop
-            if stream is not None:
-                optional_params["stream"] = stream
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.AmazonConverseConfig().map_openai_params(
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            drop_params=drop_params,
+        )
     elif custom_llm_provider == "aleph_alpha":
         supported_params = [
             "max_tokens",

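Note (illustration, not part of the diff): all Bedrock models now flow through a single Converse-style mapping instead of per-provider branches; assuming maxTokens is the Converse key for max_tokens, the consolidated call behaves roughly like:

    optional_params = litellm.AmazonConverseConfig().map_openai_params(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        non_default_params={"max_tokens": 256, "top_p": 0.9, "tool_choice": "auto"},
        optional_params={},
        drop_params=False,
    )
    # roughly: {"maxTokens": 256, "topP": 0.9, "tool_choice": {"auto": {}}}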
ruff.toml (new file, +3)

@@ -0,0 +1,3 @@
+ignore = ["F405"]
+extend-select = ["E501"]
+line-length = 120
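Note (context, not part of the diff): this config ignores F405 ("may be undefined, or defined from star imports", which the `from litellm.types.llms.bedrock import *` change above would otherwise flag), additionally enables the E501 line-length rule, and sets the limit to 120 columns; it applies when running `ruff check .` at the repo root.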