forked from phoenix/litellm-mirror
pass trace through for bedrock guardrails
This commit is contained in:
parent
026c7194f8
commit
89ba7b3e11
3 changed files with 82 additions and 23 deletions
|
@ -1365,6 +1365,10 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
)
|
)
|
||||||
setattr(model_response, "usage", usage)
|
setattr(model_response, "usage", usage)
|
||||||
|
|
||||||
|
# Add "trace" from Bedrock guardrails - if user has opted in to returning it
|
||||||
|
if "trace" in completion_response:
|
||||||
|
setattr(model_response, "trace", completion_response["trace"])
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
def encode_model_id(self, model_id: str) -> str:
|
def encode_model_id(self, model_id: str) -> str:
|
||||||
|
@ -1900,6 +1904,10 @@ class AWSEventStreamDecoder:
|
||||||
usage=usage,
|
usage=usage,
|
||||||
index=index,
|
index=index,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "trace" in chunk_data:
|
||||||
|
trace = chunk_data.get("trace")
|
||||||
|
response["provider_specific_fields"] = {"trace": trace}
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("Received streaming error - {}".format(str(e)))
|
raise Exception("Received streaming error - {}".format(str(e)))
|
||||||
|
@ -1920,6 +1928,7 @@ class AWSEventStreamDecoder:
|
||||||
"contentBlockIndex" in chunk_data
|
"contentBlockIndex" in chunk_data
|
||||||
or "stopReason" in chunk_data
|
or "stopReason" in chunk_data
|
||||||
or "metrics" in chunk_data
|
or "metrics" in chunk_data
|
||||||
|
or "trace" in chunk_data
|
||||||
):
|
):
|
||||||
return self.converse_chunk_parser(chunk_data=chunk_data)
|
return self.converse_chunk_parser(chunk_data=chunk_data)
|
||||||
######## bedrock.mistral mappings ###############
|
######## bedrock.mistral mappings ###############
|
||||||
|
|
|
@ -82,12 +82,18 @@ def test_completion_bedrock_claude_completion_auth():
|
||||||
# test_completion_bedrock_claude_completion_auth()
|
# test_completion_bedrock_claude_completion_auth()
|
||||||
|
|
||||||
|
|
||||||
def test_completion_bedrock_guardrails():
|
@pytest.mark.parametrize("streaming", [True, False])
|
||||||
|
def test_completion_bedrock_guardrails(streaming):
|
||||||
import os
|
import os
|
||||||
|
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from litellm._logging import verbose_logger
|
||||||
|
|
||||||
|
# verbose_logger.setLevel(logging.DEBUG)
|
||||||
try:
|
try:
|
||||||
|
if streaming is False:
|
||||||
response = completion(
|
response = completion(
|
||||||
model="anthropic.claude-v2",
|
model="anthropic.claude-v2",
|
||||||
messages=[
|
messages=[
|
||||||
|
@ -100,7 +106,7 @@ def test_completion_bedrock_guardrails():
|
||||||
guardrailConfig={
|
guardrailConfig={
|
||||||
"guardrailIdentifier": "ff6ujrregl1q",
|
"guardrailIdentifier": "ff6ujrregl1q",
|
||||||
"guardrailVersion": "DRAFT",
|
"guardrailVersion": "DRAFT",
|
||||||
"trace": "disabled",
|
"trace": "enabled",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
|
@ -109,6 +115,41 @@ def test_completion_bedrock_guardrails():
|
||||||
"Sorry, the model cannot answer this question. coffee guardrail applied"
|
"Sorry, the model cannot answer this question. coffee guardrail applied"
|
||||||
in response.choices[0].message.content
|
in response.choices[0].message.content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert "trace" in response
|
||||||
|
assert response.trace is not None
|
||||||
|
|
||||||
|
print("TRACE=", response.trace)
|
||||||
|
else:
|
||||||
|
|
||||||
|
response = completion(
|
||||||
|
model="anthropic.claude-v2",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"content": "where do i buy coffee from? ",
|
||||||
|
"role": "user",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
max_tokens=10,
|
||||||
|
guardrailConfig={
|
||||||
|
"guardrailIdentifier": "ff6ujrregl1q",
|
||||||
|
"guardrailVersion": "DRAFT",
|
||||||
|
"trace": "enabled",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
saw_trace = False
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
if "trace" in chunk:
|
||||||
|
saw_trace = True
|
||||||
|
print(chunk)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
saw_trace is True
|
||||||
|
), "Did not see trace in response even when trace=enabled sent in the guardrailConfig"
|
||||||
|
|
||||||
except RateLimitError:
|
except RateLimitError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -9579,6 +9579,7 @@ class CustomStreamWrapper:
|
||||||
):
|
):
|
||||||
|
|
||||||
if self.received_finish_reason is not None:
|
if self.received_finish_reason is not None:
|
||||||
|
if "provider_specific_fields" not in chunk:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
anthropic_response_obj: GChunk = chunk
|
anthropic_response_obj: GChunk = chunk
|
||||||
completion_obj["content"] = anthropic_response_obj["text"]
|
completion_obj["content"] = anthropic_response_obj["text"]
|
||||||
|
@ -9602,6 +9603,14 @@ class CustomStreamWrapper:
|
||||||
):
|
):
|
||||||
completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
|
completion_obj["tool_calls"] = [anthropic_response_obj["tool_use"]]
|
||||||
|
|
||||||
|
if (
|
||||||
|
"provider_specific_fields" in anthropic_response_obj
|
||||||
|
and anthropic_response_obj["provider_specific_fields"] is not None
|
||||||
|
):
|
||||||
|
for key, value in anthropic_response_obj[
|
||||||
|
"provider_specific_fields"
|
||||||
|
].items():
|
||||||
|
setattr(model_response, key, value)
|
||||||
response_obj = anthropic_response_obj
|
response_obj = anthropic_response_obj
|
||||||
elif (
|
elif (
|
||||||
self.custom_llm_provider
|
self.custom_llm_provider
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue