# What is this?
## Unit testing for the 'get_model_info()' function
import os
import sys
import traceback
import json
from jsonschema import validate
from typing import List, Dict, Any

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import get_model_info
from unittest.mock import AsyncMock, MagicMock, patch


def test_get_model_info_simple_model_name():
    """
    Tests if model name given, and model exists in model info - the object is returned
    """
    model = "claude-3-opus-20240229"
    litellm.get_model_info(model)


def test_get_model_info_custom_llm_with_model_name():
    """
    Tests if {custom_llm_provider}/{model_name} name given, and model exists in model info, the object is returned
    """
    model = "anthropic/claude-3-opus-20240229"
    litellm.get_model_info(model)


def test_get_model_info_custom_llm_with_same_name_vllm(monkeypatch):
    """
    Tests that a model registered under an OpenAI-compatible (vLLM) provider is returned
    when an explicit `custom_llm_provider` is passed.
    """
    model = "command-r-plus"
    provider = "openai"  # vllm is openai-compatible
    litellm.register_model(
        {
            "openai/command-r-plus": {
                "input_cost_per_token": 0.0,
                "output_cost_per_token": 0.0,
            },
        }
    )
    model_info = litellm.get_model_info(model, custom_llm_provider=provider)
    print("model_info", model_info)
    assert model_info["input_cost_per_token"] == 0.0


def test_get_model_info_shows_correct_supports_vision():
    info = litellm.get_model_info("gemini/gemini-1.5-flash")
    print("info", info)
    assert info["supports_vision"] is True


def test_get_model_info_shows_assistant_prefill():
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    info = litellm.get_model_info("deepseek/deepseek-chat")
    print("info", info)
    assert info.get("supports_assistant_prefill") is True


def test_get_model_info_shows_supports_prompt_caching():
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")
    info = litellm.get_model_info("deepseek/deepseek-chat")
    print("info", info)
    assert info.get("supports_prompt_caching") is True


def test_get_model_info_finetuned_models():
    info = litellm.get_model_info("ft:gpt-3.5-turbo:my-org:custom_suffix:id")
    print("info", info)
    assert info["input_cost_per_token"] == 0.000003


def test_get_model_info_gemini_pro():
    info = litellm.get_model_info("gemini-1.5-pro-002")
    print("info", info)
    assert info["key"] == "gemini-1.5-pro-002"


def test_get_model_info_ollama_chat():
    from litellm.llms.ollama.completion.transformation import OllamaConfig

    with patch.object(
        litellm.module_level_client,
        "post",
        return_value=MagicMock(
            json=lambda: {
                "model_info": {"llama.context_length": 32768},
                "template": "tools",
            }
        ),
    ) as mock_client:
        info = OllamaConfig().get_model_info("mistral")
        assert info["supports_function_calling"] is True

        info = get_model_info("ollama/mistral")
        print("info", info)
        assert info["supports_function_calling"] is True

        mock_client.assert_called()
        print(mock_client.call_args.kwargs)
        assert mock_client.call_args.kwargs["json"]["name"] == "mistral"


def test_get_model_info_gemini():
    """
    Tests if ALL gemini models have 'tpm' and 'rpm' in the model info
    """
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
    litellm.model_cost = litellm.get_model_cost_map(url="")

    model_map = litellm.model_cost
    for model, info in model_map.items():
        if (
            model.startswith("gemini/")
            and "gemma" not in model
            and "learnlm" not in model
        ):
            assert info.get("tpm") is not None, f"{model} does not have tpm"
            assert info.get("rpm") is not None, f"{model} does not have rpm"
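

# NOTE (descriptive comment, no behavior change): several tests in this file set
# LITELLM_LOCAL_MODEL_COST_MAP=True and then reload litellm.model_cost via
# litellm.get_model_cost_map(url=""). With that flag set, litellm loads the locally
# bundled model_prices_and_context_window.json instead of fetching the cost map over
# the network, which keeps these assertions deterministic and offline-friendly.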
info.get("rpm") is not None, f"{model} does not have rpm" def test_get_model_info_bedrock_region(): os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") args = { "model": "us.anthropic.claude-3-5-sonnet-20241022-v2:0", "custom_llm_provider": "bedrock", } litellm.model_cost.pop("us.anthropic.claude-3-5-sonnet-20241022-v2:0", None) info = litellm.get_model_info(**args) print("info", info) assert info["key"] == "anthropic.claude-3-5-sonnet-20241022-v2:0" assert info["litellm_provider"] == "bedrock" @pytest.mark.parametrize( "model", [ "ft:gpt-3.5-turbo:my-org:custom_suffix:id", "ft:gpt-4-0613:my-org:custom_suffix:id", "ft:davinci-002:my-org:custom_suffix:id", "ft:gpt-4-0613:my-org:custom_suffix:id", "ft:babbage-002:my-org:custom_suffix:id", "gpt-35-turbo", "ada", ], ) def test_get_model_info_completion_cost_unit_tests(model): info = litellm.get_model_info(model) print("info", info) def test_get_model_info_ft_model_with_provider_prefix(): args = { "model": "openai/ft:gpt-3.5-turbo:my-org:custom_suffix:id", "custom_llm_provider": "openai", } info = litellm.get_model_info(**args) print("info", info) assert info["key"] == "ft:gpt-3.5-turbo" def test_get_whitelisted_models(): """ Snapshot of all bedrock models as of 12/24/2024. Enforce any new bedrock chat model to be added as `bedrock_converse` unless explicitly whitelisted. Create whitelist to prevent naming regressions for older litellm versions. """ whitelisted_models = [] for model, info in litellm.model_cost.items(): if info["litellm_provider"] == "bedrock" and info["mode"] == "chat": whitelisted_models.append(model) # Write to a local file with open("whitelisted_bedrock_models.txt", "w") as file: for model in whitelisted_models: file.write(f"{model}\n") print("whitelisted_models written to whitelisted_bedrock_models.txt") def _enforce_bedrock_converse_models( model_cost: List[Dict[str, Any]], whitelist_models: List[str] ): """ Assert all new bedrock chat models are added as `bedrock_converse` unless explicitly whitelisted. """ # Check for unwhitelisted models for model, info in litellm.model_cost.items(): if ( info["litellm_provider"] == "bedrock" and info["mode"] == "chat" and model not in whitelist_models ): raise AssertionError( f"New bedrock chat model detected: {model}. Please set `litellm_provider='bedrock_converse'` for this model." ) def test_model_info_bedrock_converse(monkeypatch): """ Assert all new bedrock chat models are added as `bedrock_converse` unless explicitly whitelisted. This ensures they are automatically routed to the converse endpoint. """ monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") litellm.model_cost = litellm.get_model_cost_map(url="") try: # Load whitelist models from file with open("whitelisted_bedrock_models.txt", "r") as file: whitelist_models = [line.strip() for line in file.readlines()] except FileNotFoundError: pytest.skip("whitelisted_bedrock_models.txt not found") _enforce_bedrock_converse_models( model_cost=litellm.model_cost, whitelist_models=whitelist_models ) @pytest.mark.flaky(retries=6, delay=2) def test_model_info_bedrock_converse_enforcement(monkeypatch): """ Test the enforcement of the whitelist by adding a fake model and ensuring the test fails. 
""" monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") litellm.model_cost = litellm.get_model_cost_map(url="") # Add a fake unwhitelisted model litellm.model_cost["fake.bedrock-chat-model"] = { "litellm_provider": "bedrock", "mode": "chat", } try: # Load whitelist models from file with open("whitelisted_bedrock_models.txt", "r") as file: whitelist_models = [line.strip() for line in file.readlines()] # Check for unwhitelisted models with pytest.raises(AssertionError): _enforce_bedrock_converse_models( model_cost=litellm.model_cost, whitelist_models=whitelist_models ) except FileNotFoundError as e: pytest.skip("whitelisted_bedrock_models.txt not found") def test_get_model_info_custom_provider(): # Custom provider example copied from https://docs.litellm.ai/docs/providers/custom_llm_server: import litellm from litellm import CustomLLM, completion, get_llm_provider class MyCustomLLM(CustomLLM): def completion(self, *args, **kwargs) -> litellm.ModelResponse: return litellm.completion( model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}], mock_response="Hi!", ) # type: ignore my_custom_llm = MyCustomLLM() litellm.custom_provider_map = [ # 👈 KEY STEP - REGISTER HANDLER {"provider": "my-custom-llm", "custom_handler": my_custom_llm} ] resp = completion( model="my-custom-llm/my-fake-model", messages=[{"role": "user", "content": "Hello world!"}], ) assert resp.choices[0].message.content == "Hi!" # Register model info model_info = {"my-custom-llm/my-fake-model": {"max_tokens": 2048}} litellm.register_model(model_info) # Get registered model info from litellm import get_model_info get_model_info( model="my-custom-llm/my-fake-model" ) # 💥 "Exception: This model isn't mapped yet." in v1.56.10 def test_get_model_info_custom_model_router(): from litellm import Router from litellm import get_model_info litellm._turn_on_debug() router = Router( model_list=[ { "model_name": "ma-summary", "litellm_params": { "api_base": "http://ma-mix-llm-serving.cicero.svc.cluster.local/v1", "input_cost_per_token": 1, "output_cost_per_token": 1, "model": "openai/meta-llama/Meta-Llama-3-8B-Instruct", }, "model_info": { "id": "c20d603e-1166-4e0f-aa65-ed9c476ad4ca", } } ] ) info = get_model_info("c20d603e-1166-4e0f-aa65-ed9c476ad4ca") print("info", info) assert info is not None def test_get_model_info_bedrock_models(): """ Check for drift in base model info for bedrock models and regional model info for bedrock models. 
""" from litellm.llms.bedrock.common_utils import BedrockModelInfo os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") for k, v in litellm.model_cost.items(): if v["litellm_provider"] == "bedrock": k = k.replace("*/", "") potential_commitments = [ "1-month-commitment", "3-month-commitment", "6-month-commitment", ] if any(commitment in k for commitment in potential_commitments): for commitment in potential_commitments: k = k.replace(f"{commitment}/", "") base_model = BedrockModelInfo.get_base_model(k) base_model_info = litellm.model_cost[base_model] for base_model_key, base_model_value in base_model_info.items(): if "invoke/" in k: continue if base_model_key.startswith("supports_"): assert ( base_model_key in v ), f"{base_model_key} is not in model cost map for {k}" assert ( v[base_model_key] == base_model_value ), f"{base_model_key} is not equal to {base_model_value} for model {k}" def test_get_model_info_huggingface_models(monkeypatch): from litellm import Router from litellm.types.router import ModelGroupInfo monkeypatch.setenv("HUGGINGFACE_API_KEY", "hf_abc123") router = Router( model_list=[ { "model_name": "meta-llama/Meta-Llama-3-8B-Instruct", "litellm_params": { "model": "huggingface/meta-llama/Meta-Llama-3-8B-Instruct", "api_base": "https://router.huggingface.co/hf-inference/models/meta-llama/Meta-Llama-3-8B-Instruct", "api_key": os.environ["HUGGINGFACE_API_KEY"], }, } ] ) info = litellm.get_model_info("huggingface/meta-llama/Meta-Llama-3-8B-Instruct") print("info", info) assert info is not None ModelGroupInfo( model_group="meta-llama/Meta-Llama-3-8B-Instruct", providers=["huggingface"], **info, ) @pytest.mark.parametrize( "model, provider", [ ("bedrock/us-east-2/us.anthropic.claude-3-haiku-20240307-v1:0", None), ( "bedrock/us-east-2/us.anthropic.claude-3-haiku-20240307-v1:0", "bedrock", ), ], ) def test_get_model_info_cost_calculator_bedrock_region_cris_stripped(model, provider): """ ensure cross region inferencing model is used correctly Relevant Issue: https://github.com/BerriAI/litellm/issues/8115 """ info = get_model_info(model=model, custom_llm_provider=provider) print("info", info) assert info["key"] == "us.anthropic.claude-3-haiku-20240307-v1:0" assert info["litellm_provider"] == "bedrock" def test_aaamodel_prices_and_context_window_json_is_valid(): """ Validates the `model_prices_and_context_window.json` file. If this test fails after you update the json, you need to update the schema or correct the change you made. 
""" INTENDED_SCHEMA = { "type": "object", "additionalProperties": { "type": "object", "properties": { "cache_creation_input_audio_token_cost": {"type": "number"}, "cache_creation_input_token_cost": {"type": "number"}, "cache_read_input_token_cost": {"type": "number"}, "cache_read_input_audio_token_cost": {"type": "number"}, "deprecation_date": {"type": "string"}, "input_cost_per_audio_per_second": {"type": "number"}, "input_cost_per_audio_per_second_above_128k_tokens": {"type": "number"}, "input_cost_per_audio_token": {"type": "number"}, "input_cost_per_character": {"type": "number"}, "input_cost_per_character_above_128k_tokens": {"type": "number"}, "input_cost_per_image": {"type": "number"}, "input_cost_per_image_above_128k_tokens": {"type": "number"}, "input_cost_per_token_above_200k_tokens": {"type": "number"}, "input_cost_per_pixel": {"type": "number"}, "input_cost_per_query": {"type": "number"}, "input_cost_per_request": {"type": "number"}, "input_cost_per_second": {"type": "number"}, "input_cost_per_token": {"type": "number"}, "input_cost_per_token_above_128k_tokens": {"type": "number"}, "input_cost_per_token_batch_requests": {"type": "number"}, "input_cost_per_token_batches": {"type": "number"}, "input_cost_per_token_cache_hit": {"type": "number"}, "input_cost_per_video_per_second": {"type": "number"}, "input_cost_per_video_per_second_above_8s_interval": {"type": "number"}, "input_cost_per_video_per_second_above_15s_interval": { "type": "number" }, "input_cost_per_video_per_second_above_128k_tokens": {"type": "number"}, "input_dbu_cost_per_token": {"type": "number"}, "litellm_provider": {"type": "string"}, "max_audio_length_hours": {"type": "number"}, "max_audio_per_prompt": {"type": "number"}, "max_document_chunks_per_query": {"type": "number"}, "max_images_per_prompt": {"type": "number"}, "max_input_tokens": {"type": "number"}, "max_output_tokens": {"type": "number"}, "max_pdf_size_mb": {"type": "number"}, "max_query_tokens": {"type": "number"}, "max_tokens": {"type": "number"}, "max_tokens_per_document_chunk": {"type": "number"}, "max_video_length": {"type": "number"}, "max_videos_per_prompt": {"type": "number"}, "metadata": {"type": "object"}, "mode": { "type": "string", "enum": [ "audio_speech", "audio_transcription", "chat", "completion", "embedding", "image_generation", "moderation", "rerank", "responses", ], }, "output_cost_per_audio_token": {"type": "number"}, "output_cost_per_character": {"type": "number"}, "output_cost_per_character_above_128k_tokens": {"type": "number"}, "output_cost_per_image": {"type": "number"}, "output_cost_per_pixel": {"type": "number"}, "output_cost_per_second": {"type": "number"}, "output_cost_per_token": {"type": "number"}, "output_cost_per_token_above_128k_tokens": {"type": "number"}, "output_cost_per_token_above_200k_tokens": {"type": "number"}, "output_cost_per_token_batches": {"type": "number"}, "output_cost_per_reasoning_token": {"type": "number"}, "output_db_cost_per_token": {"type": "number"}, "output_dbu_cost_per_token": {"type": "number"}, "output_vector_size": {"type": "number"}, "rpd": {"type": "number"}, "rpm": {"type": "number"}, "source": {"type": "string"}, "supports_assistant_prefill": {"type": "boolean"}, "supports_audio_input": {"type": "boolean"}, "supports_audio_output": {"type": "boolean"}, "supports_embedding_image_input": {"type": "boolean"}, "supports_function_calling": {"type": "boolean"}, "supports_image_input": {"type": "boolean"}, "supports_parallel_function_calling": {"type": "boolean"}, "supports_pdf_input": {"type": 
"boolean"}, "supports_prompt_caching": {"type": "boolean"}, "supports_response_schema": {"type": "boolean"}, "supports_system_messages": {"type": "boolean"}, "supports_tool_choice": {"type": "boolean"}, "supports_video_input": {"type": "boolean"}, "supports_vision": {"type": "boolean"}, "supports_web_search": {"type": "boolean"}, "supports_reasoning": {"type": "boolean"}, "tool_use_system_prompt_tokens": {"type": "number"}, "tpm": {"type": "number"}, "supported_endpoints": { "type": "array", "items": { "type": "string", "enum": [ "/v1/responses", "/v1/embeddings", "/v1/chat/completions", "/v1/completions", "/v1/images/generations", "/v1/images/variations", "/v1/images/edits", "/v1/batch", "/v1/audio/transcriptions", "/v1/audio/speech", ], }, }, "search_context_cost_per_query": { "type": "object", "properties": { "search_context_size_low": {"type": "number"}, "search_context_size_medium": {"type": "number"}, "search_context_size_high": {"type": "number"}, }, "additionalProperties": False, }, "supported_modalities": { "type": "array", "items": { "type": "string", "enum": ["text", "audio", "image", "video"], }, }, "supported_output_modalities": { "type": "array", "items": { "type": "string", "enum": ["text", "image", "audio"], }, }, "supports_native_streaming": {"type": "boolean"}, }, "additionalProperties": False, }, } prod_json = "./model_prices_and_context_window.json" with open(prod_json, "r") as model_prices_file: actual_json = json.load(model_prices_file) assert isinstance(actual_json, dict) actual_json.pop( "sample_spec", None ) # remove the sample, whose schema is inconsistent with the real data validate(actual_json, INTENDED_SCHEMA)