Merge pull request #2425 from BerriAI/litellm_claude_3_bedrock_streaming

fix(bedrock.py): enable claude-3 streaming
This commit is contained in:
Krish Dholakia 2024-03-09 15:44:22 -08:00 committed by GitHub
commit a238603e66
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 36 additions and 2 deletions

View file

@ -126,6 +126,8 @@ class AmazonAnthropicClaude3Config:
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream":
optional_params["stream"] = value
return optional_params

View file

@ -727,6 +727,31 @@ def test_completion_claude_stream_bad_key():
# pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3_streaming():
try:
litellm.set_verbose = True
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
max_tokens=10,
stream=True,
)
complete_response = ""
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
break
complete_response += chunk
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="Replicate changed exceptions")
def test_completion_replicate_stream_bad_key():
try:

View file

@ -8778,13 +8778,20 @@ class CustomStreamWrapper:
text = chunk_data.get("completions")[0].get("data").get("text")
is_finished = True
finish_reason = "stop"
# anthropic mapping
elif "completion" in chunk_data:
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.cohere mappings ###############
# meta mapping
elif "generation" in chunk_data: