test(test_custom_callback_input.py): add bedrock testing

n

n
This commit is contained in:
Krrish Dholakia 2023-12-11 12:59:49 -08:00
parent 6a3ba74183
commit b09ecb986e
4 changed files with 165 additions and 36 deletions

View file

@ -683,9 +683,14 @@ def completion(
logger_fn=logger_fn
)
# if "stream" in optional_params and optional_params["stream"] == True:
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
# return response
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=model_response,
additional_args={"headers": headers},
)
response = model_response
elif (
"replicate" in model or
@ -730,8 +735,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
return response
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=replicate_key,
original_response=model_response,
)
response = model_response
elif custom_llm_provider=="anthropic":
@ -751,7 +764,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = anthropic.completion(
response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
@ -767,9 +780,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
return response
response = model_response
response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "nlp_cloud":
nlp_cloud_key = (
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
@ -782,7 +802,7 @@ def completion(
or "https://api.nlpcloud.io/v1/gpu/"
)
model_response = nlp_cloud.completion(
response = nlp_cloud.completion(
model=model,
messages=messages,
api_base=api_base,
@ -798,9 +818,17 @@ def completion(
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
return response
response = model_response
response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
)
response = response
elif custom_llm_provider == "aleph_alpha":
aleph_alpha_key = (
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
@ -1202,7 +1230,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
model_response = bedrock.completion(
response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
@ -1220,16 +1248,24 @@ def completion(
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="bedrock", logging_obj=logging
response, model, custom_llm_provider="bedrock", logging_obj=logging
)
else:
response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
)
return response
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT
response = model_response
response = response
elif custom_llm_provider == "vllm":
model_response = vllm.completion(
model=model,