feat(ollama.py): add support for ollama function calling

This commit is contained in:
Krrish Dholakia 2023-12-20 14:59:43 +05:30
parent bab8f3350d
commit f0df28362a
6 changed files with 211 additions and 74 deletions

View file

@ -2390,10 +2390,15 @@ def get_optional_params( # use the openai defaults
non_default_params = {k: v for k, v in passed_params.items() if (k != "model" and k != "custom_llm_provider" and k in default_params and v != default_params[k])}
optional_params = {}
## raise exception if function calling passed in for a provider that doesn't support it
if "functions" in non_default_params or "function_call" in non_default_params:
if "functions" in non_default_params or "function_call" in non_default_params or "tools" in non_default_params:
if custom_llm_provider != "openai" and custom_llm_provider != "text-completion-openai" and custom_llm_provider != "azure":
if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
if custom_llm_provider == "ollama":
# ollama actually supports json output
optional_params["format"] = "json"
litellm.add_function_to_prompt = True # so that main.py adds the function call to the prompt
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
elif litellm.add_function_to_prompt: # if user opts to add it to prompt instead
optional_params["functions_unsupported_model"] = non_default_params.pop("tools", non_default_params.pop("functions"))
else:
raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
@ -5192,9 +5197,6 @@ def exception_type(
raise original_exception
raise original_exception
elif custom_llm_provider == "ollama":
if "no attribute 'async_get_ollama_response_stream" in error_str:
exception_mapping_worked = True
raise ImportError("Import error - trying to use async for ollama. import async_generator failed. Try 'pip install async_generator'")
if isinstance(original_exception, dict):
error_str = original_exception.get("error", "")
else: