""" Manages calling Bedrock's `/converse` API + `/invoke` API """ import copy import json import time import types import urllib.parse import uuid from functools import partial from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union import httpx # type: ignore import litellm from litellm import verbose_logger from litellm.caching.caching import InMemoryCache from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.prompt_templates.factory import ( cohere_message_pt, construct_tool_use_system_prompt, contains_tag, custom_prompt, extract_between_tags, parse_xml_params, prompt_factory, ) from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, _get_httpx_client, get_async_httpx_client, ) from litellm.types.llms.bedrock import * from litellm.types.llms.openai import ( ChatCompletionToolCallChunk, ChatCompletionUsageBlock, ) from litellm.types.utils import GenericStreamingChunk as GChunk from litellm.types.utils import ModelResponse, Usage from litellm.utils import CustomStreamWrapper, get_secret from ..base_aws_llm import BaseAWSLLM from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name _response_stream_shape_cache = None bedrock_tool_name_mappings: InMemoryCache = InMemoryCache( max_size_in_memory=50, default_ttl=600 ) class AmazonCohereChatConfig: """ Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html """ documents: Optional[List[Document]] = None search_queries_only: Optional[bool] = None preamble: Optional[str] = None max_tokens: Optional[int] = None temperature: Optional[float] = None p: Optional[float] = None k: Optional[float] = None prompt_truncation: Optional[str] = None frequency_penalty: Optional[float] = None presence_penalty: Optional[float] = None seed: Optional[int] = None return_prompt: Optional[bool] = None stop_sequences: Optional[List[str]] = None raw_prompting: Optional[bool] = None def __init__( self, documents: Optional[List[Document]] = None, search_queries_only: Optional[bool] = None, preamble: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None, p: Optional[float] = None, k: Optional[float] = None, prompt_truncation: Optional[str] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, seed: Optional[int] = None, return_prompt: Optional[bool] = None, stop_sequences: Optional[str] = None, raw_prompting: Optional[bool] = None, ) -> None: locals_ = locals() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) @classmethod def get_config(cls): return { k: v for k, v in cls.__dict__.items() if not k.startswith("__") and not isinstance( v, ( types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod, ), ) and v is not None } def get_supported_openai_params(self) -> List[str]: return [ "max_tokens", "max_completion_tokens", "stream", "stop", "temperature", "top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "tools", "tool_choice", ] def map_openai_params( self, non_default_params: dict, optional_params: dict ) -> dict: for param, value in non_default_params.items(): if param == "max_tokens" or param == "max_completion_tokens": optional_params["max_tokens"] = value if param == "stream": optional_params["stream"] = value if param == "stop": if isinstance(value, str): value = [value] optional_params["stop_sequences"] = value if param == "temperature": optional_params["temperature"] = value if param == "top_p": optional_params["p"] = value if param == "frequency_penalty": optional_params["frequency_penalty"] = value if param == "presence_penalty": optional_params["presence_penalty"] = value if "seed": optional_params["seed"] = value return optional_params async def make_call( client: Optional[AsyncHTTPHandler], api_base: str, headers: dict, data: str, model: str, messages: list, logging_obj, fake_stream: bool = False, json_mode: Optional[bool] = False, ): try: if client is None: client = get_async_httpx_client( llm_provider=litellm.LlmProviders.BEDROCK ) # Create a new client if none provided response = await client.post( api_base, headers=headers, data=data, stream=not fake_stream, ) if response.status_code != 200: raise BedrockError(status_code=response.status_code, message=response.text) if fake_stream: model_response: ( ModelResponse ) = litellm.AmazonConverseConfig()._transform_response( model=model, response=response, model_response=litellm.ModelResponse(), stream=True, logging_obj=logging_obj, optional_params={}, api_key="", data=data, messages=messages, print_verbose=litellm.print_verbose, encoding=litellm.encoding, ) # type: ignore completion_stream: Any = MockResponseIterator( model_response=model_response, json_mode=json_mode ) else: decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.aiter_bytes( response.aiter_bytes(chunk_size=1024) ) # LOGGING logging_obj.post_call( input=messages, api_key="", original_response="first stream response received", additional_args={"complete_input_dict": data}, ) return completion_stream except httpx.HTTPStatusError as err: error_code = err.response.status_code raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") except Exception as e: raise BedrockError(status_code=500, message=str(e)) class BedrockLLM(BaseAWSLLM): """ Example call ``` curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \ --header 'Content-Type: application/json' \ --header 'Accept: application/json' \ --user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \ --aws-sigv4 "aws:amz:us-east-1:bedrock" \ --data-raw '{ "prompt": "Hi", "temperature": 0, "p": 0.9, "max_tokens": 4096 }' ``` """ def __init__(self) -> None: super().__init__() def convert_messages_to_prompt( self, model, messages, provider, custom_prompt_dict ) -> Tuple[str, Optional[list]]: # handle anthropic prompts and amazon titan prompts prompt = "" chat_history: Optional[list] = None ## CUSTOM PROMPT if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( role_dict=model_prompt_details["roles"], initial_prompt_value=model_prompt_details.get( "initial_prompt_value", "" ), final_prompt_value=model_prompt_details.get("final_prompt_value", ""), messages=messages, ) return prompt, None ## ELSE if provider == "anthropic" or provider == "amazon": prompt = prompt_factory( model=model, messages=messages, custom_llm_provider="bedrock" ) elif provider == "mistral": prompt = prompt_factory( model=model, messages=messages, custom_llm_provider="bedrock" ) elif provider == "meta": prompt = prompt_factory( model=model, messages=messages, custom_llm_provider="bedrock" ) elif provider == "cohere": prompt, chat_history = cohere_message_pt(messages=messages) else: prompt = "" for message in messages: if "role" in message: if message["role"] == "user": prompt += f"{message['content']}" else: prompt += f"{message['content']}" else: prompt += f"{message['content']}" return prompt, chat_history # type: ignore def process_response( # noqa: PLR0915 self, model: str, response: httpx.Response, model_response: ModelResponse, stream: bool, logging_obj: Logging, optional_params: dict, api_key: str, data: Union[dict, str], messages: List, print_verbose, encoding, ) -> Union[ModelResponse, CustomStreamWrapper]: provider = model.split(".")[0] ## LOGGING logging_obj.post_call( input=messages, api_key=api_key, original_response=response.text, additional_args={"complete_input_dict": data}, ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT try: completion_response = response.json() except Exception: raise BedrockError(message=response.text, status_code=422) outputText: Optional[str] = None try: if provider == "cohere": if "text" in completion_response: outputText = completion_response["text"] # type: ignore elif "generations" in completion_response: outputText = completion_response["generations"][0]["text"] model_response.choices[0].finish_reason = map_finish_reason( completion_response["generations"][0]["finish_reason"] ) elif provider == "anthropic": if model.startswith("anthropic.claude-3"): json_schemas: dict = {} _is_function_call = False ## Handle Tool Calling if "tools" in optional_params: _is_function_call = True for tool in optional_params["tools"]: json_schemas[tool["function"]["name"]] = tool[ "function" ].get("parameters", None) outputText = completion_response.get("content")[0].get("text", None) if outputText is not None and contains_tag( "invoke", outputText ): # OUTPUT PARSE FUNCTION CALL function_name = extract_between_tags("tool_name", outputText)[0] function_arguments_str = extract_between_tags( "invoke", outputText )[0].strip() function_arguments_str = ( f"{function_arguments_str}" ) function_arguments = parse_xml_params( function_arguments_str, json_schema=json_schemas.get( function_name, None ), # check if we have a json schema for this function name) ) _message = litellm.Message( tool_calls=[ { "id": f"call_{uuid.uuid4()}", "type": "function", "function": { "name": function_name, "arguments": json.dumps(function_arguments), }, } ], content=None, ) model_response.choices[0].message = _message # type: ignore model_response._hidden_params["original_response"] = ( outputText # allow user to access raw anthropic tool calling response ) if ( _is_function_call is True and stream is not None and stream is True ): print_verbose( "INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK" ) # return an iterator streaming_model_response = ModelResponse(stream=True) streaming_model_response.choices[0].finish_reason = getattr( model_response.choices[0], "finish_reason", "stop" ) # streaming_model_response.choices = [litellm.utils.StreamingChoices()] streaming_choice = litellm.utils.StreamingChoices() streaming_choice.index = model_response.choices[0].index _tool_calls = [] print_verbose( f"type of model_response.choices[0]: {type(model_response.choices[0])}" ) print_verbose( f"type of streaming_choice: {type(streaming_choice)}" ) if isinstance(model_response.choices[0], litellm.Choices): if getattr( model_response.choices[0].message, "tool_calls", None ) is not None and isinstance( model_response.choices[0].message.tool_calls, list ): for tool_call in model_response.choices[ 0 ].message.tool_calls: _tool_call = {**tool_call.dict(), "index": 0} _tool_calls.append(_tool_call) delta_obj = litellm.utils.Delta( content=getattr( model_response.choices[0].message, "content", None ), role=model_response.choices[0].message.role, tool_calls=_tool_calls, ) streaming_choice.delta = delta_obj streaming_model_response.choices = [streaming_choice] completion_stream = ModelResponseIterator( model_response=streaming_model_response ) print_verbose( "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" ) return litellm.CustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider="cached_response", logging_obj=logging_obj, ) model_response.choices[0].finish_reason = map_finish_reason( completion_response.get("stop_reason", "") ) _usage = litellm.Usage( prompt_tokens=completion_response["usage"]["input_tokens"], completion_tokens=completion_response["usage"]["output_tokens"], total_tokens=completion_response["usage"]["input_tokens"] + completion_response["usage"]["output_tokens"], ) setattr(model_response, "usage", _usage) else: outputText = completion_response["completion"] model_response.choices[0].finish_reason = completion_response[ "stop_reason" ] elif provider == "ai21": outputText = ( completion_response.get("completions")[0].get("data").get("text") ) elif provider == "meta": outputText = completion_response["generation"] elif provider == "mistral": outputText = completion_response["outputs"][0]["text"] model_response.choices[0].finish_reason = completion_response[ "outputs" ][0]["stop_reason"] else: # amazon titan outputText = completion_response.get("results")[0].get("outputText") except Exception as e: raise BedrockError( message="Error processing={}, Received error={}".format( response.text, str(e) ), status_code=422, ) try: if ( outputText is not None and len(outputText) > 0 and hasattr(model_response.choices[0], "message") and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is None ): model_response.choices[0].message.content = outputText # type: ignore elif ( hasattr(model_response.choices[0], "message") and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is not None ): pass else: raise Exception() except Exception as e: raise BedrockError( message="Error parsing received text={}.\nError-{}".format( outputText, str(e) ), status_code=response.status_code, ) if stream and provider == "ai21": streaming_model_response = ModelResponse(stream=True) streaming_model_response.choices[0].finish_reason = model_response.choices[ # type: ignore 0 ].finish_reason # streaming_model_response.choices = [litellm.utils.StreamingChoices()] streaming_choice = litellm.utils.StreamingChoices() streaming_choice.index = model_response.choices[0].index delta_obj = litellm.utils.Delta( content=getattr(model_response.choices[0].message, "content", None), # type: ignore role=model_response.choices[0].message.role, # type: ignore ) streaming_choice.delta = delta_obj streaming_model_response.choices = [streaming_choice] mri = ModelResponseIterator(model_response=streaming_model_response) return CustomStreamWrapper( completion_stream=mri, model=model, custom_llm_provider="cached_response", logging_obj=logging_obj, ) ## CALCULATING USAGE - bedrock returns usage in the headers bedrock_input_tokens = response.headers.get( "x-amzn-bedrock-input-token-count", None ) bedrock_output_tokens = response.headers.get( "x-amzn-bedrock-output-token-count", None ) prompt_tokens = int( bedrock_input_tokens or litellm.token_counter(messages=messages) ) completion_tokens = int( bedrock_output_tokens or litellm.token_counter( text=model_response.choices[0].message.content, # type: ignore count_response_tokens=True, ) ) model_response.created = int(time.time()) model_response.model = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ) setattr(model_response, "usage", usage) return model_response def encode_model_id(self, model_id: str) -> str: """ Double encode the model ID to ensure it matches the expected double-encoded format. Args: model_id (str): The model ID to encode. Returns: str: The double-encoded model ID. """ return urllib.parse.quote(model_id, safe="") def completion( # noqa: PLR0915 self, model: str, messages: list, api_base: Optional[str], custom_prompt_dict: dict, model_response: ModelResponse, print_verbose: Callable, encoding, logging_obj, optional_params: dict, acompletion: bool, timeout: Optional[Union[float, httpx.Timeout]], litellm_params=None, logger_fn=None, extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: try: from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials except ImportError: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") ## SETUP ## stream = optional_params.pop("stream", None) modelId = optional_params.pop("model_id", None) if modelId is not None: modelId = self.encode_model_id(model_id=modelId) else: modelId = model provider = model.split(".")[0] ## CREDENTIALS ## # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) aws_access_key_id = optional_params.pop("aws_access_key_id", None) aws_session_token = optional_params.pop("aws_session_token", None) aws_region_name = optional_params.pop("aws_region_name", None) aws_role_name = optional_params.pop("aws_role_name", None) aws_session_name = optional_params.pop("aws_session_name", None) aws_profile_name = optional_params.pop("aws_profile_name", None) aws_bedrock_runtime_endpoint = optional_params.pop( "aws_bedrock_runtime_endpoint", None ) # https://bedrock-runtime.{region_name}.amazonaws.com aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) ### SET REGION NAME ### if aws_region_name is None: # check env # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) if litellm_aws_region_name is not None and isinstance( litellm_aws_region_name, str ): aws_region_name = litellm_aws_region_name standard_aws_region_name = get_secret("AWS_REGION", None) if standard_aws_region_name is not None and isinstance( standard_aws_region_name, str ): aws_region_name = standard_aws_region_name if aws_region_name is None: aws_region_name = "us-west-2" credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, aws_region_name=aws_region_name, aws_session_name=aws_session_name, aws_profile_name=aws_profile_name, aws_role_name=aws_role_name, aws_web_identity_token=aws_web_identity_token, aws_sts_endpoint=aws_sts_endpoint, ) ### SET RUNTIME ENDPOINT ### endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, ) if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream" proxy_endpoint_url = ( f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream" ) else: endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke" sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) prompt, chat_history = self.convert_messages_to_prompt( model, messages, provider, custom_prompt_dict ) inference_params = copy.deepcopy(optional_params) json_schemas: dict = {} if provider == "cohere": if model.startswith("cohere.command-r"): ## LOAD CONFIG config = litellm.AmazonCohereChatConfig().get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v _data = {"message": prompt, **inference_params} if chat_history is not None: _data["chat_history"] = chat_history data = json.dumps(_data) else: ## LOAD CONFIG config = litellm.AmazonCohereConfig.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v if stream is True: inference_params["stream"] = ( True # cohere requires stream = True in inference params ) data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "anthropic": if model.startswith("anthropic.claude-3"): # Separate system prompt from rest of message system_prompt_idx: list[int] = [] system_messages: list[str] = [] for idx, message in enumerate(messages): if message["role"] == "system": system_messages.append(message["content"]) system_prompt_idx.append(idx) if len(system_prompt_idx) > 0: inference_params["system"] = "\n".join(system_messages) messages = [ i for j, i in enumerate(messages) if j not in system_prompt_idx ] # Format rest of message according to anthropic guidelines messages = prompt_factory( model=model, messages=messages, custom_llm_provider="anthropic_xml" ) # type: ignore ## LOAD CONFIG config = litellm.AmazonAnthropicClaude3Config.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v ## Handle Tool Calling if "tools" in inference_params: _is_function_call = True for tool in inference_params["tools"]: json_schemas[tool["function"]["name"]] = tool["function"].get( "parameters", None ) tool_calling_system_prompt = construct_tool_use_system_prompt( tools=inference_params["tools"] ) inference_params["system"] = ( inference_params.get("system", "\n") + tool_calling_system_prompt ) # add the anthropic tool calling prompt to the system prompt inference_params.pop("tools") data = json.dumps({"messages": messages, **inference_params}) else: ## LOAD CONFIG config = litellm.AmazonAnthropicConfig.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "ai21": ## LOAD CONFIG config = litellm.AmazonAI21Config.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "mistral": ## LOAD CONFIG config = litellm.AmazonMistralConfig.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v data = json.dumps({"prompt": prompt, **inference_params}) elif provider == "amazon": # amazon titan ## LOAD CONFIG config = litellm.AmazonTitanConfig.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v data = json.dumps( { "inputText": prompt, "textGenerationConfig": inference_params, } ) elif provider == "meta": ## LOAD CONFIG config = litellm.AmazonLlamaConfig.get_config() for k, v in config.items(): if ( k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v data = json.dumps({"prompt": prompt, **inference_params}) else: ## LOGGING logging_obj.pre_call( input=messages, api_key="", additional_args={ "complete_input_dict": inference_params, }, ) raise BedrockError( status_code=404, message="Bedrock HTTPX: Unknown provider={}, model={}".format( provider, model ), ) ## COMPLETION CALL headers = {"Content-Type": "application/json"} if extra_headers is not None: headers = {"Content-Type": "application/json", **extra_headers} request = AWSRequest( method="POST", url=endpoint_url, data=data, headers=headers ) sigv4.add_auth(request) if ( extra_headers is not None and "Authorization" in extra_headers ): # prevent sigv4 from overwriting the auth header request.headers["Authorization"] = extra_headers["Authorization"] prepped = request.prepare() ## LOGGING logging_obj.pre_call( input=messages, api_key="", additional_args={ "complete_input_dict": data, "api_base": proxy_endpoint_url, "headers": prepped.headers, }, ) ### ROUTING (ASYNC, STREAMING, SYNC) if acompletion: if isinstance(client, HTTPHandler): client = None if stream is True and provider != "ai21": return self.async_streaming( model=model, messages=messages, data=data, api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, logging_obj=logging_obj, optional_params=optional_params, stream=True, litellm_params=litellm_params, logger_fn=logger_fn, headers=prepped.headers, timeout=timeout, client=client, ) # type: ignore ### ASYNC COMPLETION return self.async_completion( model=model, messages=messages, data=data, api_base=proxy_endpoint_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, logging_obj=logging_obj, optional_params=optional_params, stream=stream, # type: ignore litellm_params=litellm_params, logger_fn=logger_fn, headers=prepped.headers, timeout=timeout, client=client, ) # type: ignore if client is None or isinstance(client, AsyncHTTPHandler): _params = {} if timeout is not None: if isinstance(timeout, float) or isinstance(timeout, int): timeout = httpx.Timeout(timeout) _params["timeout"] = timeout self.client = _get_httpx_client(_params) # type: ignore else: self.client = client if (stream is not None and stream is True) and provider != "ai21": response = self.client.post( url=proxy_endpoint_url, headers=prepped.headers, # type: ignore data=data, stream=stream, ) if response.status_code != 200: raise BedrockError( status_code=response.status_code, message=response.read() ) decoder = AWSEventStreamDecoder(model=model) completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) streaming_response = CustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider="bedrock", logging_obj=logging_obj, ) ## LOGGING logging_obj.post_call( input=messages, api_key="", original_response=streaming_response, additional_args={"complete_input_dict": data}, ) return streaming_response try: response = self.client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") return self.process_response( model=model, response=response, model_response=model_response, stream=stream, logging_obj=logging_obj, optional_params=optional_params, api_key="", data=data, messages=messages, print_verbose=print_verbose, encoding=encoding, ) async def async_completion( self, model: str, messages: list, api_base: str, model_response: ModelResponse, print_verbose: Callable, data: str, timeout: Optional[Union[float, httpx.Timeout]], encoding, logging_obj, stream, optional_params: dict, litellm_params=None, logger_fn=None, headers={}, client: Optional[AsyncHTTPHandler] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: if client is None: _params = {} if timeout is not None: if isinstance(timeout, float) or isinstance(timeout, int): timeout = httpx.Timeout(timeout) _params["timeout"] = timeout client = get_async_httpx_client(params=_params, llm_provider=litellm.LlmProviders.BEDROCK) # type: ignore else: client = client # type: ignore try: response = await client.post(api_base, headers=headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") return self.process_response( model=model, response=response, model_response=model_response, stream=stream if isinstance(stream, bool) else False, logging_obj=logging_obj, api_key="", data=data, messages=messages, print_verbose=print_verbose, optional_params=optional_params, encoding=encoding, ) async def async_streaming( self, model: str, messages: list, api_base: str, model_response: ModelResponse, print_verbose: Callable, data: str, timeout: Optional[Union[float, httpx.Timeout]], encoding, logging_obj, stream, optional_params: dict, litellm_params=None, logger_fn=None, headers={}, client: Optional[AsyncHTTPHandler] = None, ) -> CustomStreamWrapper: # The call is not made here; instead, we prepare the necessary objects for the stream. streaming_response = CustomStreamWrapper( completion_stream=None, make_call=partial( make_call, client=client, api_base=api_base, headers=headers, data=data, model=model, messages=messages, logging_obj=logging_obj, fake_stream=True if "ai21" in api_base else False, ), model=model, custom_llm_provider="bedrock", logging_obj=logging_obj, ) return streaming_response def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: from botocore.loaders import Loader from botocore.model import ServiceModel loader = Loader() bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") bedrock_service_model = ServiceModel(bedrock_service_dict) _response_stream_shape_cache = bedrock_service_model.shape_for("ResponseStream") return _response_stream_shape_cache class AWSEventStreamDecoder: def __init__(self, model: str) -> None: from botocore.parsers import EventStreamJSONParser self.model = model self.parser = EventStreamJSONParser() self.content_blocks: List[ContentBlockDeltaEvent] = [] def check_empty_tool_call_args(self) -> bool: """ Check if the tool call block so far has been an empty string """ args = "" # if text content block -> skip if len(self.content_blocks) == 0: return False if "text" in self.content_blocks[0]: return False for block in self.content_blocks: if "toolUse" in block: args += block["toolUse"]["input"] if len(args) == 0: return True return False def converse_chunk_parser(self, chunk_data: dict) -> GChunk: try: verbose_logger.debug("\n\nRaw Chunk: {}\n\n".format(chunk_data)) text = "" tool_use: Optional[ChatCompletionToolCallChunk] = None is_finished = False finish_reason = "" usage: Optional[ChatCompletionUsageBlock] = None index = int(chunk_data.get("contentBlockIndex", 0)) if "start" in chunk_data: start_obj = ContentBlockStartEvent(**chunk_data["start"]) self.content_blocks = [] # reset if ( start_obj is not None and "toolUse" in start_obj and start_obj["toolUse"] is not None ): ## check tool name was formatted by litellm _response_tool_name = start_obj["toolUse"]["name"] response_tool_name = get_bedrock_tool_name( response_tool_name=_response_tool_name ) tool_use = { "id": start_obj["toolUse"]["toolUseId"], "type": "function", "function": { "name": response_tool_name, "arguments": "", }, "index": index, } elif "delta" in chunk_data: delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"]) self.content_blocks.append(delta_obj) if "text" in delta_obj: text = delta_obj["text"] elif "toolUse" in delta_obj: tool_use = { "id": None, "type": "function", "function": { "name": None, "arguments": delta_obj["toolUse"]["input"], }, "index": index, } elif ( "contentBlockIndex" in chunk_data ): # stop block, no 'start' or 'delta' object is_empty = self.check_empty_tool_call_args() if is_empty: tool_use = { "id": None, "type": "function", "function": { "name": None, "arguments": "{}", }, "index": chunk_data["contentBlockIndex"], } elif "stopReason" in chunk_data: finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop")) is_finished = True elif "usage" in chunk_data: usage = ChatCompletionUsageBlock( prompt_tokens=chunk_data.get("inputTokens", 0), completion_tokens=chunk_data.get("outputTokens", 0), total_tokens=chunk_data.get("totalTokens", 0), ) response = GChunk( text=text, tool_use=tool_use, is_finished=is_finished, finish_reason=finish_reason, usage=usage, index=index, ) if "trace" in chunk_data: trace = chunk_data.get("trace") response["provider_specific_fields"] = {"trace": trace} return response except Exception as e: raise Exception("Received streaming error - {}".format(str(e))) def _chunk_parser(self, chunk_data: dict) -> GChunk: text = "" is_finished = False finish_reason = "" if "outputText" in chunk_data: text = chunk_data["outputText"] # ai21 mapping elif "ai21" in self.model: # fake ai21 streaming text = chunk_data.get("completions")[0].get("data").get("text") # type: ignore is_finished = True finish_reason = "stop" ######## bedrock.anthropic mappings ############### elif ( "contentBlockIndex" in chunk_data or "stopReason" in chunk_data or "metrics" in chunk_data or "trace" in chunk_data ): return self.converse_chunk_parser(chunk_data=chunk_data) ######## bedrock.mistral mappings ############### elif "outputs" in chunk_data: if ( len(chunk_data["outputs"]) == 1 and chunk_data["outputs"][0].get("text", None) is not None ): text = chunk_data["outputs"][0]["text"] stop_reason = chunk_data.get("stop_reason", None) if stop_reason is not None: is_finished = True finish_reason = stop_reason ######## bedrock.cohere mappings ############### # meta mapping elif "generation" in chunk_data: text = chunk_data["generation"] # bedrock.meta # cohere mapping elif "text" in chunk_data: text = chunk_data["text"] # bedrock.cohere # cohere mapping for finish reason elif "finish_reason" in chunk_data: finish_reason = chunk_data["finish_reason"] is_finished = True elif chunk_data.get("completionReason", None): is_finished = True finish_reason = chunk_data["completionReason"] return GChunk( text=text, is_finished=is_finished, finish_reason=finish_reason, usage=None, index=0, tool_use=None, ) def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]: """Given an iterator that yields lines, iterate over it & yield every event encountered""" from botocore.eventstream import EventStreamBuffer event_stream_buffer = EventStreamBuffer() for chunk in iterator: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: # sse_event = ServerSentEvent(data=message, event="completion") _data = json.loads(message) yield self._chunk_parser(chunk_data=_data) async def aiter_bytes( self, iterator: AsyncIterator[bytes] ) -> AsyncIterator[GChunk]: """Given an async iterator that yields lines, iterate over it & yield every event encountered""" from botocore.eventstream import EventStreamBuffer event_stream_buffer = EventStreamBuffer() async for chunk in iterator: event_stream_buffer.add_data(chunk) for event in event_stream_buffer: message = self._parse_message_from_event(event) if message: _data = json.loads(message) yield self._chunk_parser(chunk_data=_data) def _parse_message_from_event(self, event) -> Optional[str]: response_dict = event.to_response_dict() parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: raise ValueError(f"Bad response code, expected 200: {response_dict}") if "chunk" in parsed_response: chunk = parsed_response.get("chunk") if not chunk: return None return chunk.get("bytes").decode() # type: ignore[no-any-return] else: chunk = response_dict.get("body") if not chunk: return None return chunk.decode() # type: ignore[no-any-return] class MockResponseIterator: # for returning ai21 streaming responses def __init__(self, model_response, json_mode: Optional[bool] = False): self.model_response = model_response self.json_mode = json_mode self.is_done = False # Sync iterator def __iter__(self): return self def _handle_json_mode_chunk( self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]] ) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]: """ If JSON mode is enabled, convert the tool call to a message. Bedrock returns the JSON schema as part of the tool call OpenAI returns the JSON schema as part of the content, this handles placing it in the content Args: text: str tool_use: Optional[ChatCompletionToolCallChunk] Returns: Tuple[str, Optional[ChatCompletionToolCallChunk]] text: The text to use in the content tool_use: The ChatCompletionToolCallChunk to use in the chunk response """ tool_use: Optional[ChatCompletionToolCallChunk] = None if self.json_mode is True and tool_calls is not None: message = litellm.AnthropicConfig()._convert_tool_response_to_message( tool_calls=tool_calls ) if message is not None: text = message.content or "" tool_use = None elif tool_calls is not None and len(tool_calls) > 0: tool_use = tool_calls[0] return text, tool_use def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk: try: chunk_usage: Usage = getattr(chunk_data, "usage") text = chunk_data.choices[0].message.content or "" # type: ignore tool_use = None if self.json_mode is True: text, tool_use = self._handle_json_mode_chunk( text=text, tool_calls=chunk_data.choices[0].message.tool_calls, # type: ignore ) processed_chunk = GChunk( text=text, tool_use=tool_use, is_finished=True, finish_reason=map_finish_reason( finish_reason=chunk_data.choices[0].finish_reason or "" ), usage=ChatCompletionUsageBlock( prompt_tokens=chunk_usage.prompt_tokens, completion_tokens=chunk_usage.completion_tokens, total_tokens=chunk_usage.total_tokens, ), index=0, ) return processed_chunk except Exception as e: raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}") def __next__(self): if self.is_done: raise StopIteration self.is_done = True return self._chunk_parser(self.model_response) # Async iterator def __aiter__(self): return self async def __anext__(self): if self.is_done: raise StopAsyncIteration self.is_done = True return self._chunk_parser(self.model_response)