feat(ollama_chat.py): support ollama tool calling

Closes https://github.com/BerriAI/litellm/issues/4812
Krrish Dholakia 2024-07-26 21:51:54 -07:00
parent a264d1ca8c
commit b25d4a8cb3
5 changed files with 57 additions and 25 deletions

View file

@@ -149,7 +149,9 @@ class OllamaChatConfig:
             "response_format",
         ]
 
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self, model: str, non_default_params: dict, optional_params: dict
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["num_predict"] = value
@@ -170,6 +172,16 @@ class OllamaChatConfig:
             ### FUNCTION CALLING LOGIC ###
             if param == "tools":
                 # ollama actually supports json output
+                ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
+                try:
+                    model_info = litellm.get_model_info(
+                        model=model, custom_llm_provider="ollama_chat"
+                    )
+                    if model_info.get("supports_function_calling") is True:
+                        optional_params["tools"] = value
+                    else:
+                        raise Exception
+                except Exception:
                     optional_params["format"] = "json"
                     litellm.add_function_to_prompt = (
                         True  # so that main.py adds the function call to the prompt
@@ -198,11 +210,11 @@
 # ollama implementation
 def get_ollama_response(
     model_response: litellm.ModelResponse,
+    messages: list,
+    optional_params: dict,
     api_base="http://localhost:11434",
     api_key: Optional[str] = None,
     model="llama2",
-    messages=None,
-    optional_params=None,
     logging_obj=None,
     acompletion: bool = False,
     encoding=None,
@@ -223,6 +235,7 @@ def get_ollama_response(
     stream = optional_params.pop("stream", False)
     format = optional_params.pop("format", None)
     function_name = optional_params.pop("function_name", None)
+    tools = optional_params.pop("tools", None)
 
     for m in messages:
         if "role" in m and m["role"] == "tool":
@@ -236,6 +249,8 @@ def get_ollama_response(
     }
     if format is not None:
         data["format"] = format
+    if tools is not None:
+        data["tools"] = tools
     ## LOGGING
     logging_obj.pre_call(
         input=None,
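For reference, the `data` payload built above now carries the tools list to Ollama's `/api/chat` endpoint. A hedged sketch of the equivalent raw request; the endpoint assumes a stock local Ollama install, and the model name and tool definition are illustrative, not part of this commit:

import requests

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "stream": False,
    # Same OpenAI-style tool schema that litellm forwards when the model
    # is flagged as supporting function calling.
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

resp = requests.post("http://localhost:11434/api/chat", json=payload, timeout=60)
print(resp.json()["message"].get("tool_calls"))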
@@ -499,7 +514,8 @@ async def ollama_acompletion(
 
         ## RESPONSE OBJECT
         model_response.choices[0].finish_reason = "stop"
-        if data.get("format", "") == "json":
+
+        if data.get("format", "") == "json" and function_name is not None:
             function_call = json.loads(response_json["message"]["content"])
             message = litellm.Message(
                 content=None,
@@ -519,11 +535,8 @@ async def ollama_acompletion(
             model_response.choices[0].message = message  # type: ignore
             model_response.choices[0].finish_reason = "tool_calls"
         else:
-            model_response.choices[0].message.content = response_json[  # type: ignore
-                "message"
-            ][
-                "content"
-            ]
+            _message = litellm.Message(**response_json["message"])
+            model_response.choices[0].message = _message  # type: ignore
 
         model_response.created = int(time.time())
         model_response.model = "ollama_chat/" + data["model"]
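With the mapping and handler changes above, a model whose model_info reports supports_function_calling=True gets the OpenAI-style tools list passed straight through to Ollama instead of the old format="json" fallback. A rough usage sketch against a local Ollama server; the model name, tool definition, and the assumption that the model is registered as tool-capable are illustrative, not part of this commit:

import litellm

# Illustrative OpenAI-style tool definition.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# Assumes a local Ollama server on the default port and a model whose
# model_info has supports_function_calling=True; otherwise the mapping
# falls back to format="json" + add_function_to_prompt as before.
response = litellm.completion(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)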

View file

@@ -1,8 +1,6 @@
 model_list:
-  - model_name: "*"
+  - model_name: "mistral"
     litellm_params:
-      model: "*"
-
-litellm_settings:
-  success_callback: ["logfire"]
-  cache: true
+      model: "ollama_chat/llama3.1"
+    model_info:
+      supports_function_calling: true
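The proxy config now pins a single Ollama deployment and marks it tool-capable via model_info. A minimal sketch of calling that deployment through the LiteLLM proxy with the OpenAI SDK; the proxy URL, API key, and tool definition are placeholders:

from openai import OpenAI

# Placeholder proxy address and key; point these at your running LiteLLM proxy.
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="mistral",  # the model_name exposed by the proxy config above
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
print(response.choices[0].message.tool_calls)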

View file

@@ -3469,6 +3469,18 @@ class Router:
             model_info=_model_info,
         )
 
+        ## REGISTER MODEL INFO IN LITELLM MODEL COST MAP
+        _model_name = deployment.litellm_params.model
+        if deployment.litellm_params.custom_llm_provider is not None:
+            _model_name = (
+                deployment.litellm_params.custom_llm_provider + "/" + _model_name
+            )
+        litellm.register_model(
+            model_cost={
+                _model_name: _model_info,
+            }
+        )
+
         deployment = self._add_deployment(deployment=deployment)
 
         model = deployment.to_json(exclude_none=True)
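The Router now pushes each deployment's model_info into litellm's model cost map so capability lookups can see it. A hand-rolled sketch of the same registration; the model name is illustrative and only the metadata relevant to this commit is included:

import litellm

# Register capability metadata under the provider-prefixed name, the same
# shape the Router now builds from deployment.litellm_params.
litellm.register_model(
    model_cost={
        "ollama_chat/llama3.1": {
            "litellm_provider": "ollama_chat",
            "mode": "chat",
            "supports_function_calling": True,
        }
    }
)

# The capability check in utils.py reads this entry from litellm.model_cost.
print(litellm.supports_function_calling("ollama_chat/llama3.1"))  # True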

View file

@@ -74,6 +74,7 @@ class ModelInfo(TypedDict, total=False):
     supports_system_messages: Optional[bool]
     supports_response_schema: Optional[bool]
     supports_vision: Optional[bool]
+    supports_function_calling: Optional[bool]
 
 
 class GenericStreamingChunk(TypedDict):

View file

@@ -2089,6 +2089,7 @@ def supports_function_calling(model: str) -> bool:
     Raises:
         Exception: If the given model is not found in model_prices_and_context_window.json.
     """
+
     if model in litellm.model_cost:
         model_info = litellm.model_cost[model]
         if model_info.get("supports_function_calling", False) is True:
@@ -3293,7 +3294,9 @@ def get_optional_params(
         _check_valid_arg(supported_params=supported_params)
 
         optional_params = litellm.OllamaChatConfig().map_openai_params(
-            non_default_params=non_default_params, optional_params=optional_params
+            model=model,
+            non_default_params=non_default_params,
+            optional_params=optional_params,
         )
     elif custom_llm_provider == "nlp_cloud":
         supported_params = get_supported_openai_params(
@@ -4877,6 +4880,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
             supports_system_messages: Optional[bool]
             supports_response_schema: Optional[bool]
             supports_vision: Optional[bool]
+            supports_function_calling: Optional[bool]
 
     Raises:
         Exception: If the model is not mapped yet.
@@ -4951,6 +4955,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                 supported_openai_params=supported_openai_params,
                 supports_system_messages=None,
                 supports_response_schema=None,
+                supports_function_calling=None,
             )
         else:
             """
@@ -5041,6 +5046,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
                     "supports_response_schema", None
                 ),
                 supports_vision=_model_info.get("supports_vision", False),
+                supports_function_calling=_model_info.get(
+                    "supports_function_calling", False
+                ),
             )
         except Exception:
             raise Exception(
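End to end, the utils.py changes surface the new flag through get_model_info(). A quick, self-contained sketch that repeats the illustrative registration from the Router example above; exact lookup behavior depends on how the model was registered, so treat the printed result as the expected outcome rather than a guarantee:

import litellm

litellm.register_model(
    model_cost={
        "ollama_chat/llama3.1": {
            "litellm_provider": "ollama_chat",
            "mode": "chat",
            "supports_function_calling": True,
        }
    }
)

# get_model_info() now carries the flag through to its ModelInfo result.
info = litellm.get_model_info(model="ollama_chat/llama3.1")
print(info["supports_function_calling"])  # expected: True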