fix(utils.py): handle 'os.environ/' being passed in as kwargs

This commit is contained in:
Krrish Dholakia 2023-12-22 11:08:44 +05:30
parent d87e59db25
commit 278f61f3ed
5 changed files with 68 additions and 4 deletions

View file

@ -0,0 +1,17 @@
model_list:
- model_name: text-embedding-ada-002
litellm_params:
model: azure/azure-embedding-model
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
api_version: "2023-07-01-preview"
model_info:
mode: embedding
base_model: text-embedding-ada-002
litellm_settings:
set_verbose: True
general_settings:
background_health_checks: True # enable background health checks
health_check_interval: 300 # frequency of background health checks

View file

@ -44,6 +44,16 @@ async def _perform_health_check(model_list: list):
"""
Perform a health check for each model in the list.
"""
async def _check_img_gen_model(model_params: dict):
model_params.pop("messages", None)
model_params["prompt"] = "test from litellm"
try:
await litellm.aimage_generation(**model_params)
except Exception as e:
print_verbose(f"Health check failed for model {model_params['model']}. Error: {e}")
return False
return True
async def _check_embedding_model(model_params: dict):
model_params.pop("messages", None)
model_params["input"] = ["test from litellm"]
@ -64,17 +74,17 @@ async def _perform_health_check(model_list: list):
return True
prepped_params = []
tasks = []
for model in model_list:
litellm_params = model["litellm_params"]
model_info = model.get("model_info", {})
litellm_params["messages"] = _get_random_llm_message()
prepped_params.append(litellm_params)
if model_info.get("mode", None) == "embedding":
# this is an embedding model
tasks.append(_check_embedding_model(litellm_params))
elif model_info.get("mode", None) == "image_generation":
tasks.append(_check_img_gen_model(litellm_params))
else:
tasks.append(_check_model(litellm_params))

View file

@ -49,6 +49,8 @@ model_list:
api_version: 2023-07-01-preview
model: azure/azure-embedding-model
model_name: azure-embedding-model
model_info:
mode: "embedding"
- litellm_params:
model: gpt-3.5-turbo
model_info:
@ -76,21 +78,40 @@ model_list:
- model_name: amazon-embeddings
litellm_params:
model: "bedrock/amazon.titan-embed-text-v1"
model_info:
mode: embedding
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
litellm_params:
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
model_info:
mode: embedding
- model_name: dall-e-3
litellm_params:
model: dall-e-3
model: dall-e-3
model_info:
mode: image_generation
- model_name: dall-e-3
litellm_params:
model: "azure/dall-e-3-test"
api_version: "2023-12-01-preview"
api_base: "os.environ/AZURE_SWEDEN_API_BASE"
api_key: "os.environ/AZURE_SWEDEN_API_KEY"
model_info:
mode: image_generation
- model_name: dall-e-2
litellm_params:
model: "azure/"
api_version: "2023-06-01-preview"
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
api_key: "os.environ/AZURE_API_KEY"
model_info:
mode: image_generation
- model_name: text-embedding-ada-002
litellm_params:
model: azure/azure-embedding-model
api_base: "os.environ/AZURE_API_BASE"
api_key: "os.environ/AZURE_API_KEY"
api_version: "2023-07-01-preview"
model_info:
mode: embedding
base_model: text-embedding-ada-002

View file

@ -211,6 +211,18 @@ def test_add_new_model(client_no_auth):
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
def test_health(client_no_auth):
global headers
import time
try:
response = client_no_auth.get("/health")
assert response.status_code == 200
result = response.json()
assert result["unhealthy_count"] == 0
raise Exception(f"It worked!")
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
# test_add_new_model()
from litellm.integrations.custom_logger import CustomLogger

View file

@ -1627,6 +1627,10 @@ def client(original_function):
logging_obj = function_setup(start_time, *args, **kwargs)
kwargs["litellm_logging_obj"] = logging_obj
# CHECK FOR 'os.environ/' in kwargs
for k,v in kwargs.items():
if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
kwargs[k] = litellm.get_secret(v)
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget: