diff --git a/litellm/__init__.py b/litellm/__init__.py
index 2fc47a992..fe0dd2a56 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
 import threading, requests, os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache
 from litellm._logging import (
     set_verbose,
@@ -232,6 +232,7 @@ max_end_user_budget: Optional[float] = None
 #### RELIABILITY ####
 request_timeout: float = 6000
 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
+module_level_client = HTTPHandler(timeout=request_timeout)
 num_retries: Optional[int] = None  # per model endpoint
 default_fallbacks: Optional[List] = None
 fallbacks: Optional[List] = None
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index ce6a93174..6329f165e 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -185,6 +185,37 @@ async def make_call(
     return completion_stream


+def make_sync_call(
+    client: Optional[HTTPHandler],
+    api_base: str,
+    headers: dict,
+    data: str,
+    model: str,
+    messages: list,
+    logging_obj,
+):
+    if client is None:
+        client = HTTPHandler()  # Create a new client if none provided
+
+    response = client.post(api_base, headers=headers, data=data, stream=True)
+
+    if response.status_code != 200:
+        raise BedrockError(status_code=response.status_code, message=response.read())
+
+    decoder = AWSEventStreamDecoder(model=model)
+    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+
+    # LOGGING
+    logging_obj.post_call(
+        input=messages,
+        api_key="",
+        original_response=completion_stream,  # Pass the completion stream for logging
+        additional_args={"complete_input_dict": data},
+    )
+
+    return completion_stream
+
+
 class BedrockLLM(BaseLLM):
     """
     Example call
@@ -1081,6 +1112,7 @@ class BedrockLLM(BaseLLM):
 class AmazonConverseConfig:
     """
     Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+    #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
     """

     maxTokens: Optional[int]
@@ -1118,30 +1150,32 @@ class AmazonConverseConfig:
             and v is not None
         }

-    def get_supported_openai_params(self) -> List[str]:
-        return [
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        supported_params = [
             "max_tokens",
             "stream",
             "stream_options",
             "stop",
             "temperature",
             "top_p",
-            "tools",
-            "tool_choice",
         ]
+        if (
+            model.startswith("anthropic")
+            or model.startswith("mistral")
+            or model.startswith("cohere")
+        ):
+            supported_params.append("tools")
+
+        if model.startswith("anthropic") or model.startswith("mistral"):
+            # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+            supported_params.append("tool_choice")
+
+        return supported_params
+

     def map_tool_choice_values(
         self, model: str, tool_choice: Union[str, dict], drop_params: bool
     ) -> Optional[ToolChoiceValuesBlock]:
-        if not model.startswith("anthropic") and not model.startswith("mistral"):
-            # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
-            if drop_params == True or litellm.drop_params == True:
-                return None
-            else:
-                raise litellm.utils.UnsupportedParamsError(
-                    message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`",
-                    status_code=400,
-                )
         if tool_choice == "none":
             if litellm.drop_params is True or drop_params is True:
                 return None
@@ -1197,7 +1231,7 @@ class AmazonConverseConfig:
                 optional_params["tools"] = value
             if param == "tool_choice":
                 _tool_choice_value = self.map_tool_choice_values(
-                    model=model, tool_choice=value, drop_params=drop_params
+                    model=model, tool_choice=value, drop_params=drop_params  # type: ignore
                 )
                 if _tool_choice_value is not None:
                     optional_params["tool_choice"] = _tool_choice_value
@@ -1539,7 +1573,7 @@ class BedrockConverseLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-        if (stream is not None and stream == True) and provider != "ai21":
+        if (stream is not None and stream is True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
@@ -1561,7 +1595,7 @@ class BedrockConverseLLM(BaseLLM):
         inference_params = copy.deepcopy(optional_params)
         additional_request_keys = []
         additional_request_params = {}
-        supported_converse_params = AmazonConverseConfig().get_config().keys()
+        supported_converse_params = AmazonConverseConfig.__annotations__.keys()
         supported_tool_call_params = ["tools", "tool_choice"]
         ## TRANSFORMATION ##
         # send all model-specific params in 'additional_request_params'
@@ -1596,6 +1630,7 @@ class BedrockConverseLLM(BaseLLM):
             "messages": bedrock_messages,
             "additionalModelRequestFields": additional_request_params,
             "system": system_content_blocks,
+            "inferenceConfig": InferenceConfig(**inference_params),
         }
         if bedrock_tool_config is not None:
             _data["toolConfig"] = bedrock_tool_config
@@ -1623,7 +1658,35 @@ class BedrockConverseLLM(BaseLLM):
         )

         ### ROUTING (ASYNC, STREAMING, SYNC)
+        if (stream is not None and stream is True) and provider != "ai21":
+
+            streaming_response = CustomStreamWrapper(
+                completion_stream=None,
+                make_call=partial(
+                    make_sync_call,
+                    client=None,
+                    api_base=prepped.url,
+                    headers=prepped.headers,
+                    data=data,
+                    model=model,
+                    messages=messages,
+                    logging_obj=logging_obj,
+                ),
+                model=model,
+                custom_llm_provider="bedrock",
+                logging_obj=logging_obj,
+            )
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key="",
+                original_response=streaming_response,
+                additional_args={"complete_input_dict": data},
+            )
+            return streaming_response
         ### COMPLETION
+
         if client is None or isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
@@ -1675,6 +1738,31 @@ class AWSEventStreamDecoder:
         self.parser = EventStreamJSONParser()

     def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+        text = ""
+        tool_str = ""
+        is_finished = False
+        finish_reason = ""
+        usage: Optional[ConverseTokenUsageBlock] = None
+        if "delta" in chunk_data:
+            delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"])
+            if "text" in delta_obj:
+                text = delta_obj["text"]
+            elif "toolUse" in delta_obj:
+                tool_str = delta_obj["toolUse"]["input"]
+        elif "stopReason" in chunk_data:
+            finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop"))
+        elif "usage" in chunk_data:
+            usage = ConverseTokenUsageBlock(**chunk_data["usage"])
+        response = GenericStreamingChunk(
+            text=text,
+            tool_str=tool_str,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            usage=usage,
+        )
+        return response
+
+    def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -1763,12 +1851,11 @@ class AWSEventStreamDecoder:

     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
-        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
             raise ValueError(f"Bad response code, expected 200: {response_dict}")

-        chunk = parsed_response.get("chunk")
+        chunk = response_dict.get("body")
         if not chunk:
             return None

-        return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        return chunk.decode()  # type: ignore[no-any-return]
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index 5ec9c79bb..d8dd4f01e 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -168,6 +168,7 @@ class HTTPHandler:
         return response

     def __del__(self) -> None:
+        traceback.print_stack()
         try:
             self.close()
         except Exception:
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index c24de601f..1113adc40 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1284,7 +1284,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 #         pytest.fail(f"Error occurred: {e}")


-@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("sync_mode", [True])  # False
 @pytest.mark.parametrize(
     "model",
     [
@@ -1324,6 +1324,8 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 raise Exception("finish reason not set")
             if complete_response.strip() == "":
                 raise Exception("Empty response received")
+
+            assert False
         else:
             response: litellm.CustomStreamWrapper = await litellm.acompletion(  # type: ignore
                 model=model,
diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index 647dc1d7b..757ece516 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -107,10 +107,30 @@ class ToolConfigBlock(TypedDict, total=False):
     toolChoice: Union[str, ToolChoiceValuesBlock]


+class InferenceConfig(TypedDict, total=False):
+    maxTokens: int
+    stopSequences: List[str]
+    temperature: float
+    topP: float
+
+
+class ToolBlockDeltaEvent(TypedDict):
+    input: str
+
+
+class ContentBlockDeltaEvent(TypedDict, total=False):
+    """
+    Either 'text' or 'toolUse' will be specified for Converse API streaming response.
+    """
+
+    text: str
+    toolUse: ToolBlockDeltaEvent
+
+
 class RequestObject(TypedDict, total=False):
     additionalModelRequestFields: dict
     additionalModelResponseFieldPaths: List[str]
-    inferenceConfig: dict
+    inferenceConfig: InferenceConfig
     messages: Required[List[MessageBlock]]
     system: List[SystemContentBlock]
     toolConfig: ToolConfigBlock
@@ -118,8 +138,10 @@

 class GenericStreamingChunk(TypedDict):
     text: Required[str]
+    tool_str: Required[str]
     is_finished: Required[bool]
     finish_reason: Required[str]
+    usage: Optional[ConverseTokenUsageBlock]


 class Document(TypedDict):
diff --git a/litellm/utils.py b/litellm/utils.py
index 6db5f540c..75dd85328 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -239,6 +239,8 @@ def map_finish_reason(
         return "length"
     elif finish_reason == "tool_use":  # anthropic
         return "tool_calls"
+    elif finish_reason == "content_filtered":
+        return "content_filter"
     return finish_reason


@@ -6330,7 +6332,7 @@ def get_supported_openai_params(
     - None if unmapped
     """
     if custom_llm_provider == "bedrock":
-        return litellm.AmazonConverseConfig().get_supported_openai_params()
+        return litellm.AmazonConverseConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_supported_openai_params()
     elif custom_llm_provider == "ollama_chat":
@@ -11242,12 +11244,27 @@ class CustomStreamWrapper:
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "bedrock":
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
                 if self.received_finish_reason is not None:
                     raise StopIteration
-                response_obj = self.handle_bedrock_stream(chunk)
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
+
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
             elif self.custom_llm_provider == "sagemaker":
                 print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
                 response_obj = self.handle_sagemaker_stream(chunk)
@@ -11509,7 +11526,7 @@ class CustomStreamWrapper:
             and hasattr(model_response, "usage")
             and hasattr(model_response.usage, "prompt_tokens")
         ):
-            if self.sent_first_chunk == False:
+            if self.sent_first_chunk is False:
                 completion_obj["role"] = "assistant"
                 self.sent_first_chunk = True
             model_response.choices[0].delta = Delta(**completion_obj)
@@ -11677,6 +11694,8 @@ class CustomStreamWrapper:

     def __next__(self):
         try:
+            if self.completion_stream is None:
+                self.fetch_sync_stream()
             while True:
                 if (
                     isinstance(self.completion_stream, str)
@@ -11751,6 +11770,14 @@ class CustomStreamWrapper:
                     custom_llm_provider=self.custom_llm_provider,
                 )

+    def fetch_sync_stream(self):
+        if self.completion_stream is None and self.make_call is not None:
+            # Call make_call to get the completion stream
+            self.completion_stream = self.make_call(client=litellm.module_level_client)
+            self._stream_iter = self.completion_stream.__iter__()
+
+        return self.completion_stream
+
     async def fetch_stream(self):
         if self.completion_stream is None and self.make_call is not None:
             # Call make_call to get the completion stream
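
Usage sketch (not part of the diff): with this patch, a synchronous stream=True call against a Bedrock Converse model is routed through make_sync_call and CustomStreamWrapper.fetch_sync_stream instead of the async-only path. A minimal way to exercise it, assuming AWS credentials are configured in the environment; the model ID below is a placeholder, not something this patch pins down:

    import litellm

    # Synchronous streaming call. The HTTP stream is fetched lazily on the first
    # __next__(), which now calls fetch_sync_stream() -> make_sync_call() using
    # litellm.module_level_client (the new module-level HTTPHandler).
    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model ID
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    for chunk in response:
        # Each chunk is an OpenAI-style delta; content may be None on the final chunk.
        print(chunk.choices[0].delta.content or "", end="")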