From ee9359adad03273c27a7c8bd9d106823194d3102 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 16 Aug 2024 08:44:29 -0700
Subject: [PATCH 1/3] add provider_specific_fields to GenericStreamingChunk

---
 litellm/types/utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 354f0f03f8..c78cc2edcf 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -88,6 +88,9 @@ class GenericStreamingChunk(TypedDict, total=False):
     usage: Optional[ChatCompletionUsageBlock]
     index: int
 
+    # use this dict if you want to return any provider specific fields in the response
+    provider_specific_fields: Optional[Dict[str, Any]]
+
 
 from enum import Enum
 

From 98c9191f843f1785d7955c90abf54da1b873a941 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 16 Aug 2024 09:10:56 -0700
Subject: [PATCH 2/3] pass trace through for bedrock guardrails

---
 litellm/llms/bedrock_httpx.py            |  9 +++
 litellm/tests/test_bedrock_completion.py | 85 ++++++++++++++++++------
 litellm/utils.py                         | 11 ++-
 3 files changed, 82 insertions(+), 23 deletions(-)

diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 73387212ff..e455597524 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
         setattr(model_response, "usage", usage)
 
+        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
+        if "trace" in completion_response:
+            setattr(model_response, "trace", completion_response["trace"])
+
         return model_response
 
     def encode_model_id(self, model_id: str) -> str:
@@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
                 usage=usage,
                 index=index,
             )
+
+            if "trace" in chunk_data:
+                trace = chunk_data.get("trace")
+                response["provider_specific_fields"] = {"trace": trace}
             return response
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))
@@ -1920,6 +1928,7 @@
                 "contentBlockIndex" in chunk_data
                 or "stopReason" in chunk_data
                 or "metrics" in chunk_data
+                or "trace" in chunk_data
             ):
                 return self.converse_chunk_parser(chunk_data=chunk_data)
             ######## bedrock.mistral mappings ###############
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index c331021213..4892601b15 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -82,33 +82,74 @@ def test_completion_bedrock_claude_completion_auth():
 # test_completion_bedrock_claude_completion_auth()
 
 
-def test_completion_bedrock_guardrails():
+@pytest.mark.parametrize("streaming", [True, False])
+def test_completion_bedrock_guardrails(streaming):
     import os
 
     litellm.set_verbose = True
+    import logging
+
+    from litellm._logging import verbose_logger
+
+    # verbose_logger.setLevel(logging.DEBUG)
     try:
-        response = completion(
-            model="anthropic.claude-v2",
-            messages=[
-                {
-                    "content": "where do i buy coffee from? ",
-                    "role": "user",
-                }
-            ],
-            max_tokens=10,
-            guardrailConfig={
-                "guardrailIdentifier": "ff6ujrregl1q",
-                "guardrailVersion": "DRAFT",
-                "trace": "disabled",
-            },
-        )
-        # Add any assertions here to check the response
-        print(response)
-        assert (
-            "Sorry, the model cannot answer this question. coffee guardrail applied"
-            in response.choices[0].message.content
-        )
+        if streaming is False:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+            # Add any assertions here to check the response
+            print(response)
+            assert (
+                "Sorry, the model cannot answer this question. coffee guardrail applied"
+                in response.choices[0].message.content
+            )
+
+            assert "trace" in response
+            assert response.trace is not None
+
+            print("TRACE=", response.trace)
+        else:
+
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                stream=True,
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+
+            saw_trace = False
+
+            for chunk in response:
+                if "trace" in chunk:
+                    saw_trace = True
+                print(chunk)
+
+            assert (
+                saw_trace is True
+            ), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
+
     except RateLimitError:
         pass
     except Exception as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index 0875b0e0e5..414550be48 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9579,7 +9579,8 @@
             ):
                 if self.received_finish_reason is not None:
-                    raise StopIteration
+                    if "provider_specific_fields" not in chunk:
+                        raise StopIteration
                 anthropic_response_obj: GChunk = chunk
                 completion_obj["content"] = anthropic_response_obj["text"]
                 if anthropic_response_obj["is_finished"]:
@@ -9602,6 +9603,14 @@
                 ):
                     completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
 
+                if (
+                    "provider_specific_fields" in anthropic_response_obj
+                    and anthropic_response_obj["provider_specific_fields"] is not None
+                ):
+                    for key, value in anthropic_response_obj[
+                        "provider_specific_fields"
+                    ].items():
+                        setattr(model_response, key, value)
                 response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider

From 262bf1491778e00ba81f1be9c1ab3771f290829c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 16 Aug 2024 11:35:43 -0700
Subject: [PATCH 3/3] return traces in bedrock guardrails when enabled

---
 litellm/utils.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/litellm/utils.py b/litellm/utils.py
index 414550be48..d915690fc6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -10226,6 +10226,14 @@
                 return
             elif self.received_finish_reason is not None:
                 if self.sent_last_chunk is True:
+                    # Bedrock returns the guardrail trace in the last chunk - we want to return this here
+                    if (
+                        self.custom_llm_provider == "bedrock"
+                        and "trace" in model_response
+                    ):
+                        return model_response
+
+                    # Default - return StopIteration
                     raise StopIteration
                 # flush any remaining holding chunk
                 if len(self.holding_chunk) > 0: