Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(ollama_chat.py): fix passing assistant message with tool call param
Fixes https://github.com/BerriAI/litellm/issues/5319
commit 2dd616bad0 (parent e45ec0ef46)
4 changed files with 53 additions and 8 deletions
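The change targets two problems when an assistant message that contains tool calls is passed back into the Ollama chat endpoint: the message may be a pydantic model rather than a plain dict (the serialization issue referenced in #5319), and its tool-call arguments arrive as a JSON string while Ollama expects a parsed object. A rough, self-contained sketch of the serialization half, using a stand-in Message class rather than LiteLLM's actual type:

import json
from typing import List, Optional

from pydantic import BaseModel


class Message(BaseModel):  # stand-in for the pydantic message type; fields are illustrative
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[dict]] = None


m = Message(
    role="assistant",
    tool_calls=[{"type": "function", "function": {"name": "f", "arguments": "{}"}}],
)

# json.dumps cannot handle the pydantic object itself:
#   json.dumps({"messages": [m]})  -> TypeError: Object of type Message is not JSON serializable
# so the fix dumps it to a plain dict first, dropping None fields:
payload_msg = m.model_dump(exclude_none=True)
print(json.dumps({"messages": [payload_msg]}))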
@@ -4,14 +4,17 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import List, Optional
 
 import aiohttp
 import httpx
 import requests
+from pydantic import BaseModel
 
 import litellm
 from litellm import verbose_logger
+from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
+from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 
 
 class OllamaError(Exception):

@@ -175,7 +178,7 @@ class OllamaChatConfig:
             ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
             try:
                 model_info = litellm.get_model_info(
-                    model=model, custom_llm_provider="ollama_chat"
+                    model=model, custom_llm_provider="ollama"
                 )
                 if model_info.get("supports_function_calling") is True:
                     optional_params["tools"] = value

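The hunk above only forwards tools when the model map says the model supports function calling. A minimal sketch of that gate in isolation (assumes litellm is installed; litellm.get_model_info raises for models it does not recognize, which is presumably why the lookup sits inside a try block in OllamaChatConfig):

import litellm


def ollama_model_supports_tools(model: str) -> bool:
    # Hypothetical helper, not part of LiteLLM: mirrors the gate shown above.
    try:
        model_info = litellm.get_model_info(model=model, custom_llm_provider="ollama")
    except Exception:
        return False  # unknown model -> don't claim tool support
    return model_info.get("supports_function_calling") is True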
@@ -237,13 +240,30 @@ def get_ollama_response(
     function_name = optional_params.pop("function_name", None)
     tools = optional_params.pop("tools", None)
 
+    new_messages = []
     for m in messages:
-        if "role" in m and m["role"] == "tool":
-            m["role"] = "assistant"
+        if isinstance(
+            m, BaseModel
+        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+            m = m.model_dump(exclude_none=True)
+        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
+            new_tools: List[OllamaToolCall] = []
+            for tool in m["tool_calls"]:
+                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
+                if typed_tool["type"] == "function":
+                    ollama_tool_call = OllamaToolCall(
+                        function=OllamaToolCallFunction(
+                            name=typed_tool["function"]["name"],
+                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                        )
+                    )
+                    new_tools.append(ollama_tool_call)
+            m["tool_calls"] = new_tools
+        new_messages.append(m)
 
     data = {
         "model": model,
-        "messages": messages,
+        "messages": new_messages,
         "options": optional_params,
         "stream": stream,
     }

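To make the new loop concrete, here is a self-contained sketch of the tool-call conversion it performs (values are made up; the real code builds OllamaToolCall / OllamaToolCallFunction typed dicts, which are plain dicts at runtime):

import json


def convert_tool_calls(tool_calls: list) -> list:
    """OpenAI-style tool calls carry arguments as a JSON string; Ollama wants a dict."""
    converted = []
    for tool in tool_calls:
        if tool.get("type") == "function":
            converted.append(
                {
                    "function": {
                        "name": tool["function"]["name"],
                        "arguments": json.loads(tool["function"]["arguments"]),
                    }
                }
            )
    return converted


assistant_message = {
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_123",  # hypothetical id
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "arguments": '{"location": "San Francisco", "unit": "celsius"}',
            },
        }
    ],
}
assistant_message["tool_calls"] = convert_tool_calls(assistant_message["tool_calls"])
# -> [{"function": {"name": "get_current_weather",
#                   "arguments": {"location": "San Francisco", "unit": "celsius"}}}]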
@@ -263,7 +283,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 api_key=api_key,

@@ -283,7 +303,7 @@ def get_ollama_response(
                 function_name=function_name,
             )
         return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(
             url=url, api_key=api_key, data=data, logging_obj=logging_obj
         )

@@ -2464,7 +2464,7 @@ def completion(
                 model_response=model_response,
                 encoding=encoding,
             )
-            if acompletion is True or optional_params.get("stream", False) == True:
+            if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
 
             response = generator

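Several hunks in this commit also tighten "== True" comparisons to "is True". The distinction matters because equality coerces while identity does not; a quick illustration:

stream = 1              # a truthy value that is not the bool True
print(stream == True)   # True  (bool is a subclass of int, so equality coerces)
print(stream is True)   # False (identity: only the actual True object matches)

stream = True
print(stream is True)   # True; the streaming branch now runs only for a literal bool True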
@@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
 )
 def test_parallel_function_call(model):
     try:
+        litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {

litellm/types/llms/ollama.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import json
+from typing import Any, Optional, TypedDict, Union
+
+from pydantic import BaseModel
+from typing_extensions import (
+    Protocol,
+    Required,
+    Self,
+    TypeGuard,
+    get_origin,
+    override,
+    runtime_checkable,
+)
+
+
+class OllamaToolCallFunction(
+    TypedDict
+):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
+    name: str
+    arguments: dict
+
+
+class OllamaToolCall(TypedDict):
+    function: OllamaToolCallFunction
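For reference, the new typed dicts can be built directly; since TypedDicts are plain dicts at runtime, the result serializes cleanly into the request payload (values below are illustrative):

import json

from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction

call = OllamaToolCall(
    function=OllamaToolCallFunction(
        name="get_current_weather",
        arguments={"location": "Boston", "unit": "fahrenheit"},
    )
)
print(json.dumps(call))
# {"function": {"name": "get_current_weather", "arguments": {"location": "Boston", "unit": "fahrenheit"}}}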