forked from phoenix/litellm-mirror

fix(bedrock_httpx.py): add async support for bedrock amazon, meta, mistral models

This commit is contained in:
parent 0293f7766a
commit 92c2e2af6a

6 changed files with 1441 additions and 4383 deletions
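What this change enables, shown as a minimal usage sketch (not part of the diff): calling litellm.acompletion directly with Bedrock Amazon Titan, Meta Llama, and Mistral models. Model IDs are taken from the test parametrization further down; AWS credentials configured in the environment are assumed.

# Minimal sketch; assumes AWS credentials are set in the environment.
import asyncio
import litellm

async def main():
    for model in [
        "bedrock/amazon.titan-tg1-large",
        "meta.llama3-8b-instruct-v1:0",
        "mistral.mistral-7b-instruct-v0:2",
    ]:
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hey! how's it going?"}],
            temperature=0.2,
            max_tokens=100,
        )
        print(model, "->", response.choices[0].message.content)

asyncio.run(main())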
@@ -307,7 +307,7 @@ class BedrockLLM(BaseLLM):
         try:
             if provider == "cohere":
-                model_response.choices[0].message.content = completion_response["text"]  # type: ignore
+                outputText = completion_response["text"]  # type: ignore
             elif provider == "anthropic":
                 if model.startswith("anthropic.claude-3"):
                     json_schemas: dict = {}
@@ -427,6 +427,15 @@ class BedrockLLM(BaseLLM):
                 outputText = (
                     completion_response.get("completions")[0].get("data").get("text")
                 )
+            elif provider == "meta":
+                outputText = completion_response["generation"]
+            elif provider == "mistral":
+                outputText = completion_response["outputs"][0]["text"]
+                model_response["finish_reason"] = completion_response["outputs"][0][
+                    "stop_reason"
+                ]
+            else:  # amazon titan
+                outputText = completion_response.get("results")[0].get("outputText")
         except Exception as e:
             raise BedrockError(
                 message="Error processing={}, Received error={}".format(
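For reference, a sketch of the non-streaming response shapes these new branches expect. Only the key names come from the parsing code above; the payload values are illustrative.

# Illustrative payloads; key names taken from the branches above.
meta_body = {"generation": "Hi there!"}
mistral_body = {"outputs": [{"text": "Hi there!", "stop_reason": "stop"}]}
titan_body = {"results": [{"outputText": "Hi there!"}]}

assert meta_body["generation"] == "Hi there!"                          # provider == "meta"
assert mistral_body["outputs"][0]["text"] == "Hi there!"               # provider == "mistral"
assert titan_body.get("results")[0].get("outputText") == "Hi there!"   # amazon titan fallback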
@@ -691,6 +700,40 @@ class BedrockLLM(BaseLLM):
                     inference_params[k] = v

             data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "mistral":
+            ## LOAD CONFIG
+            config = litellm.AmazonMistralConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+
+            data = json.dumps({"prompt": prompt, **inference_params})
+        elif provider == "amazon":  # amazon titan
+            ## LOAD CONFIG
+            config = litellm.AmazonTitanConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+
+            data = json.dumps(
+                {
+                    "inputText": prompt,
+                    "textGenerationConfig": inference_params,
+                }
+            )
+        elif provider == "meta":
+            ## LOAD CONFIG
+            config = litellm.AmazonLlamaConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in inference_params
+                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
+                    inference_params[k] = v
+            data = json.dumps({"prompt": prompt, **inference_params})
         else:
             raise Exception("UNSUPPORTED PROVIDER")
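The three config-merge loops added above all follow the precedence rule spelled out in the inline comment: a value passed to completion() wins over the library-level Amazon*Config default. A small sketch of that rule; the parameter name maxTokenCount is assumed from AmazonTitanConfig.

# Sketch of the precedence rule: per-call params beat litellm.Amazon*Config defaults.
import litellm

litellm.AmazonTitanConfig(maxTokenCount=256)        # library-level default (assumed param name)
config = litellm.AmazonTitanConfig.get_config()     # e.g. {"maxTokenCount": 256}

inference_params = {"maxTokenCount": 100}           # value passed to completion()
for k, v in config.items():
    if k not in inference_params:                   # only fill in what the caller didn't set
        inference_params[k] = v

assert inference_params["maxTokenCount"] == 100     # per-call value preserved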
@@ -326,10 +326,7 @@ async def acompletion(
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
             or custom_llm_provider == "predibase"
-            or (
-                custom_llm_provider == "bedrock"
-                and ("cohere" in model or "anthropic" in model or "ai21" in model)
-            )
+            or custom_llm_provider == "bedrock"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
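With bedrock now in this list unconditionally, async Bedrock calls take the same path as the other natively async providers: the sync wrapper runs in an executor, and because acompletion=True makes it return a coroutine, that coroutine is awaited. A generic sketch of the pattern, not litellm's exact code:

# Generic sketch of the "run in executor, await if a coroutine came back" pattern.
import asyncio

def func_with_context(acompletion: bool = True):
    async def _async_call():
        return "async result"
    # when acompletion=True the provider returns a coroutine instead of a response
    return _async_call() if acompletion else "sync result"

async def route():
    loop = asyncio.get_running_loop()
    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):
        init_response = await init_response
    return init_response

print(asyncio.run(route()))  # -> "async result"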
@@ -1982,7 +1979,6 @@ def completion(
             # boto3 reads keys from .env
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict

-            if "cohere" in model or "anthropic" in model or "ai21" in model:
             response = bedrock_chat_completion.completion(
                 model=model,
                 messages=messages,
@@ -1998,43 +1994,6 @@ def completion(
                 timeout=timeout,
                 acompletion=acompletion,
             )
-            else:
-                response = bedrock.completion(
-                    model=model,
-                    messages=messages,
-                    custom_prompt_dict=litellm.custom_prompt_dict,
-                    model_response=model_response,
-                    print_verbose=print_verbose,
-                    optional_params=optional_params,
-                    litellm_params=litellm_params,
-                    logger_fn=logger_fn,
-                    encoding=encoding,
-                    logging_obj=logging,
-                    extra_headers=extra_headers,
-                    timeout=timeout,
-                )
-
-                if (
-                    "stream" in optional_params
-                    and optional_params["stream"] == True
-                    and not isinstance(response, CustomStreamWrapper)
-                ):
-                    # don't try to access stream object,
-                    if "ai21" in model:
-                        response = CustomStreamWrapper(
-                            response,
-                            model,
-                            custom_llm_provider="bedrock",
-                            logging_obj=logging,
-                        )
-                    else:
-                        response = CustomStreamWrapper(
-                            iter(response),
-                            model,
-                            custom_llm_provider="bedrock",
-                            logging_obj=logging,
-                        )
-
             if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
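With the else branch gone, every Bedrock model is routed through bedrock_chat_completion (the httpx-based BedrockLLM client), which already hands back a CustomStreamWrapper when streaming is requested. A usage sketch for the sync streaming path; the model ID is taken from the tests in this diff.

# Sync streaming sketch; chunks follow the OpenAI-style delta format.
import litellm

response = litellm.completion(
    model="bedrock/amazon.titan-tg1-large",
    messages=[{"role": "user", "content": "Hey! how's it going?"}],
    max_tokens=100,
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="")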
File diff suppressed because it is too large
@@ -2670,6 +2670,9 @@ def response_format_tests(response: litellm.ModelResponse):
         "anthropic.claude-3-sonnet-20240229-v1:0",
         "anthropic.claude-instant-v1",
         "bedrock/ai21.j2-mid",
+        "mistral.mistral-7b-instruct-v0:2",
+        "bedrock/amazon.titan-tg1-large",
+        "meta.llama3-8b-instruct-v1:0",
     ],
 )
 @pytest.mark.asyncio
@@ -2692,7 +2695,7 @@ async def test_completion_bedrock_httpx_models(sync_mode, model):
             model=model,
             messages=[{"role": "user", "content": "Hey! how's it going?"}],
             temperature=0.2,
-            max_tokens=200,
+            max_tokens=100,
         )

         assert isinstance(response, litellm.ModelResponse)
@@ -2728,24 +2731,6 @@ def test_completion_bedrock_titan_null_response():
         pytest.fail(f"An error occurred - {str(e)}")


-def test_completion_bedrock_titan():
-    try:
-        response = completion(
-            model="bedrock/amazon.titan-tg1-large",
-            messages=messages,
-            temperature=0.2,
-            max_tokens=200,
-            top_p=0.8,
-            logger_fn=logger_fn,
-        )
-        # Add any assertions here to check the response
-        print(response)
-    except RateLimitError:
-        pass
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
-
-
 # test_completion_bedrock_titan()
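The hand-written Titan test removed above is superseded by the parametrized test_completion_bedrock_httpx_models shown earlier. A hedged sketch of how such a sync/async parametrized test can be laid out; only the call arguments visible in the diff come from the source, the name and structure here are assumed.

# Sketch of a sync/async parametrized Bedrock test (requires pytest-asyncio).
import pytest
import litellm
from litellm import acompletion, completion

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.parametrize(
    "model",
    [
        "bedrock/amazon.titan-tg1-large",
        "mistral.mistral-7b-instruct-v0:2",
        "meta.llama3-8b-instruct-v1:0",
    ],
)
@pytest.mark.asyncio
async def test_bedrock_httpx_models_sketch(sync_mode, model):
    kwargs = dict(
        model=model,
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        temperature=0.2,
        max_tokens=100,
    )
    response = completion(**kwargs) if sync_mode else await acompletion(**kwargs)
    assert isinstance(response, litellm.ModelResponse)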
@@ -1048,6 +1048,9 @@ async def test_completion_replicate_llama3_streaming(sync_mode):
         "anthropic.claude-3-sonnet-20240229-v1:0",
         "anthropic.claude-instant-v1",
         "bedrock/ai21.j2-mid",
+        "mistral.mistral-7b-instruct-v0:2",
+        "bedrock/amazon.titan-tg1-large",
+        "meta.llama3-8b-instruct-v1:0",
     ],
 )
 @pytest.mark.asyncio
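The same three Bedrock models are added to the streaming test matrix. An async streaming usage sketch for one of them; the model ID is taken from that list and AWS credentials are assumed.

# Async streaming sketch for a newly covered Bedrock model.
import asyncio
import litellm

async def stream_once():
    response = await litellm.acompletion(
        model="mistral.mistral-7b-instruct-v0:2",
        messages=[{"role": "user", "content": "Hey! how's it going?"}],
        max_tokens=100,
        stream=True,
    )
    async for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            print(delta, end="")

asyncio.run(stream_once())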
@@ -10637,75 +10637,11 @@ class CustomStreamWrapper:
             raise e

     def handle_bedrock_stream(self, chunk):
-        if "cohere" in self.model or "anthropic" in self.model:
         return {
             "text": chunk["text"],
             "is_finished": chunk["is_finished"],
             "finish_reason": chunk["finish_reason"],
         }
-        if hasattr(chunk, "get"):
-            chunk = chunk.get("chunk")
-            chunk_data = json.loads(chunk.get("bytes").decode())
-        else:
-            chunk_data = json.loads(chunk.decode())
-        if chunk_data:
-            text = ""
-            is_finished = False
-            finish_reason = ""
-            if "outputText" in chunk_data:
-                text = chunk_data["outputText"]
-            # ai21 mapping
-            if "ai21" in self.model:  # fake ai21 streaming
-                text = chunk_data.get("completions")[0].get("data").get("text")
-                is_finished = True
-                finish_reason = "stop"
-            ######## bedrock.anthropic mappings ###############
-            elif "completion" in chunk_data:  # not claude-3
-                text = chunk_data["completion"]  # bedrock.anthropic
-                stop_reason = chunk_data.get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            elif "delta" in chunk_data:
-                if chunk_data["delta"].get("text", None) is not None:
-                    text = chunk_data["delta"]["text"]
-                stop_reason = chunk_data["delta"].get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            ######## bedrock.mistral mappings ###############
-            elif "outputs" in chunk_data:
-                if (
-                    len(chunk_data["outputs"]) == 1
-                    and chunk_data["outputs"][0].get("text", None) is not None
-                ):
-                    text = chunk_data["outputs"][0]["text"]
-                stop_reason = chunk_data.get("stop_reason", None)
-                if stop_reason != None:
-                    is_finished = True
-                    finish_reason = stop_reason
-            ######## bedrock.cohere mappings ###############
-            # meta mapping
-            elif "generation" in chunk_data:
-                text = chunk_data["generation"]  # bedrock.meta
-            # cohere mapping
-            elif "text" in chunk_data:
-                text = chunk_data["text"]  # bedrock.cohere
-            # cohere mapping for finish reason
-            elif "finish_reason" in chunk_data:
-                finish_reason = chunk_data["finish_reason"]
-                is_finished = True
-            elif chunk_data.get("completionReason", None):
-                is_finished = True
-                finish_reason = chunk_data["completionReason"]
-            elif chunk.get("error", None):
-                raise Exception(chunk["error"])
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        return ""

     def handle_sagemaker_stream(self, chunk):
         if "data: [DONE]" in chunk:
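After this simplification, handle_bedrock_stream no longer decodes raw AWS event-stream bytes or maps per-provider fields; that work now appears to live in the httpx Bedrock client, and the wrapper only reshapes an already-decoded dict. A sketch of the chunk shape it now expects; the keys come from the remaining return statement, the values are illustrative.

# Illustrative chunk as already decoded by the Bedrock httpx client.
chunk = {"text": "Hello", "is_finished": False, "finish_reason": ""}

# The simplified handler passes these three keys through unchanged.
parsed = {
    "text": chunk["text"],
    "is_finished": chunk["is_finished"],
    "finish_reason": chunk["finish_reason"],
}
assert parsed == chunk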
@@ -11508,14 +11444,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "replicate"
                 or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider == "predibase"
-                or (
-                    self.custom_llm_provider == "bedrock"
-                    and (
-                        "cohere" in self.model
-                        or "anthropic" in self.model
-                        or "ai21" in self.model
-                    )
-                )
+                or self.custom_llm_provider == "bedrock"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream: