fix(ollama_chat.py): fix passing assistant message with tool call param

Fixes https://github.com/BerriAI/litellm/issues/5319
Krrish Dholakia 2024-08-22 09:59:52 -07:00
parent e45ec0ef46
commit 2dd616bad0
4 changed files with 53 additions and 8 deletions

litellm/llms/ollama_chat.py

@@ -4,14 +4,17 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import List, Optional

 import aiohttp
 import httpx
 import requests
+from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
+from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
+from litellm.types.llms.openai import ChatCompletionAssistantToolCall


 class OllamaError(Exception):
@@ -175,7 +178,7 @@ class OllamaChatConfig:
                 ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
                 try:
                     model_info = litellm.get_model_info(
-                        model=model, custom_llm_provider="ollama_chat"
+                        model=model, custom_llm_provider="ollama"
                     )
                     if model_info.get("supports_function_calling") is True:
                         optional_params["tools"] = value
@@ -237,13 +240,30 @@ def get_ollama_response(
     function_name = optional_params.pop("function_name", None)
     tools = optional_params.pop("tools", None)

+    new_messages = []
     for m in messages:
-        if "role" in m and m["role"] == "tool":
-            m["role"] = "assistant"
+        if isinstance(
+            m, BaseModel
+        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+            m = m.model_dump(exclude_none=True)
+        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
+            new_tools: List[OllamaToolCall] = []
+            for tool in m["tool_calls"]:
+                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
+                if typed_tool["type"] == "function":
+                    ollama_tool_call = OllamaToolCall(
+                        function=OllamaToolCallFunction(
+                            name=typed_tool["function"]["name"],
+                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                        )
+                    )
+                    new_tools.append(ollama_tool_call)
+            m["tool_calls"] = new_tools
+        new_messages.append(m)

     data = {
         "model": model,
-        "messages": messages,
+        "messages": new_messages,
         "options": optional_params,
         "stream": stream,
     }
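
The net effect of this hunk: an assistant message carrying OpenAI-style tool calls (where function.arguments is a JSON string) is converted into the shape Ollama's /api/chat endpoint expects (arguments as a decoded object, no id/type wrapper). A standalone sketch of the same transformation, with an illustrative message:

import json

# OpenAI-style assistant message (what message.model_dump() yields):
# note that "arguments" is a JSON-encoded string.
openai_msg = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "Boston", "unit": "fahrenheit"}',
            },
        }
    ],
}

# Ollama wants only the function name plus decoded arguments.
ollama_tool_calls = [
    {"function": {"name": tc["function"]["name"],
                  "arguments": json.loads(tc["function"]["arguments"])}}
    for tc in openai_msg["tool_calls"]
    if tc["type"] == "function"
]
print(ollama_tool_calls)
# [{'function': {'name': 'get_current_weather',
#                'arguments': {'location': 'Boston', 'unit': 'fahrenheit'}}}]
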
@@ -263,7 +283,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 api_key=api_key,
@@ -283,7 +303,7 @@ def get_ollama_response(
                 function_name=function_name,
             )
         return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(
             url=url, api_key=api_key, data=data, logging_obj=logging_obj
         )

litellm/main.py

@@ -2464,7 +2464,7 @@ def completion(
                 model_response=model_response,
                 encoding=encoding,
             )
-            if acompletion is True or optional_params.get("stream", False) == True:
+            if acompletion is True or optional_params.get("stream", False) is True:
                 return generator

             response = generator
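
The == True to is True cleanups in this commit are not purely cosmetic: == coerces across types (1 == True holds), while is only matches the bool singleton. A quick illustration:

# "== True" accepts truthy-equal values; "is True" accepts only the True singleton.
stream = 1             # a flag that arrived as an int
print(stream == True)  # True  -- 1 compares equal to True
print(stream is True)  # False -- 1 is not the bool singleton
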

litellm/tests/test_function_calling.py

@@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
 )
 def test_parallel_function_call(model):
     try:
+        litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
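
For context, a minimal repro sketch of issue #5319 that this test path exercises (model name, tool schema, and the canned tool result are illustrative; a running local Ollama server is assumed):

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]
messages = [{"role": "user", "content": "What's the weather like in Boston?"}]
first = litellm.completion(
    model="ollama_chat/llama3.1", messages=messages, tools=tools
)

# The assistant message is a pydantic object, not a dict; appending it
# verbatim to the history is what used to break serialization in ollama_chat.
assistant_msg = first.choices[0].message
messages.append(assistant_msg)
messages.append(
    {
        "role": "tool",
        "tool_call_id": assistant_msg.tool_calls[0].id,
        "content": '{"temperature": "72F"}',
    }
)
second = litellm.completion(model="ollama_chat/llama3.1", messages=messages)
print(second.choices[0].message.content)
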

litellm/types/llms/ollama.py

@@ -0,0 +1,24 @@
+import json
+from typing import Any, Optional, TypedDict, Union
+
+from pydantic import BaseModel
+from typing_extensions import (
+    Protocol,
+    Required,
+    Self,
+    TypeGuard,
+    get_origin,
+    override,
+    runtime_checkable,
+)
+
+
+class OllamaToolCallFunction(
+    TypedDict
+):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
+    name: str
+    arguments: dict
+
+
+class OllamaToolCall(TypedDict):
+    function: OllamaToolCallFunction
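
These TypedDicts mirror Ollama's Go-side api.ToolCall type linked in the comment above. A usage sketch (the types are redeclared here so the snippet runs standalone):

import json
from typing import TypedDict


class OllamaToolCallFunction(TypedDict):
    name: str
    arguments: dict


class OllamaToolCall(TypedDict):
    function: OllamaToolCallFunction


# Build the wire shape Ollama expects from an OpenAI-style function call,
# whose arguments arrive as a JSON string.
call = OllamaToolCall(
    function=OllamaToolCallFunction(
        name="get_current_weather",
        arguments=json.loads('{"location": "Boston"}'),
    )
)
assert call["function"]["arguments"] == {"location": "Boston"}
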