diff --git a/litellm/tests/test_get_model_cost_map.py b/litellm/tests/test_get_model_cost_map.py
index 8a86cc51f..8486ae199 100644
--- a/litellm/tests/test_get_model_cost_map.py
+++ b/litellm/tests/test_get_model_cost_map.py
@@ -7,8 +7,6 @@ sys.path.insert(
 import time
 from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models
 
-print(get_max_tokens("gpt-3.5-turbo"))
-
 def test_get_gpt3_tokens():
     max_tokens = get_max_tokens("gpt-3.5-turbo")
     results = max_tokens['max_tokens']
@@ -21,3 +19,10 @@ def test_get_palm_tokens():
     results = max_tokens['max_tokens']
     print(results)
 # test_get_palm_tokens()
+
+def test_zephyr_hf_tokens():
+    max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
+    results = max_tokens["max_tokens"]
+    print(results)
+
+test_zephyr_hf_tokens()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index a4d1016a1..31581253b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2354,8 +2354,42 @@ def get_max_tokens(model: str):
         "mode": "chat"
     }
     """
+    def _get_max_position_embeddings(model_name):
+        # Construct the URL for the config.json file
+        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
+
+        try:
+            # Make the HTTP request to get the raw JSON file
+            response = requests.get(config_url)
+            response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)
+
+            # Parse the JSON response
+            config_json = response.json()
+
+            # Extract and return the max_position_embeddings
+            max_position_embeddings = config_json.get("max_position_embeddings")
+
+            if max_position_embeddings is not None:
+                return max_position_embeddings
+            else:
+                return None
+        except requests.exceptions.RequestException as e:
+            return None
     try:
-        return litellm.model_cost[model]
+        if model in litellm.model_cost:
+            return litellm.model_cost[model]
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        if custom_llm_provider == "huggingface":
+            max_tokens = _get_max_position_embeddings(model_name=model)
+            return {
+                "max_tokens": max_tokens,
+                "input_cost_per_token": 0,
+                "output_cost_per_token": 0,
+                "litellm_provider": "huggingface",
+                "mode": "chat"
+            }
+        else:
+            raise Exception()
     except:
         raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
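
A minimal usage sketch of the new fallback path (the model name comes from the test above; the max_tokens value is fetched live from the model's config.json on the Hugging Face Hub, so the number shown in the comment is illustrative):

    from litellm import get_max_tokens

    # "huggingface/HuggingFaceH4/zephyr-7b-beta" is not in litellm.model_cost,
    # so get_max_tokens() resolves the provider via get_llm_provider() and
    # falls back to max_position_embeddings from the repo's config.json.
    info = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
    print(info["max_tokens"])        # e.g. 32768 for this Mistral-based model
    print(info["litellm_provider"])  # "huggingface"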