diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 2c0e41b1d..2d24af877 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -7,7 +7,18 @@ import json from enum import Enum import requests, copy # type: ignore import time -from typing import Callable, Optional, List, Literal, Union, Any, TypedDict, Tuple +from typing import ( + Callable, + Optional, + List, + Literal, + Union, + Any, + TypedDict, + Tuple, + Iterator, + AsyncIterator, +) from litellm.utils import ( ModelResponse, Usage, @@ -330,10 +341,10 @@ class BedrockLLM(BaseLLM): encoding, logging_obj, optional_params: dict, + acompletion: bool, timeout: Optional[Union[float, httpx.Timeout]], litellm_params=None, logger_fn=None, - acompletion: bool = False, extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -346,6 +357,9 @@ class BedrockLLM(BaseLLM): except ImportError as e: raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + ## SETUP ## + stream = optional_params.pop("stream", None) + ## CREDENTIALS ## # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) @@ -400,7 +414,10 @@ class BedrockLLM(BaseLLM): else: endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - endpoint_url = f"{endpoint_url}/model/{model}/invoke" + if stream is not None and stream == True: + endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream" + else: + endpoint_url = f"{endpoint_url}/model/{model}/invoke" sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) @@ -409,7 +426,6 @@ class BedrockLLM(BaseLLM): model, messages, provider, custom_prompt_dict ) inference_params = copy.deepcopy(optional_params) - stream = inference_params.pop("stream", False) if provider == "cohere": if model.startswith("cohere.command-r"): @@ -420,11 +436,6 @@ class BedrockLLM(BaseLLM): k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - if optional_params.get("stream", False) == True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) - _data = {"message": prompt, **inference_params} if chat_history is not None: _data["chat_history"] = chat_history @@ -437,7 +448,7 @@ class BedrockLLM(BaseLLM): k not in inference_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in inference_params[k] = v - if optional_params.get("stream", False) == True: + if stream == True: inference_params["stream"] = ( True # cohere requires stream = True in inference params ) @@ -446,6 +457,7 @@ class BedrockLLM(BaseLLM): raise Exception("UNSUPPORTED PROVIDER") ## COMPLETION CALL + headers = {"Content-Type": "application/json"} if extra_headers is not None: headers = {"Content-Type": "application/json", **extra_headers} @@ -455,11 +467,39 @@ class BedrockLLM(BaseLLM): sigv4.add_auth(request) prepped = request.prepare() + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepped.url, + "headers": prepped.headers, + }, + ) + ### ROUTING (ASYNC, STREAMING, SYNC) if acompletion: if isinstance(client, HTTPHandler): client = None - + if stream: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore ### ASYNC COMPLETION return self.async_completion( model=model, @@ -488,17 +528,29 @@ class BedrockLLM(BaseLLM): self.client = HTTPHandler(**_params) # type: ignore else: self.client = client + if stream is not None and stream == True: + response = self.client.post( + url=prepped.url, + headers=prepped.headers, # type: ignore + data=data, + stream=stream, + ) - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key="", - additional_args={ - "complete_input_dict": data, - "api_base": prepped.url, - "headers": prepped.headers, - }, - ) + if response.status_code != 200: + raise BedrockError( + status_code=response.status_code, message=response.text + ) + + decoder = AWSEventStreamDecoder() + + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore @@ -565,5 +617,117 @@ class BedrockLLM(BaseLLM): encoding=encoding, ) + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + self.client = AsyncHTTPHandler(**_params) # type: ignore + else: + self.client = client # type: ignore + + response = await self.client.post(api_base, headers=headers, data=data, stream=True) # type: ignore + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.text) + + decoder = AWSEventStreamDecoder() + + completion_stream = decoder.aiter_bytes(response.aiter_bytes(chunk_size=1024)) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + def embedding(self, *args, **kwargs): return super().embedding(*args, **kwargs) + + +def get_response_stream_shape(): + from botocore.model import ServiceModel + from botocore.loaders import Loader + + loader = Loader() + bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") + bedrock_service_model = ServiceModel(bedrock_service_dict) + return bedrock_service_model.shape_for("ResponseStream") + + +class AWSEventStreamDecoder: + def __init__(self) -> None: + from botocore.parsers import EventStreamJSONParser + + self.parser = EventStreamJSONParser() + + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + # sse_event = ServerSentEvent(data=message, event="completion") + _data = json.loads(message) + streaming_chunk: GenericStreamingChunk = GenericStreamingChunk( + text=_data.get("text", ""), + is_finished=_data.get("is_finished", False), + finish_reason=_data.get("finish_reason", ""), + ) + yield streaming_chunk + + async def aiter_bytes( + self, iterator: AsyncIterator[bytes] + ) -> AsyncIterator[GenericStreamingChunk]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + async for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + _data = json.loads(message) + streaming_chunk: GenericStreamingChunk = GenericStreamingChunk( + text=_data.get("text", ""), + is_finished=_data.get("is_finished", False), + finish_reason=_data.get("finish_reason", ""), + ) + yield streaming_chunk + + def _parse_message_from_event(self, event) -> str | None: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + chunk = parsed_response.get("chunk") + if not chunk: + return None + + return chunk.get("bytes").decode() # type: ignore[no-any-return] diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 529ba3b39..0adbd95bf 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -91,11 +91,15 @@ class HTTPHandler: def post( self, url: str, - data: Optional[dict] = None, + data: Optional[Union[dict, str]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, + stream: bool = False, ): - response = self.client.post(url, data=data, params=params, headers=headers) + req = self.client.build_request( + "POST", url, data=data, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) return response def __del__(self) -> None: diff --git a/litellm/main.py b/litellm/main.py index d2f3939fd..8e150f3e6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -257,7 +257,7 @@ async def acompletion( - If `stream` is True, the function returns an async generator that yields completion lines. """ loop = asyncio.get_event_loop() - custom_llm_provider = None + custom_llm_provider = kwargs.get("custom_llm_provider", None) # Adjusted to use explicit arguments instead of *args and **kwargs completion_kwargs = { "model": model, @@ -289,9 +289,10 @@ async def acompletion( "model_list": model_list, "acompletion": True, # assuming this is a required parameter } - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=completion_kwargs.get("base_url", None) - ) + if custom_llm_provider is None: + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=completion_kwargs.get("base_url", None) + ) try: # Use a partial function to pass your keyword arguments func = partial(completion, **completion_kwargs, **kwargs) @@ -300,9 +301,6 @@ async def acompletion( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) if ( custom_llm_provider == "openai" or custom_llm_provider == "azure" @@ -324,6 +322,7 @@ async def acompletion( or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" or custom_llm_provider == "predibase" + or (custom_llm_provider == "bedrock" and "cohere" in model) or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1937,6 +1936,7 @@ def completion( logging_obj=logging, extra_headers=extra_headers, timeout=timeout, + acompletion=acompletion, ) else: response = bedrock.completion( @@ -1954,26 +1954,26 @@ def completion( timeout=timeout, ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and not isinstance(response, CustomStreamWrapper) - ): - # don't try to access stream object, - if "ai21" in model: - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) - else: - response = CustomStreamWrapper( - iter(response), - model, - custom_llm_provider="bedrock", - logging_obj=logging, - ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + if "ai21" in model: + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) + else: + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="bedrock", + logging_obj=logging, + ) if optional_params.get("stream", False): ## LOGGING diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 5c0e17a3e..13f6c651b 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -984,6 +984,65 @@ def test_vertex_ai_stream(): # pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize("sync_mode", [True]) +@pytest.mark.asyncio +async def test_bedrock_cohere_command_r_streaming(sync_mode): + try: + litellm.set_verbose = True + if sync_mode: + final_chunk: Optional[litellm.ModelResponse] = None + response: litellm.CustomStreamWrapper = completion( # type: ignore + model="bedrock/cohere.command-r-plus-v1:0", + messages=messages, + max_tokens=10, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + for idx, chunk in enumerate(response): + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + else: + response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore + model="bedrock/cohere.command-r-plus-v1:0", + messages=messages, + max_tokens=100, # type: ignore + stream=True, + ) + complete_response = "" + # Add any assertions here to check the response + has_finish_reason = False + idx = 0 + final_chunk: Optional[litellm.ModelResponse] = None + async for chunk in response: + final_chunk = chunk + chunk, finished = streaming_format_tests(idx, chunk) + if finished: + has_finish_reason = True + break + complete_response += chunk + idx += 1 + if has_finish_reason == False: + raise Exception("finish reason not set") + if complete_response.strip() == "": + raise Exception("Empty response received") + print(f"completion_response: {complete_response}\n\nFinalChunk: {final_chunk}") + raise Exception("it worked!") + except RateLimitError: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + def test_bedrock_claude_3_streaming(): try: litellm.set_verbose = True diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 87ef6fd3c..529ab71f2 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -1,6 +1,63 @@ -from typing import TypedDict +from typing import TypedDict, Any +import json +from typing_extensions import ( + Self, + Protocol, + TypeGuard, + override, + get_origin, + runtime_checkable, + Required, +) + + +class GenericStreamingChunk(TypedDict): + text: Required[str] + is_finished: Required[bool] + finish_reason: Required[str] class Document(TypedDict): title: str snippet: str + + +class ServerSentEvent: + def __init__( + self, + *, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None, + ) -> None: + if data is None: + data = "" + + self._id = id + self._data = data + self._event = event or None + self._retry = retry + + @property + def event(self) -> str | None: + return self._event + + @property + def id(self) -> str | None: + return self._id + + @property + def retry(self) -> int | None: + return self._retry + + @property + def data(self) -> str: + return self._data + + def json(self) -> Any: + return json.loads(self.data) + + @override + def __repr__(self) -> str: + return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" diff --git a/litellm/utils.py b/litellm/utils.py index 0fd7963ae..6ceb5fecc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10262,6 +10262,12 @@ class CustomStreamWrapper: raise e def handle_bedrock_stream(self, chunk): + if "cohere" in self.model: + return { + "text": chunk["text"], + "is_finished": chunk["is_finished"], + "finish_reason": chunk["finish_reason"], + } if hasattr(chunk, "get"): chunk = chunk.get("chunk") chunk_data = json.loads(chunk.get("bytes").decode()) @@ -11068,6 +11074,7 @@ class CustomStreamWrapper: or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "cached_response" or self.custom_llm_provider == "predibase" + or (self.custom_llm_provider == "bedrock" and "cohere" in self.model) or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: