feat(ollama_chat.py): support ollama tool calling

Closes https://github.com/BerriAI/litellm/issues/4812
Krrish Dholakia 2024-07-26 21:51:54 -07:00
parent a264d1ca8c
commit b25d4a8cb3
5 changed files with 57 additions and 25 deletions

@@ -149,7 +149,9 @@ class OllamaChatConfig:
             "response_format",
         ]

-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self, model: str, non_default_params: dict, optional_params: dict
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["num_predict"] = value
@@ -170,6 +172,16 @@ class OllamaChatConfig:
             ### FUNCTION CALLING LOGIC ###
             if param == "tools":
                 # ollama actually supports json output
-                optional_params["format"] = "json"
-                litellm.add_function_to_prompt = (
-                    True  # so that main.py adds the function call to the prompt
+                ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
+                try:
+                    model_info = litellm.get_model_info(
+                        model=model, custom_llm_provider="ollama_chat"
+                    )
+                    if model_info.get("supports_function_calling") is True:
+                        optional_params["tools"] = value
+                    else:
+                        raise Exception
+                except Exception:
+                    optional_params["format"] = "json"
+                    litellm.add_function_to_prompt = (
+                        True  # so that main.py adds the function call to the prompt
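Note: the hunk above gates native tool passthrough on the registered model info. When `supports_function_calling` is set, the OpenAI-style `tools` list is forwarded untouched; otherwise the existing `format=json` fallback stays in place. As a rough illustration, this is the kind of value that flows through `non_default_params["tools"]` (the tool name and schema here are made up, not part of the commit):

# Illustrative only: an OpenAI-style tool definition as it would arrive in
# non_default_params["tools"] and be forwarded unchanged when the model is
# registered with supports_function_calling=True.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",  # hypothetical tool
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]
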
@@ -198,11 +210,11 @@ class OllamaChatConfig:
 # ollama implementation
 def get_ollama_response(
+    model_response: litellm.ModelResponse,
+    messages: list,
+    optional_params: dict,
     api_base="http://localhost:11434",
     api_key: Optional[str] = None,
     model="llama2",
-    messages=None,
-    optional_params=None,
     logging_obj=None,
     acompletion: bool = False,
     encoding=None,
@@ -223,6 +235,7 @@ def get_ollama_response(
     stream = optional_params.pop("stream", False)
     format = optional_params.pop("format", None)
     function_name = optional_params.pop("function_name", None)
+    tools = optional_params.pop("tools", None)

     for m in messages:
         if "role" in m and m["role"] == "tool":
@@ -236,6 +249,8 @@
     }
     if format is not None:
         data["format"] = format
+    if tools is not None:
+        data["tools"] = tools
     ## LOGGING
     logging_obj.pre_call(
         input=None,
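
Note: with `tools` popped out of `optional_params` and attached to `data`, the request body sent to Ollama's `/api/chat` endpoint looks roughly like the sketch below. The values are illustrative, the remaining `optional_params` end up under `options`, and `tools` reuses the OpenAI-style list from the previous sketch.

# Approximate request body POSTed to f"{api_base}/api/chat", following the
# data dict built above; values are illustrative.
data = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "options": {"temperature": 0.1},  # whatever remains in optional_params
    "stream": False,
    "tools": tools,  # the OpenAI-style list from the previous sketch
}
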
@@ -499,7 +514,8 @@ async def ollama_acompletion(
             ## RESPONSE OBJECT
             model_response.choices[0].finish_reason = "stop"
-            if data.get("format", "") == "json":
+            if data.get("format", "") == "json" and function_name is not None:
                 function_call = json.loads(response_json["message"]["content"])
                 message = litellm.Message(
                     content=None,
@@ -519,11 +535,8 @@
                 model_response.choices[0].message = message  # type: ignore
                 model_response.choices[0].finish_reason = "tool_calls"
             else:
-                model_response.choices[0].message.content = response_json[  # type: ignore
-                    "message"
-                ][
-                    "content"
-                ]
+                _message = litellm.Message(**response_json["message"])
+                model_response.choices[0].message = _message  # type: ignore

             model_response.created = int(time.time())
             model_response.model = "ollama_chat/" + data["model"]
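
Note: taken together, the ollama_chat.py hunks above let the async path return native `tool_calls` instead of forcing JSON-mode parsing. A minimal end-to-end sketch (untested), assuming a local Ollama server with `llama3.1` pulled and model info registered so the capability check passes:

import litellm

# Untested sketch: assumes a running Ollama server at http://localhost:11434
# with llama3.1 pulled, and "ollama_chat/llama3.1" registered with
# supports_function_calling=True (see the register_model sketch further down);
# otherwise the JSON-prompt fallback is used instead of native tool calls.
response = litellm.completion(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
)
print(response.choices[0].message.tool_calls)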

@@ -1,8 +1,6 @@
 model_list:
-  - model_name: "*"
+  - model_name: "mistral"
     litellm_params:
-      model: "*"
-
-litellm_settings:
-  success_callback: ["logfire"]
-  cache: true
+      model: "ollama_chat/llama3.1"
+    model_info:
+      supports_function_calling: true
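
Note: the sample proxy config above maps the public `mistral` deployment name to `ollama_chat/llama3.1` and declares `supports_function_calling` in `model_info`, which the router change below pushes into the model cost map. A rough client-side sketch (untested), assuming the proxy was started with `litellm --config <this config file>` on the default port 4000 and Ollama is serving llama3.1 locally:

from openai import OpenAI

# Untested sketch: proxy assumed at localhost:4000; any api_key works when
# the proxy runs without auth.
client = OpenAI(base_url="http://localhost:4000", api_key="sk-anything")

resp = client.chat.completions.create(
    model="mistral",  # the model_name exposed by the config above
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",  # hypothetical tool
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                },
            },
        }
    ],
)
print(resp.choices[0].message.tool_calls)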

@@ -3469,6 +3469,18 @@ class Router:
             model_info=_model_info,
         )

+        ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+        _model_name = deployment.litellm_params.model
+        if deployment.litellm_params.custom_llm_provider is not None:
+            _model_name = (
+                deployment.litellm_params.custom_llm_provider + "/" + _model_name
+            )
+        litellm.register_model(
+            model_cost={
+                _model_name: _model_info,
+            }
+        )
+
         deployment = self._add_deployment(deployment=deployment)

         model = deployment.to_json(exclude_none=True)
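
Note: this Router change mirrors each deployment's `model_info` into litellm's model cost map at registration time, which is what makes the `supports_function_calling` lookup in `OllamaChatConfig.map_openai_params` succeed. A standalone sketch of the equivalent call (the entry contents are illustrative):

import litellm

# Rough equivalent of what the Router does above for an ollama_chat/llama3.1
# deployment whose model_info sets supports_function_calling: true.
litellm.register_model(
    model_cost={
        "ollama_chat/llama3.1": {
            "litellm_provider": "ollama_chat",
            "supports_function_calling": True,
        }
    }
)

print(litellm.model_cost["ollama_chat/llama3.1"]["supports_function_calling"])  # True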

@@ -74,6 +74,7 @@ class ModelInfo(TypedDict, total=False):
     supports_system_messages: Optional[bool]
     supports_response_schema: Optional[bool]
     supports_vision: Optional[bool]
+    supports_function_calling: Optional[bool]


 class GenericStreamingChunk(TypedDict):

@@ -2089,6 +2089,7 @@ def supports_function_calling(model: str) -> bool:
     Raises:
         Exception: If the given model is not found in model_prices_and_context_window.json.
     """
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
         if model_info.get("supports_function_calling", False) is True:
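
Note: `litellm.supports_function_calling` keys off the same flag in the model cost map, so locally registered models are covered too. A short sketch (untested):

import litellm

# Models in the bundled cost map that carry the flag report True directly.
print(litellm.supports_function_calling("gpt-3.5-turbo"))

# A locally served model resolves only after it has been registered (for
# example via litellm.register_model as sketched earlier); models missing
# from the map raise an Exception, per the docstring above.
print(litellm.supports_function_calling("ollama_chat/llama3.1"))
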
@@ -3293,7 +3294,9 @@
         _check_valid_arg(supported_params=supported_params)
         optional_params = litellm.OllamaChatConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
         )
     elif custom_llm_provider == "nlp_cloud":
         supported_params = get_supported_openai_params(
@@ -4877,6 +4880,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             supports_system_messages: Optional[bool]
             supports_response_schema: Optional[bool]
             supports_vision: Optional[bool]
+            supports_function_calling: Optional[bool]

     Raises:
         Exception: If the model is not mapped yet.
@@ -4951,6 +4955,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                 supported_openai_params=supported_openai_params,
                 supports_system_messages=None,
                 supports_response_schema=None,
+                supports_function_calling=None,
             )
         else:
             """
@@ -5041,6 +5046,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     "supports_response_schema", None
                 ),
                 supports_vision=_model_info.get("supports_vision", False),
+                supports_function_calling=_model_info.get(
+                    "supports_function_calling", False
+                ),
             )
     except Exception:
         raise Exception(
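
Note: with `supports_function_calling` now threaded through `get_model_info`, the capability check added in `OllamaChatConfig.map_openai_params` can be reproduced directly. A sketch (untested), assuming the `register_model` call shown earlier has populated the cost map:

import litellm

# Mirrors the check added in OllamaChatConfig.map_openai_params: forward tools
# natively when the registered model info carries the flag, otherwise fall
# back to format=json function calling.
model_info = litellm.get_model_info(model="llama3.1", custom_llm_provider="ollama_chat")
if model_info.get("supports_function_calling") is True:
    print("native tool calling")
else:
    print("json-mode fallback")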