diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 845d0e2dd..842d946c6 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -185,6 +185,8 @@ class OllamaConfig: "name": "mistral" }' """ + if model.startswith("ollama/") or model.startswith("ollama_chat/"): + model = model.split("/", 1)[1] api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" try: diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py index 82ce9c465..11506ed3d 100644 --- a/tests/local_testing/test_get_model_info.py +++ b/tests/local_testing/test_get_model_info.py @@ -89,11 +89,16 @@ def test_get_model_info_ollama_chat(): "template": "tools", } ), - ): + ) as mock_client: info = OllamaConfig().get_model_info("mistral") - print("info", info) assert info["supports_function_calling"] is True info = get_model_info("ollama/mistral") - print("info", info) + assert info["supports_function_calling"] is True + + mock_client.assert_called() + + print(mock_client.call_args.kwargs) + + assert mock_client.call_args.kwargs["json"]["name"] == "mistral"