Merge pull request #4033 from BerriAI/litellm_bedrock_converse_api

feat(bedrock_httpx.py): add support for bedrock converse api
This commit is contained in:
Krish Dholakia 2024-06-07 12:08:10 -07:00 committed by GitHub
commit 6b703f0ebf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 1496 additions and 4374 deletions

1
.gitignore vendored
View file

@ -59,3 +59,4 @@ myenv/*
litellm/proxy/_experimental/out/404/index.html
litellm/proxy/_experimental/out/model_hub/index.html
litellm/proxy/_experimental/out/onboarding/index.html
litellm/tests/log.txt

View file

@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache
from litellm._logging import (
set_verbose,
@ -233,6 +233,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
request_timeout: float = 6000
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None
@ -766,7 +767,7 @@ from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig
from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig
from .llms.bedrock import (
AmazonTitanConfig,
AmazonAI21Config,

View file

@ -38,6 +38,8 @@ from .prompt_templates.factory import (
extract_between_tags,
parse_xml_params,
contains_tag,
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM
@ -45,6 +47,11 @@ import httpx # type: ignore
from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator
from litellm.types.llms.bedrock import *
import urllib.parse
from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
)
class AmazonCohereChatConfig:
@ -118,6 +125,8 @@ class AmazonCohereChatConfig:
"presence_penalty",
"seed",
"stop",
"tools",
"tool_choice",
]
def map_openai_params(
@ -176,6 +185,37 @@ async def make_call(
return completion_stream
def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class BedrockLLM(BaseLLM):
"""
Example call
@ -1000,12 +1040,12 @@ class BedrockLLM(BaseLLM):
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
self.client = AsyncHTTPHandler(**_params) # type: ignore
client = AsyncHTTPHandler(**_params) # type: ignore
else:
self.client = client # type: ignore
client = client # type: ignore
try:
response = await self.client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@ -1069,6 +1109,745 @@ class BedrockLLM(BaseLLM):
return super().embedding(*args, **kwargs)
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
"""
maxTokens: Optional[int]
stopSequences: Optional[List[str]]
temperature: Optional[int]
topP: Optional[int]
def __init__(
self,
maxTokens: Optional[int] = None,
stopSequences: Optional[List[str]] = None,
temperature: Optional[int] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List[str]:
supported_params = [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
"extra_headers",
]
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
elif tool_choice == "required":
return ToolChoiceValuesBlock(any={})
elif tool_choice == "auto":
return ToolChoiceValuesBlock(auto={})
elif isinstance(tool_choice, dict):
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
specific_tool = SpecificToolChoiceBlock(
name=tool_choice.get("function", {}).get("name", "")
)
return ToolChoiceValuesBlock(tool=specific_tool)
else:
raise litellm.utils.UnsupportedParamsError(
message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
tool_choice
),
status_code=400,
)
def get_supported_image_types(self) -> List[str]:
return ["png", "jpeg", "gif", "webp"]
def map_openai_params(
self,
model: str,
non_default_params: dict,
optional_params: dict,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["maxTokens"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
if isinstance(value, str):
value = [value]
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["topP"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
model=model, tool_choice=value, drop_params=drop_params # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
return optional_params
class BedrockConverseLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: ModelResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = ConverseResponseBlock(**response.json()) # type: ignore
except Exception as e:
raise BedrockError(
message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
response.text, str(e)
),
status_code=422,
)
"""
Bedrock Response Object has optional message block
completion_response["output"].get("message", None)
A message block looks like this (Example 1):
"output": {
"message": {
"role": "assistant",
"content": [
{
"text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?"
}
]
}
},
(Example 2):
"output": {
"message": {
"role": "assistant",
"content": [
{
"toolUse": {
"toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA",
"name": "top_song",
"input": {
"sign": "WZPZ"
}
}
}
]
}
}
"""
message: Optional[MessageBlock] = completion_response["output"]["message"]
chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
if message is not None:
for content in message["content"]:
"""
- Content is either a tool response or text
"""
if "text" in content:
content_str += content["text"]
if "toolUse" in content:
_function_chunk = ChatCompletionToolCallFunctionChunk(
name=content["toolUse"]["name"],
arguments=json.dumps(content["toolUse"]["input"]),
)
_tool_response_chunk = ChatCompletionToolCallChunk(
id=content["toolUse"]["toolUseId"],
type="function",
function=_function_chunk,
)
tools.append(_tool_response_chunk)
chat_completion_message["content"] = content_str
chat_completion_message["tool_calls"] = tools
## CALCULATING USAGE - bedrock returns usage in the headers
input_tokens = completion_response["usage"]["inputTokens"]
output_tokens = completion_response["usage"]["outputTokens"]
total_tokens = completion_response["usage"]["totalTokens"]
model_response.choices = [
litellm.Choices(
finish_reason=map_finish_reason(completion_response["stopReason"]),
index=0,
message=litellm.Message(**chat_completion_message),
)
]
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=total_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def encode_model_id(self, model_id: str) -> str:
"""
Double encode the model ID to ensure it matches the expected double-encoded format.
Args:
model_id (str): The model ID to encode.
Returns:
str: The double-encoded model ID.
"""
return urllib.parse.quote(model_id, safe="")
def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
):
"""
Return a boto3.Credentials object
"""
import boto3
## CHECK IS 'os.environ/' passed in
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
session = boto3.Session(
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=aws_region_name,
)
return session.get_credentials()
elif aws_role_name is not None and aws_session_name is not None:
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id, # [OPTIONAL]
aws_secret_access_key=aws_secret_access_key, # [OPTIONAL]
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials
credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
secret_key=sts_credentials["SecretAccessKey"],
token=sts_credentials["SessionToken"],
)
return credentials
elif aws_profile_name is not None: ### CHECK SESSION ###
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name)
return client.get_credentials()
else:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
return session.get_credentials()
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> CustomStreamWrapper:
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
client=client,
api_base=api_base,
headers=headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
data: str,
timeout: Optional[Union[float, httpx.Timeout]],
encoding,
logging_obj,
stream,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
if client is None:
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream if isinstance(stream, bool) else False,
logging_obj=logging_obj,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
litellm_params=None,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
stream = optional_params.pop("stream", None)
modelId = optional_params.pop("model_id", None)
if modelId is not None:
modelId = self.encode_model_id(model_id=modelId)
else:
modelId = model
provider = model.split(".")[0]
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
# Separate system prompt from rest of message
system_prompt_indices = []
system_content_blocks: List[SystemContentBlock] = []
for idx, message in enumerate(messages):
if message["role"] == "system":
_system_content_block = SystemContentBlock(text=message["content"])
system_content_blocks.append(_system_content_block)
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
k not in supported_converse_params
and k not in supported_tool_call_params
):
additional_request_params[k] = v
additional_request_keys.append(k)
for key in additional_request_keys:
inference_params.pop(key, None)
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
messages=messages
)
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
bedrock_tool_config: Optional[ToolConfigBlock] = None
if len(bedrock_tools) > 0:
tool_choice_values: ToolChoiceValuesBlock = inference_params.pop(
"tool_choice", None
)
bedrock_tool_config = ToolConfigBlock(
tools=bedrock_tools,
)
if tool_choice_values is not None:
bedrock_tool_config["toolChoice"] = tool_choice_values
_data: RequestObject = {
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
"inferenceConfig": InferenceConfig(**inference_params),
}
if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config
data = json.dumps(_data)
## COMPLETION CALL
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=data, headers=headers
)
sigv4.add_auth(request)
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
### ROUTING (ASYNC, STREAMING, SYNC)
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream is True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=True,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=prepped.url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
logging_obj=logging_obj,
optional_params=optional_params,
stream=stream, # type: ignore
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=prepped.headers,
timeout=timeout,
client=client,
) # type: ignore
if (stream is not None and stream is True) and provider != "ai21":
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=prepped.url,
headers=prepped.headers, # type: ignore
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = HTTPHandler(**_params) # type: ignore
else:
client = client
try:
response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
optional_params=optional_params,
api_key="",
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
def get_response_stream_shape():
from botocore.model import ServiceModel
from botocore.loaders import Loader
@ -1086,6 +1865,31 @@ class AWSEventStreamDecoder:
self.model = model
self.parser = EventStreamJSONParser()
def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
is_finished = False
@ -1098,19 +1902,8 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
return self.converse_chunk_parser(chunk_data=chunk_data)
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
@ -1137,11 +1930,11 @@ class AWSEventStreamDecoder:
is_finished = True
finish_reason = chunk_data["completionReason"]
return GenericStreamingChunk(
**{
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
tool_str="",
usage=None,
)
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]:
@ -1178,9 +1971,14 @@ class AWSEventStreamDecoder:
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
return chunk.decode() # type: ignore[no-any-return]

View file

@ -156,12 +156,13 @@ class HTTPHandler:
self,
url: str,
data: Optional[Union[dict, str]] = None,
json: Optional[Union[dict, str]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
stream: bool = False,
):
req = self.client.build_request(
"POST", url, data=data, params=params, headers=headers # type: ignore
"POST", url, data=data, json=json, params=params, headers=headers # type: ignore
)
response = self.client.send(req, stream=stream)
return response

View file

@ -3,14 +3,7 @@ import requests, traceback
import json, re, xml.etree.ElementTree as ET
from jinja2 import Template, exceptions, meta, BaseLoader
from jinja2.sandbox import ImmutableSandboxedEnvironment
from typing import (
Any,
List,
Mapping,
MutableMapping,
Optional,
Sequence,
)
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
import litellm
import litellm.types
from litellm.types.completion import (
@ -24,7 +17,7 @@ from litellm.types.completion import (
import litellm.types.llms
from litellm.types.llms.anthropic import *
import uuid
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
import litellm.types.llms.vertex_ai
@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url):
try:
from PIL import Image
except:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO
try:
@ -1613,6 +1604,380 @@ def azure_text_pt(messages: list):
return prompt
###### AMAZON BEDROCK #######
from litellm.types.llms.bedrock import (
ToolResultContentBlock as BedrockToolResultContentBlock,
ToolResultBlock as BedrockToolResultBlock,
ToolConfigBlock as BedrockToolConfigBlock,
ToolUseBlock as BedrockToolUseBlock,
ImageSourceBlock as BedrockImageSourceBlock,
ImageBlock as BedrockImageBlock,
ContentBlock as BedrockContentBlock,
ToolInputSchemaBlock as BedrockToolInputSchemaBlock,
ToolSpecBlock as BedrockToolSpecBlock,
ToolBlock as BedrockToolBlock,
ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock,
)
def get_image_details(image_url) -> Tuple[str, str]:
try:
import base64
# Send a GET request to the image URL
response = requests.get(image_url)
response.raise_for_status() # Raise an exception for HTTP errors
# Check the response's content type to ensure it is an image
content_type = response.headers.get("content-type")
if not content_type or "image" not in content_type:
raise ValueError(
f"URL does not point to a valid image (content-type: {content_type})"
)
# Convert the image content to base64 bytes
base64_bytes = base64.b64encode(response.content).decode("utf-8")
# Get mime-type
mime_type = content_type.split("/")[
1
] # Extract mime-type from content-type header
return base64_bytes, mime_type
except requests.RequestException as e:
raise Exception(f"Request failed: {e}")
except Exception as e:
raise e
def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
if "base64" in image_url:
# Case 1: Images with base64 encoding
import base64, re
# base 64 is passed as data:image/jpeg;base64,<base-64-encoded-image>
image_metadata, img_without_base_64 = image_url.split(",")
# read mime_type from img_without_base_64=data:image/jpeg;base64
# Extract MIME type using regular expression
mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
if mime_type_match:
mime_type = mime_type_match.group(1)
image_format = mime_type.split("/")[1]
else:
mime_type = "image/jpeg"
image_format = "jpeg"
_blob = BedrockImageSourceBlock(bytes=img_without_base_64)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
elif "https:/" in image_url:
# Case 2: Images with direct links
image_bytes, image_format = get_image_details(image_url)
_blob = BedrockImageSourceBlock(bytes=image_bytes)
supported_image_formats = (
litellm.AmazonConverseConfig().get_supported_image_types()
)
if image_format in supported_image_formats:
return BedrockImageBlock(source=_blob, format=image_format) # type: ignore
else:
# Handle the case when the image format is not supported
raise ValueError(
"Unsupported image format: {}. Supported formats: {}".format(
image_format, supported_image_formats
)
)
else:
raise ValueError(
"Unsupported image type. Expected either image url or base64 encoded string - \
e.g. 'data:image/jpeg;base64,<base64-encoded-string>'"
)
def _convert_to_bedrock_tool_call_invoke(
tool_calls: list,
) -> List[BedrockContentBlock]:
"""
OpenAI tool invokes:
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_abc123",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\n\"location\": \"Boston, MA\"\n}"
}
}
]
},
"""
"""
Bedrock tool invokes:
[
{
"role": "assistant",
"toolUse": {
"input": {"location": "Boston, MA", ..},
"name": "get_current_weather",
"toolUseId": "call_abc123"
}
}
]
"""
"""
- json.loads argument
- extract name
- extract id
"""
try:
_parts_list: List[BedrockContentBlock] = []
for tool in tool_calls:
if "function" in tool:
id = tool["id"]
name = tool["function"].get("name", "")
arguments = tool["function"].get("arguments", "")
arguments_dict = json.loads(arguments)
bedrock_tool = BedrockToolUseBlock(
input=arguments_dict, name=name, toolUseId=id
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)
return _parts_list
except Exception as e:
raise Exception(
"Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format(
tool_calls, str(e)
)
)
def _convert_to_bedrock_tool_call_result(
message: dict,
) -> BedrockMessageBlock:
"""
OpenAI message with a tool result looks like:
{
"tool_call_id": "tool_1",
"role": "tool",
"name": "get_current_weather",
"content": "function result goes here",
},
OpenAI message with a function call result looks like:
{
"role": "function",
"name": "get_current_weather",
"content": "function result goes here",
}
"""
"""
Bedrock result looks like this:
{
"role": "user",
"content": [
{
"toolResult": {
"toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q",
"content": [
{
"json": {
"song": "Elemental Hotel",
"artist": "8 Storey Hike"
}
}
]
}
}
]
}
"""
"""
-
"""
content = message.get("content", "")
name = message.get("name", "")
id = message.get("tool_call_id", str(uuid.uuid4()))
tool_result_content_block = BedrockToolResultContentBlock(text=content)
tool_result = BedrockToolResultBlock(
content=[tool_result_content_block],
toolUseId=id,
)
content_block = BedrockContentBlock(toolResult=tool_result)
return BedrockMessageBlock(role="user", content=[content_block])
def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
- Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
- Please ensure that function response turn comes immediately after a function call turn
"""
contents: List[BedrockMessageBlock] = []
msg_i = 0
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "user":
if isinstance(messages[msg_i]["content"], list):
_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
_parts.append(BedrockContentBlock(image=_part)) # type: ignore
user_content.extend(_parts)
else:
_part = BedrockContentBlock(text=messages[msg_i]["content"])
user_content.append(_part)
msg_i += 1
if user_content:
contents.append(BedrockMessageBlock(role="user", content=user_content))
assistant_content: List[BedrockContentBlock] = []
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
if isinstance(messages[msg_i]["content"], list):
assistants_parts: List[BedrockContentBlock] = []
for element in messages[msg_i]["content"]:
if isinstance(element, dict):
if element["type"] == "text":
assistants_part = BedrockContentBlock(text=element["text"])
assistants_parts.append(assistants_part)
elif element["type"] == "image_url":
image_url = element["image_url"]["url"]
assistants_part = _process_bedrock_converse_image_block( # type: ignore
image_url=image_url
)
assistants_parts.append(
BedrockContentBlock(image=assistants_part) # type: ignore
)
assistant_content.extend(assistants_parts)
elif messages[msg_i].get(
"tool_calls", []
): # support assistant tool invoke convertion
assistant_content.extend(
_convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"])
)
else:
assistant_text = (
messages[msg_i].get("content") or ""
) # either string or none
if assistant_text:
assistant_content.append(BedrockContentBlock(text=assistant_text))
msg_i += 1
if assistant_content:
contents.append(
BedrockMessageBlock(role="assistant", content=assistant_content)
)
## APPEND TOOL CALL MESSAGES ##
if msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
contents.append(tool_call_result)
msg_i += 1
if msg_i == init_msg_i: # prevent infinite loops
raise Exception(
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
messages[msg_i]
)
)
return contents
def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]:
"""
OpenAI tools looks like:
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
"""
"""
Bedrock toolConfig looks like:
"tools": [
{
"toolSpec": {
"name": "top_song",
"description": "Get the most popular song played on a radio station.",
"inputSchema": {
"json": {
"type": "object",
"properties": {
"sign": {
"type": "string",
"description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP."
}
},
"required": [
"sign"
]
}
}
}
}
]
"""
tool_block_list: List[BedrockToolBlock] = []
for tool in tools:
parameters = tool.get("function", {}).get("parameters", None)
name = tool.get("function", {}).get("name", "")
description = tool.get("function", {}).get("description", "")
tool_input_schema = BedrockToolInputSchemaBlock(json=parameters)
tool_spec = BedrockToolSpecBlock(
inputSchema=tool_input_schema, name=name, description=description
)
tool_block = BedrockToolBlock(toolSpec=tool_spec)
tool_block_list.append(tool_block)
return tool_block_list
# Function call template
def function_call_prompt(messages: list, functions: list):
function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:"""

View file

@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
@ -122,6 +122,7 @@ huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################
@ -2103,22 +2104,40 @@ def completion(
logging_obj=logging,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if model.startswith("anthropic"):
response = bedrock_converse_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
else:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
acompletion=acompletion,
client=client,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(

File diff suppressed because it is too large Load diff

View file

@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None,
reason="Cannot run without being in CircleCI Runner",
@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth():
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
@pytest.mark.parametrize(
"image_url",
[
"",
"https://avatars.githubusercontent.com/u/29436595?v=",
],
)
def test_bedrock_claude_3(image_url):
try:
litellm.set_verbose = True
data = {
@ -294,7 +303,7 @@ def test_bedrock_claude_3():
{
"image_url": {
"detail": "high",
"url": "",
"url": image_url,
},
"type": "image_url",
},
@ -313,7 +322,6 @@ def test_bedrock_claude_3():
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
@ -552,7 +560,7 @@ def test_bedrock_ptu():
assert "url" in mock_client_post.call_args.kwargs
assert (
mock_client_post.call_args.kwargs["url"]
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke"
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse"
)
mock_client_post.assert_called_once()

View file

@ -300,7 +300,11 @@ def test_completion_claude_3():
pytest.fail(f"Error occurred: {e}")
def test_completion_claude_3_function_call():
@pytest.mark.parametrize(
"model",
["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"],
)
def test_completion_claude_3_function_call(model):
litellm.set_verbose = True
tools = [
{
@ -331,13 +335,14 @@ def test_completion_claude_3_function_call():
try:
# test without max tokens
response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice={
"type": "function",
"function": {"name": "get_current_weather"},
},
drop_params=True,
)
# Add any assertions, here to check response args
@ -364,10 +369,11 @@ def test_completion_claude_3_function_call():
)
# In the second response, Claude should deduce answer from tool results
second_response = completion(
model="anthropic/claude-3-opus-20240229",
model=model,
messages=messages,
tools=tools,
tool_choice="auto",
drop_params=True,
)
print(second_response)
except Exception as e:

View file

@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import (
claude_2_1_pt,
llama_2_chat_pt,
prompt_factory,
_bedrock_tools_pt,
)
@ -128,3 +129,27 @@ def test_anthropic_messages_pt():
# codellama_prompt_format()
def test_bedrock_tool_calling_pt():
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
]
converted_tools = _bedrock_tools_pt(tools=tools)
print(converted_tools)

View file

@ -1284,18 +1284,18 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
# pytest.fail(f"Error occurred: {e}")
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("sync_mode", [True]) # False
@pytest.mark.parametrize(
"model",
[
# "bedrock/cohere.command-r-plus-v1:0",
# "anthropic.claude-3-sonnet-20240229-v1:0",
# "anthropic.claude-instant-v1",
# "bedrock/ai21.j2-mid",
# "mistral.mistral-7b-instruct-v0:2",
# "bedrock/amazon.titan-tg1-large",
# "meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14"
"bedrock/cohere.command-r-plus-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
"cohere.command-text-v14",
],
)
@pytest.mark.asyncio

View file

@ -1,4 +1,4 @@
from typing import TypedDict, Any, Union, Optional
from typing import TypedDict, Any, Union, Optional, Literal, List
import json
from typing_extensions import (
Self,
@ -11,10 +11,137 @@ from typing_extensions import (
)
class SystemContentBlock(TypedDict):
text: str
class ImageSourceBlock(TypedDict):
bytes: Optional[str] # base 64 encoded string
class ImageBlock(TypedDict):
format: Literal["png", "jpeg", "gif", "webp"]
source: ImageSourceBlock
class ToolResultContentBlock(TypedDict, total=False):
image: ImageBlock
json: dict
text: str
class ToolResultBlock(TypedDict, total=False):
content: Required[List[ToolResultContentBlock]]
toolUseId: Required[str]
status: Literal["success", "error"]
class ToolUseBlock(TypedDict):
input: dict
name: str
toolUseId: str
class ContentBlock(TypedDict, total=False):
text: str
image: ImageBlock
toolResult: ToolResultBlock
toolUse: ToolUseBlock
class MessageBlock(TypedDict):
content: List[ContentBlock]
role: Literal["user", "assistant"]
class ConverseMetricsBlock(TypedDict):
latencyMs: float # time in ms
class ConverseResponseOutputBlock(TypedDict):
message: Optional[MessageBlock]
class ConverseTokenUsageBlock(TypedDict):
inputTokens: int
outputTokens: int
totalTokens: int
class ConverseResponseBlock(TypedDict):
additionalModelResponseFields: dict
metrics: ConverseMetricsBlock
output: ConverseResponseOutputBlock
stopReason: (
str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered
)
usage: ConverseTokenUsageBlock
class ToolInputSchemaBlock(TypedDict):
json: Optional[dict]
class ToolSpecBlock(TypedDict, total=False):
inputSchema: Required[ToolInputSchemaBlock]
name: Required[str]
description: str
class ToolBlock(TypedDict):
toolSpec: Optional[ToolSpecBlock]
class SpecificToolChoiceBlock(TypedDict):
name: str
class ToolChoiceValuesBlock(TypedDict, total=False):
any: dict
auto: dict
tool: SpecificToolChoiceBlock
class ToolConfigBlock(TypedDict, total=False):
tools: Required[List[ToolBlock]]
toolChoice: Union[str, ToolChoiceValuesBlock]
class InferenceConfig(TypedDict, total=False):
maxTokens: int
stopSequences: List[str]
temperature: float
topP: float
class ToolBlockDeltaEvent(TypedDict):
input: str
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
"""
text: str
toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_str: Required[str]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
class Document(TypedDict):

View file

@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False):
extra_headers: Optional[Dict[str, str]]
extra_body: Optional[Dict[str, str]]
timeout: Optional[float]
class ChatCompletionToolCallFunctionChunk(TypedDict):
name: str
arguments: str
class ChatCompletionToolCallChunk(TypedDict):
id: str
type: Literal["function"]
function: ChatCompletionToolCallFunctionChunk
class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]

View file

@ -239,6 +239,8 @@ def map_finish_reason(
return "length"
elif finish_reason == "tool_use": # anthropic
return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason
@ -4064,7 +4066,9 @@ def openai_token_counter(
for c in value:
if c["type"] == "text":
text += c["text"]
num_tokens += len(encoding.encode(c["text"], disallowed_special=()))
num_tokens += len(
encoding.encode(c["text"], disallowed_special=())
)
elif c["type"] == "image_url":
if isinstance(c["image_url"], dict):
image_url_dict = c["image_url"]
@ -5637,19 +5641,29 @@ def get_optional_params(
optional_params["stream"] = stream
elif "anthropic" in model:
_check_valid_arg(supported_params=supported_params)
# anthropic params on bedrock
# \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}"
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route.
if model.startswith("anthropic.claude-3"):
optional_params = (
litellm.AmazonAnthropicClaude3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
)
else:
optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
else: # bedrock httpx route
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
non_default_params=non_default_params,
optional_params=optional_params,
drop_params=(
drop_params
if drop_params is not None and isinstance(drop_params, bool)
else False
),
)
elif "amazon" in model: # amazon titan llms
_check_valid_arg(supported_params=supported_params)
@ -6427,20 +6441,7 @@ def get_supported_openai_params(
- None if unmapped
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("amazon"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
elif model.startswith("meta"):
return ["max_tokens", "temperature", "top_p", "stream"]
elif model.startswith("cohere"):
return ["stream", "temperature", "max_tokens"]
elif model.startswith("mistral"):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat":
@ -11405,12 +11406,27 @@ class CustomStreamWrapper:
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
response_obj = self.handle_bedrock_stream(chunk)
response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
@ -11677,7 +11693,7 @@ class CustomStreamWrapper:
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk == False:
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@ -11845,6 +11861,8 @@ class CustomStreamWrapper:
def __next__(self):
try:
if self.completion_stream is None:
self.fetch_sync_stream()
while True:
if (
isinstance(self.completion_stream, str)
@ -11919,6 +11937,14 @@ class CustomStreamWrapper:
custom_llm_provider=self.custom_llm_provider,
)
def fetch_sync_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
self.completion_stream = self.make_call(client=litellm.module_level_client)
self._stream_iter = self.completion_stream.__iter__()
return self.completion_stream
async def fetch_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream

View file

@ -1 +1,3 @@
ignore = ["F403", "F401"]
ignore = ["F405"]
extend-select = ["E501"]
line-length = 120