pass trace through for bedrock guardrails

Ishaan Jaff 2024-08-16 09:10:56 -07:00
parent 026c7194f8
commit 89ba7b3e11
3 changed files with 82 additions and 23 deletions

View file

@@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
             )
         setattr(model_response, "usage", usage)
+
+        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
+        if "trace" in completion_response:
+            setattr(model_response, "trace", completion_response["trace"])
         return model_response

     def encode_model_id(self, model_id: str) -> str:

@@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
                 usage=usage,
                 index=index,
             )
+
+            if "trace" in chunk_data:
+                trace = chunk_data.get("trace")
+                response["provider_specific_fields"] = {"trace": trace}
             return response
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))

@@ -1920,6 +1928,7 @@ class AWSEventStreamDecoder:
                 "contentBlockIndex" in chunk_data
                 or "stopReason" in chunk_data
                 or "metrics" in chunk_data
+                or "trace" in chunk_data
             ):
                 return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
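
With the non-streaming change above, the guardrail trace lands directly on the ModelResponse and can be read back as response.trace. A minimal usage sketch, mirroring the test in the next file (the guardrailIdentifier/guardrailVersion values are placeholders taken from that test; the attribute is only set when Bedrock actually returns a trace):

import litellm

response = litellm.completion(
    model="anthropic.claude-v2",
    messages=[{"role": "user", "content": "where do i buy coffee from? "}],
    max_tokens=10,
    guardrailConfig={
        "guardrailIdentifier": "ff6ujrregl1q",  # placeholder guardrail id
        "guardrailVersion": "DRAFT",
        "trace": "enabled",  # opt in; with "disabled" no trace comes back
    },
)

# Set via setattr() in the hunk above, only when the response carried a trace.
print(getattr(response, "trace", None))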

View file

@@ -82,12 +82,18 @@ def test_completion_bedrock_claude_completion_auth():
 # test_completion_bedrock_claude_completion_auth()

-def test_completion_bedrock_guardrails():
+@pytest.mark.parametrize("streaming", [True, False])
+def test_completion_bedrock_guardrails(streaming):
     import os

     litellm.set_verbose = True
+    import logging
+
+    from litellm._logging import verbose_logger
+
+    # verbose_logger.setLevel(logging.DEBUG)

     try:
+        if streaming is False:
             response = completion(
                 model="anthropic.claude-v2",
                 messages=[

@@ -100,7 +106,7 @@ def test_completion_bedrock_guardrails():
                 guardrailConfig={
                     "guardrailIdentifier": "ff6ujrregl1q",
                     "guardrailVersion": "DRAFT",
-                    "trace": "disabled",
+                    "trace": "enabled",
                 },
             )
             # Add any assertions here to check the response

@@ -109,6 +115,41 @@ def test_completion_bedrock_guardrails():
                 "Sorry, the model cannot answer this question. coffee guardrail applied"
                 in response.choices[0].message.content
             )
+            assert "trace" in response
+            assert response.trace is not None
+            print("TRACE=", response.trace)
+        else:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                stream=True,
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+
+            saw_trace = False
+            for chunk in response:
+                if "trace" in chunk:
+                    saw_trace = True
+                print(chunk)
+
+            assert (
+                saw_trace is True
+            ), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
     except RateLimitError:
         pass
     except Exception as e:
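
In the streaming branch the trace does not arrive on a single final response: the event-stream decoder (first file) parks it in provider_specific_fields on the parsed chunk, and CustomStreamWrapper (last file below) copies it onto the streamed chunk object. A consumer-side sketch, assuming the same placeholder guardrail config as the test:

import litellm

stream = litellm.completion(
    model="anthropic.claude-v2",
    messages=[{"role": "user", "content": "where do i buy coffee from? "}],
    stream=True,
    max_tokens=10,
    guardrailConfig={
        "guardrailIdentifier": "ff6ujrregl1q",  # placeholder guardrail id
        "guardrailVersion": "DRAFT",
        "trace": "enabled",
    },
)

for chunk in stream:
    # "trace" is only present on the chunk(s) that carried the guardrail trace
    trace = getattr(chunk, "trace", None)
    if trace is not None:
        print("guardrail trace:", trace)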

View file

@@ -9579,6 +9579,7 @@ class CustomStreamWrapper:
             ):
                 if self.received_finish_reason is not None:
+                    if "provider_specific_fields" not in chunk:
                         raise StopIteration
                 anthropic_response_obj: GChunk = chunk
                 completion_obj["content"] = anthropic_response_obj["text"]

@@ -9602,6 +9603,14 @@ class CustomStreamWrapper:
                 ):
                     completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
+
+                if (
+                    "provider_specific_fields" in anthropic_response_obj
+                    and anthropic_response_obj["provider_specific_fields"] is not None
+                ):
+                    for key, value in anthropic_response_obj[
+                        "provider_specific_fields"
+                    ].items():
+                        setattr(model_response, key, value)
                 response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider
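
Two things happen in this file: the StopIteration guard keeps the wrapper from discarding a chunk that arrives after the finish reason but still carries provider_specific_fields (which is where the decoder parked the guardrail trace), and the new loop copies every provider-specific entry onto the outgoing chunk as an attribute. A standalone sketch of that copy pattern (SimpleNamespace stands in for LiteLLM's ModelResponse, and the chunk shape here is an assumption, not the library's actual type):

from types import SimpleNamespace

# Shape loosely following what converse_chunk_parser produces above (assumed).
parsed_chunk = {
    "text": "",
    "provider_specific_fields": {"trace": {"guardrail": "..."}},
}

model_response = SimpleNamespace()  # stand-in for the streamed chunk object

fields = parsed_chunk.get("provider_specific_fields")
if fields is not None:
    for key, value in fields.items():
        # Same setattr() copy as in the hunk above: each provider-specific
        # entry becomes a plain attribute on the outgoing chunk.
        setattr(model_response, key, value)

assert model_response.trace == {"guardrail": "..."}
print(model_response.trace)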