diff --git a/litellm/__init__.py b/litellm/__init__.py
index 8b7a13e4f..d1a9dc302 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -346,6 +346,7 @@ from .utils import (
     acreate,
     get_model_list,
     get_max_tokens,
+    get_model_info,
     register_prompt_template,
     validate_environment,
     check_valid_key,
diff --git a/litellm/tests/test_get_model_cost_map.py b/litellm/tests/test_get_model_cost_map.py
index 8486ae199..c4e8232c6 100644
--- a/litellm/tests/test_get_model_cost_map.py
+++ b/litellm/tests/test_get_model_cost_map.py
@@ -9,20 +9,21 @@ from litellm import get_max_tokens, model_cost, open_ai_chat_completion_models

 def test_get_gpt3_tokens():
     max_tokens = get_max_tokens("gpt-3.5-turbo")
-    results = max_tokens['max_tokens']
-    print(results)
-# test_get_gpt3_tokens()
+    print(max_tokens)
+    assert max_tokens==4097
+    # print(results)
+test_get_gpt3_tokens()

 def test_get_palm_tokens():
     # # 🦄🦄🦄🦄🦄🦄🦄🦄
     max_tokens = get_max_tokens("palm/chat-bison")
-    results = max_tokens['max_tokens']
-    print(results)
-# test_get_palm_tokens()
+    assert max_tokens == 4096
+    print(max_tokens)
+test_get_palm_tokens()

 def test_zephyr_hf_tokens():
     max_tokens = get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta")
-    results = max_tokens["max_tokens"]
-    print(results)
+    print(max_tokens)
+    assert max_tokens == 32768

 test_zephyr_hf_tokens()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 07d376d30..96ad325fc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2358,6 +2358,58 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]):
     return api_key

 def get_max_tokens(model: str):
+    """
+    Get the maximum number of tokens allowed for a given model.
+
+    Parameters:
+    model (str): The name of the model.
+
+    Returns:
+        int: The maximum number of tokens allowed for the given model.
+
+    Raises:
+        Exception: If the model is not mapped yet.
+
+    Example:
+        >>> get_max_tokens("gpt-4")
+        8192
+    """
+    def _get_max_position_embeddings(model_name):
+        # Construct the URL for the config.json file
+        config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
+
+        try:
+            # Make the HTTP request to get the raw JSON file
+            response = requests.get(config_url)
+            response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)
+
+            # Parse the JSON response
+            config_json = response.json()
+
+            # Extract and return the max_position_embeddings
+            max_position_embeddings = config_json.get("max_position_embeddings")
+
+            if max_position_embeddings is not None:
+                return max_position_embeddings
+            else:
+                return None
+        except requests.exceptions.RequestException as e:
+            return None
+
+    try:
+        if model in litellm.model_cost:
+            return litellm.model_cost[model]["max_tokens"]
+        model, custom_llm_provider, _, _ = get_llm_provider(model=model)
+        if custom_llm_provider == "huggingface":
+            max_tokens = _get_max_position_embeddings(model_name=model)
+            return max_tokens
+        else:
+            raise Exception()
+    except:
+        raise Exception("This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json")
+
+
+def get_model_info(model: str):
     """
     Get a dict for the maximum tokens (context window),
     input_cost_per_token, output_cost_per_token for a given model.
@@ -2377,7 +2429,7 @@ def get_max_tokens(model: str):
         Exception: If the model is not mapped yet.

     Example:
-        >>> get_max_tokens("gpt-4")
+        >>> get_model_info("gpt-4")
         {
             "max_tokens": 8192,
             "input_cost_per_token": 0.00003,
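
A minimal usage sketch of the API split this patch makes: get_max_tokens now returns a bare int (the context window), while the old dict-returning behavior moves to the new get_model_info. The model names and expected values below are taken from the tests and docstrings in the diff; treat the session as illustrative, not authoritative.

import litellm

# Mapped models resolve from the static cost map
# (model_prices_and_context_window.json):
print(litellm.get_max_tokens("gpt-3.5-turbo"))  # 4097

# Unmapped Hugging Face models fall back to reading max_position_embeddings
# from the repo's config.json over HTTP:
print(litellm.get_max_tokens("huggingface/HuggingFaceH4/zephyr-7b-beta"))  # 32768

# get_model_info keeps the old dict shape (context window plus per-token costs):
info = litellm.get_model_info("gpt-4")
print(info["max_tokens"])            # 8192
print(info["input_cost_per_token"])  # 0.00003

Both helpers raise an Exception pointing at model_prices_and_context_window.json when the model cannot be resolved, so callers that previously indexed the result with ["max_tokens"] (as the old tests did) must switch to the plain int or call get_model_info instead.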