Merge pull request #5243 from BerriAI/litellm_add_bedrock_traces_in_response

[Feat] Add Bedrock Guardrail `trace` in response when trace=enabled

Commit 6de7785442. 3 changed files with 90 additions and 23 deletions.
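For context, this is the caller-facing behavior the PR enables. A minimal sketch, assuming AWS credentials are configured for litellm's Bedrock provider; the guardrail identifier below is a placeholder, not a real one:

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-v2",
    messages=[{"role": "user", "content": "where do i buy coffee from?"}],
    max_tokens=10,
    guardrailConfig={
        "guardrailIdentifier": "my-guardrail-id",  # placeholder - use your own guardrail
        "guardrailVersion": "DRAFT",
        "trace": "enabled",  # opt in to getting the guardrail trace back
    },
)

# With trace=enabled, the guardrail trace is attached to the response object.
print("TRACE=", response.trace)
```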
@@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
         )
         setattr(model_response, "usage", usage)
 
+        # Add "trace" from Bedrock guardrails - if user has opted in to returning it
+        if "trace" in completion_response:
+            setattr(model_response, "trace", completion_response["trace"])
+
         return model_response
 
     def encode_model_id(self, model_id: str) -> str:
@@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
                 usage=usage,
                 index=index,
             )
+
+            if "trace" in chunk_data:
+                trace = chunk_data.get("trace")
+                response["provider_specific_fields"] = {"trace": trace}
             return response
         except Exception as e:
             raise Exception("Received streaming error - {}".format(str(e)))
@@ -1920,6 +1928,7 @@ class AWSEventStreamDecoder:
                 "contentBlockIndex" in chunk_data
                 or "stopReason" in chunk_data
                 or "metrics" in chunk_data
+                or "trace" in chunk_data
             ):
                 return self.converse_chunk_parser(chunk_data=chunk_data)
             ######## bedrock.mistral mappings ###############
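On the streaming side, the two hunks above parse `trace` out of the final Bedrock chunk and carry it through `provider_specific_fields`. A minimal consumption sketch, with the same placeholder guardrail as above:

```python
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-v2",
    messages=[{"role": "user", "content": "where do i buy coffee from?"}],
    stream=True,
    max_tokens=10,
    guardrailConfig={
        "guardrailIdentifier": "my-guardrail-id",  # placeholder
        "guardrailVersion": "DRAFT",
        "trace": "enabled",
    },
)

for chunk in response:
    # Bedrock sends the guardrail trace on the last chunk when trace=enabled.
    if "trace" in chunk:
        print("guardrail trace:", chunk.trace)
```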
@@ -82,33 +82,74 @@ def test_completion_bedrock_claude_completion_auth():
 # test_completion_bedrock_claude_completion_auth()
 
 
-def test_completion_bedrock_guardrails():
+@pytest.mark.parametrize("streaming", [True, False])
+def test_completion_bedrock_guardrails(streaming):
     import os
 
     litellm.set_verbose = True
     import logging
 
     from litellm._logging import verbose_logger
 
     # verbose_logger.setLevel(logging.DEBUG)
     try:
-        response = completion(
-            model="anthropic.claude-v2",
-            messages=[
-                {
-                    "content": "where do i buy coffee from? ",
-                    "role": "user",
-                }
-            ],
-            max_tokens=10,
-            guardrailConfig={
-                "guardrailIdentifier": "ff6ujrregl1q",
-                "guardrailVersion": "DRAFT",
-                "trace": "disabled",
-            },
-        )
-        # Add any assertions here to check the response
-        print(response)
-        assert (
-            "Sorry, the model cannot answer this question. coffee guardrail applied"
-            in response.choices[0].message.content
-        )
+        if streaming is False:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+            # Add any assertions here to check the response
+            print(response)
+            assert (
+                "Sorry, the model cannot answer this question. coffee guardrail applied"
+                in response.choices[0].message.content
+            )
+
+            assert "trace" in response
+            assert response.trace is not None
+
+            print("TRACE=", response.trace)
+        else:
+            response = completion(
+                model="anthropic.claude-v2",
+                messages=[
+                    {
+                        "content": "where do i buy coffee from? ",
+                        "role": "user",
+                    }
+                ],
+                stream=True,
+                max_tokens=10,
+                guardrailConfig={
+                    "guardrailIdentifier": "ff6ujrregl1q",
+                    "guardrailVersion": "DRAFT",
+                    "trace": "enabled",
+                },
+            )
+
+            saw_trace = False
+
+            for chunk in response:
+                if "trace" in chunk:
+                    saw_trace = True
+                print(chunk)
+
+            assert (
+                saw_trace is True
+            ), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
+
     except RateLimitError:
         pass
     except Exception as e:
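The test can use `"trace" in response` because the field is set directly on the response object; when trace was not requested, the attribute is simply absent. Outside a test, a defensive read avoids an `AttributeError`:

```python
# Safe regardless of what guardrailConfig asked for:
trace = getattr(response, "trace", None)
if trace is not None:
    print("TRACE=", trace)
```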
@@ -9581,7 +9581,8 @@ class CustomStreamWrapper:
             ):
                 if self.received_finish_reason is not None:
-                    raise StopIteration
+                    if "provider_specific_fields" not in chunk:
+                        raise StopIteration
                 anthropic_response_obj: GChunk = chunk
                 completion_obj["content"] = anthropic_response_obj["text"]
                 if anthropic_response_obj["is_finished"]:
@@ -9604,6 +9605,14 @@ class CustomStreamWrapper:
                 ):
                     completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
 
+                if (
+                    "provider_specific_fields" in anthropic_response_obj
+                    and anthropic_response_obj["provider_specific_fields"] is not None
+                ):
+                    for key, value in anthropic_response_obj[
+                        "provider_specific_fields"
+                    ].items():
+                        setattr(model_response, key, value)
                 response_obj = anthropic_response_obj
             elif (
                 self.custom_llm_provider
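The hunk above is an instance of a small, reusable pattern: lift whatever provider-specific fields a parsed chunk carries onto the outgoing response object. A standalone sketch of that pattern; the function name and signature are illustrative, not litellm's API:

```python
from typing import Any, Dict, Optional


def attach_provider_fields(model_response: Any, fields: Optional[Dict[str, Any]]) -> None:
    """Copy provider-specific fields (e.g. Bedrock's guardrail "trace")
    onto the response so callers can read them as plain attributes."""
    if not fields:
        return
    for key, value in fields.items():
        setattr(model_response, key, value)
```

Using `setattr` keeps the OpenAI-shaped response class unchanged while still letting provider extras ride along.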
@@ -10219,6 +10228,14 @@ class CustomStreamWrapper:
                 return
             elif self.received_finish_reason is not None:
                 if self.sent_last_chunk is True:
+                    # Bedrock returns the guardrail trace in the last chunk - we want to return this here
+                    if (
+                        self.custom_llm_provider == "bedrock"
+                        and "trace" in model_response
+                    ):
+                        return model_response
+
+                    # Default - return StopIteration
                     raise StopIteration
                 # flush any remaining holding chunk
                 if len(self.holding_chunk) > 0:
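Combined with the `provider_specific_fields` guard in the earlier hunk, the stream-termination logic becomes: after the finish reason has been sent, a Bedrock stream that still holds a guardrail trace yields that final chunk once before stopping. A simplified stand-in for that control flow, not litellm's actual method:

```python
from typing import Any


def finish_or_emit_trace(provider: str, model_response: Any, sent_last_chunk: bool) -> Any:
    """Sketch: surface a Bedrock guardrail trace on the last chunk
    instead of immediately raising StopIteration."""
    if sent_last_chunk:
        if provider == "bedrock" and getattr(model_response, "trace", None) is not None:
            return model_response
        raise StopIteration
```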