Merge branch 'main' into litellm_arize_dynamic_logging

This commit is contained in:
Ishaan Jaff 2025-03-18 22:13:35 -07:00
commit 57e5c94360
13 changed files with 225 additions and 135 deletions

View file

@ -1,4 +1,3 @@
import json
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_logger

View file

@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
model, custom_llm_provider
)
if custom_llm_provider:
if (
model.split("/")[0] == custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = model.replace("{}/".format(custom_llm_provider), "")
return model, custom_llm_provider, dynamic_api_key, api_base
if custom_llm_provider and (
model.split("/")[0] != custom_llm_provider
): # handle scenario where model="azure/*" and custom_llm_provider="azure"
model = custom_llm_provider + "/" + model
if api_key and api_key.startswith("os.environ/"):
dynamic_api_key = get_secret_str(api_key)
# check if llm provider part of model name
if (
model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list_set
@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
elif custom_llm_provider == "snowflake":
api_base = (
api_base
or get_secret("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
api_base
or get_secret_str("SNOWFLAKE_API_BASE")
or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
) # type: ignore
dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
if api_base is not None and not isinstance(api_base, str):
raise Exception("api base needs to be a string. api_base={}".format(api_base))

View file

@ -3257,6 +3257,7 @@ class StandardLoggingPayloadSetup:
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
if hidden_params is not None:
for key in StandardLoggingHiddenParams.__annotations__.keys():
@ -3371,6 +3372,7 @@ def get_standard_logging_object_payload(
response_cost=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
)
@ -3656,6 +3658,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
additional_headers=None,
litellm_overhead_time_ms=None,
batch_models=None,
litellm_model_name=None,
)
# Convert numeric values to appropriate types

View file

@ -44,6 +44,7 @@ class ResponseMetadata:
"additional_headers": process_response_headers(
self._get_value_from_hidden_params("additional_headers") or {}
),
"litellm_model_name": model,
}
self._update_hidden_params(new_params)

View file

@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
# ------------
# Setup values
# ------------
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
rpm_key = f"{model_id}:rpm:{current_minute}"
deployment_name = deployment.get("litellm_params", {}).get("model")
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
local_result = self.router_cache.get_cache(
key=rpm_key, local_only=True
) # check local result first
@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
model_id = deployment.get("model_info", {}).get("id")
rpm_key = f"{model_id}:rpm:{current_minute}"
deployment_name = deployment.get("litellm_params", {}).get("model")
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
local_result = await self.router_cache.async_get_cache(
key=rpm_key, local_only=True
) # check local result first
@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"].get("litellm_model_name")
id = standard_logging_object.get("model_id")
if model_group is None or id is None:
if model_group is None or id is None or model is None:
return
elif isinstance(id, int):
id = str(id)
@ -244,7 +250,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
@ -276,6 +282,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
if standard_logging_object is None:
raise ValueError("standard_logging_object not passed in.")
model_group = standard_logging_object.get("model_group")
model = standard_logging_object["hidden_params"]["litellm_model_name"]
id = standard_logging_object.get("model_id")
if model_group is None or id is None:
return
@ -290,7 +297,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
"%H-%M"
) # use the same timezone regardless of system clock
tpm_key = f"{id}:tpm:{current_minute}"
tpm_key = f"{id}:{model}:tpm:{current_minute}"
# ------------
# Update usage
# ------------
@ -458,8 +465,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
tpm_key = "{}:tpm:{}".format(id, current_minute)
rpm_key = "{}:rpm:{}".format(id, current_minute)
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)
@ -576,8 +584,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
id = m.get("model_info", {}).get(
"id"
) # a deployment should always have an 'id'. this is set in router.py
tpm_key = "{}:tpm:{}".format(id, current_minute)
rpm_key = "{}:rpm:{}".format(id, current_minute)
deployment_name = m.get("litellm_params", {}).get("model")
tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
tpm_keys.append(tpm_key)
rpm_keys.append(rpm_key)

View file

@ -1625,13 +1625,16 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):
class StandardLoggingHiddenParams(TypedDict):
model_id: Optional[str]
model_id: Optional[
str
] # id of the model in the router, separates multiple models with the same name but different credentials
cache_key: Optional[str]
api_base: Optional[str]
response_cost: Optional[str]
litellm_overhead_time_ms: Optional[float]
additional_headers: Optional[StandardLoggingAdditionalHeaders]
batch_models: Optional[List[str]]
litellm_model_name: Optional[str] # the model name sent to the provider by litellm
class StandardLoggingModelInformation(TypedDict):

View file

@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
@pytest.mark.parametrize("namespace", [None, "test"])
@pytest.mark.asyncio
async def test_redis_cache_async_increment(namespace):
async def test_redis_cache_async_increment(namespace, monkeypatch):
monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
redis_cache = RedisCache(namespace=namespace)
# Create an AsyncMock for the Redis client
mock_redis_instance = AsyncMock()
@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
@pytest.mark.asyncio
async def test_redis_client_init_with_socket_timeout():
async def test_redis_client_init_with_socket_timeout(monkeypatch):
monkeypatch.setenv("REDIS_HOST", "my-fake-host")
redis_cache = RedisCache(socket_timeout=1.0)
assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
client = redis_cache.init_async_client()

View file

@ -8,84 +8,24 @@ import pytest
sys.path.insert(
0, os.path.abspath("../../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig
@pytest.mark.parametrize("is_async", [True, False])
@pytest.mark.asyncio
async def test_azure_ai_request_format(is_async):
async def test_get_openai_compatible_provider_info():
"""
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
for both synchronous and asynchronous calls
"""
litellm._turn_on_debug()
config = AzureAIStudioConfig()
# Set up the test parameters
api_key = "00xxx"
api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
model = "azure_ai/gpt-4o-mini"
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{"role": "user", "content": "hi"},
]
api_base, dynamic_api_key, custom_llm_provider = (
config._get_openai_compatible_provider_info(
model="azure_ai/gpt-4o-mini",
api_base="https://my-base",
api_key="my-key",
custom_llm_provider="azure_ai",
)
)
if is_async:
# Mock AsyncHTTPHandler.post method for async test
with patch(
"litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
) as mock_post:
# Set up mock response
mock_post.return_value = AsyncMock()
# Call the acompletion function
try:
await litellm.acompletion(
custom_llm_provider="azure_ai",
api_key=api_key,
api_base=api_base,
model=model,
messages=messages,
)
except Exception as e:
# We expect an exception since we're mocking the response
pass
else:
# Mock HTTPHandler.post method for sync test
with patch(
"litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
) as mock_post:
# Set up mock response
mock_post.return_value = MagicMock()
# Call the completion function
try:
litellm.completion(
custom_llm_provider="azure_ai",
api_key=api_key,
api_base=api_base,
model=model,
messages=messages,
)
except Exception as e:
# We expect an exception since we're mocking the response
pass
# Verify the request was made with the correct parameters
mock_post.assert_called_once()
call_args = mock_post.call_args
print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
# Check URL
assert call_args.kwargs["url"] == api_base
# Check headers
assert call_args.kwargs["headers"]["api-key"] == api_key
# Check request body
request_body = json.loads(call_args.kwargs["data"])
assert (
request_body["model"] == "gpt-4o-mini"
) # Model name should be stripped of provider prefix
assert request_body["messages"] == messages
assert custom_llm_provider == "azure"

View file

@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
"api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
}
@pytest.mark.asyncio
async def test_azure_ai_request_format():
"""
Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
for both synchronous and asynchronous calls
"""
from openai import AsyncAzureOpenAI, AzureOpenAI
litellm._turn_on_debug()
# Set up the test parameters
api_key = os.getenv("AZURE_API_KEY")
api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
model = "azure_ai/gpt-4o"
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "Hello! How can I assist you today?"},
{"role": "user", "content": "hi"},
]
await litellm.acompletion(
custom_llm_provider="azure_ai",
api_key=api_key,
api_base=api_base,
model=model,
messages=messages,
)

View file

@ -2,7 +2,9 @@ import asyncio
import json
import logging
import os
import time
from unittest.mock import patch, Mock
import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
from litellm import Choices
import pytest
from dotenv import load_dotenv
@ -81,3 +83,52 @@ def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
config = ArizeLogger.get_arize_config()
assert config.endpoint == "grpc://test.endpoint"
assert config.protocol == "otlp_grpc"
@pytest.mark.skip(
reason="Works locally but not in CI/CD. We'll need a better way to test Arize on CI/CD"
)
def test_arize_callback():
litellm.callbacks = ["arize"]
os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
os.environ["ARIZE_API_KEY"] = "test_api_key"
os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
# Set the batch span processor to quickly flush after a span has been added
# This is to ensure that the span is exported before the test ends
os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
try:
with patch.object(
opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
"export",
new=Mock(),
) as patched_export:
litellm.completion(
model="openai/test-model",
messages=[{"role": "user", "content": "arize test content"}],
stream=False,
mock_response="hello there!",
)
time.sleep(1) # Wait for the batch span processor to flush
assert patched_export.called
finally:
# Clean up environment variables
for key in [
"ARIZE_SPACE_KEY",
"ARIZE_API_KEY",
"ARIZE_ENDPOINT",
"OTEL_BSP_MAX_QUEUE_SIZE",
"OTEL_BSP_MAX_EXPORT_BATCH_SIZE",
"OTEL_BSP_SCHEDULE_DELAY_MILLIS",
"OTEL_BSP_EXPORT_TIMEOUT_MILLIS",
]:
if key in os.environ:
del os.environ[key]
# Reset callbacks
litellm.callbacks = []

View file

@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic():
)
assert custom_llm_provider == "bedrock"
assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"])
def test_xai_api_base(model):
args = {
"model": model,
"custom_llm_provider": "xai",
"api_base": None,
"api_key": "xai-my-specialkey",
"litellm_params": None,
}
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
**args
)
assert custom_llm_provider == "xai"
assert model == "grok-2-vision-latest"
assert api_base == "https://api.x.ai/v1"
assert dynamic_api_key == "xai-my-specialkey"

View file

@ -20,7 +20,7 @@ sys.path.insert(
from unittest.mock import AsyncMock, MagicMock, patch
from litellm.types.utils import StandardLoggingPayload
import pytest
from litellm.types.router import DeploymentTypedDict
import litellm
from litellm import Router
from litellm.caching.caching import DualCache
@ -47,12 +47,14 @@ def test_tpm_rpm_updated():
deployment_id = "1234"
deployment = "azure/chatgpt-v-2"
total_tokens = 50
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
standard_logging_payload["model_group"] = model_group
standard_logging_payload["model_id"] = deployment_id
standard_logging_payload["total_tokens"] = total_tokens
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
kwargs = {
"litellm_params": {
"model": deployment,
"metadata": {
"model_group": model_group,
"deployment": deployment,
@ -62,10 +64,16 @@ def test_tpm_rpm_updated():
"standard_logging_object": standard_logging_payload,
}
litellm_deployment_dict: DeploymentTypedDict = {
"model_name": model_group,
"litellm_params": {"model": deployment},
"model_info": {"id": deployment_id},
}
start_time = time.time()
response_obj = {"usage": {"total_tokens": total_tokens}}
end_time = time.time()
lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict)
lowest_tpm_logger.log_success_event(
response_obj=response_obj,
kwargs=kwargs,
@ -74,8 +82,8 @@ def test_tpm_rpm_updated():
)
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
tpm_count_api_key = f"{deployment_id}:{deployment}:tpm:{current_minute}"
rpm_count_api_key = f"{deployment_id}:{deployment}:rpm:{current_minute}"
print(f"tpm_count_api_key={tpm_count_api_key}")
assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
@ -113,6 +121,7 @@ def test_get_available_deployments():
standard_logging_payload["model_group"] = model_group
standard_logging_payload["model_id"] = deployment_id
standard_logging_payload["total_tokens"] = total_tokens
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
kwargs = {
"litellm_params": {
"metadata": {
@ -135,10 +144,11 @@ def test_get_available_deployments():
## DEPLOYMENT 2 ##
total_tokens = 20
deployment_id = "5678"
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
standard_logging_payload["model_group"] = model_group
standard_logging_payload["model_id"] = deployment_id
standard_logging_payload["total_tokens"] = total_tokens
standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
kwargs = {
"litellm_params": {
"metadata": {
@ -209,11 +219,12 @@ def test_router_get_available_deployments():
print(f"router id's: {router.get_model_ids()}")
## DEPLOYMENT 1 ##
deployment_id = 1
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
standard_logging_payload["model_group"] = "azure-model"
standard_logging_payload["model_id"] = str(deployment_id)
total_tokens = 50
standard_logging_payload["total_tokens"] = total_tokens
standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
kwargs = {
"litellm_params": {
"metadata": {
@ -237,6 +248,9 @@ def test_router_get_available_deployments():
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload["model_group"] = "azure-model"
standard_logging_payload["model_id"] = str(deployment_id)
standard_logging_payload["hidden_params"][
"litellm_model_name"
] = "azure/gpt-35-turbo"
kwargs = {
"litellm_params": {
"metadata": {
@ -293,10 +307,11 @@ def test_router_skip_rate_limited_deployments():
## DEPLOYMENT 1 ##
deployment_id = 1
total_tokens = 1439
standard_logging_payload = create_standard_logging_payload()
standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
standard_logging_payload["model_group"] = "azure-model"
standard_logging_payload["model_id"] = str(deployment_id)
standard_logging_payload["total_tokens"] = total_tokens
standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
kwargs = {
"litellm_params": {
"metadata": {
@ -699,3 +714,54 @@ def test_return_potential_deployments():
)
assert len(potential_deployments) == 1
@pytest.mark.asyncio
async def test_tpm_rpm_routing_model_name_checks():
deployment = {
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_base": os.getenv("AZURE_API_BASE"),
"mock_response": "Hey, how's it going?",
},
}
router = Router(model_list=[deployment], routing_strategy="usage-based-routing-v2")
async def side_effect_pre_call_check(*args, **kwargs):
return args[0]
with patch.object(
router.lowesttpm_logger_v2,
"async_pre_call_check",
side_effect=side_effect_pre_call_check,
) as mock_object, patch.object(
router.lowesttpm_logger_v2, "async_log_success_event"
) as mock_logging_event:
response = await router.acompletion(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey!"}]
)
mock_object.assert_called()
print(f"mock_object.call_args: {mock_object.call_args[0][0]}")
assert (
mock_object.call_args[0][0]["litellm_params"]["model"]
== deployment["litellm_params"]["model"]
)
await asyncio.sleep(1)
mock_logging_event.assert_called()
print(f"mock_logging_event: {mock_logging_event.call_args.kwargs}")
standard_logging_payload: StandardLoggingPayload = (
mock_logging_event.call_args.kwargs.get("kwargs", {}).get(
"standard_logging_object"
)
)
assert (
standard_logging_payload["hidden_params"]["litellm_model_name"]
== "azure/chatgpt-v-2"
)

View file

@ -15,35 +15,6 @@ import litellm
from litellm.types.utils import Choices
def test_arize_callback():
litellm.callbacks = ["arize"]
os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
os.environ["ARIZE_API_KEY"] = "test_api_key"
os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
# Set the batch span processor to quickly flush after a span has been added
# This is to ensure that the span is exported before the test ends
os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
with patch.object(
opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
"export",
new=Mock(),
) as patched_export:
completion(
model="openai/test-model",
messages=[{"role": "user", "content": "arize test content"}],
stream=False,
mock_response="hello there!",
)
time.sleep(1) # Wait for the batch span processor to flush
assert patched_export.called
def test_arize_set_attributes():
"""
Test setting attributes for Arize