forked from phoenix/litellm-mirror
fix(bedrock_httpx.py): add async support for bedrock amazon, meta, mistral models
This commit is contained in:
parent
0293f7766a
commit
92c2e2af6a
6 changed files with 1441 additions and 4383 deletions
|
@ -307,7 +307,7 @@ class BedrockLLM(BaseLLM):
|
|||
|
||||
try:
|
||||
if provider == "cohere":
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
outputText = completion_response["text"] # type: ignore
|
||||
elif provider == "anthropic":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
json_schemas: dict = {}
|
||||
|
@ -427,6 +427,15 @@ class BedrockLLM(BaseLLM):
|
|||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
)
|
||||
elif provider == "meta":
|
||||
outputText = completion_response["generation"]
|
||||
elif provider == "mistral":
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
model_response["finish_reason"] = completion_response["outputs"][0][
|
||||
"stop_reason"
|
||||
]
|
||||
else: # amazon titan
|
||||
outputText = completion_response.get("results")[0].get("outputText")
|
||||
except Exception as e:
|
||||
raise BedrockError(
|
||||
message="Error processing={}, Received error={}".format(
|
||||
|
@ -691,6 +700,40 @@ class BedrockLLM(BaseLLM):
|
|||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "mistral":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonMistralConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
elif provider == "amazon": # amazon titan
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonTitanConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
data = json.dumps(
|
||||
{
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
)
|
||||
elif provider == "meta":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
data = json.dumps({"prompt": prompt, **inference_params})
|
||||
else:
|
||||
raise Exception("UNSUPPORTED PROVIDER")
|
||||
|
||||
|
|
|
@ -326,10 +326,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "sagemaker"
|
||||
or custom_llm_provider == "anthropic"
|
||||
or custom_llm_provider == "predibase"
|
||||
or (
|
||||
custom_llm_provider == "bedrock"
|
||||
and ("cohere" in model or "anthropic" in model or "ai21" in model)
|
||||
)
|
||||
or custom_llm_provider == "bedrock"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -1982,59 +1979,21 @@ def completion(
|
|||
# boto3 reads keys from .env
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
|
||||
if "cohere" in model or "anthropic" in model or "ai21" in model:
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
else:
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and not isinstance(response, CustomStreamWrapper)
|
||||
):
|
||||
# don't try to access stream object,
|
||||
if "ai21" in model:
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
else:
|
||||
response = CustomStreamWrapper(
|
||||
iter(response),
|
||||
model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging,
|
||||
)
|
||||
|
||||
response = bedrock_chat_completion.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -2670,6 +2670,9 @@ def response_format_tests(response: litellm.ModelResponse):
|
|||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
@ -2692,7 +2695,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
|
|||
model=model,
|
||||
messages=[{"role": "user", "content": "Hey! how's it going?"}],
|
||||
temperature=0.2,
|
||||
max_tokens=200,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
assert isinstance(response, litellm.ModelResponse)
|
||||
|
@ -2728,24 +2731,6 @@ def test_completion_bedrock_titan_null_response():
|
|||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_completion_bedrock_titan():
|
||||
try:
|
||||
response = completion(
|
||||
model="bedrock/amazon.titan-tg1-large",
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=200,
|
||||
top_p=0.8,
|
||||
logger_fn=logger_fn,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
# test_completion_bedrock_titan()
|
||||
|
||||
|
||||
|
|
|
@ -1048,6 +1048,9 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
|
|||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-instant-v1",
|
||||
"bedrock/ai21.j2-mid",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"bedrock/amazon.titan-tg1-large",
|
||||
"meta.llama3-8b-instruct-v1:0",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -10637,75 +10637,11 @@ class CustomStreamWrapper:
|
|||
raise e
|
||||
|
||||
def handle_bedrock_stream(self, chunk):
|
||||
if "cohere" in self.model or "anthropic" in self.model:
|
||||
return {
|
||||
"text": chunk["text"],
|
||||
"is_finished": chunk["is_finished"],
|
||||
"finish_reason": chunk["finish_reason"],
|
||||
}
|
||||
if hasattr(chunk, "get"):
|
||||
chunk = chunk.get("chunk")
|
||||
chunk_data = json.loads(chunk.get("bytes").decode())
|
||||
else:
|
||||
chunk_data = json.loads(chunk.decode())
|
||||
if chunk_data:
|
||||
text = ""
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
if "outputText" in chunk_data:
|
||||
text = chunk_data["outputText"]
|
||||
# ai21 mapping
|
||||
if "ai21" in self.model: # fake ai21 streaming
|
||||
text = chunk_data.get("completions")[0].get("data").get("text")
|
||||
is_finished = True
|
||||
finish_reason = "stop"
|
||||
######## bedrock.anthropic mappings ###############
|
||||
elif "completion" in chunk_data: # not claude-3
|
||||
text = chunk_data["completion"] # bedrock.anthropic
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
elif "delta" in chunk_data:
|
||||
if chunk_data["delta"].get("text", None) is not None:
|
||||
text = chunk_data["delta"]["text"]
|
||||
stop_reason = chunk_data["delta"].get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
######## bedrock.mistral mappings ###############
|
||||
elif "outputs" in chunk_data:
|
||||
if (
|
||||
len(chunk_data["outputs"]) == 1
|
||||
and chunk_data["outputs"][0].get("text", None) is not None
|
||||
):
|
||||
text = chunk_data["outputs"][0]["text"]
|
||||
stop_reason = chunk_data.get("stop_reason", None)
|
||||
if stop_reason != None:
|
||||
is_finished = True
|
||||
finish_reason = stop_reason
|
||||
######## bedrock.cohere mappings ###############
|
||||
# meta mapping
|
||||
elif "generation" in chunk_data:
|
||||
text = chunk_data["generation"] # bedrock.meta
|
||||
# cohere mapping
|
||||
elif "text" in chunk_data:
|
||||
text = chunk_data["text"] # bedrock.cohere
|
||||
# cohere mapping for finish reason
|
||||
elif "finish_reason" in chunk_data:
|
||||
finish_reason = chunk_data["finish_reason"]
|
||||
is_finished = True
|
||||
elif chunk_data.get("completionReason", None):
|
||||
is_finished = True
|
||||
finish_reason = chunk_data["completionReason"]
|
||||
elif chunk.get("error", None):
|
||||
raise Exception(chunk["error"])
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
return ""
|
||||
return {
|
||||
"text": chunk["text"],
|
||||
"is_finished": chunk["is_finished"],
|
||||
"finish_reason": chunk["finish_reason"],
|
||||
}
|
||||
|
||||
def handle_sagemaker_stream(self, chunk):
|
||||
if "data: [DONE]" in chunk:
|
||||
|
@ -11508,14 +11444,7 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "replicate"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider == "predibase"
|
||||
or (
|
||||
self.custom_llm_provider == "bedrock"
|
||||
and (
|
||||
"cohere" in self.model
|
||||
or "anthropic" in self.model
|
||||
or "ai21" in self.model
|
||||
)
|
||||
)
|
||||
or self.custom_llm_provider == "bedrock"
|
||||
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||
):
|
||||
async for chunk in self.completion_stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue