fix(bedrock_httpx.py): add async support for bedrock amazon, meta, mistral models

Krrish Dholakia 2024-05-16 22:39:25 -07:00
parent 0293f7766a
commit 92c2e2af6a
6 changed files with 1441 additions and 4383 deletions
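For context, the behavior this commit lands can be exercised through litellm's async completion API. A minimal usage sketch (the Mistral model id below is illustrative and not taken from this commit; AWS credentials are assumed to be set in the environment):

    import asyncio
    import litellm

    async def main():
        # With this commit, stream=True goes through the native async httpx
        # path for Bedrock Amazon/Meta/Mistral models as well.
        response = await litellm.acompletion(
            model="bedrock/mistral.mistral-7b-instruct-v0:2",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
            stream=True,
        )
        async for chunk in response:
            # Each chunk is an OpenAI-style streaming delta; content may be None.
            print(chunk.choices[0].delta.content or "", end="")

    asyncio.run(main())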


@@ -10637,75 +10637,11 @@ class CustomStreamWrapper:
             raise e
 
     def handle_bedrock_stream(self, chunk):
-        if "cohere" in self.model or "anthropic" in self.model:
-            return {
-                "text": chunk["text"],
-                "is_finished": chunk["is_finished"],
-                "finish_reason": chunk["finish_reason"],
-            }
-        if hasattr(chunk, "get"):
-            chunk = chunk.get("chunk")
-            chunk_data = json.loads(chunk.get("bytes").decode())
-        else:
-            chunk_data = json.loads(chunk.decode())
-        if chunk_data:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            if "outputText" in chunk_data:
-                text = chunk_data["outputText"]
-            # ai21 mapping
-            if "ai21" in self.model:  # fake ai21 streaming
-                text = chunk_data.get("completions")[0].get("data").get("text")
-                is_finished = True
-                finish_reason = "stop"
-            ######## bedrock.anthropic mappings ###############
-            elif "completion" in chunk_data:  # not claude-3
-                text = chunk_data["completion"]  # bedrock.anthropic
-                stop_reason = chunk_data.get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            elif "delta" in chunk_data:
-                if chunk_data["delta"].get("text", None) is not None:
-                    text = chunk_data["delta"]["text"]
-                stop_reason = chunk_data["delta"].get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            ######## bedrock.mistral mappings ###############
-            elif "outputs" in chunk_data:
-                if (
-                    len(chunk_data["outputs"]) == 1
-                    and chunk_data["outputs"][0].get("text", None) is not None
-                ):
-                    text = chunk_data["outputs"][0]["text"]
-                stop_reason = chunk_data.get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            ######## bedrock.cohere mappings ###############
-            # meta mapping
-            elif "generation" in chunk_data:
-                text = chunk_data["generation"]  # bedrock.meta
-            # cohere mapping
-            elif "text" in chunk_data:
-                text = chunk_data["text"]  # bedrock.cohere
-            # cohere mapping for finish reason
-            elif "finish_reason" in chunk_data:
-                finish_reason = chunk_data["finish_reason"]
-                is_finished = True
-            elif chunk_data.get("completionReason", None):
-                is_finished = True
-                finish_reason = chunk_data["completionReason"]
-            elif chunk.get("error", None):
-                raise Exception(chunk["error"])
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        return ""
+        return {
+            "text": chunk["text"],
+            "is_finished": chunk["is_finished"],
+            "finish_reason": chunk["finish_reason"],
+        }
 
     def handle_sagemaker_stream(self, chunk):
         if "data: [DONE]" in chunk:
@@ -11508,14 +11444,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
-                or (
-                    self.custom_llm_provider == "bedrock"
-                    and (
-                        "cohere" in self.model
-                        or "anthropic" in self.model
-                        or "ai21" in self.model
-                    )
-                )
+                or self.custom_llm_provider == "bedrock"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream:
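This second hunk widens the gate: previously only Bedrock cohere/anthropic/ai21 models were consumed with `async for`, while other Bedrock models fell through to the synchronous iteration path; now every "bedrock" stream takes the async path. A hedged sketch of the resulting predicate (the set name is illustrative, and the set is partial, mirroring only the branches visible in the hunk above):

    # Providers whose streams are now iterated natively with `async for`.
    NATIVE_ASYNC_PROVIDERS = {"replicate", "cached_response", "predibase", "bedrock"}

    def uses_native_async_stream(provider: str) -> bool:
        # The real condition also checks litellm.openai_compatible_endpoints
        # and other providers listed above this hunk.
        return provider in NATIVE_ASYNC_PROVIDERS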