Merge branch 'main' into litellm_vertex_completion_httpx

Krish Dholakia 2024-06-12 21:19:22 -07:00 committed by GitHub
commit 05e21441a6
56 changed files with 568 additions and 145 deletions


@@ -4903,6 +4903,18 @@ def get_optional_params_embeddings(
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    if custom_llm_provider == "vertex_ai":
        supported_params = get_supported_openai_params(
            model=model,
            custom_llm_provider="vertex_ai",
            request_type="embeddings",
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.VertexAITextEmbeddingConfig().map_openai_params(
            non_default_params=non_default_params, optional_params={}
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    if custom_llm_provider == "vertex_ai":
        if len(non_default_params.keys()) > 0:
            if litellm.drop_params is True:  # drop the unsupported non-default values
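
The added branch above maps OpenAI-style embedding arguments onto Vertex AI's parameter names via VertexAITextEmbeddingConfig before the request is built. A minimal sketch of exercising the same helpers directly; the model name and the "dimensions" key are illustrative assumptions, not taken from the diff, and which keys actually survive depends on VertexAITextEmbeddingConfig.get_supported_openai_params():

import litellm
from litellm.utils import get_supported_openai_params

# Ask the same question the new branch asks: which OpenAI-style embedding
# params does the vertex_ai route accept?
supported = get_supported_openai_params(
    model="textembedding-gecko",  # placeholder model name
    custom_llm_provider="vertex_ai",
    request_type="embeddings",
)
print(supported)

# Map a hypothetical OpenAI-style argument the way the new branch does.
mapped = litellm.VertexAITextEmbeddingConfig().map_openai_params(
    non_default_params={"dimensions": 256},  # illustrative key only
    optional_params={},
)
print(mapped)
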
@@ -4936,7 +4948,18 @@ def get_optional_params_embeddings(
                    message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
                )
            return {**non_default_params, **kwargs}
    if custom_llm_provider == "mistral":
        supported_params = get_supported_openai_params(
            model=model,
            custom_llm_provider="mistral",
            request_type="embeddings",
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.MistralEmbeddingConfig().map_openai_params(
            non_default_params=non_default_params, optional_params={}
        )
        final_params = {**optional_params, **kwargs}
        return final_params
    if (
        custom_llm_provider != "openai"
        and custom_llm_provider != "azure"
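
The mistral branch mirrors the vertex_ai one and returns before the generic "openai/azure only" check further down, so mistral embedding requests are mapped through MistralEmbeddingConfig instead of falling into the UnsupportedParamsError path. A rough sketch, with the model name and the encoding_format value as illustrative assumptions:

import litellm

# Which keys survive depends on MistralEmbeddingConfig.get_supported_openai_params().
mapped = litellm.MistralEmbeddingConfig().map_openai_params(
    non_default_params={"encoding_format": "float"},  # illustrative key only
    optional_params={},
)
print(mapped)

# End to end, an embedding call would route through get_optional_params_embeddings:
# response = litellm.embedding(model="mistral/mistral-embed", input=["hello world"])
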
@@ -6355,7 +6378,10 @@ def get_supported_openai_params(
            "max_retries",
        ]
    elif custom_llm_provider == "mistral":
        return litellm.MistralConfig().get_supported_openai_params()
        if request_type == "chat_completion":
            return litellm.MistralConfig().get_supported_openai_params()
        elif request_type == "embeddings":
            return litellm.MistralEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "replicate":
        return [
            "stream",
@@ -6397,7 +6423,10 @@ def get_supported_openai_params(
    elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
        return ["temperature", "top_p", "stream", "n", "stop", "max_tokens"]
    elif custom_llm_provider == "vertex_ai":
        return litellm.VertexAIConfig().get_supported_openai_params()
        if request_type == "chat_completion":
            return litellm.VertexAIConfig().get_supported_openai_params()
        elif request_type == "embeddings":
            return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
    elif custom_llm_provider == "sagemaker":
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
    elif custom_llm_provider == "aleph_alpha":
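
Both the mistral and vertex_ai branches above now dispatch on the new request_type argument, so one helper answers for chat completions and for embeddings. A small sketch of the two call shapes (model names are placeholders; the default request_type is assumed to be "chat_completion", matching the existing call sites that omit it):

from litellm.utils import get_supported_openai_params

# Chat-completion params for Mistral, relying on the assumed default request_type.
chat_params = get_supported_openai_params(
    model="mistral-large-latest",  # placeholder model name
    custom_llm_provider="mistral",
)

# Embedding params for Vertex AI, exactly as queried by the embeddings code path above.
embedding_params = get_supported_openai_params(
    model="textembedding-gecko",  # placeholder model name
    custom_llm_provider="vertex_ai",
    request_type="embeddings",
)
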
@@ -7207,6 +7236,9 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
    elif custom_llm_provider == "ollama":
        return litellm.OllamaConfig().get_required_params()
    elif custom_llm_provider == "azure_ai":
        return litellm.AzureAIStudioConfig().get_required_params()
    else:
        return []
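
With the azure_ai branch in place, get_provider_fields can report the fields Azure AI Studio requires instead of falling through to the empty default. A short sketch; the exact entries come from AzureAIStudioConfig.get_required_params():

from litellm.utils import get_provider_fields

# Returns a List[ProviderField]; each entry describes one required setting.
fields = get_provider_fields(custom_llm_provider="azure_ai")
for field in fields:
    print(field)
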
@@ -10081,6 +10113,14 @@ def get_secret(
            return oidc_token
        else:
            raise ValueError("Github OIDC provider failed")
    elif oidc_provider == "azure":
        # https://azure.github.io/azure-workload-identity/docs/quick-start.html
        azure_federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
        if azure_federated_token_file is None:
            raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")
        with open(azure_federated_token_file, "r") as f:
            oidc_token = f.read()
        return oidc_token
    else:
        raise ValueError("Unsupported OIDC provider")
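
The new azure branch reads the federated service-account token that Azure Workload Identity projects into the pod, at the path exported via AZURE_FEDERATED_TOKEN_FILE. A standalone sketch of the same lookup; how the token is subsequently exchanged for Azure credentials is outside this hunk:

import os

# Azure Workload Identity mounts a short-lived OIDC token and exports its path.
token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
if token_file is None:
    raise ValueError("AZURE_FEDERATED_TOKEN_FILE not found in environment")

with open(token_file, "r") as f:
    oidc_token = f.read()

# oidc_token can now be presented as a client assertion in a federated
# credential flow (not shown here).
print(len(oidc_token))
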