feat(utils.py): unify common auth params across azure/vertex_ai/bedrock/watsonx

Krrish Dholakia 2024-04-27 11:06:18 -07:00
parent c9d7437d16
commit 48f19cf839
8 changed files with 194 additions and 20 deletions


@@ -4619,7 +4619,36 @@ def get_optional_params(
k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
): # allow dynamically setting vertex ai init logic
continue
passed_params[k] = v
optional_params = {}
common_auth_dict = litellm.common_cloud_provider_auth_params
if custom_llm_provider in common_auth_dict["providers"]:
"""
Check for the common auth params ["project", "region_name", "token"]
and translate them for each of ["azure", "vertex_ai", "watsonx", "aws"].
"""
if custom_llm_provider == "azure":
optional_params = litellm.AzureOpenAIConfig().map_special_auth_params(
non_default_params=passed_params, optional_params=optional_params
)
elif custom_llm_provider == "bedrock":
optional_params = (
litellm.AmazonBedrockGlobalConfig().map_special_auth_params(
non_default_params=passed_params, optional_params=optional_params
)
)
elif custom_llm_provider == "vertex_ai":
optional_params = litellm.VertexAIConfig().map_special_auth_params(
non_default_params=passed_params, optional_params=optional_params
)
elif custom_llm_provider == "watsonx":
optional_params = litellm.IBMWatsonXAIConfig().map_special_auth_params(
non_default_params=passed_params, optional_params=optional_params
)
default_params = {
"functions": None,
"function_call": None,
@@ -4655,7 +4684,7 @@ def get_optional_params(
and v != default_params[k]
)
}
## raise exception if function calling passed in for a provider that doesn't support it
if (
"functions" in non_default_params
@@ -5446,17 +5475,21 @@ def get_optional_params(
optional_params["random_seed"] = seed
if stop is not None:
optional_params["stop_sequences"] = stop
# WatsonX-only parameters
extra_body = {}
if "decoding_method" in passed_params:
extra_body["decoding_method"] = passed_params.pop("decoding_method")
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
extra_body["min_new_tokens"] = passed_params.pop(
"min_tokens", passed_params.pop("min_new_tokens")
)
if "top_k" in passed_params:
extra_body["top_k"] = passed_params.pop("top_k")
if "truncate_input_tokens" in passed_params:
extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
extra_body["truncate_input_tokens"] = passed_params.pop(
"truncate_input_tokens"
)
if "length_penalty" in passed_params:
extra_body["length_penalty"] = passed_params.pop("length_penalty")
if "time_limit" in passed_params:
@@ -5464,7 +5497,7 @@ def get_optional_params(
if "return_options" in passed_params:
extra_body["return_options"] = passed_params.pop("return_options")
optional_params["extra_body"] = (
extra_body  # openai client supports `extra_body` param
)
else: # assume passing in params for openai/azure openai
print_verbose(
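For context, a hedged sketch of how the WatsonX-only parameters handled above would be passed in practice (the model name is illustrative):

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "hello"}],
    decoding_method="sample",  # forwarded to the provider via extra_body
    min_new_tokens=16,  # forwarded via extra_body
    truncate_input_tokens=2048,  # forwarded via extra_body
)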
@@ -9793,7 +9826,7 @@ class CustomStreamWrapper:
"is_finished": chunk["is_finished"],
"finish_reason": finish_reason,
}
def handle_watsonx_stream(self, chunk):
try:
if isinstance(chunk, dict):
@@ -9801,19 +9834,21 @@ class CustomStreamWrapper:
elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
if "generated_text" in chunk:
response = chunk.replace("data: ", "").strip()
parsed_response = json.loads(response)
else:
return {"text": "", "is_finished": False}
else:
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
raise ValueError(f"Unable to parse response. Original response: {chunk}")
raise ValueError(
f"Unable to parse response. Original response: {chunk}"
)
results = parsed_response.get("results", [])
if len(results) > 0:
text = results[0].get("generated_text", "")
finish_reason = results[0].get("stop_reason")
is_finished = finish_reason != "not_finished"
return {
"text": text,
"is_finished": is_finished,
@@ -10085,14 +10120,19 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj.get("prompt_tokens") is not None:
prompt_token_count = getattr(
model_response.usage, "prompt_tokens", 0
)
model_response.usage.prompt_tokens = (
prompt_token_count + response_obj["prompt_tokens"]
)
if response_obj.get("completion_tokens") is not None:
model_response.usage.completion_tokens = response_obj[
"completion_tokens"
]
model_response.usage.total_tokens = getattr(
model_response.usage, "prompt_tokens", 0
) + getattr(model_response.usage, "completion_tokens", 0)
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
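The accumulation above can be checked with a small worked example; SimpleNamespace stands in for the real usage object:

from types import SimpleNamespace

# prompt_tokens accumulates across chunks, completion_tokens takes the
# latest reported value, and total_tokens is recomputed from both.
usage = SimpleNamespace(prompt_tokens=0, completion_tokens=0, total_tokens=0)
chunks = [
    {"prompt_tokens": 10},
    {"prompt_tokens": 0, "completion_tokens": 5},
]
for response_obj in chunks:
    if response_obj.get("prompt_tokens") is not None:
        usage.prompt_tokens = getattr(usage, "prompt_tokens", 0) + response_obj["prompt_tokens"]
    if response_obj.get("completion_tokens") is not None:
        usage.completion_tokens = response_obj["completion_tokens"]
        usage.total_tokens = getattr(usage, "prompt_tokens", 0) + getattr(
            usage, "completion_tokens", 0
        )
print(usage)  # namespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)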