diff --git a/docs/my-website/docs/providers/ollama.md b/docs/my-website/docs/providers/ollama.md
index 51d91ccb6..78c91bb63 100644
--- a/docs/my-website/docs/providers/ollama.md
+++ b/docs/my-website/docs/providers/ollama.md
@@ -5,6 +5,12 @@ LiteLLM supports all models from [Ollama](https://github.com/jmorganca/ollama)
 Open In Colab
 
+:::info
+
+We recommend using [ollama_chat](#using-ollama-apichat) for better responses.
+
+:::
+
 ## Pre-requisites
 Ensure you have your ollama server running
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 017bd46ac..506147166 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -588,6 +588,7 @@ from .llms.petals import PetalsConfig
 from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
+from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
 from .llms.bedrock import (
     AmazonTitanConfig,
diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index dec74fa92..8378a95ff 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -18,7 +18,7 @@ class OllamaError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
-class OllamaConfig:
+class OllamaChatConfig:
     """
     Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
 
@@ -108,6 +108,7 @@ class OllamaConfig:
             k: v
             for k, v in cls.__dict__.items()
             if not k.startswith("__")
+            and k != "function_name"  # special param for function calling
             and not isinstance(
                 v,
                 (
@@ -120,6 +121,61 @@ class OllamaConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(
+        self,
+    ):
+        return [
+            "max_tokens",
+            "stream",
+            "top_p",
+            "temperature",
+            "frequency_penalty",
+            "stop",
+            "tools",
+            "tool_choice",
+            "functions",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["num_predict"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "frequency_penalty":
+                optional_params["repeat_penalty"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            ### FUNCTION CALLING LOGIC ###
+            if param == "tools":
+                # ollama actually supports json output
+                optional_params["format"] = "json"
+                litellm.add_function_to_prompt = (
+                    True  # so that main.py adds the function call to the prompt
+                )
+                optional_params["functions_unsupported_model"] = value
+
+                if len(optional_params["functions_unsupported_model"]) == 1:
+                    optional_params["function_name"] = optional_params[
+                        "functions_unsupported_model"
+                    ][0]["function"]["name"]
+
+            if param == "functions":
+                # ollama actually supports json output
+                optional_params["format"] = "json"
+                litellm.add_function_to_prompt = (
+                    True  # so that main.py adds the function call to the prompt
+                )
+                optional_params["functions_unsupported_model"] = non_default_params.pop(
+                    "functions"
+                )
+        non_default_params.pop("tool_choice", None)  # causes ollama requests to hang
+        return optional_params
+
 
 # ollama implementation
 def get_ollama_response(
@@ -138,7 +194,7 @@ def get_ollama_response(
     url = f"{api_base}/api/chat"
 
     ## Load Config
-    config = litellm.OllamaConfig.get_config()
+    config = litellm.OllamaChatConfig.get_config()
     for k, v in config.items():
         if (
             k not in optional_params
@@ -147,6 +203,7 @@ def get_ollama_response(
     stream = optional_params.pop("stream", False)
     format = optional_params.pop("format", None)
+    function_name = optional_params.pop("function_name", None)
 
     for m in messages:
         if "role" in m and m["role"] == "tool":
@@ -187,6 +244,7 @@ def get_ollama_response(
             model_response=model_response,
             encoding=encoding,
             logging_obj=logging_obj,
+            function_name=function_name,
         )
         return response
     elif stream == True:
@@ -290,7 +348,9 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
         traceback.print_exc()
 
 
-async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
+async def ollama_acompletion(
+    url, data, model_response, encoding, logging_obj, function_name
+):
     data["stream"] = False
     try:
         timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
@@ -324,7 +384,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                         "id": f"call_{str(uuid.uuid4())}",
                         "function": {
                             "arguments": response_json["message"]["content"],
-                            "name": "",
+                            "name": function_name or "",
                         },
                         "type": "function",
                     }
diff --git a/litellm/utils.py b/litellm/utils.py
index 4c48c5516..38836a4bc 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4147,8 +4147,9 @@ def get_optional_params(
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
         and custom_llm_provider != "bedrock"
+        and custom_llm_provider != "ollama_chat"
     ):
-        if custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
+        if custom_llm_provider == "ollama":
             # ollama actually supports json output
             optional_params["format"] = "json"
             litellm.add_function_to_prompt = (
@@ -4174,7 +4175,7 @@ def get_optional_params(
         else:
             raise UnsupportedParamsError(
                 status_code=500,
-                message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.",
+                message=f"Function calling is not supported by {custom_llm_provider}.",
             )
 
     def _check_valid_arg(supported_params):
@@ -4687,28 +4688,13 @@ def get_optional_params(
         if stop is not None:
             optional_params["stop"] = stop
     elif custom_llm_provider == "ollama_chat":
-        supported_params = [
-            "max_tokens",
-            "stream",
-            "top_p",
-            "temperature",
-            "frequency_penalty",
-            "stop",
-        ]
+        supported_params = litellm.OllamaChatConfig().get_supported_openai_params()
+
         _check_valid_arg(supported_params=supported_params)
-        if max_tokens is not None:
-            optional_params["num_predict"] = max_tokens
-        if stream:
-            optional_params["stream"] = stream
-        if temperature is not None:
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if frequency_penalty is not None:
-            optional_params["repeat_penalty"] = frequency_penalty
-        if stop is not None:
-            optional_params["stop"] = stop
+        optional_params = litellm.OllamaChatConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "nlp_cloud":
         supported_params = [
             "max_tokens",
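
Taken together, these changes give `ollama_chat` its own `OllamaChatConfig`: OpenAI-style parameters are mapped in one place, and `tools`/`functions` arguments are handled by forcing `format=json`, injecting the function definitions into the prompt, and carrying the function name through to the returned `tool_calls` when exactly one tool is supplied. A minimal sketch of how a caller might exercise this path is shown below; the `ollama_chat/mistral` model name, the weather tool schema, and the assumption of a local Ollama server on the default port are illustrative, not part of the patch.

```python
# Sketch: exercising the ollama_chat function-calling path from the patch above.
# Assumes litellm with this change installed and a local Ollama server
# (default http://localhost:11434) with the `mistral` model pulled.
import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool, for illustration only
            "description": "Get the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"},
                },
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="ollama_chat/mistral",  # routed through OllamaChatConfig
    messages=[{"role": "user", "content": "What is the weather in Boston?"}],
    tools=tools,  # mapped to format="json" + add_function_to_prompt by map_openai_params
)

# Because exactly one tool was passed, the patch records its name as
# `function_name`, so the returned tool call carries a non-empty name.
print(response.choices[0].message.tool_calls)
```

Since the Ollama chat API at this point returns plain JSON text rather than a native tool-call object, the name on the returned call comes from the single supplied tool; with multiple tools it falls back to an empty string, per the `function_name or ""` default in `ollama_acompletion`.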