fix(bedrock_httpx.py): add async support for bedrock amazon, meta, mistral models

Krrish Dholakia 2024-05-16 22:39:25 -07:00
parent 0293f7766a
commit 92c2e2af6a
6 changed files with 1441 additions and 4383 deletions

@@ -307,7 +307,7 @@ class BedrockLLM(BaseLLM):
try:
if provider == "cohere":
model_response.choices[0].message.content = completion_response["text"] # type: ignore
outputText = completion_response["text"] # type: ignore
elif provider == "anthropic":
if model.startswith("anthropic.claude-3"):
json_schemas: dict = {}
@@ -427,6 +427,15 @@ class BedrockLLM(BaseLLM):
outputText = (
completion_response.get("completions")[0].get("data").get("text")
)
elif provider == "meta":
outputText = completion_response["generation"]
elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"]
model_response["finish_reason"] = completion_response["outputs"][0][
"stop_reason"
]
else: # amazon titan
outputText = completion_response.get("results")[0].get("outputText")
except Exception as e:
raise BedrockError(
message="Error processing={}, Received error={}".format(
@@ -691,6 +700,40 @@ class BedrockLLM(BaseLLM):
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "mistral":
## LOAD CONFIG
config = litellm.AmazonMistralConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
elif provider == "amazon": # amazon titan
## LOAD CONFIG
config = litellm.AmazonTitanConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps(
{
"inputText": prompt,
"textGenerationConfig": inference_params,
}
)
elif provider == "meta":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = json.dumps({"prompt": prompt, **inference_params})
else:
raise Exception("UNSUPPORTED PROVIDER")

@@ -326,10 +326,7 @@ async def acompletion(
or custom_llm_provider == "sagemaker"
or custom_llm_provider == "anthropic"
or custom_llm_provider == "predibase"
or (
custom_llm_provider == "bedrock"
and ("cohere" in model or "anthropic" in model or "ai21" in model)
)
or custom_llm_provider == "bedrock"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
@@ -1982,7 +1979,6 @@ def completion(
# boto3 reads keys from .env
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
if "cohere" in model or "anthropic" in model or "ai21" in model:
response = bedrock_chat_completion.completion(
model=model,
messages=messages,
@@ -1998,43 +1994,6 @@ def completion(
timeout=timeout,
acompletion=acompletion,
)
else:
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
response,
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
else:
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="bedrock",
logging_obj=logging,
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(

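With the cohere/anthropic/ai21 check removed above, every bedrock model now routes through bedrock_chat_completion and the async branch, so litellm.acompletion can be awaited directly for the amazon, meta, and mistral models too. A minimal usage sketch (not part of the diff; assumes AWS credentials are configured in the environment):

import asyncio
import litellm

async def main():
    response = await litellm.acompletion(
        model="bedrock/meta.llama3-8b-instruct-v1:0",  # one of the newly covered providers
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        temperature=0.2,
        max_tokens=100,
    )
    print(response.choices[0].message.content)

asyncio.run(main())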
File diff suppressed because it is too large.

@@ -2670,6 +2670,9 @@ def response_format_tests(response: litellm.ModelResponse):
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
],
)
@pytest.mark.asyncio
@@ -2692,7 +2695,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
model=model,
messages=[{"role": "user", "content": "Hey! how's it going?"}],
temperature=0.2,
max_tokens=200,
max_tokens=100,
)
assert isinstance(response, litellm.ModelResponse)
@@ -2728,24 +2731,6 @@ def test_completion_bedrock_titan_null_response():
pytest.fail(f"An error occurred - {str(e)}")
def test_completion_bedrock_titan():
try:
response = completion(
model="bedrock/amazon.titan-tg1-large",
messages=messages,
temperature=0.2,
max_tokens=200,
top_p=0.8,
logger_fn=logger_fn,
)
# Add any assertions here to check the response
print(response)
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_bedrock_titan()
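The sync_mode parameter in the test above drives both code paths through a single parametrized test. The pattern is roughly as follows (a sketch of the structure, not the exact test body; requires pytest-asyncio):

import pytest
import litellm

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize("model", ["bedrock/amazon.titan-tg1-large"])
@pytest.mark.asyncio
async def test_bedrock_httpx_sketch(sync_mode, model):
    messages = [{"role": "user", "content": "Hey! how's it going?"}]
    kwargs = dict(model=model, messages=messages, temperature=0.2, max_tokens=100)
    if sync_mode:
        response = litellm.completion(**kwargs)
    else:
        response = await litellm.acompletion(**kwargs)
    assert isinstance(response, litellm.ModelResponse)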

@@ -1048,6 +1048,9 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-instant-v1",
"bedrock/ai21.j2-mid",
"mistral.mistral-7b-instruct-v0:2",
"bedrock/amazon.titan-tg1-large",
"meta.llama3-8b-instruct-v1:0",
],
)
@pytest.mark.asyncio

@@ -10637,75 +10637,11 @@ class CustomStreamWrapper:
raise e
def handle_bedrock_stream(self, chunk):
if "cohere" in self.model or "anthropic" in self.model:
return {
"text": chunk["text"],
"is_finished": chunk["is_finished"],
"finish_reason": chunk["finish_reason"],
}
if hasattr(chunk, "get"):
chunk = chunk.get("chunk")
chunk_data = json.loads(chunk.get("bytes").decode())
else:
chunk_data = json.loads(chunk.decode())
if chunk_data:
text = ""
is_finished = False
finish_reason = ""
if "outputText" in chunk_data:
text = chunk_data["outputText"]
# ai21 mapping
if "ai21" in self.model: # fake ai21 streaming
text = chunk_data.get("completions")[0].get("data").get("text")
is_finished = True
finish_reason = "stop"
######## bedrock.anthropic mappings ###############
elif "completion" in chunk_data: # not claude-3
text = chunk_data["completion"] # bedrock.anthropic
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
elif "delta" in chunk_data:
if chunk_data["delta"].get("text", None) is not None:
text = chunk_data["delta"]["text"]
stop_reason = chunk_data["delta"].get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.mistral mappings ###############
elif "outputs" in chunk_data:
if (
len(chunk_data["outputs"]) == 1
and chunk_data["outputs"][0].get("text", None) is not None
):
text = chunk_data["outputs"][0]["text"]
stop_reason = chunk_data.get("stop_reason", None)
if stop_reason != None:
is_finished = True
finish_reason = stop_reason
######## bedrock.cohere mappings ###############
# meta mapping
elif "generation" in chunk_data:
text = chunk_data["generation"] # bedrock.meta
# cohere mapping
elif "text" in chunk_data:
text = chunk_data["text"] # bedrock.cohere
# cohere mapping for finish reason
elif "finish_reason" in chunk_data:
finish_reason = chunk_data["finish_reason"]
is_finished = True
elif chunk_data.get("completionReason", None):
is_finished = True
finish_reason = chunk_data["completionReason"]
elif chunk.get("error", None):
raise Exception(chunk["error"])
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
return ""
def handle_sagemaker_stream(self, chunk):
if "data: [DONE]" in chunk:
@@ -11508,14 +11444,7 @@ class CustomStreamWrapper:
or self.custom_llm_provider == "replicate"
or self.custom_llm_provider == "cached_response"
or self.custom_llm_provider == "predibase"
or (
self.custom_llm_provider == "bedrock"
and (
"cohere" in self.model
or "anthropic" in self.model
or "ai21" in self.model
)
)
or self.custom_llm_provider == "bedrock"
or self.custom_llm_provider in litellm.openai_compatible_endpoints
):
async for chunk in self.completion_stream:
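With the provider check simplified to custom_llm_provider == "bedrock" above, the async iteration branch covers every Bedrock provider. A minimal streaming sketch (not part of the diff; assumes AWS credentials are configured in the environment):

import asyncio
import litellm

async def stream_bedrock():
    response = await litellm.acompletion(
        model="bedrock/amazon.titan-tg1-large",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        max_tokens=100,
        stream=True,
    )
    async for chunk in response:
        # Each chunk is an OpenAI-style streaming delta.
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(stream_bedrock())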