forked from phoenix/litellm-mirror
Merge pull request #4033 from BerriAI/litellm_bedrock_converse_api
feat(bedrock_httpx.py): add support for bedrock converse api
This commit is contained in:
commit
6b703f0ebf
15 changed files with 1496 additions and 4374 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -59,3 +59,4 @@ myenv/*
|
|||
litellm/proxy/_experimental/out/404/index.html
|
||||
litellm/proxy/_experimental/out/model_hub/index.html
|
||||
litellm/proxy/_experimental/out/onboarding/index.html
|
||||
litellm/tests/log.txt
|
||||
|
|
|
@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
|
|||
### INIT VARIABLES ###
|
||||
import threading, requests, os
|
||||
from typing import Callable, List, Optional, Dict, Union, Any, Literal
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.caching import Cache
|
||||
from litellm._logging import (
|
||||
set_verbose,
|
||||
|
@ -233,6 +233,7 @@ max_end_user_budget: Optional[float] = None
|
|||
#### RELIABILITY ####
|
||||
request_timeout: float = 6000
|
||||
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
|
||||
module_level_client = HTTPHandler(timeout=request_timeout)
|
||||
num_retries: Optional[int] = None # per model endpoint
|
||||
default_fallbacks: Optional[List] = None
|
||||
fallbacks: Optional[List] = None
|
||||
|
@ -766,7 +767,7 @@ from .llms.sagemaker import SagemakerConfig
|
|||
from .llms.ollama import OllamaConfig
|
||||
from .llms.ollama_chat import OllamaChatConfig
|
||||
from .llms.maritalk import MaritTalkConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig
|
||||
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
|
||||
from .llms.bedrock import (
|
||||
AmazonTitanConfig,
|
||||
AmazonAI21Config,
|
||||
|
|
|
@ -38,6 +38,8 @@ from .prompt_templates.factory import (
|
|||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
contains_tag,
|
||||
_bedrock_converse_messages_pt,
|
||||
_bedrock_tools_pt,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from .base import BaseLLM
|
||||
|
@ -45,6 +47,11 @@ import httpx # type: ignore
|
|||
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
|
||||
from litellm.types.llms.bedrock import *
|
||||
import urllib.parse
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
)
|
||||
|
||||
|
||||
class AmazonCohereChatConfig:
|
||||
|
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
|
|||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
|
@ -176,6 +185,37 @@ async def make_call(
|
|||
return completion_stream
|
||||
|
||||
|
||||
def make_sync_call(
|
||||
client: Optional[HTTPHandler],
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
model: str,
|
||||
messages: list,
|
||||
logging_obj,
|
||||
):
|
||||
if client is None:
|
||||
client = HTTPHandler() # Create a new client if none provided
|
||||
|
||||
response = client.post(api_base, headers=headers, data=data, stream=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BedrockError(status_code=response.status_code, message=response.read())
|
||||
|
||||
decoder = AWSEventStreamDecoder(model=model)
|
||||
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=completion_stream, # Pass the completion stream for logging
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return completion_stream
|
||||
|
||||
|
||||
class BedrockLLM(BaseLLM):
|
||||
"""
|
||||
Example call
|
||||
|
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
|
|||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
self.client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.client = client # type: ignore
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
|
@ -1069,6 +1109,745 @@ class BedrockLLM(BaseLLM):
|
|||
return super().embedding(*args, **kwargs)
|
||||
|
||||
|
||||
class AmazonConverseConfig:
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int]
|
||||
stopSequences: Optional[List[str]]
|
||||
temperature: Optional[int]
|
||||
topP: Optional[int]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
stopSequences: Optional[List[str]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = [
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
if (
|
||||
model.startswith("anthropic")
|
||||
or model.startswith("mistral")
|
||||
or model.startswith("cohere")
|
||||
):
|
||||
supported_params.append("tools")
|
||||
|
||||
if model.startswith("anthropic") or model.startswith("mistral"):
|
||||
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
supported_params.append("tool_choice")
|
||||
|
||||
return supported_params
|
||||
|
||||
def map_tool_choice_values(
|
||||
self, model: str, tool_choice: Union[str, dict], drop_params: bool
|
||||
) -> Optional[ToolChoiceValuesBlock]:
|
||||
if tool_choice == "none":
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
return None
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
elif tool_choice == "required":
|
||||
return ToolChoiceValuesBlock(any={})
|
||||
elif tool_choice == "auto":
|
||||
return ToolChoiceValuesBlock(auto={})
|
||||
elif isinstance(tool_choice, dict):
|
||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
specific_tool = SpecificToolChoiceBlock(
|
||||
name=tool_choice.get("function", {}).get("name", "")
|
||||
)
|
||||
return ToolChoiceValuesBlock(tool=specific_tool)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
def get_supported_image_types(self) -> List[str]:
|
||||
return ["png", "jpeg", "gif", "webp"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["maxTokens"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
value = [value]
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["topP"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "tool_choice":
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
model=model, tool_choice=value, drop_params=drop_params # type: ignore
|
||||
)
|
||||
if _tool_choice_value is not None:
|
||||
optional_params["tool_choice"] = _tool_choice_value
|
||||
return optional_params
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseLLM):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
model: str,
|
||||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: ModelResponse,
|
||||
stream: bool,
|
||||
logging_obj: Logging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
messages: List,
|
||||
print_verbose,
|
||||
encoding,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response.text,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response.text}")
|
||||
|
||||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
|
||||
response.text, str(e)
|
||||
),
|
||||
status_code=422,
|
||||
)
|
||||
|
||||
"""
|
||||
Bedrock Response Object has optional message block
|
||||
|
||||
completion_response["output"].get("message", None)
|
||||
|
||||
A message block looks like this (Example 1):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
(Example 2):
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"toolUse": {
|
||||
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
|
||||
"name": "top_song",
|
||||
"input": {
|
||||
"sign": "WZPZ"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
"""
|
||||
message: Optional[MessageBlock] = completion_response["output"]["message"]
|
||||
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
|
||||
content_str = ""
|
||||
tools: List[ChatCompletionToolCallChunk] = []
|
||||
if message is not None:
|
||||
for content in message["content"]:
|
||||
"""
|
||||
- Content is either a tool response or text
|
||||
"""
|
||||
if "text" in content:
|
||||
content_str += content["text"]
|
||||
if "toolUse" in content:
|
||||
_function_chunk = ChatCompletionToolCallFunctionChunk(
|
||||
name=content["toolUse"]["name"],
|
||||
arguments=json.dumps(content["toolUse"]["input"]),
|
||||
)
|
||||
_tool_response_chunk = ChatCompletionToolCallChunk(
|
||||
id=content["toolUse"]["toolUseId"],
|
||||
type="function",
|
||||
function=_function_chunk,
|
||||
)
|
||||
tools.append(_tool_response_chunk)
|
||||
chat_completion_message["content"] = content_str
|
||||
chat_completion_message["tool_calls"] = tools
|
||||
|
||||
## CALCULATING USAGE - bedrock returns usage in the headers
|
||||
input_tokens = completion_response["usage"]["inputTokens"]
|
||||
output_tokens = completion_response["usage"]["outputTokens"]
|
||||
total_tokens = completion_response["usage"]["totalTokens"]
|
||||
|
||||
model_response.choices = [
|
||||
litellm.Choices(
|
||||
finish_reason=map_finish_reason(completion_response["stopReason"]),
|
||||
index=0,
|
||||
message=litellm.Message(**chat_completion_message),
|
||||
)
|
||||
]
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
|
||||
return model_response
|
||||
|
||||
def encode_model_id(self, model_id: str) -> str:
|
||||
"""
|
||||
Double encode the model ID to ensure it matches the expected double-encoded format.
|
||||
Args:
|
||||
model_id (str): The model ID to encode.
|
||||
Returns:
|
||||
str: The double-encoded model ID.
|
||||
"""
|
||||
return urllib.parse.quote(model_id, safe="")
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
import boto3
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
|
||||
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
secret_key=sts_credentials["SecretAccessKey"],
|
||||
token=sts_credentials["SessionToken"],
|
||||
)
|
||||
return credentials
|
||||
elif aws_profile_name is not None: ### CHECK SESSION ###
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
client = boto3.Session(profile_name=aws_profile_name)
|
||||
|
||||
return client.get_credentials()
|
||||
else:
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
|
||||
async def async_streaming(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_call,
|
||||
client=client,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streaming_response
|
||||
|
||||
async def async_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
data: str,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
encoding,
|
||||
logging_obj,
|
||||
stream,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client # type: ignore
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream if isinstance(stream, bool) else False,
|
||||
logging_obj=logging_obj,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion: bool,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||
):
|
||||
try:
|
||||
import boto3
|
||||
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
modelId = optional_params.pop("model_id", None)
|
||||
if modelId is not None:
|
||||
modelId = self.encode_model_id(model_id=modelId)
|
||||
else:
|
||||
modelId = model
|
||||
|
||||
provider = model.split(".")[0]
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
|
||||
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
|
||||
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
|
||||
aws_region_name = optional_params.pop("aws_region_name", None)
|
||||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
||||
### SET REGION NAME ###
|
||||
if aws_region_name is None:
|
||||
# check env #
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
|
||||
if litellm_aws_region_name is not None and isinstance(
|
||||
litellm_aws_region_name, str
|
||||
):
|
||||
aws_region_name = litellm_aws_region_name
|
||||
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
if standard_aws_region_name is not None and isinstance(
|
||||
standard_aws_region_name, str
|
||||
):
|
||||
aws_region_name = standard_aws_region_name
|
||||
|
||||
if aws_region_name is None:
|
||||
aws_region_name = "us-west-2"
|
||||
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
aws_role_name=aws_role_name,
|
||||
aws_web_identity_token=aws_web_identity_token,
|
||||
)
|
||||
|
||||
### SET RUNTIME ENDPOINT ###
|
||||
endpoint_url = ""
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||
|
||||
# Separate system prompt from rest of message
|
||||
system_prompt_indices = []
|
||||
system_content_blocks: List[SystemContentBlock] = []
|
||||
for idx, message in enumerate(messages):
|
||||
if message["role"] == "system":
|
||||
_system_content_block = SystemContentBlock(text=message["content"])
|
||||
system_content_blocks.append(_system_content_block)
|
||||
system_prompt_indices.append(idx)
|
||||
if len(system_prompt_indices) > 0:
|
||||
for idx in reversed(system_prompt_indices):
|
||||
messages.pop(idx)
|
||||
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
additional_request_keys = []
|
||||
additional_request_params = {}
|
||||
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
## TRANSFORMATION ##
|
||||
# send all model-specific params in 'additional_request_params'
|
||||
for k, v in inference_params.items():
|
||||
if (
|
||||
k not in supported_converse_params
|
||||
and k not in supported_tool_call_params
|
||||
):
|
||||
additional_request_params[k] = v
|
||||
additional_request_keys.append(k)
|
||||
for key in additional_request_keys:
|
||||
inference_params.pop(key, None)
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
messages=messages
|
||||
)
|
||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||
inference_params.pop("tools", [])
|
||||
)
|
||||
bedrock_tool_config: Optional[ToolConfigBlock] = None
|
||||
if len(bedrock_tools) > 0:
|
||||
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
|
||||
"tool_choice", None
|
||||
)
|
||||
bedrock_tool_config = ToolConfigBlock(
|
||||
tools=bedrock_tools,
|
||||
)
|
||||
if tool_choice_values is not None:
|
||||
bedrock_tool_config["toolChoice"] = tool_choice_values
|
||||
|
||||
_data: RequestObject = {
|
||||
"messages": bedrock_messages,
|
||||
"additionalModelRequestFields": additional_request_params,
|
||||
"system": system_content_blocks,
|
||||
"inferenceConfig": InferenceConfig(**inference_params),
|
||||
}
|
||||
if bedrock_tool_config is not None:
|
||||
_data["toolConfig"] = bedrock_tool_config
|
||||
data = json.dumps(_data)
|
||||
## COMPLETION CALL
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if extra_headers is not None:
|
||||
headers = {"Content-Type": "application/json", **extra_headers}
|
||||
request = AWSRequest(
|
||||
method="POST", url=endpoint_url, data=data, headers=headers
|
||||
)
|
||||
sigv4.add_auth(request)
|
||||
prepped = request.prepare()
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": prepped.url,
|
||||
"headers": prepped.headers,
|
||||
},
|
||||
)
|
||||
|
||||
### ROUTING (ASYNC, STREAMING, SYNC)
|
||||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream is True and provider != "ai21":
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=True,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
### ASYNC COMPLETION
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
data=data,
|
||||
api_base=prepped.url,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
stream=stream, # type: ignore
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=prepped.headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
) # type: ignore
|
||||
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=None,
|
||||
make_call=partial(
|
||||
make_sync_call,
|
||||
client=None,
|
||||
api_base=prepped.url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
data=data,
|
||||
model=model,
|
||||
messages=messages,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key="",
|
||||
original_response=streaming_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return streaming_response
|
||||
### COMPLETION
|
||||
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = timeout
|
||||
client = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
client = client
|
||||
try:
|
||||
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=response.text)
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
model=model,
|
||||
response=response,
|
||||
model_response=model_response,
|
||||
stream=stream,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
)
|
||||
|
||||
|
||||
def get_response_stream_shape():
|
||||
from botocore.model import ServiceModel
|
||||
from botocore.loaders import Loader
|
||||
|
@ -1086,6 +1865,31 @@ class AWSEventStreamDecoder:
|
|||
self.model = model
|
||||
self.parser = EventStreamJSONParser()
|
||||
|
||||
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
tool_str = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ConverseTokenUsageBlock] = None
|
||||
if "delta" in chunk_data:
|
||||
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
|
||||
if "text" in delta_obj:
|
||||
text = delta_obj["text"]
|
||||
elif "toolUse" in delta_obj:
|
||||
tool_str = delta_obj["toolUse"]["input"]
|
||||
elif "stopReason" in chunk_data:
|
||||
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
|
||||
elif "usage" in chunk_data:
|
||||
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
|
||||
response = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_str=tool_str,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
return response
|
||||
|
||||
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
|
||||
text = ""
|
||||
is_finished = False
|
||||
|
@ -1098,19 +1902,8 @@ class AWSEventStreamDecoder:
|
|||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
elif "completion" in chunk_data: # not claude-3
|
||||
text = chunk_data["completion"] # bedrock.anthropic
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
elif "delta" in chunk_data:
|
||||
if chunk_data["delta"].get("text", None) is not None:
|
||||
text = chunk_data["delta"]["text"]
|
||||
stop_reason = chunk_data["delta"].get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
if (
|
||||
|
@ -1137,11 +1930,11 @@ class AWSEventStreamDecoder:
|
|||
is_finished = True
|
||||
finish_reason = chunk_data["completionReason"]
|
||||
return GenericStreamingChunk(
|
||||
**{
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
text=text,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
tool_str="",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
|
||||
|
@ -1178,9 +1971,14 @@ class AWSEventStreamDecoder:
|
|||
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
|
||||
if response_dict["status_code"] != 200:
|
||||
raise ValueError(f"Bad response code, expected 200: {response_dict}")
|
||||
if "chunk" in parsed_response:
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
||||
else:
|
||||
chunk = response_dict.get("body")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
chunk = parsed_response.get("chunk")
|
||||
if not chunk:
|
||||
return None
|
||||
|
||||
return chunk.get("bytes").decode() # type: ignore[no-any-return]
|
||||
return chunk.decode() # type: ignore[no-any-return]
|
||||
|
|
|
@ -156,12 +156,13 @@ class HTTPHandler:
|
|||
self,
|
||||
url: str,
|
||||
data: Optional[Union[dict, str]] = None,
|
||||
json: Optional[Union[dict, str]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
req = self.client.build_request(
|
||||
"POST", url, data=data, params=params, headers=headers # type: ignore
|
||||
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
|
||||
)
|
||||
response = self.client.send(req, stream=stream)
|
||||
return response
|
||||
|
|
|
@ -3,14 +3,7 @@ import requests, traceback
|
|||
import json, re, xml.etree.ElementTree as ET
|
||||
from jinja2 import Template, exceptions, meta, BaseLoader
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||
import litellm
|
||||
import litellm.types
|
||||
from litellm.types.completion import (
|
||||
|
@ -24,7 +17,7 @@ from litellm.types.completion import (
|
|||
import litellm.types.llms
|
||||
from litellm.types.llms.anthropic import *
|
||||
import uuid
|
||||
|
||||
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
|
||||
import litellm.types.llms.vertex_ai
|
||||
|
||||
|
||||
|
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
|
|||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
raise Exception(
|
||||
"gemini image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
raise Exception("image conversion failed please run `pip install Pillow`")
|
||||
from io import BytesIO
|
||||
|
||||
try:
|
||||
|
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
|
|||
return prompt
|
||||
|
||||
|
||||
###### AMAZON BEDROCK #######
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
ToolResultContentBlock as BedrockToolResultContentBlock,
|
||||
ToolResultBlock as BedrockToolResultBlock,
|
||||
ToolConfigBlock as BedrockToolConfigBlock,
|
||||
ToolUseBlock as BedrockToolUseBlock,
|
||||
ImageSourceBlock as BedrockImageSourceBlock,
|
||||
ImageBlock as BedrockImageBlock,
|
||||
ContentBlock as BedrockContentBlock,
|
||||
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
|
||||
ToolSpecBlock as BedrockToolSpecBlock,
|
||||
ToolBlock as BedrockToolBlock,
|
||||
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
|
||||
)
|
||||
|
||||
|
||||
def get_image_details(image_url) -> Tuple[str, str]:
|
||||
try:
|
||||
import base64
|
||||
|
||||
# Send a GET request to the image URL
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
|
||||
# Check the response's content type to ensure it is an image
|
||||
content_type = response.headers.get("content-type")
|
||||
if not content_type or "image" not in content_type:
|
||||
raise ValueError(
|
||||
f"URL does not point to a valid image (content-type: {content_type})"
|
||||
)
|
||||
|
||||
# Convert the image content to base64 bytes
|
||||
base64_bytes = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
# Get mime-type
|
||||
mime_type = content_type.split("/")[
|
||||
1
|
||||
] # Extract mime-type from content-type header
|
||||
|
||||
return base64_bytes, mime_type
|
||||
|
||||
except requests.RequestException as e:
|
||||
raise Exception(f"Request failed: {e}")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
|
||||
if "base64" in image_url:
|
||||
# Case 1: Images with base64 encoding
|
||||
import base64, re
|
||||
|
||||
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
|
||||
image_metadata, img_without_base_64 = image_url.split(",")
|
||||
|
||||
# read mime_type from img_without_base_64=data:image/jpeg;base64
|
||||
# Extract MIME type using regular expression
|
||||
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
|
||||
if mime_type_match:
|
||||
mime_type = mime_type_match.group(1)
|
||||
image_format = mime_type.split("/")[1]
|
||||
else:
|
||||
mime_type = "image/jpeg"
|
||||
image_format = "jpeg"
|
||||
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
|
||||
supported_image_formats = (
|
||||
litellm.AmazonConverseConfig().get_supported_image_types()
|
||||
)
|
||||
if image_format in supported_image_formats:
|
||||
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
|
||||
else:
|
||||
# Handle the case when the image format is not supported
|
||||
raise ValueError(
|
||||
"Unsupported image format: {}. Supported formats: {}".format(
|
||||
image_format, supported_image_formats
|
||||
)
|
||||
)
|
||||
elif "https:/" in image_url:
|
||||
# Case 2: Images with direct links
|
||||
image_bytes, image_format = get_image_details(image_url)
|
||||
_blob = BedrockImageSourceBlock(bytes=image_bytes)
|
||||
supported_image_formats = (
|
||||
litellm.AmazonConverseConfig().get_supported_image_types()
|
||||
)
|
||||
if image_format in supported_image_formats:
|
||||
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
|
||||
else:
|
||||
# Handle the case when the image format is not supported
|
||||
raise ValueError(
|
||||
"Unsupported image format: {}. Supported formats: {}".format(
|
||||
image_format, supported_image_formats
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported image type. Expected either image url or base64 encoded string - \
|
||||
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_bedrock_tool_call_invoke(
|
||||
tool_calls: list,
|
||||
) -> List[BedrockContentBlock]:
|
||||
"""
|
||||
OpenAI tool invokes:
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"""
|
||||
"""
|
||||
Bedrock tool invokes:
|
||||
[
|
||||
{
|
||||
"role": "assistant",
|
||||
"toolUse": {
|
||||
"input": {"location": "Boston, MA", ..},
|
||||
"name": "get_current_weather",
|
||||
"toolUseId": "call_abc123"
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
"""
|
||||
- json.loads argument
|
||||
- extract name
|
||||
- extract id
|
||||
"""
|
||||
|
||||
try:
|
||||
_parts_list: List[BedrockContentBlock] = []
|
||||
for tool in tool_calls:
|
||||
if "function" in tool:
|
||||
id = tool["id"]
|
||||
name = tool["function"].get("name", "")
|
||||
arguments = tool["function"].get("arguments", "")
|
||||
arguments_dict = json.loads(arguments)
|
||||
bedrock_tool = BedrockToolUseBlock(
|
||||
input=arguments_dict, name=name, toolUseId=id
|
||||
)
|
||||
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
|
||||
_parts_list.append(bedrock_content_block)
|
||||
return _parts_list
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
|
||||
tool_calls, str(e)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _convert_to_bedrock_tool_call_result(
|
||||
message: dict,
|
||||
) -> BedrockMessageBlock:
|
||||
"""
|
||||
OpenAI message with a tool result looks like:
|
||||
{
|
||||
"tool_call_id": "tool_1",
|
||||
"role": "tool",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
},
|
||||
|
||||
OpenAI message with a function call result looks like:
|
||||
{
|
||||
"role": "function",
|
||||
"name": "get_current_weather",
|
||||
"content": "function result goes here",
|
||||
}
|
||||
"""
|
||||
"""
|
||||
Bedrock result looks like this:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"toolResult": {
|
||||
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
|
||||
"content": [
|
||||
{
|
||||
"json": {
|
||||
"song": "Elemental Hotel",
|
||||
"artist": "8 Storey Hike"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
"""
|
||||
-
|
||||
"""
|
||||
content = message.get("content", "")
|
||||
name = message.get("name", "")
|
||||
id = message.get("tool_call_id", str(uuid.uuid4()))
|
||||
|
||||
tool_result_content_block = BedrockToolResultContentBlock(text=content)
|
||||
tool_result = BedrockToolResultBlock(
|
||||
content=[tool_result_content_block],
|
||||
toolUseId=id,
|
||||
)
|
||||
content_block = BedrockContentBlock(toolResult=tool_result)
|
||||
|
||||
return BedrockMessageBlock(role="user", content=[content_block])
|
||||
|
||||
|
||||
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
|
||||
"""
|
||||
Converts given messages from OpenAI format to Bedrock format
|
||||
|
||||
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
|
||||
- Please ensure that function response turn comes immediately after a function call turn
|
||||
"""
|
||||
|
||||
contents: List[BedrockMessageBlock] = []
|
||||
msg_i = 0
|
||||
while msg_i < len(messages):
|
||||
user_content: List[BedrockContentBlock] = []
|
||||
init_msg_i = msg_i
|
||||
## MERGE CONSECUTIVE USER CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
_parts: List[BedrockContentBlock] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
_part = BedrockContentBlock(text=element["text"])
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
_part = _process_bedrock_converse_image_block( # type: ignore
|
||||
image_url=image_url
|
||||
)
|
||||
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
|
||||
user_content.extend(_parts)
|
||||
else:
|
||||
_part = BedrockContentBlock(text=messages[msg_i]["content"])
|
||||
user_content.append(_part)
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if user_content:
|
||||
contents.append(BedrockMessageBlock(role="user", content=user_content))
|
||||
assistant_content: List[BedrockContentBlock] = []
|
||||
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
|
||||
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
|
||||
if isinstance(messages[msg_i]["content"], list):
|
||||
assistants_parts: List[BedrockContentBlock] = []
|
||||
for element in messages[msg_i]["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
assistants_part = BedrockContentBlock(text=element["text"])
|
||||
assistants_parts.append(assistants_part)
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
assistants_part = _process_bedrock_converse_image_block( # type: ignore
|
||||
image_url=image_url
|
||||
)
|
||||
assistants_parts.append(
|
||||
BedrockContentBlock(image=assistants_part) # type: ignore
|
||||
)
|
||||
assistant_content.extend(assistants_parts)
|
||||
elif messages[msg_i].get(
|
||||
"tool_calls", []
|
||||
): # support assistant tool invoke convertion
|
||||
assistant_content.extend(
|
||||
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
|
||||
)
|
||||
else:
|
||||
assistant_text = (
|
||||
messages[msg_i].get("content") or ""
|
||||
) # either string or none
|
||||
if assistant_text:
|
||||
assistant_content.append(BedrockContentBlock(text=assistant_text))
|
||||
|
||||
msg_i += 1
|
||||
|
||||
if assistant_content:
|
||||
contents.append(
|
||||
BedrockMessageBlock(role="assistant", content=assistant_content)
|
||||
)
|
||||
|
||||
## APPEND TOOL CALL MESSAGES ##
|
||||
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
|
||||
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
|
||||
contents.append(tool_call_result)
|
||||
msg_i += 1
|
||||
if msg_i == init_msg_i: # prevent infinite loops
|
||||
raise Exception(
|
||||
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
|
||||
messages[msg_i]
|
||||
)
|
||||
)
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
|
||||
"""
|
||||
OpenAI tools looks like:
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
"""
|
||||
Bedrock toolConfig looks like:
|
||||
"tools": [
|
||||
{
|
||||
"toolSpec": {
|
||||
"name": "top_song",
|
||||
"description": "Get the most popular song played on a radio station.",
|
||||
"inputSchema": {
|
||||
"json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sign": {
|
||||
"type": "string",
|
||||
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"sign"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
"""
|
||||
tool_block_list: List[BedrockToolBlock] = []
|
||||
for tool in tools:
|
||||
parameters = tool.get("function", {}).get("parameters", None)
|
||||
name = tool.get("function", {}).get("name", "")
|
||||
description = tool.get("function", {}).get("description", "")
|
||||
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
|
||||
tool_spec = BedrockToolSpecBlock(
|
||||
inputSchema=tool_input_schema, name=name, description=description
|
||||
)
|
||||
tool_block = BedrockToolBlock(toolSpec=tool_spec)
|
||||
tool_block_list.append(tool_block)
|
||||
|
||||
return tool_block_list
|
||||
|
||||
|
||||
# Function call template
|
||||
def function_call_prompt(messages: list, functions: list):
|
||||
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""
|
||||
|
|
|
@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.bedrock_httpx import BedrockLLM
|
||||
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
|
||||
from .llms.vertex_httpx import VertexLLM
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
|
@ -122,6 +122,7 @@ huggingface = Huggingface()
|
|||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
@ -2103,22 +2104,40 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
else:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
if model.startswith("anthropic"):
|
||||
response = bedrock_converse_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
else:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
client=client,
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
|
||||
reason="Cannot run without being in CircleCI Runner",
|
||||
|
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_bedrock_claude_3():
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"image_url",
|
||||
[
|
||||
"",
|
||||
"https://avatars.githubusercontent.com/u/29436595?v=",
|
||||
],
|
||||
)
|
||||
def test_bedrock_claude_3(image_url):
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
data = {
|
||||
|
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
|
|||
{
|
||||
"image_url": {
|
||||
"detail": "high",
|
||||
"url": "",
|
||||
"url": image_url,
|
||||
},
|
||||
"type": "image_url",
|
||||
},
|
||||
|
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
|
|||
# Add any assertions here to check the response
|
||||
assert len(response.choices) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
@ -552,7 +560,7 @@ def test_bedrock_ptu():
|
|||
assert "url" in mock_client_post.call_args.kwargs
|
||||
assert (
|
||||
mock_client_post.call_args.kwargs["url"]
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
|
||||
)
|
||||
mock_client_post.assert_called_once()
|
||||
|
||||
|
|
|
@ -300,7 +300,11 @@ def test_completion_claude_3():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_completion_claude_3_function_call():
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
)
|
||||
def test_completion_claude_3_function_call(model):
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
|
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
|
|||
try:
|
||||
# test without max tokens
|
||||
response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {"name": "get_current_weather"},
|
||||
},
|
||||
drop_params=True,
|
||||
)
|
||||
|
||||
# Add any assertions, here to check response args
|
||||
|
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
|
|||
)
|
||||
# In the second response, Claude should deduce answer from tool results
|
||||
second_response = completion(
|
||||
model="anthropic/claude-3-opus-20240229",
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
drop_params=True,
|
||||
)
|
||||
print(second_response)
|
||||
except Exception as e:
|
||||
|
|
|
@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
|
|||
claude_2_1_pt,
|
||||
llama_2_chat_pt,
|
||||
prompt_factory,
|
||||
_bedrock_tools_pt,
|
||||
)
|
||||
|
||||
|
||||
|
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
|
|||
|
||||
|
||||
# codellama_prompt_format()
|
||||
def test_bedrock_tool_calling_pt():
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
converted_tools = _bedrock_tools_pt(tools=tools)
|
||||
|
||||
print(converted_tools)
|
||||
|
|
|
@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
|||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.parametrize("sync_mode", [True]) # False
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
# "bedrock/cohere.command-r-plus-v1:0",
|
||||
# "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
# "anthropic.claude-instant-v1",
|
||||
# "bedrock/ai21.j2-mid",
|
||||
# "mistral.mistral-7b-instruct-v0:2",
|
||||
# "bedrock/amazon.titan-tg1-large",
|
||||
# "meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14"
|
||||
"bedrock/cohere.command-r-plus-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
"cohere.command-text-v14",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TypedDict, Any, Union, Optional
|
||||
from typing import TypedDict, Any, Union, Optional, Literal, List
|
||||
import json
|
||||
from typing_extensions import (
|
||||
Self,
|
||||
|
@ -11,10 +11,137 @@ from typing_extensions import (
|
|||
)
|
||||
|
||||
|
||||
class SystemContentBlock(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class ImageSourceBlock(TypedDict):
|
||||
bytes: Optional[str] # base 64 encoded string
|
||||
|
||||
|
||||
class ImageBlock(TypedDict):
|
||||
format: Literal["png", "jpeg", "gif", "webp"]
|
||||
source: ImageSourceBlock
|
||||
|
||||
|
||||
class ToolResultContentBlock(TypedDict, total=False):
|
||||
image: ImageBlock
|
||||
json: dict
|
||||
text: str
|
||||
|
||||
|
||||
class ToolResultBlock(TypedDict, total=False):
|
||||
content: Required[List[ToolResultContentBlock]]
|
||||
toolUseId: Required[str]
|
||||
status: Literal["success", "error"]
|
||||
|
||||
|
||||
class ToolUseBlock(TypedDict):
|
||||
input: dict
|
||||
name: str
|
||||
toolUseId: str
|
||||
|
||||
|
||||
class ContentBlock(TypedDict, total=False):
|
||||
text: str
|
||||
image: ImageBlock
|
||||
toolResult: ToolResultBlock
|
||||
toolUse: ToolUseBlock
|
||||
|
||||
|
||||
class MessageBlock(TypedDict):
|
||||
content: List[ContentBlock]
|
||||
role: Literal["user", "assistant"]
|
||||
|
||||
|
||||
class ConverseMetricsBlock(TypedDict):
|
||||
latencyMs: float # time in ms
|
||||
|
||||
|
||||
class ConverseResponseOutputBlock(TypedDict):
|
||||
message: Optional[MessageBlock]
|
||||
|
||||
|
||||
class ConverseTokenUsageBlock(TypedDict):
|
||||
inputTokens: int
|
||||
outputTokens: int
|
||||
totalTokens: int
|
||||
|
||||
|
||||
class ConverseResponseBlock(TypedDict):
|
||||
additionalModelResponseFields: dict
|
||||
metrics: ConverseMetricsBlock
|
||||
output: ConverseResponseOutputBlock
|
||||
stopReason: (
|
||||
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
|
||||
)
|
||||
usage: ConverseTokenUsageBlock
|
||||
|
||||
|
||||
class ToolInputSchemaBlock(TypedDict):
|
||||
json: Optional[dict]
|
||||
|
||||
|
||||
class ToolSpecBlock(TypedDict, total=False):
|
||||
inputSchema: Required[ToolInputSchemaBlock]
|
||||
name: Required[str]
|
||||
description: str
|
||||
|
||||
|
||||
class ToolBlock(TypedDict):
|
||||
toolSpec: Optional[ToolSpecBlock]
|
||||
|
||||
|
||||
class SpecificToolChoiceBlock(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
class ToolChoiceValuesBlock(TypedDict, total=False):
|
||||
any: dict
|
||||
auto: dict
|
||||
tool: SpecificToolChoiceBlock
|
||||
|
||||
|
||||
class ToolConfigBlock(TypedDict, total=False):
|
||||
tools: Required[List[ToolBlock]]
|
||||
toolChoice: Union[str, ToolChoiceValuesBlock]
|
||||
|
||||
|
||||
class InferenceConfig(TypedDict, total=False):
|
||||
maxTokens: int
|
||||
stopSequences: List[str]
|
||||
temperature: float
|
||||
topP: float
|
||||
|
||||
|
||||
class ToolBlockDeltaEvent(TypedDict):
|
||||
input: str
|
||||
|
||||
|
||||
class ContentBlockDeltaEvent(TypedDict, total=False):
|
||||
"""
|
||||
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
|
||||
"""
|
||||
|
||||
text: str
|
||||
toolUse: ToolBlockDeltaEvent
|
||||
|
||||
|
||||
class RequestObject(TypedDict, total=False):
|
||||
additionalModelRequestFields: dict
|
||||
additionalModelResponseFieldPaths: List[str]
|
||||
inferenceConfig: InferenceConfig
|
||||
messages: Required[List[MessageBlock]]
|
||||
system: List[SystemContentBlock]
|
||||
toolConfig: ToolConfigBlock
|
||||
|
||||
|
||||
class GenericStreamingChunk(TypedDict):
|
||||
text: Required[str]
|
||||
tool_str: Required[str]
|
||||
is_finished: Required[bool]
|
||||
finish_reason: Required[str]
|
||||
usage: Optional[ConverseTokenUsageBlock]
|
||||
|
||||
|
||||
class Document(TypedDict):
|
||||
|
|
|
@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
|
|||
extra_headers: Optional[Dict[str, str]]
|
||||
extra_body: Optional[Dict[str, str]]
|
||||
timeout: Optional[float]
|
||||
|
||||
|
||||
class ChatCompletionToolCallFunctionChunk(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ChatCompletionToolCallChunk(TypedDict):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionToolCallFunctionChunk
|
||||
|
||||
|
||||
class ChatCompletionResponseMessage(TypedDict, total=False):
|
||||
content: Optional[str]
|
||||
tool_calls: List[ChatCompletionToolCallChunk]
|
||||
role: Literal["assistant"]
|
||||
|
|
|
@ -239,6 +239,8 @@ def map_finish_reason(
|
|||
return "length"
|
||||
elif finish_reason == "tool_use": # anthropic
|
||||
return "tool_calls"
|
||||
elif finish_reason == "content_filtered":
|
||||
return "content_filter"
|
||||
return finish_reason
|
||||
|
||||
|
||||
|
@ -4064,7 +4066,9 @@ def openai_token_counter(
|
|||
for c in value:
|
||||
if c["type"] == "text":
|
||||
text += c["text"]
|
||||
num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
|
||||
num_tokens += len(
|
||||
encoding.encode(c["text"], disallowed_special=())
|
||||
)
|
||||
elif c["type"] == "image_url":
|
||||
if isinstance(c["image_url"], dict):
|
||||
image_url_dict = c["image_url"]
|
||||
|
@ -5637,19 +5641,29 @@ def get_optional_params(
|
|||
optional_params["stream"] = stream
|
||||
elif "anthropic" in model:
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
# anthropic params on bedrock
|
||||
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
optional_params = (
|
||||
litellm.AmazonAnthropicClaude3Config().map_openai_params(
|
||||
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
optional_params = (
|
||||
litellm.AmazonAnthropicClaude3Config().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
)
|
||||
else:
|
||||
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
)
|
||||
else:
|
||||
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
|
||||
else: # bedrock httpx route
|
||||
optional_params = litellm.AmazonConverseConfig().map_openai_params(
|
||||
model=model,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
elif "amazon" in model: # amazon titan llms
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
|
@ -6427,20 +6441,7 @@ def get_supported_openai_params(
|
|||
- None if unmapped
|
||||
"""
|
||||
if custom_llm_provider == "bedrock":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
|
||||
elif model.startswith("anthropic"):
|
||||
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
|
||||
elif model.startswith("ai21"):
|
||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||
elif model.startswith("amazon"):
|
||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
||||
elif model.startswith("meta"):
|
||||
return ["max_tokens", "temperature", "top_p", "stream"]
|
||||
elif model.startswith("cohere"):
|
||||
return ["stream", "temperature", "max_tokens"]
|
||||
elif model.startswith("mistral"):
|
||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
||||
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "ollama":
|
||||
return litellm.OllamaConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "ollama_chat":
|
||||
|
@ -11405,12 +11406,27 @@ class CustomStreamWrapper:
|
|||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "bedrock":
|
||||
from litellm.types.llms.bedrock import GenericStreamingChunk
|
||||
|
||||
if self.received_finish_reason is not None:
|
||||
raise StopIteration
|
||||
response_obj = self.handle_bedrock_stream(chunk)
|
||||
response_obj: GenericStreamingChunk = chunk
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
|
||||
if (
|
||||
self.stream_options
|
||||
and self.stream_options.get("include_usage", False) is True
|
||||
and response_obj["usage"] is not None
|
||||
):
|
||||
self.sent_stream_usage = True
|
||||
model_response.usage = litellm.Usage(
|
||||
prompt_tokens=response_obj["usage"]["inputTokens"],
|
||||
completion_tokens=response_obj["usage"]["outputTokens"],
|
||||
total_tokens=response_obj["usage"]["totalTokens"],
|
||||
)
|
||||
elif self.custom_llm_provider == "sagemaker":
|
||||
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
|
||||
response_obj = self.handle_sagemaker_stream(chunk)
|
||||
|
@ -11677,7 +11693,7 @@ class CustomStreamWrapper:
|
|||
and hasattr(model_response, "usage")
|
||||
and hasattr(model_response.usage, "prompt_tokens")
|
||||
):
|
||||
if self.sent_first_chunk == False:
|
||||
if self.sent_first_chunk is False:
|
||||
completion_obj["role"] = "assistant"
|
||||
self.sent_first_chunk = True
|
||||
model_response.choices[0].delta = Delta(**completion_obj)
|
||||
|
@ -11845,6 +11861,8 @@ class CustomStreamWrapper:
|
|||
|
||||
def __next__(self):
|
||||
try:
|
||||
if self.completion_stream is None:
|
||||
self.fetch_sync_stream()
|
||||
while True:
|
||||
if (
|
||||
isinstance(self.completion_stream, str)
|
||||
|
@ -11919,6 +11937,14 @@ class CustomStreamWrapper:
|
|||
custom_llm_provider=self.custom_llm_provider,
|
||||
)
|
||||
|
||||
def fetch_sync_stream(self):
|
||||
if self.completion_stream is None and self.make_call is not None:
|
||||
# Call make_call to get the completion stream
|
||||
self.completion_stream = self.make_call(client=litellm.module_level_client)
|
||||
self._stream_iter = self.completion_stream.__iter__()
|
||||
|
||||
return self.completion_stream
|
||||
|
||||
async def fetch_stream(self):
|
||||
if self.completion_stream is None and self.make_call is not None:
|
||||
# Call make_call to get the completion stream
|
||||
|
|
|
@ -1 +1,3 @@
|
|||
ignore = ["F403", "F401"]
|
||||
ignore = ["F405"]
|
||||
extend-select = ["E501"]
|
||||
line-length = 120
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue