diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 9e39b81e5d..4f5d4f2636 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -419,9 +419,14 @@ class BedrockLLM(BaseLLM):
                         + completion_response["usage"]["output_tokens"],
                     )
                     setattr(model_response, "usage", _usage)
-            else:
-                outputText = completion_response["completion"]
-                model_response["finish_reason"] = completion_response["stop_reason"]
+                else:
+                    outputText = completion_response["completion"]
+
+                    model_response["finish_reason"] = completion_response["stop_reason"]
+            elif provider == "ai21":
+                outputText = (
+                    completion_response.get("completions")[0].get("data").get("text")
+                )
         except Exception as e:
             raise BedrockError(
                 message="Error processing={}, Received error={}".format(
@@ -430,6 +435,49 @@ class BedrockLLM(BaseLLM):
                 status_code=422,
             )

+        try:
+            if (
+                len(outputText) > 0
+                and hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is None
+            ):
+                model_response["choices"][0]["message"]["content"] = outputText
+            elif (
+                hasattr(model_response.choices[0], "message")
+                and getattr(model_response.choices[0].message, "tool_calls", None)
+                is not None
+            ):
+                pass
+            else:
+                raise Exception()
+        except:
+            raise BedrockError(
+                message=json.dumps(outputText), status_code=response.status_code
+            )
+
+        if stream and provider == "ai21":
+            streaming_model_response = ModelResponse(stream=True)
+            streaming_model_response.choices[0].finish_reason = model_response.choices[  # type: ignore
+                0
+            ].finish_reason
+            # streaming_model_response.choices = [litellm.utils.StreamingChoices()]
+            streaming_choice = litellm.utils.StreamingChoices()
+            streaming_choice.index = model_response.choices[0].index
+            delta_obj = litellm.utils.Delta(
+                content=getattr(model_response.choices[0].message, "content", None),
+                role=model_response.choices[0].message.role,
+            )
+            streaming_choice.delta = delta_obj
+            streaming_model_response.choices = [streaming_choice]
+            mri = ModelResponseIterator(model_response=streaming_model_response)
+            return CustomStreamWrapper(
+                completion_stream=mri,
+                model=model,
+                custom_llm_provider="cached_response",
+                logging_obj=logging_obj,
+            )
+
         ## CALCULATING USAGE - bedrock returns usage in the headers
         bedrock_input_tokens = response.headers.get(
             "x-amzn-bedrock-input-token-count", None
@@ -489,6 +537,7 @@ class BedrockLLM(BaseLLM):

         ## SETUP ##
         stream = optional_params.pop("stream", None)
+        provider = model.split(".")[0]

         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
@@ -544,14 +593,13 @@ class BedrockLLM(BaseLLM):
         else:
             endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             endpoint_url = f"{endpoint_url}/model/{model}/invoke-with-response-stream"
         else:
             endpoint_url = f"{endpoint_url}/model/{model}/invoke"

         sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)

-        provider = model.split(".")[0]
         prompt, chat_history = self.convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
@@ -633,6 +681,16 @@ class BedrockLLM(BaseLLM):
                 ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                     inference_params[k] = v
             data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "ai21":
+            ## LOAD CONFIG
+            config = litellm.AmazonAI21Config.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")

@@ -662,7 +720,7 @@ class BedrockLLM(BaseLLM):
         if acompletion:
             if isinstance(client, HTTPHandler):
                 client = None
-            if stream:
+            if stream == True and provider != "ai21":
                 return self.async_streaming(
                     model=model,
                     messages=messages,
@@ -691,7 +749,7 @@ class BedrockLLM(BaseLLM):
                 encoding=encoding,
                 logging_obj=logging_obj,
                 optional_params=optional_params,
-                stream=False,
+                stream=stream,  # type: ignore
                 litellm_params=litellm_params,
                 logger_fn=logger_fn,
                 headers=prepped.headers,
@@ -708,7 +766,7 @@ class BedrockLLM(BaseLLM):
             self.client = HTTPHandler(**_params)  # type: ignore
         else:
             self.client = client
-        if stream is not None and stream == True:
+        if (stream is not None and stream == True) and provider != "ai21":
             response = self.client.post(
                 url=prepped.url,
                 headers=prepped.headers,  # type: ignore
@@ -787,7 +845,7 @@ class BedrockLLM(BaseLLM):
             model=model,
             response=response,
             model_response=model_response,
-            stream=stream,
+            stream=stream if isinstance(stream, bool) else False,
             logging_obj=logging_obj,
             api_key="",
             data=data,
diff --git a/litellm/main.py b/litellm/main.py
index 73acf00153..769b5964a0 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -328,7 +328,7 @@ async def acompletion(
         or custom_llm_provider == "predibase"
         or (
             custom_llm_provider == "bedrock"
-            and ("cohere" in model or "anthropic" in model)
+            and ("cohere" in model or "anthropic" in model or "ai21" in model)
         )
         or custom_llm_provider in litellm.openai_compatible_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -1982,7 +1982,7 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            if "cohere" in model or "anthropic" in model:
+            if "cohere" in model or "anthropic" in model or "ai21" in model:
                 response = bedrock_chat_completion.completion(
                     model=model,
                     messages=messages,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 3b8845fb5b..f3ec308fbb 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -2665,7 +2665,12 @@ def response_format_tests(response: litellm.ModelResponse):
 @pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.parametrize(
     "model",
-    ["bedrock/cohere.command-r-plus-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"],
+    [
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
+    ],
 )
 @pytest.mark.asyncio
 async def test_completion_bedrock_httpx_models(sync_mode, model):
@@ -2675,6 +2680,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
             response = completion(
                 model=model,
                 messages=[{"role": "user", "content": "Hey! how's it going?"}],
+                temperature=0.2,
+                max_tokens=200,
             )

             assert isinstance(response, litellm.ModelResponse)
@@ -2684,6 +2691,8 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
             response = await litellm.acompletion(
                 model=model,
                 messages=[{"role": "user", "content": "Hey! how's it going?"}],
+                temperature=0.2,
+                max_tokens=200,
             )

             assert isinstance(response, litellm.ModelResponse)
@@ -2740,48 +2749,9 @@ def test_completion_bedrock_titan():
 # test_completion_bedrock_titan()


-def test_completion_bedrock_claude():
-    print("calling claude")
-    try:
-        response = completion(
-            model="anthropic.claude-instant-v1",
-            messages=messages,
-            max_tokens=10,
-            temperature=0.1,
-            logger_fn=logger_fn,
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_claude()


-def test_completion_bedrock_cohere():
-    print("calling bedrock cohere")
-    litellm.set_verbose = True
-    try:
-        response = completion(
-            model="bedrock/cohere.command-text-v14",
-            messages=[{"role": "user", "content": "hi"}],
-            temperature=0.1,
-            max_tokens=10,
-            stream=True,
-        )
-        # Add any assertions here to check the response
-        print(response)
-        for chunk in response:
-            print(chunk)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_cohere()


@@ -2804,23 +2774,6 @@ def test_completion_bedrock_cohere():
 #         pytest.fail(f"Error occurred: {e}")

 # test_completion_bedrock_claude_stream()
-# def test_completion_bedrock_ai21():
-#     try:
-#         litellm.set_verbose = False
-#         response = completion(
-#             model="bedrock/ai21.j2-mid",
-#             messages=messages,
-#             temperature=0.2,
-#             top_p=0.2,
-#             max_tokens=20
-#         )
-#         # Add any assertions here to check the response
-#         print(response)
-#     except RateLimitError:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
-

 ######## Test VLLM ########
 # def test_completion_vllm():
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 8c125198f9..e4aa8b1356 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1044,8 +1044,10 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        # "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0"
+        "bedrock/cohere.command-r-plus-v1:0",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-instant-v1",
+        "bedrock/ai21.j2-mid",
     ],
 )
 @pytest.mark.asyncio
diff --git a/litellm/utils.py b/litellm/utils.py
index 82a33f7ad2..51f31a1ff3 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -11510,7 +11510,11 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "predibase"
                 or (
                     self.custom_llm_provider == "bedrock"
-                    and ("cohere" in self.model or "anthropic" in self.model)
+                    and (
+                        "cohere" in self.model
+                        or "anthropic" in self.model
+                        or "ai21" in self.model
+                    )
                 )
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):