Mirror of https://github.com/BerriAI/litellm.git
test(test_custom_callback_input.py): add bedrock testing
This commit is contained in: parent 6a3ba74183, commit b09ecb986e
4 changed files with 165 additions and 36 deletions
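The commit title points at test_custom_callback_input.py, which exercises what litellm hands to user callbacks. As orientation only, here is a rough, hypothetical sketch of that kind of check, using litellm's documented success_callback hook and an illustrative Bedrock model id (neither is taken from this diff):

```python
import litellm

captured = {}

def track_callback_input(kwargs, completion_response, start_time, end_time):
    # Record what litellm passed into the callback so a test could assert on it.
    captured["model"] = kwargs.get("model")
    captured["messages"] = kwargs.get("messages")
    captured["response"] = completion_response

litellm.success_callback = [track_callback_input]

# Illustrative Bedrock model id; the real test may use different models and params.
response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
print(captured["model"], len(captured["messages"]))
```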
@@ -683,9 +683,14 @@ def completion(
                 logger_fn=logger_fn
             )
 
-            # if "stream" in optional_params and optional_params["stream"] == True:
-            #     response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
-            #     return response
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=model_response,
+                    additional_args={"headers": headers},
+                )
             response = model_response
         elif (
             "replicate" in model or
@@ -730,8 +735,16 @@ def completion(
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
-                return response
+                model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
+
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=replicate_key,
+                    original_response=model_response,
+                )
+
             response = model_response
 
         elif custom_llm_provider=="anthropic":
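Each provider branch in this diff converges on the same epilogue: wrap the raw output in CustomStreamWrapper when streaming, then fire post-call logging whenever the call is streaming or async, instead of returning early. A minimal sketch of that control flow, with a stand-in logging object rather than litellm's real Logging class:

```python
class StubLogging:
    """Stand-in for litellm's logging object; only mimics the post_call shape used in the diff."""

    def post_call(self, input, api_key, original_response, additional_args=None):
        print("post_call:", {"messages": input, "response_type": type(original_response).__name__})


def finish_branch(raw_response, messages, api_key, optional_params, acompletion, logging, wrap_stream):
    # Mirrors the branch epilogue above: wrap streams, log for stream/async paths, return the result.
    if optional_params.get("stream", False):
        raw_response = wrap_stream(raw_response)  # stands in for CustomStreamWrapper(...)
    if optional_params.get("stream", False) or acompletion == True:
        ## LOGGING
        logging.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response,
        )
    return raw_response


resp = finish_branch(
    raw_response="hello from the provider",
    messages=[{"role": "user", "content": "hi"}],
    api_key="sk-fake",
    optional_params={"stream": False},
    acompletion=True,
    logging=StubLogging(),
    wrap_stream=iter,
)
```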
@@ -751,7 +764,7 @@ def completion(
                 custom_prompt_dict
                 or litellm.custom_prompt_dict
             )
-            model_response = anthropic.completion(
+            response = anthropic.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -767,9 +780,16 @@ def completion(
             )
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
-                return response
-            response = model_response
+                response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
+
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                )
+            response = response
         elif custom_llm_provider == "nlp_cloud":
             nlp_cloud_key = (
                 api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
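The anthropic and nlp_cloud hunks drop the early return on streams; the raw stream is wrapped and then falls through to the shared logging call. A toy stand-in for what such a wrapper does (the real CustomStreamWrapper in litellm does far more, e.g. provider-specific chunk parsing):

```python
class ToyStreamWrapper:
    """Simplified stand-in: iterate a provider's raw chunks and yield them in a common shape."""

    def __init__(self, completion_stream, model, custom_llm_provider=None, logging_obj=None):
        self.completion_stream = completion_stream
        self.model = model
        self.custom_llm_provider = custom_llm_provider
        self.logging_obj = logging_obj

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)  # StopIteration ends the wrapped stream too
        # A real wrapper normalizes provider-specific chunk formats here.
        return {"model": self.model, "choices": [{"delta": {"content": str(chunk)}}]}


for piece in ToyStreamWrapper(iter(["Hel", "lo"]), model="claude-instant-1", custom_llm_provider="anthropic"):
    print(piece["choices"][0]["delta"]["content"], end="")
```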
@@ -782,7 +802,7 @@ def completion(
                 or "https://api.nlpcloud.io/v1/gpu/"
             )
 
-            model_response = nlp_cloud.completion(
+            response = nlp_cloud.completion(
                 model=model,
                 messages=messages,
                 api_base=api_base,
@@ -798,9 +818,17 @@ def completion(
 
             if "stream" in optional_params and optional_params["stream"] == True:
                 # don't try to access stream object,
-                response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
-                return response
-            response = model_response
+                response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
+
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                )
+
+            response = response
         elif custom_llm_provider == "aleph_alpha":
             aleph_alpha_key = (
                 api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
@@ -1202,7 +1230,7 @@ def completion(
                 custom_prompt_dict
                 or litellm.custom_prompt_dict
             )
-            model_response = bedrock.completion(
+            response = bedrock.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=litellm.custom_prompt_dict,
@@ -1220,16 +1248,24 @@ def completion(
                 # don't try to access stream object,
                 if "ai21" in model:
                     response = CustomStreamWrapper(
-                        model_response, model, custom_llm_provider="bedrock", logging_obj=logging
+                        response, model, custom_llm_provider="bedrock", logging_obj=logging
                     )
                 else:
                     response = CustomStreamWrapper(
-                        iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
+                        iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
                     )
                 return response
 
+            if optional_params.get("stream", False):
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=None,
+                    original_response=response,
+                )
+
             ## RESPONSE OBJECT
-            response = model_response
+            response = response
         elif custom_llm_provider == "vllm":
             model_response = vllm.completion(
                 model=model,
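In the bedrock hunk, AI21 output goes to the wrapper as-is while other Bedrock output is first wrapped in iter(): the stream wrapper expects something it can call next() on. A small illustration of that distinction (the payloads here are made up, not taken from the Bedrock SDK):

```python
def drain(stream):
    # Core loop of a stream wrapper: call next() until the stream is exhausted.
    chunks = []
    while True:
        try:
            chunks.append(next(stream))
        except StopIteration:
            return chunks


generator_style = (c for c in ["already", " an", " iterator"])  # generators support next() directly
list_style = ["needs", " iter()", " first"]                     # lists do not

print(drain(generator_style))
print(drain(iter(list_style)))
# drain(list_style) would raise TypeError: 'list' object is not an iterator
```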