diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 73387212f..e45559752 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
         setattr(model_response, "usage", usage)
 
+        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
+        if "trace" in completion_response:
+            setattr(model_response, "trace", completion_response["trace"])
+
         return model_response
 
     def encode_model_id(self, model_id: str) -> str:
@@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
                 usage=usage,
                 index=index,
             )
+
+            if "trace" in chunk_data:
+                trace = chunk_data.get("trace")
+                response["provider_specific_fields"] = {"trace": trace}
             return response
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))
@@ -1920,6 +1928,7 @@ class AWSEventStreamDecoder:
             "contentBlockIndex" in chunk_data
             or "stopReason" in chunk_data
             or "metrics" in chunk_data
+            or "trace" in chunk_data
         ):
             return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index c33102121..4892601b1 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -82,33 +82,74 @@ def test_completion_bedrock_claude_completion_auth():
 # test_completion_bedrock_claude_completion_auth()
 
 
-def test_completion_bedrock_guardrails():
+@pytest.mark.parametrize("streaming", [True, False])
+def test_completion_bedrock_guardrails(streaming):
     import os
 
     litellm.set_verbose = True
+    import logging
+    from litellm._logging import verbose_logger
+
+    # verbose_logger.setLevel(logging.DEBUG)
 
     try:
-        response = completion(
-            model="anthropic.claude-v2",
-            messages=[
-                {
-                    "content": "where do i buy coffee from? ",
-                    "role": "user",
-                }
-            ],
-            max_tokens=10,
-            guardrailConfig={
-                "guardrailIdentifier": "ff6ujrregl1q",
-                "guardrailVersion": "DRAFT",
-                "trace": "disabled",
-            },
-        )
-        # Add any assertions here to check the response
-        print(response)
-        assert (
-            "Sorry, the model cannot answer this question. coffee guardrail applied"
-            in response.choices[0].message.content
-        )
+        if streaming is False:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+            # Add any assertions here to check the response
+            print(response)
+            assert (
+                "Sorry, the model cannot answer this question. coffee guardrail applied"
+                in response.choices[0].message.content
+            )
+
+            assert "trace" in response
+            assert response.trace is not None
+
+            print("TRACE=", response.trace)
+        else:
+
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                stream=True,
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+
+            saw_trace = False
+
+            for chunk in response:
+                if "trace" in chunk:
+                    saw_trace = True
+                print(chunk)
+
+            assert (
+                saw_trace is True
+            ), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
+
     except RateLimitError:
         pass
     except Exception as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 0875b0e0e..414550be4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9579,7 +9579,8 @@ class CustomStreamWrapper:
             ):
                 if self.received_finish_reason is not None:
-                    raise StopIteration
+                    if "provider_specific_fields" not in chunk:
+                        raise StopIteration
                 anthropic_response_obj: GChunk = chunk
                 completion_obj["content"] = anthropic_response_obj["text"]
                 if anthropic_response_obj["is_finished"]:
@@ -9602,6 +9603,14 @@ class CustomStreamWrapper:
                 ):
                     completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
+                if (
+                    "provider_specific_fields" in anthropic_response_obj
+                    and anthropic_response_obj["provider_specific_fields"] is not None
+                ):
+                    for key, value in anthropic_response_obj[
+                        "provider_specific_fields"
+                    ].items():
+                        setattr(model_response, key, value)
                 response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider