commit 57e5c94360
Author: Ishaan Jaff
Date:   2025-03-18 22:13:35 -07:00

    Merge branch 'main' into litellm_arize_dynamic_logging

13 changed files with 225 additions and 135 deletions


@@ -1,4 +1,3 @@
-import json
 from typing import TYPE_CHECKING, Any, Optional
 from litellm._logging import verbose_logger


@@ -129,17 +129,15 @@ def get_llm_provider(  # noqa: PLR0915
                 model, custom_llm_provider
             )
-        if custom_llm_provider:
-            if (
-                model.split("/")[0] == custom_llm_provider
-            ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
-                model = model.replace("{}/".format(custom_llm_provider), "")
-            return model, custom_llm_provider, dynamic_api_key, api_base
+        if custom_llm_provider and (
+            model.split("/")[0] != custom_llm_provider
+        ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
+            model = custom_llm_provider + "/" + model
         if api_key and api_key.startswith("os.environ/"):
             dynamic_api_key = get_secret_str(api_key)
         # check if llm provider part of model name
         if (
             model.split("/", 1)[0] in litellm.provider_list
             and model.split("/", 1)[0] not in litellm.model_list_set
@@ -574,10 +572,10 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
     elif custom_llm_provider == "snowflake":
         api_base = (
             api_base
-            or get_secret("SNOWFLAKE_API_BASE")
+            or get_secret_str("SNOWFLAKE_API_BASE")
             or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
         )  # type: ignore
-        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
         if api_base is not None and not isinstance(api_base, str):
             raise Exception("api base needs to be a string. api_base={}".format(api_base))
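
For orientation, a minimal sketch of what the reworked provider resolution is expected to return when only custom_llm_provider is supplied; the concrete values mirror the xai regression test added further down in this diff, not a verified run:

import litellm

# With the early return gone, a bare model name gets prefixed with the provider
# and falls through to the normal provider-resolution logic, so provider
# defaults such as api_base are still filled in.
model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="grok-2-vision-latest",
    custom_llm_provider="xai",
    api_key="xai-my-specialkey",
)
assert model == "grok-2-vision-latest"
assert provider == "xai"
assert api_base == "https://api.x.ai/v1"
assert dynamic_api_key == "xai-my-specialkey"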


@@ -3257,6 +3257,7 @@ class StandardLoggingPayloadSetup:
                 additional_headers=None,
                 litellm_overhead_time_ms=None,
                 batch_models=None,
+                litellm_model_name=None,
             )
         if hidden_params is not None:
             for key in StandardLoggingHiddenParams.__annotations__.keys():
@@ -3371,6 +3372,7 @@ def get_standard_logging_object_payload(
                 response_cost=None,
                 litellm_overhead_time_ms=None,
                 batch_models=None,
+                litellm_model_name=None,
             )
         )
@@ -3656,6 +3658,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
         additional_headers=None,
         litellm_overhead_time_ms=None,
         batch_models=None,
+        litellm_model_name=None,
     )
     # Convert numeric values to appropriate types


@@ -44,6 +44,7 @@ class ResponseMetadata:
             "additional_headers": process_response_headers(
                 self._get_value_from_hidden_params("additional_headers") or {}
             ),
+            "litellm_model_name": model,
         }
         self._update_hidden_params(new_params)


@@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # ------------
             # Setup values
             # ------------
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
             local_result = self.router_cache.get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
             local_result = await self.router_cache.async_get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             if standard_logging_object is None:
                 raise ValueError("standard_logging_object not passed in.")
             model_group = standard_logging_object.get("model_group")
+            model = standard_logging_object["hidden_params"].get("litellm_model_name")
             id = standard_logging_object.get("model_id")
-            if model_group is None or id is None:
+            if model_group is None or id is None or model is None:
                 return
             elif isinstance(id, int):
                 id = str(id)
@@ -244,7 +250,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 "%H-%M"
             )  # use the same timezone regardless of system clock
-            tpm_key = f"{id}:tpm:{current_minute}"
+            tpm_key = f"{id}:{model}:tpm:{current_minute}"
             # ------------
             # Update usage
             # ------------
@@ -276,6 +282,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             if standard_logging_object is None:
                 raise ValueError("standard_logging_object not passed in.")
             model_group = standard_logging_object.get("model_group")
+            model = standard_logging_object["hidden_params"]["litellm_model_name"]
             id = standard_logging_object.get("model_id")
             if model_group is None or id is None:
                 return
@@ -290,7 +297,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
                 "%H-%M"
             )  # use the same timezone regardless of system clock
-            tpm_key = f"{id}:tpm:{current_minute}"
+            tpm_key = f"{id}:{model}:tpm:{current_minute}"
             # ------------
             # Update usage
             # ------------
@@ -458,8 +465,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             id = m.get("model_info", {}).get(
                 "id"
             )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
             tpm_keys.append(tpm_key)
             rpm_keys.append(rpm_key)
@@ -576,8 +584,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             id = m.get("model_info", {}).get(
                 "id"
             )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
             tpm_keys.append(tpm_key)
             rpm_keys.append(rpm_key)
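
As a quick illustration of the new cache-key shape introduced above, a minimal sketch with deployment values taken from the router tests later in this diff (the timestamp is whatever minute the code runs in):

from datetime import datetime, timezone

model_id = "1234"                      # deployment["model_info"]["id"]
deployment_name = "azure/chatgpt-v-2"  # deployment["litellm_params"]["model"]
current_minute = datetime.now(timezone.utc).strftime("%H-%M")

# usage keys now embed the underlying deployment name alongside the router id
tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
print(tpm_key)  # e.g. "1234:azure/chatgpt-v-2:tpm:22-13"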


@@ -1625,13 +1625,16 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):
 class StandardLoggingHiddenParams(TypedDict):
-    model_id: Optional[str]
+    model_id: Optional[
+        str
+    ]  # id of the model in the router, separates multiple models with the same name but different credentials
     cache_key: Optional[str]
     api_base: Optional[str]
     response_cost: Optional[str]
     litellm_overhead_time_ms: Optional[float]
     additional_headers: Optional[StandardLoggingAdditionalHeaders]
     batch_models: Optional[List[str]]
+    litellm_model_name: Optional[str]  # the model name sent to the provider by litellm


 class StandardLoggingModelInformation(TypedDict):
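
To show where the new hidden-params field is expected to surface, a hedged sketch of reading it off a standard logging payload; the literal values are illustrative and mirror the usage-based-routing test later in this diff:

# shape follows StandardLoggingHiddenParams above; values are examples only
standard_logging_payload = {
    "model_group": "gpt-3.5-turbo",  # router alias the caller used
    "model_id": "1234",              # router deployment id
    "hidden_params": {
        "model_id": "1234",
        "litellm_model_name": "azure/chatgpt-v-2",  # model actually sent to the provider
    },
}
assert (
    standard_logging_payload["hidden_params"]["litellm_model_name"]
    == "azure/chatgpt-v-2"
)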


@@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
 @pytest.mark.parametrize("namespace", [None, "test"])
 @pytest.mark.asyncio
-async def test_redis_cache_async_increment(namespace):
+async def test_redis_cache_async_increment(namespace, monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
     redis_cache = RedisCache(namespace=namespace)
     # Create an AsyncMock for the Redis client
     mock_redis_instance = AsyncMock()
@@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):
 @pytest.mark.asyncio
-async def test_redis_client_init_with_socket_timeout():
+async def test_redis_client_init_with_socket_timeout(monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
     redis_cache = RedisCache(socket_timeout=1.0)
     assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
     client = redis_cache.init_async_client()


@@ -8,84 +8,24 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../../../../..")
 )  # Adds the parent directory to the system path
-import litellm
+from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig


-@pytest.mark.parametrize("is_async", [True, False])
 @pytest.mark.asyncio
-async def test_azure_ai_request_format(is_async):
+async def test_get_openai_compatible_provider_info():
     """
     Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
     for both synchronous and asynchronous calls
     """
-    litellm._turn_on_debug()
-
-    # Set up the test parameters
-    api_key = "00xxx"
-    api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
-    model = "azure_ai/gpt-4o-mini"
-    messages = [
-        {"role": "user", "content": "hi"},
-        {"role": "assistant", "content": "Hello! How can I assist you today?"},
-        {"role": "user", "content": "hi"},
-    ]
-
-    if is_async:
-        # Mock AsyncHTTPHandler.post method for async test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = AsyncMock()
-
-            # Call the acompletion function
-            try:
-                await litellm.acompletion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-    else:
-        # Mock HTTPHandler.post method for sync test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = MagicMock()
-
-            # Call the completion function
-            try:
-                litellm.completion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-    # Verify the request was made with the correct parameters
-    mock_post.assert_called_once()
-    call_args = mock_post.call_args
-    print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
-
-    # Check URL
-    assert call_args.kwargs["url"] == api_base
-
-    # Check headers
-    assert call_args.kwargs["headers"]["api-key"] == api_key
-
-    # Check request body
-    request_body = json.loads(call_args.kwargs["data"])
-    assert (
-        request_body["model"] == "gpt-4o-mini"
-    )  # Model name should be stripped of provider prefix
-    assert request_body["messages"] == messages
+    config = AzureAIStudioConfig()
+    api_base, dynamic_api_key, custom_llm_provider = (
+        config._get_openai_compatible_provider_info(
+            model="azure_ai/gpt-4o-mini",
+            api_base="https://my-base",
+            api_key="my-key",
+            custom_llm_provider="azure_ai",
+        )
+    )
+
+    assert custom_llm_provider == "azure"


@@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
             "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
             "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
         }
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_request_format():
+    """
+    Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
+    for both synchronous and asynchronous calls
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    litellm._turn_on_debug()
+
+    # Set up the test parameters
+    api_key = os.getenv("AZURE_API_KEY")
+    api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
+    model = "azure_ai/gpt-4o"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+
+    await litellm.acompletion(
+        custom_llm_provider="azure_ai",
+        api_key=api_key,
+        api_base=api_base,
+        model=model,
+        messages=messages,
+    )


@@ -2,7 +2,9 @@ import asyncio
 import json
 import logging
 import os
+import time
+from unittest.mock import patch, Mock
+import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
 from litellm import Choices
 import pytest
 from dotenv import load_dotenv
@@ -81,3 +83,52 @@ def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
     config = ArizeLogger.get_arize_config()
     assert config.endpoint == "grpc://test.endpoint"
     assert config.protocol == "otlp_grpc"
+
+
+@pytest.mark.skip(
+    reason="Works locally but not in CI/CD. We'll need a better way to test Arize on CI/CD"
+)
+def test_arize_callback():
+    litellm.callbacks = ["arize"]
+    os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
+    os.environ["ARIZE_API_KEY"] = "test_api_key"
+    os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
+
+    # Set the batch span processor to quickly flush after a span has been added
+    # This is to ensure that the span is exported before the test ends
+    os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
+    os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
+    os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
+    os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
+
+    try:
+        with patch.object(
+            opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
+            "export",
+            new=Mock(),
+        ) as patched_export:
+            litellm.completion(
+                model="openai/test-model",
+                messages=[{"role": "user", "content": "arize test content"}],
+                stream=False,
+                mock_response="hello there!",
+            )
+
+            time.sleep(1)  # Wait for the batch span processor to flush
+            assert patched_export.called
+    finally:
+        # Clean up environment variables
+        for key in [
+            "ARIZE_SPACE_KEY",
+            "ARIZE_API_KEY",
+            "ARIZE_ENDPOINT",
+            "OTEL_BSP_MAX_QUEUE_SIZE",
+            "OTEL_BSP_MAX_EXPORT_BATCH_SIZE",
+            "OTEL_BSP_SCHEDULE_DELAY_MILLIS",
+            "OTEL_BSP_EXPORT_TIMEOUT_MILLIS",
+        ]:
+            if key in os.environ:
+                del os.environ[key]
+
+        # Reset callbacks
+        litellm.callbacks = []


@@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic():
     )
     assert custom_llm_provider == "bedrock"
     assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+
+@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"])
+def test_xai_api_base(model):
+    args = {
+        "model": model,
+        "custom_llm_provider": "xai",
+        "api_base": None,
+        "api_key": "xai-my-specialkey",
+        "litellm_params": None,
+    }
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        **args
+    )
+    assert custom_llm_provider == "xai"
+    assert model == "grok-2-vision-latest"
+    assert api_base == "https://api.x.ai/v1"
+    assert dynamic_api_key == "xai-my-specialkey"


@@ -20,7 +20,7 @@ sys.path.insert(
 from unittest.mock import AsyncMock, MagicMock, patch
 from litellm.types.utils import StandardLoggingPayload
 import pytest
+from litellm.types.router import DeploymentTypedDict
 import litellm
 from litellm import Router
 from litellm.caching.caching import DualCache
@@ -47,12 +47,14 @@ def test_tpm_rpm_updated():
     deployment_id = "1234"
     deployment = "azure/chatgpt-v-2"
     total_tokens = 50
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
+            "model": deployment,
             "metadata": {
                 "model_group": model_group,
                 "deployment": deployment,
@@ -62,10 +64,16 @@ def test_tpm_rpm_updated():
         "standard_logging_object": standard_logging_payload,
     }

+    litellm_deployment_dict: DeploymentTypedDict = {
+        "model_name": model_group,
+        "litellm_params": {"model": deployment},
+        "model_info": {"id": deployment_id},
+    }
+
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
-    lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
+    lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict)
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
@@ -74,8 +82,8 @@ def test_tpm_rpm_updated():
     )
     dt = get_utc_datetime()
     current_minute = dt.strftime("%H-%M")
-    tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
-    rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
+    tpm_count_api_key = f"{deployment_id}:{deployment}:tpm:{current_minute}"
+    rpm_count_api_key = f"{deployment_id}:{deployment}:rpm:{current_minute}"
     print(f"tpm_count_api_key={tpm_count_api_key}")
     assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
@@ -113,6 +121,7 @@ def test_get_available_deployments():
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -135,10 +144,11 @@ def test_get_available_deployments():
     ## DEPLOYMENT 2 ##
     total_tokens = 20
     deployment_id = "5678"
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -209,11 +219,12 @@ def test_router_get_available_deployments():
     print(f"router id's: {router.get_model_ids()}")
     ## DEPLOYMENT 1 ##
     deployment_id = 1
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     total_tokens = 50
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -237,6 +248,9 @@ def test_router_get_available_deployments():
     standard_logging_payload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
+    standard_logging_payload["hidden_params"][
+        "litellm_model_name"
+    ] = "azure/gpt-35-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -293,10 +307,11 @@ def test_router_skip_rate_limited_deployments():
     ## DEPLOYMENT 1 ##
     deployment_id = 1
     total_tokens = 1439
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -699,3 +714,54 @@ def test_return_potential_deployments():
     )
     assert len(potential_deployments) == 1
+
+
+@pytest.mark.asyncio
+async def test_tpm_rpm_routing_model_name_checks():
+    deployment = {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "mock_response": "Hey, how's it going?",
+        },
+    }
+    router = Router(model_list=[deployment], routing_strategy="usage-based-routing-v2")
+
+    async def side_effect_pre_call_check(*args, **kwargs):
+        return args[0]
+
+    with patch.object(
+        router.lowesttpm_logger_v2,
+        "async_pre_call_check",
+        side_effect=side_effect_pre_call_check,
+    ) as mock_object, patch.object(
+        router.lowesttpm_logger_v2, "async_log_success_event"
+    ) as mock_logging_event:
+        response = await router.acompletion(
+            model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey!"}]
+        )
+        mock_object.assert_called()
+        print(f"mock_object.call_args: {mock_object.call_args[0][0]}")
+        assert (
+            mock_object.call_args[0][0]["litellm_params"]["model"]
+            == deployment["litellm_params"]["model"]
+        )
+
+        await asyncio.sleep(1)
+
+        mock_logging_event.assert_called()
+        print(f"mock_logging_event: {mock_logging_event.call_args.kwargs}")
+        standard_logging_payload: StandardLoggingPayload = (
+            mock_logging_event.call_args.kwargs.get("kwargs", {}).get(
+                "standard_logging_object"
+            )
+        )
+        assert (
+            standard_logging_payload["hidden_params"]["litellm_model_name"]
+            == "azure/chatgpt-v-2"
+        )


@@ -15,35 +15,6 @@ import litellm
 from litellm.types.utils import Choices
-
-def test_arize_callback():
-    litellm.callbacks = ["arize"]
-    os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
-    os.environ["ARIZE_API_KEY"] = "test_api_key"
-    os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
-
-    # Set the batch span processor to quickly flush after a span has been added
-    # This is to ensure that the span is exported before the test ends
-    os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
-    os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
-    os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
-    os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
-
-    with patch.object(
-        opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
-        "export",
-        new=Mock(),
-    ) as patched_export:
-        completion(
-            model="openai/test-model",
-            messages=[{"role": "user", "content": "arize test content"}],
-            stream=False,
-            mock_response="hello there!",
-        )
-
-        time.sleep(1)  # Wait for the batch span processor to flush
-        assert patched_export.called
-

 def test_arize_set_attributes():
     """
     Test setting attributes for Arize