feat(bedrock_httpx.py): working bedrock converse api streaming

Krrish Dholakia 2024-06-06 22:13:21 -07:00
parent a995a0b172
commit 51ba5652a0
6 changed files with 165 additions and 25 deletions
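For orientation, here is a minimal sketch of the sync streaming path this commit enables (not part of the diff: the model ID is a placeholder, valid AWS credentials are assumed, and stream_options is only needed if you want the usage block surfaced by CustomStreamWrapper below):

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder Converse-supported model
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
    stream_options={"include_usage": True},  # optional; maps to the usage handling added below
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")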

View file

@@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*
### INIT VARIABLES ###
import threading, requests, os
from typing import Callable, List, Optional, Dict, Union, Any, Literal
- from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache
from litellm._logging import (
set_verbose,
@@ -232,6 +232,7 @@ max_end_user_budget: Optional[float] = None
#### RELIABILITY ####
request_timeout: float = 6000
module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
module_level_client = HTTPHandler(timeout=request_timeout)
num_retries: Optional[int] = None  # per model endpoint
default_fallbacks: Optional[List] = None
fallbacks: Optional[List] = None

View file

@@ -185,6 +185,37 @@ async def make_call(
return completion_stream

def make_sync_call(
client: Optional[HTTPHandler],
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
if client is None:
client = HTTPHandler() # Create a new client if none provided
response = client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise BedrockError(status_code=response.status_code, message=response.read())
decoder = AWSEventStreamDecoder(model=model)
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class BedrockLLM(BaseLLM):
"""
Example call
@@ -1081,6 +1112,7 @@ class BedrockLLM(BaseLLM):
class AmazonConverseConfig:
"""
Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
#2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
"""
maxTokens: Optional[int]
@@ -1118,30 +1150,32 @@ class AmazonConverseConfig:
and v is not None
}
- def get_supported_openai_params(self) -> List[str]:
+ def get_supported_openai_params(self, model: str) -> List[str]:
- return [
+ supported_params = [
"max_tokens",
"stream",
"stream_options",
"stop",
"temperature",
"top_p",
- "tools",
- "tool_choice",
]
if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
):
supported_params.append("tools")
if model.startswith("anthropic") or model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")
return supported_params
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict], drop_params: bool
) -> Optional[ToolChoiceValuesBlock]:
if not model.startswith("anthropic") and not model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
if drop_params == True or litellm.drop_params == True:
return None
else:
raise litellm.utils.UnsupportedParamsError(
message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`",
status_code=400,
)
if tool_choice == "none":
if litellm.drop_params is True or drop_params is True:
return None
@@ -1197,7 +1231,7 @@ class AmazonConverseConfig:
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice_value = self.map_tool_choice_values(
- model=model, tool_choice=value, drop_params=drop_params
+ model=model, tool_choice=value, drop_params=drop_params  # type: ignore
)
if _tool_choice_value is not None:
optional_params["tool_choice"] = _tool_choice_value
@@ -1539,7 +1573,7 @@ class BedrockConverseLLM(BaseLLM):
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
- if (stream is not None and stream == True) and provider != "ai21":
+ if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
@@ -1561,7 +1595,7 @@ class BedrockConverseLLM(BaseLLM):
inference_params = copy.deepcopy(optional_params)
additional_request_keys = []
additional_request_params = {}
- supported_converse_params = AmazonConverseConfig().get_config().keys()
+ supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
## TRANSFORMATION ##
# send all model-specific params in 'additional_request_params'
@@ -1596,6 +1630,7 @@ class BedrockConverseLLM(BaseLLM):
"messages": bedrock_messages,
"additionalModelRequestFields": additional_request_params,
"system": system_content_blocks,
"inferenceConfig": InferenceConfig(**inference_params),
}
if bedrock_tool_config is not None:
_data["toolConfig"] = bedrock_tool_config
@@ -1623,7 +1658,35 @@ class BedrockConverseLLM(BaseLLM):
)
### ROUTING (ASYNC, STREAMING, SYNC)
if (stream is not None and stream is True) and provider != "ai21":
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=None,
api_base=prepped.url,
headers=prepped.headers,
data=data,
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=streaming_response,
additional_args={"complete_input_dict": data},
)
return streaming_response
### COMPLETION
if client is None or isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
@@ -1675,6 +1738,31 @@ class AWSEventStreamDecoder:
self.parser = EventStreamJSONParser()
def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = ""
tool_str = ""
is_finished = False
finish_reason = ""
usage: Optional[ConverseTokenUsageBlock] = None
if "delta" in chunk_data:
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
if "text" in delta_obj:
text = delta_obj["text"]
elif "toolUse" in delta_obj:
tool_str = delta_obj["toolUse"]["input"]
elif "stopReason" in chunk_data:
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
elif "usage" in chunk_data:
usage = ConverseTokenUsageBlock(**chunk_data["usage"])
response = GenericStreamingChunk(
text=text,
tool_str=tool_str,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
)
return response
def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
text = "" text = ""
is_finished = False is_finished = False
finish_reason = "" finish_reason = ""
@@ -1763,12 +1851,11 @@ class AWSEventStreamDecoder:
def _parse_message_from_event(self, event) -> Optional[str]:
response_dict = event.to_response_dict()
- parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
- chunk = parsed_response.get("chunk")
+ chunk = response_dict.get("body")
if not chunk:
return None
- return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+ return chunk.decode()  # type: ignore[no-any-return]

View file

@@ -168,6 +168,7 @@ class HTTPHandler:
return response
def __del__(self) -> None:
traceback.print_stack()
try:
self.close()
except Exception:

View file

@@ -1284,7 +1284,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
# pytest.fail(f"Error occurred: {e}")
- @pytest.mark.parametrize("sync_mode", [True, False])
+ @pytest.mark.parametrize("sync_mode", [True])  # False
@pytest.mark.parametrize(
"model",
[
@@ -1324,6 +1324,8 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
raise Exception("finish reason not set")
if complete_response.strip() == "":
raise Exception("Empty response received")
assert False
else:
response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
model=model,

View file

@@ -107,10 +107,30 @@ class ToolConfigBlock(TypedDict, total=False):
toolChoice: Union[str, ToolChoiceValuesBlock]
class InferenceConfig(TypedDict, total=False):
maxTokens: int
stopSequences: List[str]
temperature: float
topP: float
class ToolBlockDeltaEvent(TypedDict):
input: str
class ContentBlockDeltaEvent(TypedDict, total=False):
"""
Either 'text' or 'toolUse' will be specified for Converse API streaming response.
"""
text: str
toolUse: ToolBlockDeltaEvent
class RequestObject(TypedDict, total=False):
additionalModelRequestFields: dict
additionalModelResponseFieldPaths: List[str]
- inferenceConfig: dict
+ inferenceConfig: InferenceConfig
messages: Required[List[MessageBlock]]
system: List[SystemContentBlock]
toolConfig: ToolConfigBlock
@@ -118,8 +138,10 @@ class RequestObject(TypedDict, total=False):
class GenericStreamingChunk(TypedDict):
text: Required[str]
tool_str: Required[str]
is_finished: Required[bool]
finish_reason: Required[str]
usage: Optional[ConverseTokenUsageBlock]
class Document(TypedDict):
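Taken together, the new typed blocks compose roughly like this (a sketch, not part of the diff; values are arbitrary):

config: InferenceConfig = {"maxTokens": 512, "temperature": 0.2, "topP": 0.9}
chunk: GenericStreamingChunk = {
    "text": "Hello",
    "tool_str": "",
    "is_finished": False,
    "finish_reason": "",
    "usage": {"inputTokens": 12, "outputTokens": 3, "totalTokens": 15},
}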

View file

@@ -239,6 +239,8 @@ def map_finish_reason(
return "length"
elif finish_reason == "tool_use":  # anthropic
return "tool_calls"
elif finish_reason == "content_filtered":
return "content_filter"
return finish_reason
@@ -6330,7 +6332,7 @@ def get_supported_openai_params(
- None if unmapped
"""
if custom_llm_provider == "bedrock":
- return litellm.AmazonConverseConfig().get_supported_openai_params()
+ return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "ollama":
return litellm.OllamaConfig().get_supported_openai_params()
elif custom_llm_provider == "ollama_chat":
@@ -11242,12 +11244,27 @@ class CustomStreamWrapper:
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "bedrock":
from litellm.types.llms.bedrock import GenericStreamingChunk
if self.received_finish_reason is not None:
raise StopIteration
- response_obj = self.handle_bedrock_stream(chunk)
+ response_obj: GenericStreamingChunk = chunk
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
if (
self.stream_options
and self.stream_options.get("include_usage", False) is True
and response_obj["usage"] is not None
):
self.sent_stream_usage = True
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"]["inputTokens"],
completion_tokens=response_obj["usage"]["outputTokens"],
total_tokens=response_obj["usage"]["totalTokens"],
)
elif self.custom_llm_provider == "sagemaker": elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk) response_obj = self.handle_sagemaker_stream(chunk)
@@ -11509,7 +11526,7 @@ class CustomStreamWrapper:
and hasattr(model_response, "usage")
and hasattr(model_response.usage, "prompt_tokens")
):
- if self.sent_first_chunk == False:
+ if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
@@ -11677,6 +11694,8 @@ class CustomStreamWrapper:
def __next__(self):
try:
if self.completion_stream is None:
self.fetch_sync_stream()
while True:
if (
isinstance(self.completion_stream, str)
@@ -11751,6 +11770,14 @@ class CustomStreamWrapper:
custom_llm_provider=self.custom_llm_provider,
)
def fetch_sync_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
self.completion_stream = self.make_call(client=litellm.module_level_client)
self._stream_iter = self.completion_stream.__iter__()
return self.completion_stream
async def fetch_stream(self):
if self.completion_stream is None and self.make_call is not None:
# Call make_call to get the completion stream
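Net effect for the sync streaming path: BedrockConverseLLM.completion builds a CustomStreamWrapper with completion_stream=None and make_call=partial(make_sync_call, client=None, ...), so no HTTP request is made at construction time. The first __next__() call sees the empty stream and invokes fetch_sync_stream(), which calls make_call with litellm.module_level_client; iteration then proceeds over the decoded Converse events.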