mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Add support for vertex_ai/ model name prefix format
This commit adds support for using the vertex_ai/ prefix with models in the format: vertex_ai/claude-3-7-sonnet@20250219 The implementation: 1. Adds explicit handling for the vertex_ai/ prefix in get_llm_provider() 2. Correctly sets the custom_llm_provider to 'vertex_ai' 3. Returns the base model name without the prefix 4. Includes tests to verify the functionality This allows for proper provider detection and cost calculation with Vertex AI hosted models.
This commit is contained in:
parent
ce828408da
commit
4bae7df72d
2 changed files with 61 additions and 0 deletions
|
@ -121,6 +121,12 @@ def get_llm_provider( # noqa: PLR0915
|
|||
custom_llm_provider = "openai"
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
# Handle vertex_ai prefix - allows model format vertex_ai/claude-3-7-sonnet@20250219
|
||||
if model.split("/", 1)[0].lower() == "vertex_ai":
|
||||
base_model = model.split("/", 1)[1]
|
||||
custom_llm_provider = "vertex_ai"
|
||||
return base_model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
### Handle cases when custom_llm_provider is set to cohere/command-r-plus but it should use cohere_chat route
|
||||
model, custom_llm_provider = handle_cohere_chat_model_custom_llm_provider(
|
||||
model, custom_llm_provider
|
||||
|
|
55
litellm/tests/local_testing/test_vertex_ai_provider.py
Normal file
55
litellm/tests/local_testing/test_vertex_ai_provider.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.cost_calculator import completion_cost
|
||||
|
||||
|
||||
class TestVertexAIProvider(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Ensure vertex_ai is in the provider list for testing
|
||||
if "vertex_ai" not in [p.value for p in litellm.provider_list]:
|
||||
# This is a mock addition just for testing
|
||||
litellm.provider_list.append(litellm.types.utils.LlmProviders("vertex_ai"))
|
||||
|
||||
def test_get_llm_provider_vertex_ai(self):
|
||||
"""Test get_llm_provider correctly identifies 'vertex_ai/' prefix."""
|
||||
# Case with vertex_ai prefix
|
||||
model, provider, _, _ = get_llm_provider(model="vertex_ai/claude-3-7-sonnet@20250219")
|
||||
self.assertEqual(model, "claude-3-7-sonnet@20250219")
|
||||
self.assertEqual(provider, "vertex_ai")
|
||||
|
||||
# No need to test model name formatting anymore - we're now using the original model name
|
||||
|
||||
def test_end_to_end_cost_calculation(self):
|
||||
"""Test the end-to-end cost calculation pipeline for vertex_ai models with date formats."""
|
||||
# Create a mock response object with usage information
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage = {"prompt_tokens": 100, "completion_tokens": 50}
|
||||
|
||||
# This is a simple test that verifies the cost calculation process doesn't fail
|
||||
# and that the model name with @ symbol works correctly with the vertex_ai provider
|
||||
try:
|
||||
# We're not testing the actual cost values here, just that the pipeline executes without error
|
||||
completion_cost(
|
||||
completion_response=mock_response,
|
||||
model="vertex_ai/claude-3-sonnet@20240229", # Use a model known to exist in the database
|
||||
custom_llm_provider="vertex_ai"
|
||||
)
|
||||
success = True
|
||||
except Exception as e:
|
||||
success = False
|
||||
# Avoid using print for linting reasons
|
||||
error_msg = f"Cost calculation failed: {str(e)}"
|
||||
self.fail(error_msg)
|
||||
|
||||
self.assertTrue(success, "Cost calculation should complete without errors")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Add table
Add a link
Reference in a new issue