forked from phoenix/litellm-mirror

commit 5d24a72b7e (parent 9aa05c19d1)
fix(bedrock_httpx.py): support mapping for bedrock cohere command r text

5 changed files with 106 additions and 4085 deletions
bedrock_httpx.py

@@ -307,7 +307,13 @@ class BedrockLLM(BaseLLM):
         try:
             if provider == "cohere":
-                outputText = completion_response["text"]  # type: ignore
+                if "text" in completion_response:
+                    outputText = completion_response["text"]  # type: ignore
+                elif "generations" in completion_response:
+                    outputText = completion_response["generations"][0]["text"]
+                    model_response["finish_reason"] = map_finish_reason(
+                        completion_response["generations"][0]["finish_reason"]
+                    )
             elif provider == "anthropic":
                 if model.startswith("anthropic.claude-3"):
                     json_schemas: dict = {}
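The hunk above makes the Cohere text mapping tolerant of both response shapes Bedrock can return: legacy Command models put the completion under a flat "text" key, while Command R responses nest it under "generations" and also carry a finish_reason to map. A minimal self-contained sketch of that branching; the sample payload shapes and the map_finish_reason stand-in are illustrative assumptions, not the exact Bedrock wire format:

from typing import Optional, Tuple

# Hypothetical stand-in for litellm's map_finish_reason helper, included only
# so the sketch runs on its own; the real mapping lives inside litellm.
def map_finish_reason(reason: str) -> str:
    return {"COMPLETE": "stop", "MAX_TOKENS": "length"}.get(reason, reason)

def extract_cohere_text(completion_response: dict) -> Tuple[str, Optional[str]]:
    # Mirrors the diff: prefer the flat "text" key (legacy Command models),
    # else fall back to the nested "generations" shape (Command R).
    if "text" in completion_response:
        return completion_response["text"], None
    elif "generations" in completion_response:
        generation = completion_response["generations"][0]
        return generation["text"], map_finish_reason(generation["finish_reason"])
    raise ValueError("unrecognized cohere response shape")

# Illustrative payloads; the shapes are assumed from the keys used in the diff.
print(extract_cohere_text({"text": "hi"}))
print(extract_cohere_text({"generations": [{"text": "hi", "finish_reason": "COMPLETE"}]}))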
@@ -1981,21 +1981,60 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            response = bedrock_chat_completion.completion(
-                model=model,
-                messages=messages,
-                custom_prompt_dict=litellm.custom_prompt_dict,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                encoding=encoding,
-                logging_obj=logging,
-                extra_headers=extra_headers,
-                timeout=timeout,
-                acompletion=acompletion,
-            )
+            if (
+                "aws_bedrock_client" in optional_params
+            ):  # use old bedrock flow for aws_bedrock_client users.
+                response = bedrock.completion(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=litellm.custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    extra_headers=extra_headers,
+                    timeout=timeout,
+                )
+
+                if (
+                    "stream" in optional_params
+                    and optional_params["stream"] == True
+                    and not isinstance(response, CustomStreamWrapper)
+                ):
+                    # don't try to access stream object,
+                    if "ai21" in model:
+                        response = CustomStreamWrapper(
+                            response,
+                            model,
+                            custom_llm_provider="bedrock",
+                            logging_obj=logging,
+                        )
+                    else:
+                        response = CustomStreamWrapper(
+                            iter(response),
+                            model,
+                            custom_llm_provider="bedrock",
+                            logging_obj=logging,
+                        )
+            else:
+                response = bedrock_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    custom_prompt_dict=custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    logging_obj=logging,
+                    extra_headers=extra_headers,
+                    timeout=timeout,
+                    acompletion=acompletion,
+                )
             if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
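This routing hunk is the heart of the change: callers who pass their own boto3 client stay on the legacy bedrock.completion flow, including its manual CustomStreamWrapper handling, while everyone else goes through the new HTTPX-based bedrock_chat_completion handler. A hedged usage sketch of the two entry points; the client construction details here are assumptions, only the aws_bedrock_client key comes from the diff:

import boto3
import litellm

# New default path: no client passed, so the HTTPX-based flow handles the call.
resp = litellm.completion(
    model="bedrock/cohere.command-text-v14",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# Legacy path: supplying a bedrock-runtime client puts "aws_bedrock_client"
# into optional_params, which the diff routes to the old boto3 flow.
client = boto3.client("bedrock-runtime", region_name="us-west-2")
resp = litellm.completion(
    model="bedrock/cohere.command-text-v14",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    aws_bedrock_client=client,
)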
(File diff suppressed because it is too large.)
@@ -2673,6 +2673,7 @@ def response_format_tests(response: litellm.ModelResponse):
         "mistral.mistral-7b-instruct-v0:2",
         "bedrock/amazon.titan-tg1-large",
         "meta.llama3-8b-instruct-v1:0",
+        "cohere.command-text-v14",
     ],
 )
 @pytest.mark.asyncio
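This test hunk simply appends the Bedrock Cohere text model to an existing @pytest.mark.parametrize list, so the surrounding async test now runs one extra time with model set to "cohere.command-text-v14". A minimal sketch of the mechanism; the test name and body below are hypothetical, not the repo's:

import pytest

# Each model id in the list becomes its own collected test case, so adding
# one string to the list adds exactly one new case.
@pytest.mark.parametrize(
    "model",
    [
        "meta.llama3-8b-instruct-v1:0",
        "cohere.command-text-v14",
    ],
)
@pytest.mark.asyncio
async def test_bedrock_model_smoke(model):
    assert model  # placeholder; the real test calls litellm and checks the response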
@@ -1044,13 +1044,14 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
 @pytest.mark.parametrize(
     "model",
     [
-        "bedrock/cohere.command-r-plus-v1:0",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
-        "anthropic.claude-instant-v1",
-        "bedrock/ai21.j2-mid",
-        "mistral.mistral-7b-instruct-v0:2",
-        "bedrock/amazon.titan-tg1-large",
-        "meta.llama3-8b-instruct-v1:0",
+        # "bedrock/cohere.command-r-plus-v1:0",
+        # "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "anthropic.claude-instant-v1",
+        # "bedrock/ai21.j2-mid",
+        # "mistral.mistral-7b-instruct-v0:2",
+        # "bedrock/amazon.titan-tg1-large",
+        # "meta.llama3-8b-instruct-v1:0",
+        "cohere.command-text-v14"
     ],
 )
 @pytest.mark.asyncio
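Here the streaming matrix is narrowed to just the Cohere text model while the new flow is validated; the other Bedrock models are commented out rather than removed. A hedged sketch of the kind of streaming call this remaining case exercises, assuming the bedrock/ prefix form and AWS credentials already configured in the environment:

import asyncio
import litellm

async def main():
    # Streamed bedrock cohere call; chunks arrive as incremental OpenAI-style deltas.
    response = await litellm.acompletion(
        model="bedrock/cohere.command-text-v14",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())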