feat(utils.py): unify common auth params across azure/vertex_ai/bedrock/watsonx
Commit 48f19cf839 (parent c9d7437d16)
8 changed files with 194 additions and 20 deletions

@@ -5,6 +5,9 @@ LiteLLM allows you to specify the following:
 * API Base
 * API Version
 * API Type
+* Project
+* Location
+* Token
 
 Useful Helper functions:
 * [`check_valid_key()`](#check_valid_key)

@@ -43,6 +46,24 @@ os.environ['AZURE_API_TYPE'] = "azure" # [OPTIONAL]
 os.environ['OPENAI_API_BASE'] = "https://openai-gpt-4-test2-v-12.openai.azure.com/"
 ```
 
+### Setting Project, Location, Token
+
+For cloud providers:
+- Azure
+- Bedrock
+- GCP
+- Watson AI
+
+you might need to set additional parameters. LiteLLM provides a common set of params that we map across all providers.
+
+| Param | LiteLLM param | Watson | Vertex AI | Azure | Bedrock |
+|------|--------------|--------------|--------------|--------------|--------------|
+| Project | project | watsonx_project | vertex_project | n/a | n/a |
+| Region | region_name | watsonx_region_name | vertex_location | n/a | aws_region_name |
+| Token | token | watsonx_token or token | n/a | azure_ad_token | n/a |
+
+If you want, you can call them by their provider-specific params as well.
+
 ## litellm variables
 
 ### litellm.api_key

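A minimal usage sketch of the unified params, assuming a Vertex AI model (the model name and values here are illustrative, not part of this commit):

```python
import litellm

# Unified params: `project` and `region_name` are mapped to
# `vertex_project` / `vertex_location` for Vertex AI.
response = litellm.completion(
    model="vertex_ai/gemini-pro",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    project="my-gcp-project",      # placeholder GCP project id
    region_name="us-east1",
)
```
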
@@ -58,6 +58,7 @@ max_tokens = 256 # OpenAI Defaults
 drop_params = False
 modify_params = False
 retry = True
+### AUTH ###
 api_key: Optional[str] = None
 openai_key: Optional[str] = None
 azure_key: Optional[str] = None

@@ -76,6 +77,10 @@ cloudflare_api_key: Optional[str] = None
 baseten_key: Optional[str] = None
 aleph_alpha_key: Optional[str] = None
 nlp_cloud_key: Optional[str] = None
+common_cloud_provider_auth_params: dict = {
+    "params": ["project", "region_name", "token"],
+    "providers": ["vertex_ai", "bedrock", "watsonx", "azure"],
+}
 use_client: bool = False
 ssl_verify: bool = True
 disable_streaming_logging: bool = False

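A small sketch of how this registry can be consulted, mirroring the check added to `get_optional_params` further down:

```python
import litellm

provider = "bedrock"
# The registry pairs the unified param names with the providers that honor them.
if provider in litellm.common_cloud_provider_auth_params["providers"]:
    print(litellm.common_cloud_provider_auth_params["params"])
    # -> ['project', 'region_name', 'token']
```
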
@@ -654,6 +659,7 @@ from .llms.bedrock import (
     AmazonLlamaConfig,
     AmazonStabilityConfig,
     AmazonMistralConfig,
+    AmazonBedrockGlobalConfig,
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError

@@ -96,6 +96,15 @@ class AzureOpenAIConfig(OpenAIConfig):
             top_p,
         )
 
+    def get_mapped_special_auth_params(self) -> dict:
+        return {"token": "azure_ad_token"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "token":
+                optional_params["azure_ad_token"] = value
+        return optional_params
+
 
 def select_azure_base_url_or_endpoint(azure_client_params: dict):
     # azure_client_params = {

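For illustration, the Azure mapping in isolation (the token value is a placeholder):

```python
from litellm import AzureOpenAIConfig

# `token` is the only unified auth param Azure maps; it becomes `azure_ad_token`.
mapped = AzureOpenAIConfig().map_special_auth_params(
    non_default_params={"token": "test-token"},
    optional_params={},
)
print(mapped)  # {'azure_ad_token': 'test-token'}
```
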
@@ -29,6 +29,24 @@ class BedrockError(Exception):
         ) # Call the base class constructor with the parameters it needs
 
 
+class AmazonBedrockGlobalConfig:
+    def __init__(self):
+        pass
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Mapping of common auth params across bedrock/vertex/azure/watsonx
+        """
+        return {"region_name": "aws_region_name"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+
 class AmazonTitanConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

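Likewise for Bedrock, where only `region_name` is mapped; unified params outside the mapping are simply left out of `optional_params` (values are placeholders):

```python
from litellm import AmazonBedrockGlobalConfig

mapped = AmazonBedrockGlobalConfig().map_special_auth_params(
    non_default_params={"region_name": "us-east-1", "token": "ignored-here"},
    optional_params={},
)
print(mapped)  # {'aws_region_name': 'us-east-1'}
```
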
@@ -184,6 +184,20 @@ class VertexAIConfig:
                 pass
         return optional_params
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 import asyncio

@@ -529,7 +543,7 @@ def completion(
             "instances": instances,
             "vertex_location": vertex_location,
             "vertex_project": vertex_project,
-            "safety_settings":safety_settings,
+            "safety_settings": safety_settings,
             **optional_params,
         }
         if optional_params.get("stream", False) is True:

@@ -131,6 +131,24 @@ class IBMWatsonXAIConfig:
             "stream",  # equivalent to stream
         ]
 
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {
+            "project": "watsonx_project",
+            "region_name": "watsonx_region_name",
+            "token": "watsonx_token",
+        }
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
 
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts

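With all four configs in place, the mappings can be compared side by side (a sketch; the printed output is assembled from the return values above):

```python
import litellm

configs = {
    "azure": litellm.AzureOpenAIConfig(),
    "bedrock": litellm.AmazonBedrockGlobalConfig(),
    "vertex_ai": litellm.VertexAIConfig(),
    "watsonx": litellm.IBMWatsonXAIConfig(),
}
for provider, config in configs.items():
    print(provider, config.get_mapped_special_auth_params())
# azure      {'token': 'azure_ad_token'}
# bedrock    {'region_name': 'aws_region_name'}
# vertex_ai  {'project': 'vertex_project', 'region_name': 'vertex_location'}
# watsonx    {'project': 'watsonx_project', 'region_name': 'watsonx_region_name', 'token': 'watsonx_token'}
```
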
@@ -2654,6 +2654,7 @@ def test_completion_palm_stream():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 def test_completion_watsonx():
     litellm.set_verbose = True
     model_name = "watsonx/ibm/granite-13b-chat-v2"

@@ -2671,10 +2672,57 @@ def test_completion_watsonx():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
+@pytest.mark.parametrize(
+    "provider, model, project, region_name, token",
+    [
+        ("azure", "chatgpt-v-2", None, None, "test-token"),
+        ("vertex_ai", "anthropic-claude-3", "adroit-crow-1", "us-east1", None),
+        ("watsonx", "ibm/granite", "96946574", "dallas", "1234"),
+        ("bedrock", "anthropic.claude-3", None, "us-east-1", None),
+    ],
+)
+def test_unified_auth_params(provider, model, project, region_name, token):
+    """
+    Check if params = ["project", "region_name", "token"]
+    are correctly translated for = ["azure", "vertex_ai", "watsonx", "aws"]
+
+    tests get_optional_params
+    """
+    data = {
+        "project": project,
+        "region_name": region_name,
+        "token": token,
+        "custom_llm_provider": provider,
+        "model": model,
+    }
+
+    translated_optional_params = litellm.utils.get_optional_params(**data)
+
+    if provider == "azure":
+        special_auth_params = (
+            litellm.AzureOpenAIConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "bedrock":
+        special_auth_params = (
+            litellm.AmazonBedrockGlobalConfig().get_mapped_special_auth_params()
+        )
+    elif provider == "vertex_ai":
+        special_auth_params = litellm.VertexAIConfig().get_mapped_special_auth_params()
+    elif provider == "watsonx":
+        special_auth_params = (
+            litellm.IBMWatsonXAIConfig().get_mapped_special_auth_params()
+        )
+
+    for param, value in special_auth_params.items():
+        assert param in data
+        assert value in translated_optional_params
+
+
 @pytest.mark.asyncio
 async def test_acompletion_watsonx():
     litellm.set_verbose = True
-    model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID")
+    model_name = "watsonx/deployment/" + os.getenv("WATSONX_DEPLOYMENT_ID")
     print("testing watsonx")
     try:
         response = await litellm.acompletion(

@@ -4619,7 +4619,36 @@ def get_optional_params(
                 k.startswith("vertex_") and custom_llm_provider != "vertex_ai"
             ): # allow dynamically setting vertex ai init logic
                 continue
+
         passed_params[k] = v
+
+    optional_params = {}
+
+    common_auth_dict = litellm.common_cloud_provider_auth_params
+    if custom_llm_provider in common_auth_dict["providers"]:
+        """
+        Check if params = ["project", "region_name", "token"]
+        and correctly translate for = ["azure", "vertex_ai", "watsonx", "aws"]
+        """
+        if custom_llm_provider == "azure":
+            optional_params = litellm.AzureOpenAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "bedrock":
+            optional_params = (
+                litellm.AmazonBedrockGlobalConfig().map_special_auth_params(
+                    non_default_params=passed_params, optional_params=optional_params
+                )
+            )
+        elif custom_llm_provider == "vertex_ai":
+            optional_params = litellm.VertexAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+        elif custom_llm_provider == "watsonx":
+            optional_params = litellm.IBMWatsonXAIConfig().map_special_auth_params(
+                non_default_params=passed_params, optional_params=optional_params
+            )
+
     default_params = {
         "functions": None,
         "function_call": None,

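The watsonx case from `test_unified_auth_params` above, driven through this dispatch:

```python
import litellm

translated = litellm.utils.get_optional_params(
    model="ibm/granite-13b-chat-v2",
    custom_llm_provider="watsonx",
    project="96946574",
    region_name="dallas",
    token="1234",
)
# The unified params come back under their watsonx-specific names.
assert "watsonx_project" in translated
assert "watsonx_region_name" in translated
assert "watsonx_token" in translated
```
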
@@ -4655,7 +4684,7 @@ def get_optional_params(
             and v != default_params[k]
         )
     }
-    optional_params = {}
+
     ## raise exception if function calling passed in for a provider that doesn't support it
     if (
         "functions" in non_default_params

@@ -5452,11 +5481,15 @@ def get_optional_params(
         if "decoding_method" in passed_params:
             extra_body["decoding_method"] = passed_params.pop("decoding_method")
         if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
-            extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
+            extra_body["min_new_tokens"] = passed_params.pop(
+                "min_tokens", passed_params.pop("min_new_tokens")
+            )
         if "top_k" in passed_params:
             extra_body["top_k"] = passed_params.pop("top_k")
         if "truncate_input_tokens" in passed_params:
-            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
+            extra_body["truncate_input_tokens"] = passed_params.pop(
+                "truncate_input_tokens"
+            )
         if "length_penalty" in passed_params:
             extra_body["length_penalty"] = passed_params.pop("length_penalty")
         if "time_limit" in passed_params:

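A hedged sketch of how these watsonx-specific kwargs might be passed at the call site; the param names come from the code above, while the model and values are illustrative and would require watsonx credentials to run:

```python
import litellm

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hi"}],
    decoding_method="greedy",      # routed into extra_body
    truncate_input_tokens=512,     # routed into extra_body
)
```
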
@@ -9801,19 +9834,21 @@ class CustomStreamWrapper:
             elif isinstance(chunk, (str, bytes)):
                 if isinstance(chunk, bytes):
                     chunk = chunk.decode("utf-8")
-                if 'generated_text' in chunk:
-                    response = chunk.replace('data: ', '').strip()
+                if "generated_text" in chunk:
+                    response = chunk.replace("data: ", "").strip()
                     parsed_response = json.loads(response)
                 else:
                     return {"text": "", "is_finished": False}
             else:
                 print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
-                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+                raise ValueError(
+                    f"Unable to parse response. Original response: {chunk}"
+                )
             results = parsed_response.get("results", [])
             if len(results) > 0:
                 text = results[0].get("generated_text", "")
                 finish_reason = results[0].get("stop_reason")
-                is_finished = finish_reason != 'not_finished'
+                is_finished = finish_reason != "not_finished"
                 return {
                     "text": text,
                     "is_finished": is_finished,

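The shape of a text chunk this parser accepts, inferred from the code above; the payload values are made up for illustration:

```python
import json

chunk = 'data: {"results": [{"generated_text": "Hello", "stop_reason": "not_finished"}]}'
response = chunk.replace("data: ", "").strip()
parsed_response = json.loads(response)
results = parsed_response.get("results", [])
if results:
    text = results[0].get("generated_text", "")
    is_finished = results[0].get("stop_reason") != "not_finished"
    print({"text": text, "is_finished": is_finished})
    # {'text': 'Hello', 'is_finished': False}
```
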
@@ -10085,14 +10120,19 @@ class CustomStreamWrapper:
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj.get("prompt_tokens") is not None:
-                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
-                    model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
-                if response_obj.get("completion_tokens") is not None:
-                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
-                    model_response.usage.total_tokens = (
-                        getattr(model_response.usage, "prompt_tokens", 0)
-                        + getattr(model_response.usage, "completion_tokens", 0)
-                    )
+                    prompt_token_count = getattr(
+                        model_response.usage, "prompt_tokens", 0
+                    )
+                    model_response.usage.prompt_tokens = (
+                        prompt_token_count + response_obj["prompt_tokens"]
+                    )
+                if response_obj.get("completion_tokens") is not None:
+                    model_response.usage.completion_tokens = response_obj[
+                        "completion_tokens"
+                    ]
+                    model_response.usage.total_tokens = getattr(
+                        model_response.usage, "prompt_tokens", 0
+                    ) + getattr(model_response.usage, "completion_tokens", 0)
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":
