Merge pull request #4918 from BerriAI/litellm_ollama_tool_calling

feat(ollama_chat.py): support ollama tool calling
Krish Dholakia · 2024-07-26 22:16:58 -07:00 · committed by GitHub
commit f9c2fec1a6
8 changed files with 194 additions and 25 deletions


@@ -1,3 +1,6 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Ollama
LiteLLM supports all models from [Ollama](https://github.com/ollama/ollama)
@@ -84,6 +87,120 @@ response = completion(
)
```
## Example Usage - Tool Calling
To use ollama tool calling, pass `tools=[{..}]` to `litellm.completion()`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import completion
import litellm
## [OPTIONAL] REGISTER MODEL - not all Ollama models support function calling; LiteLLM falls back to JSON-mode tool calls when native tool calling is not supported.
# litellm.register_model(model_cost={
#     "ollama_chat/llama3.1": {
#         "supports_function_calling": True
#     },
# })
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
response = completion(
model="ollama_chat/llama3.1",
messages=messages,
tools=tools
)
```
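
The response follows the OpenAI tool-calling schema. As a minimal sketch (`tool_calls` may be `None` if the model answers in plain text, and the exact values depend on the model), the returned tool call can be read back like this:

```python
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
    for tool_call in tool_calls:
        print(tool_call.function.name)       # e.g. "get_current_weather"
        print(tool_call.function.arguments)  # JSON string of arguments
else:
    print(response.choices[0].message.content)
```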
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
  - model_name: "llama3.1"
    litellm_params:
      model: "ollama_chat/llama3.1"
    model_info:
      supports_function_calling: true
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer sk-1234' \
-d '{
  "model": "llama3.1",
  "messages": [
    {
      "role": "user",
      "content": "What'\''s the weather like in Boston today?"
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and state, e.g. San Francisco, CA"
            },
            "unit": {
              "type": "string",
              "enum": ["celsius", "fahrenheit"]
            }
          },
          "required": ["location"]
        }
      }
    }
  ],
  "tool_choice": "auto",
  "stream": true
}'
```
</TabItem>
</Tabs>
## Using ollama `api/chat`
In order to send ollama requests to `POST /api/chat` on your ollama server, set the model prefix to `ollama_chat`
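
For example, a minimal sketch (`api_base` defaults to the local Ollama server, as seen in the handler below):

```python
from litellm import completion

# The "ollama_chat/" prefix routes the request to POST /api/chat on the Ollama server.
response = completion(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    api_base="http://localhost:11434",  # optional; this is the default Ollama endpoint
)
print(response.choices[0].message.content)
```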


@@ -149,7 +149,9 @@ class OllamaChatConfig:
            "response_format",
        ]

-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self, model: str, non_default_params: dict, optional_params: dict
+    ):
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["num_predict"] = value
@@ -170,6 +172,16 @@ class OllamaChatConfig:
            ### FUNCTION CALLING LOGIC ###
            if param == "tools":
                # ollama actually supports json output
                ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
                try:
                    model_info = litellm.get_model_info(
                        model=model, custom_llm_provider="ollama_chat"
                    )
                    if model_info.get("supports_function_calling") is True:
                        optional_params["tools"] = value
                    else:
                        raise Exception
                except Exception:
                    optional_params["format"] = "json"
                    litellm.add_function_to_prompt = (
                        True  # so that main.py adds the function call to the prompt
                    )
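
The net effect of the branch above: if the model is registered as supporting function calling, the OpenAI-style `tools` list is passed through to Ollama unchanged; otherwise LiteLLM falls back to JSON-mode prompting. A simplified paraphrase (not the exact source):

```python
import litellm

def _map_tools(model: str, tools: list, optional_params: dict) -> dict:
    """Simplified paraphrase of the tools branch in map_openai_params above."""
    try:
        model_info = litellm.get_model_info(model=model, custom_llm_provider="ollama_chat")
        if model_info.get("supports_function_calling") is True:
            optional_params["tools"] = tools  # native tool calling
            return optional_params
        raise Exception("model not registered as supporting function calling")
    except Exception:
        # fallback: request JSON output and let LiteLLM add the function schema to the prompt
        optional_params["format"] = "json"
        litellm.add_function_to_prompt = True
        return optional_params
```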
@@ -198,11 +210,11 @@ class OllamaChatConfig:
# ollama implementation
def get_ollama_response(
    model_response: litellm.ModelResponse,
    messages: list,
    optional_params: dict,
    api_base="http://localhost:11434",
    api_key: Optional[str] = None,
    model="llama2",
-   messages=None,
-   optional_params=None,
    logging_obj=None,
    acompletion: bool = False,
    encoding=None,
@@ -223,6 +235,7 @@ def get_ollama_response(
    stream = optional_params.pop("stream", False)
    format = optional_params.pop("format", None)
    function_name = optional_params.pop("function_name", None)
    tools = optional_params.pop("tools", None)

    for m in messages:
        if "role" in m and m["role"] == "tool":
@@ -236,6 +249,8 @@ def get_ollama_response(
    }
    if format is not None:
        data["format"] = format
    if tools is not None:
        data["tools"] = tools
    ## LOGGING
    logging_obj.pre_call(
        input=None,
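
With that change, the body POSTed to Ollama's `/api/chat` can carry the tools list directly. A rough sketch of the payload (keys follow Ollama's documented chat API; the exact construction in the source may differ):

```python
import json

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "What's the weather like in Boston today?"}],
    "stream": False,
    # included only when native tool calling is used (abbreviated tool definition for illustration):
    "tools": [{"type": "function", "function": {"name": "get_current_weather"}}],
    # included only on the JSON-mode fallback path:
    # "format": "json",
}
print(json.dumps(payload, indent=2))
```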
@@ -499,7 +514,8 @@ async def ollama_acompletion(
        ## RESPONSE OBJECT
        model_response.choices[0].finish_reason = "stop"

-       if data.get("format", "") == "json":
+       if data.get("format", "") == "json" and function_name is not None:
            function_call = json.loads(response_json["message"]["content"])
            message = litellm.Message(
                content=None,
@@ -519,11 +535,8 @@ async def ollama_acompletion(
            model_response.choices[0].message = message  # type: ignore
            model_response.choices[0].finish_reason = "tool_calls"
        else:
-           model_response.choices[0].message.content = response_json[  # type: ignore
-               "message"
-           ][
-               "content"
-           ]
+           _message = litellm.Message(**response_json["message"])
+           model_response.choices[0].message = _message  # type: ignore

        model_response.created = int(time.time())
        model_response.model = "ollama_chat/" + data["model"]
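
The non-streaming async path now distinguishes two cases: the JSON-mode fallback (parse the text content into a synthetic tool call) and the native path (keep whatever Ollama returned, including any `tool_calls`). A simplified paraphrase, with illustrative shapes:

```python
import json
import litellm

def _build_message(data: dict, response_json: dict, function_name):
    """Simplified paraphrase of the response handling above (shapes are illustrative)."""
    if data.get("format", "") == "json" and function_name is not None:
        # JSON-mode fallback: the model's text content holds the function arguments
        function_call = json.loads(response_json["message"]["content"])
        return litellm.Message(
            content=None,
            tool_calls=[
                {
                    "id": "call_sketch",  # the real code generates a unique id
                    "type": "function",
                    "function": {"name": function_name, "arguments": json.dumps(function_call)},
                }
            ],
        )
    # native path: pass Ollama's message through, preserving any tool_calls it returned
    return litellm.Message(**response_json["message"])
```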


@@ -3956,6 +3956,16 @@
        "litellm_provider": "ollama",
        "mode": "chat"
    },
    "ollama/llama3.1": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.0,
        "output_cost_per_token": 0.0,
        "litellm_provider": "ollama",
        "mode": "chat",
        "supports_function_calling": true
    },
    "ollama/mistral": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,


@@ -1,8 +1,6 @@
model_list:
-  - model_name: "*"
+  - model_name: "llama3.1"
    litellm_params:
-      model: "*"
+      model: "ollama_chat/llama3.1"
+    model_info:
+      supports_function_calling: true
-
-litellm_settings:
-  success_callback: ["logfire"]
-  cache: true
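
The same deployment can be built programmatically; a sketch of the equivalent `Router` entry (the proxy constructs this from the YAML above):

```python
from litellm import Router

# Sketch: programmatic equivalent of the proxy config above.
router = Router(
    model_list=[
        {
            "model_name": "llama3.1",
            "litellm_params": {"model": "ollama_chat/llama3.1"},
            "model_info": {"supports_function_calling": True},
        }
    ]
)
```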


@@ -3469,6 +3469,18 @@ class Router:
                model_info=_model_info,
            )

            ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
            _model_name = deployment.litellm_params.model
            if deployment.litellm_params.custom_llm_provider is not None:
                _model_name = (
                    deployment.litellm_params.custom_llm_provider + "/" + _model_name
                )
            litellm.register_model(
                model_cost={
                    _model_name: _model_info,
                }
            )

            deployment = self._add_deployment(deployment=deployment)

            model = deployment.to_json(exclude_none=True)
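
In effect, each deployment's `model_info` is now published to the global cost map under its provider-prefixed name, which is what `OllamaChatConfig.map_openai_params` later looks up. A rough illustration for the llama3.1 deployment (names simplified; the real code derives them from the deployment object):

```python
import litellm

_model_name = "ollama_chat/llama3.1"   # provider-prefixed model name
_model_info = {"supports_function_calling": True}

litellm.register_model(model_cost={_model_name: _model_info})

print(litellm.model_cost["ollama_chat/llama3.1"]["supports_function_calling"])  # True
```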


@@ -74,6 +74,7 @@ class ModelInfo(TypedDict, total=False):
    supports_system_messages: Optional[bool]
    supports_response_schema: Optional[bool]
    supports_vision: Optional[bool]
    supports_function_calling: Optional[bool]


class GenericStreamingChunk(TypedDict):


@@ -2089,6 +2089,7 @@ def supports_function_calling(model: str) -> bool:
    Raises:
        Exception: If the given model is not found in model_prices_and_context_window.json.
    """

    if model in litellm.model_cost:
        model_info = litellm.model_cost[model]
        if model_info.get("supports_function_calling", False) is True:
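
With the flag in the cost map, this existing helper now reports Ollama models correctly. A minimal sketch:

```python
import litellm

print(litellm.supports_function_calling("ollama/llama3.1"))  # True - entry added in this PR
print(litellm.supports_function_calling("ollama/mistral"))   # False unless its entry sets the flag
```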
@@ -3293,7 +3294,9 @@ def get_optional_params(
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.OllamaChatConfig().map_openai_params(
-           non_default_params=non_default_params, optional_params=optional_params
+           model=model,
+           non_default_params=non_default_params,
+           optional_params=optional_params,
        )
    elif custom_llm_provider == "nlp_cloud":
        supported_params = get_supported_openai_params(
@@ -4877,6 +4880,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
        supports_system_messages: Optional[bool]
        supports_response_schema: Optional[bool]
        supports_vision: Optional[bool]
        supports_function_calling: Optional[bool]

    Raises:
        Exception: If the model is not mapped yet.
@@ -4951,6 +4955,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
            supported_openai_params=supported_openai_params,
            supports_system_messages=None,
            supports_response_schema=None,
            supports_function_calling=None,
        )
    else:
        """
@@ -5041,6 +5046,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                    "supports_response_schema", None
                ),
                supports_vision=_model_info.get("supports_vision", False),
                supports_function_calling=_model_info.get(
                    "supports_function_calling", False
                ),
            )
        except Exception:
            raise Exception(


@@ -3956,6 +3956,16 @@
        "litellm_provider": "ollama",
        "mode": "chat"
    },
    "ollama/llama3.1": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,
        "max_output_tokens": 8192,
        "input_cost_per_token": 0.0,
        "output_cost_per_token": 0.0,
        "litellm_provider": "ollama",
        "mode": "chat",
        "supports_function_calling": true
    },
    "ollama/mistral": {
        "max_tokens": 8192,
        "max_input_tokens": 8192,