diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 1461cfd90..c3a563ce4 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -74,6 +74,7 @@ BEDROCK_CONVERSE_MODELS = [
     "anthropic.claude-v2",
     "anthropic.claude-v2:1",
     "anthropic.claude-v1",
+    "anthropic.claude-instant-v1",
     "ai21.jamba-instruct-v1:0",
 ]
 
@@ -195,13 +196,39 @@ async def make_call(
     if client is None:
         client = _get_async_httpx_client()  # Create a new client if none provided
 
-    response = await client.post(api_base, headers=headers, data=data, stream=True)
+    response = await client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )
 
     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.text)
 
-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: (
+            ModelResponse
+        ) = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.aiter_bytes(
+            response.aiter_bytes(chunk_size=1024)
+        )
 
     # LOGGING
     logging_obj.post_call(
@@ -233,13 +260,35 @@ def make_sync_call(
     if client is None:
         client = _get_httpx_client()  # Create a new client if none provided
 
-    response = client.post(api_base, headers=headers, data=data, stream=True)
+    response = client.post(
+        api_base,
+        headers=headers,
+        data=data,
+        stream=True if "ai21" not in api_base else False,
+    )
 
     if response.status_code != 200:
         raise BedrockError(status_code=response.status_code, message=response.read())
 
-    decoder = AWSEventStreamDecoder(model=model)
-    completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
+    if "ai21" in api_base:
+        aws_bedrock_process_response = BedrockConverseLLM()
+        model_response: ModelResponse = aws_bedrock_process_response.process_response(
+            model=model,
+            response=response,
+            model_response=litellm.ModelResponse(),
+            stream=True,
+            logging_obj=logging_obj,
+            optional_params={},
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=litellm.print_verbose,
+            encoding=litellm.encoding,
+        )  # type: ignore
+        completion_stream: Any = MockResponseIterator(model_response=model_response)
+    else:
+        decoder = AWSEventStreamDecoder(model=model)
+        completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
 
     # LOGGING
     logging_obj.post_call(
@@ -1348,7 +1397,7 @@ class BedrockConverseLLM(BaseLLM):
         response: Union[requests.Response, httpx.Response],
         model_response: ModelResponse,
         stream: bool,
-        logging_obj: Logging,
+        logging_obj: Optional[Logging],
         optional_params: dict,
         api_key: str,
         data: Union[dict, str],
@@ -1358,12 +1407,13 @@ class BedrockConverseLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
 
         ## LOGGING
-        logging_obj.post_call(
-            input=messages,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
+        if logging_obj is not None:
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=response.text,
+                additional_args={"complete_input_dict": data},
+            )
         print_verbose(f"raw model_response: {response.text}")
 
         ## RESPONSE OBJECT
@@ -1900,7 +1950,7 @@ class BedrockConverseLLM(BaseLLM):
         if acompletion:
            if isinstance(client, HTTPHandler):
                 client = None
-            if stream is True and provider != "ai21":
+            if stream is True:
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -1937,7 +1987,7 @@ class BedrockConverseLLM(BaseLLM):
                 client=client,
             )  # type: ignore
 
-        if (stream is not None and stream is True) and provider != "ai21":
+        if stream is not None and stream is True:
 
             streaming_response = CustomStreamWrapper(
                 completion_stream=None,
@@ -1981,7 +2031,7 @@ class BedrockConverseLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             optional_params=optional_params,
             api_key="",
@@ -2168,3 +2218,49 @@ class AWSEventStreamDecoder:
             return None
 
         return chunk.decode()  # type: ignore[no-any-return]
+
+
+class MockResponseIterator:  # for returning ai21 streaming responses
+    def __init__(self, model_response):
+        self.model_response = model_response
+        self.is_done = False
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def _chunk_parser(self, chunk_data: ModelResponse) -> GenericStreamingChunk:
+
+        try:
+            chunk_usage: litellm.Usage = getattr(chunk_data, "usage")
+            processed_chunk = GenericStreamingChunk(
+                text=chunk_data.choices[0].message.content or "",  # type: ignore
+                tool_use=None,
+                is_finished=True,
+                finish_reason=chunk_data.choices[0].finish_reason,  # type: ignore
+                usage=ConverseTokenUsageBlock(
+                    inputTokens=chunk_usage.prompt_tokens,
+                    outputTokens=chunk_usage.completion_tokens,
+                    totalTokens=chunk_usage.total_tokens,
+                ),
+                index=0,
+            )
+            return processed_chunk
+        except Exception:
+            raise ValueError(f"Failed to decode chunk: {chunk_data}")
+
+    def __next__(self):
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        return self._chunk_parser(self.model_response)
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 8c7943893..d07aa681d 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1312,22 +1312,22 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 #             pytest.fail(f"Error occurred: {e}")
 
 
-@pytest.mark.parametrize("sync_mode", [True])  # False
+@pytest.mark.parametrize("sync_mode", [True, False])  #
 @pytest.mark.parametrize(
-    "model",
+    "model, region",
     [
-        "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        "anthropic.claude-instant-v1",
-        "bedrock/ai21.j2-mid",
-        "mistral.mistral-7b-instruct-v0:2",
-        "bedrock/amazon.titan-tg1-large",
-        "meta.llama3-8b-instruct-v1:0",
-        "cohere.command-text-v14",
+        ["bedrock/ai21.jamba-instruct-v1:0", "us-east-1"],
+        ["bedrock/cohere.command-r-plus-v1:0", None],
+        ["anthropic.claude-3-sonnet-20240229-v1:0", None],
+        ["anthropic.claude-instant-v1", None],
+        ["mistral.mistral-7b-instruct-v0:2", None],
+        ["bedrock/amazon.titan-tg1-large", None],
+        ["meta.llama3-8b-instruct-v1:0", None],
+        ["cohere.command-text-v14", None],
     ],
 )
 @pytest.mark.asyncio
-async def test_bedrock_httpx_streaming(sync_mode, model):
+async def test_bedrock_httpx_streaming(sync_mode, model, region):
     try:
         litellm.set_verbose = True
         if sync_mode:
@@ -1337,6 +1337,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 messages=messages,
                 max_tokens=10,  # type: ignore
                 stream=True,
+                aws_region_name=region,
             )
             complete_response = ""
             # Add any assertions here to check the response
@@ -1358,6 +1359,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model):
                 messages=messages,
                 max_tokens=100,  # type: ignore
                 stream=True,
+                aws_region_name=region,
             )
             complete_response = ""
             # Add any assertions here to check the response
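
The user-facing effect of the change above: `stream=True` now works end to end for Bedrock AI21 models. Whenever `"ai21"` appears in `api_base`, the HTTP call is made non-streaming and the complete response is replayed as a one-chunk stream through `MockResponseIterator`, presumably because these models lacked native event-stream support at the time. A rough usage sketch, not part of this patch (it assumes valid AWS credentials in the environment; the model ID and region mirror the new test case):

```python
import litellm

# Streaming request to AI21 Jamba on Bedrock. Under the hood litellm makes a
# single non-streaming call and MockResponseIterator replays the full response,
# so the "stream" delivers the whole completion at once rather than token by token.
response = litellm.completion(
    model="bedrock/ai21.jamba-instruct-v1:0",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=100,
    stream=True,
    aws_region_name="us-east-1",
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```

Wrapping the full response in a mock iterator keeps the caller-facing contract uniform: every Bedrock model returns a `CustomStreamWrapper`, and only the chunk granularity differs for models without native streaming.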