diff --git a/docs/my-website/docs/providers/ollama.md b/docs/my-website/docs/providers/ollama.md index c1c8fc57c..63b79fe3a 100644 --- a/docs/my-website/docs/providers/ollama.md +++ b/docs/my-website/docs/providers/ollama.md @@ -1,3 +1,6 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + # Ollama LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama) @@ -84,6 +87,120 @@ response = completion( ) ``` +## Example Usage - Tool Calling + +To use ollama tool calling, pass `tools=[{..}]` to `litellm.completion()` + + + + +```python +from litellm import completion +import litellm + +## [OPTIONAL] REGISTER MODEL - not all ollama models support function calling; litellm defaults to json mode tool calls if native tool calling is not supported. + +# litellm.register_model(model_cost={ +# "ollama_chat/llama3.1": { +# "supports_function_calling": True +# }, +# }) + +tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } +] + +messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + + +response = completion( + model="ollama_chat/llama3.1", + messages=messages, + tools=tools +) +``` + + + + +1. Setup config.yaml + +```yaml +model_list: + - model_name: "llama3.1" + litellm_params: + model: "ollama_chat/llama3.1" + model_info: + supports_function_calling: true +``` + +2. Start proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! 
+ +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ +-H 'Content-Type: application/json' \ +-H 'Authorization: Bearer sk-1234' \ +-d '{ + "model": "llama3.1", + "messages": [ + { + "role": "user", + "content": "What'\''s the weather like in Boston today?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "tool_choice": "auto", + "stream": true +}' +``` + + + ## Using ollama `api/chat` In order to send ollama requests to `POST /api/chat` on your ollama server, set the model prefix to `ollama_chat` diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index ebd0f22fb..a6b975026 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -149,7 +149,9 @@ class OllamaChatConfig: "response_format", ] - def map_openai_params(self, non_default_params: dict, optional_params: dict): + def map_openai_params( + self, model: str, non_default_params: dict, optional_params: dict + ): for param, value in non_default_params.items(): if param == "max_tokens": optional_params["num_predict"] = value @@ -170,16 +172,26 @@ class OllamaChatConfig: ### FUNCTION CALLING LOGIC ### if param == "tools": # ollama actually supports json output - optional_params["format"] = "json" - litellm.add_function_to_prompt = ( - True # so that main.py adds the function call to the prompt - ) - optional_params["functions_unsupported_model"] = value + ## CHECK IF MODEL SUPPORTS TOOL CALLING ## + try: + model_info = litellm.get_model_info( + model=model, custom_llm_provider="ollama_chat" + ) + if model_info.get("supports_function_calling") is True: + 
optional_params["tools"] = value + else: + raise Exception + except Exception: + optional_params["format"] = "json" + litellm.add_function_to_prompt = ( + True # so that main.py adds the function call to the prompt + ) + optional_params["functions_unsupported_model"] = value - if len(optional_params["functions_unsupported_model"]) == 1: - optional_params["function_name"] = optional_params[ - "functions_unsupported_model" - ][0]["function"]["name"] + if len(optional_params["functions_unsupported_model"]) == 1: + optional_params["function_name"] = optional_params[ + "functions_unsupported_model" + ][0]["function"]["name"] if param == "functions": # ollama actually supports json output @@ -198,11 +210,11 @@ class OllamaChatConfig: # ollama implementation def get_ollama_response( model_response: litellm.ModelResponse, + messages: list, + optional_params: dict, api_base="http://localhost:11434", api_key: Optional[str] = None, model="llama2", - messages=None, - optional_params=None, logging_obj=None, acompletion: bool = False, encoding=None, @@ -223,6 +235,7 @@ def get_ollama_response( stream = optional_params.pop("stream", False) format = optional_params.pop("format", None) function_name = optional_params.pop("function_name", None) + tools = optional_params.pop("tools", None) for m in messages: if "role" in m and m["role"] == "tool": @@ -236,6 +249,8 @@ def get_ollama_response( } if format is not None: data["format"] = format + if tools is not None: + data["tools"] = tools ## LOGGING logging_obj.pre_call( input=None, @@ -499,7 +514,8 @@ async def ollama_acompletion( ## RESPONSE OBJECT model_response.choices[0].finish_reason = "stop" - if data.get("format", "") == "json": + + if data.get("format", "") == "json" and function_name is not None: function_call = json.loads(response_json["message"]["content"]) message = litellm.Message( content=None, @@ -519,11 +535,8 @@ async def ollama_acompletion( model_response.choices[0].message = message # type: ignore 
model_response.choices[0].finish_reason = "tool_calls" else: - model_response.choices[0].message.content = response_json[ # type: ignore - "message" - ][ - "content" - ] + _message = litellm.Message(**response_json["message"]) + model_response.choices[0].message = _message # type: ignore model_response.created = int(time.time()) model_response.model = "ollama_chat/" + data["model"] diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index d4985bffd..2689d0566 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -3956,6 +3956,16 @@ "litellm_provider": "ollama", "mode": "chat" }, + "ollama/llama3.1": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 8192, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": "ollama", + "mode": "chat", + "supports_function_calling": true + }, "ollama/mistral": { "max_tokens": 8192, "max_input_tokens": 8192, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index c12847736..34bf7d89a 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,8 +1,6 @@ model_list: - - model_name: "*" + - model_name: "llama3.1" litellm_params: - model: "*" - -litellm_settings: - success_callback: ["logfire"] - cache: true \ No newline at end of file + model: "ollama_chat/llama3.1" + model_info: + supports_function_calling: true \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index eff5f94db..d72f3ea5e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -3469,6 +3469,18 @@ class Router: model_info=_model_info, ) + ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP + _model_name = deployment.litellm_params.model + if deployment.litellm_params.custom_llm_provider is not None: + _model_name = ( + deployment.litellm_params.custom_llm_provider + "/" + 
_model_name + ) + litellm.register_model( + model_cost={ + _model_name: _model_info, + } + ) + deployment = self._add_deployment(deployment=deployment) model = deployment.to_json(exclude_none=True) diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 88bfa19e9..e64099aa6 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -74,6 +74,7 @@ class ModelInfo(TypedDict, total=False): supports_system_messages: Optional[bool] supports_response_schema: Optional[bool] supports_vision: Optional[bool] + supports_function_calling: Optional[bool] class GenericStreamingChunk(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index a8ef6119b..358904677 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2089,6 +2089,7 @@ def supports_function_calling(model: str) -> bool: Raises: Exception: If the given model is not found in model_prices_and_context_window.json. """ + if model in litellm.model_cost: model_info = litellm.model_cost[model] if model_info.get("supports_function_calling", False) is True: @@ -3293,7 +3294,9 @@ def get_optional_params( _check_valid_arg(supported_params=supported_params) optional_params = litellm.OllamaChatConfig().map_openai_params( - non_default_params=non_default_params, optional_params=optional_params + model=model, + non_default_params=non_default_params, + optional_params=optional_params, ) elif custom_llm_provider == "nlp_cloud": supported_params = get_supported_openai_params( @@ -4877,6 +4880,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod supports_system_messages: Optional[bool] supports_response_schema: Optional[bool] supports_vision: Optional[bool] + supports_function_calling: Optional[bool] Raises: Exception: If the model is not mapped yet. 
@@ -4951,6 +4955,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod supported_openai_params=supported_openai_params, supports_system_messages=None, supports_response_schema=None, + supports_function_calling=None, ) else: """ @@ -5041,6 +5046,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod "supports_response_schema", None ), supports_vision=_model_info.get("supports_vision", False), + supports_function_calling=_model_info.get( + "supports_function_calling", False + ), ) except Exception: raise Exception( diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index d4985bffd..2689d0566 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -3956,6 +3956,16 @@ "litellm_provider": "ollama", "mode": "chat" }, + "ollama/llama3.1": { + "max_tokens": 8192, + "max_input_tokens": 8192, + "max_output_tokens": 8192, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": "ollama", + "mode": "chat", + "supports_function_calling": true + }, "ollama/mistral": { "max_tokens": 8192, "max_input_tokens": 8192,