diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 73387212ff..e455597524 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
         setattr(model_response, "usage", usage)
 
+        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
+        if "trace" in completion_response:
+            setattr(model_response, "trace", completion_response["trace"])
+
         return model_response
 
     def encode_model_id(self, model_id: str) -> str:
@@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
                     usage=usage,
                     index=index,
                 )
+
+                if "trace" in chunk_data:
+                    trace = chunk_data.get("trace")
+                    response["provider_specific_fields"] = {"trace": trace}
                 return response
             except Exception as e:
                 raise Exception("Received streaming error - {}".format(str(e)))
@@ -1920,6 +1928,7 @@ class AWSEventStreamDecoder:
                 "contentBlockIndex" in chunk_data
                 or "stopReason" in chunk_data
                 or "metrics" in chunk_data
+                or "trace" in chunk_data
             ):
                 return self.converse_chunk_parser(chunk_data=chunk_data)
             ######## bedrock.mistral mappings ###############
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index c331021213..4892601b15 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -82,33 +82,74 @@ def test_completion_bedrock_claude_completion_auth():
 # test_completion_bedrock_claude_completion_auth()
 
 
-def test_completion_bedrock_guardrails():
+@pytest.mark.parametrize("streaming", [True, False])
+def test_completion_bedrock_guardrails(streaming):
     import os
 
     litellm.set_verbose = True
+    import logging
+    from litellm._logging import verbose_logger
+
+    # verbose_logger.setLevel(logging.DEBUG)
 
     try:
-        response = completion(
-            model="anthropic.claude-v2",
-            messages=[
-                {
-                    "content": "where do i buy coffee from? ",
-                    "role": "user",
-                }
-            ],
-            max_tokens=10,
-            guardrailConfig={
-                "guardrailIdentifier": "ff6ujrregl1q",
-                "guardrailVersion": "DRAFT",
-                "trace": "disabled",
-            },
-        )
-        # Add any assertions here to check the response
-        print(response)
-        assert (
-            "Sorry, the model cannot answer this question. coffee guardrail applied"
-            in response.choices[0].message.content
-        )
+        if streaming is False:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+            # Add any assertions here to check the response
+            print(response)
+            assert (
+                "Sorry, the model cannot answer this question. coffee guardrail applied"
+                in response.choices[0].message.content
+            )
+
+            assert "trace" in response
+            assert response.trace is not None
+
+            print("TRACE=", response.trace)
+        else:
+
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                stream=True,
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+
+            saw_trace = False
+
+            for chunk in response:
+                if "trace" in chunk:
+                    saw_trace = True
+                print(chunk)
+
+            assert (
+                saw_trace is True
+            ), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
+
     except RateLimitError:
         pass
     except Exception as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 40564c1077..f89a4a295a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9581,7 +9581,8 @@ class CustomStreamWrapper:
             ):
                 if self.received_finish_reason is not None:
-                    raise StopIteration
+                    if "provider_specific_fields" not in chunk:
+                        raise StopIteration
                 anthropic_response_obj: GChunk = chunk
                 completion_obj["content"] = anthropic_response_obj["text"]
                 if anthropic_response_obj["is_finished"]:
@@ -9604,6 +9605,14 @@ class CustomStreamWrapper:
                 ):
                     completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
 
+                if (
+                    "provider_specific_fields" in anthropic_response_obj
+                    and anthropic_response_obj["provider_specific_fields"] is not None
+                ):
+                    for key, value in anthropic_response_obj[
+                        "provider_specific_fields"
+                    ].items():
+                        setattr(model_response, key, value)
                 response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider
@@ -10219,6 +10228,14 @@ class CustomStreamWrapper:
                 return
             elif self.received_finish_reason is not None:
                 if self.sent_last_chunk is True:
+                    # Bedrock returns the guardrail trace in the last chunk - we want to return this here
+                    if (
+                        self.custom_llm_provider == "bedrock"
+                        and "trace" in model_response
+                    ):
+                        return model_response
+
+                    # Default - return StopIteration
                     raise StopIteration
                 # flush any remaining holding chunk
                 if len(self.holding_chunk) > 0: