Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
fix(ollama_chat.py): fix passing assistant message with tool call param
Fixes https://github.com/BerriAI/litellm/issues/5319
Parent: e45ec0ef46
Commit: 2dd616bad0

4 changed files with 53 additions and 8 deletions
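For context, a minimal sketch of the call shape that triggered the bug (not from the commit; the model name, tool-call id, and weather values are illustrative): an assistant turn replaying an OpenAI-format tool call, where "function.arguments" is a JSON string rather than a dict.

import litellm

messages = [
    {"role": "user", "content": "What's the weather in Boston?"},
    {
        # Assistant message echoing a prior tool call, in OpenAI wire format.
        # Note: "arguments" is a JSON *string*; Ollama expects a dict.
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston", "unit": "fahrenheit"}',
                },
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "72F and sunny"},
]

# Before this commit, ollama_chat forwarded the message (possibly still a
# Pydantic object) and its OpenAI-shaped tool_calls to Ollama unchanged,
# which the Ollama API rejects. Requires a running Ollama server to execute.
response = litellm.completion(model="ollama_chat/llama3.1", messages=messages)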
litellm/llms/ollama_chat.py

@@ -4,14 +4,17 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import List, Optional

 import aiohttp
 import httpx
 import requests
 from pydantic import BaseModel

 import litellm
 from litellm import verbose_logger
+from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
+from litellm.types.llms.openai import ChatCompletionAssistantToolCall
+

 class OllamaError(Exception):
@@ -175,7 +178,7 @@ class OllamaChatConfig:
         ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
         try:
             model_info = litellm.get_model_info(
-                model=model, custom_llm_provider="ollama_chat"
+                model=model, custom_llm_provider="ollama"
             )
             if model_info.get("supports_function_calling") is True:
                 optional_params["tools"] = value
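A quick sketch of what that capability gate does at runtime (model name illustrative; assumes the model is registered in litellm's model map, since get_model_info raises for unmapped models, hence the surrounding try/except):

import litellm

model_info = litellm.get_model_info(model="llama3.1", custom_llm_provider="ollama")
# model_info behaves like a dict at runtime, so .get() is safe here
if model_info.get("supports_function_calling") is True:
    print("model supports tools; the tools param will be kept")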
@@ -237,13 +240,30 @@ def get_ollama_response(
     function_name = optional_params.pop("function_name", None)
     tools = optional_params.pop("tools", None)

+    new_messages = []
     for m in messages:
-        if "role" in m and m["role"] == "tool":
-            m["role"] = "assistant"
+        if isinstance(
+            m, BaseModel
+        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+            m = m.model_dump(exclude_none=True)
+        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
+            new_tools: List[OllamaToolCall] = []
+            for tool in m["tool_calls"]:
+                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
+                if typed_tool["type"] == "function":
+                    ollama_tool_call = OllamaToolCall(
+                        function=OllamaToolCallFunction(
+                            name=typed_tool["function"]["name"],
+                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                        )
+                    )
+                    new_tools.append(ollama_tool_call)
+            m["tool_calls"] = new_tools
+        new_messages.append(m)

     data = {
         "model": model,
-        "messages": messages,
+        "messages": new_messages,
         "options": optional_params,
         "stream": stream,
     }
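The core of the fix is the shape conversion inside that loop. A standalone sketch of the same transformation outside litellm (values illustrative):

import json

# OpenAI wire format: "arguments" is a JSON-encoded string.
openai_tool_call = {
    "id": "call_1",
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "arguments": '{"location": "Boston", "unit": "fahrenheit"}',
    },
}

# Ollama wire format (see api/types.go in the Ollama repo): "arguments" is a
# dict, and there is no id/type wrapper, just the function object.
ollama_tool_call = {
    "function": {
        "name": openai_tool_call["function"]["name"],
        "arguments": json.loads(openai_tool_call["function"]["arguments"]),
    }
}

print(ollama_tool_call)
# {'function': {'name': 'get_current_weather',
#               'arguments': {'location': 'Boston', 'unit': 'fahrenheit'}}}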
@@ -263,7 +283,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 api_key=api_key,
@@ -283,7 +303,7 @@ def get_ollama_response(
             function_name=function_name,
         )
         return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(
             url=url, api_key=api_key, data=data, logging_obj=logging_obj
         )
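The `== True` to `is True` changes here (and in `completion()` below) are more than style. Because bool is a subclass of int in Python, `==` lets truthy non-bool values through, while `is True` matches only the actual True singleton:

stream = 1  # e.g. a flag that arrived as an int rather than a bool

print(stream == True)  # True  -- 1 == True, since bool subclasses int
print(stream is True)  # False -- identity check matches only the bool True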
litellm/main.py

@@ -2464,7 +2464,7 @@ def completion(
                 model_response=model_response,
                 encoding=encoding,
             )
-            if acompletion is True or optional_params.get("stream", False) == True:
+            if acompletion is True or optional_params.get("stream", False) is True:
                 return generator

             response = generator
@@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
 )
 def test_parallel_function_call(model):
     try:
         litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
litellm/types/llms/ollama.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+import json
+from typing import Any, Optional, TypedDict, Union
+
+from pydantic import BaseModel
+from typing_extensions import (
+    Protocol,
+    Required,
+    Self,
+    TypeGuard,
+    get_origin,
+    override,
+    runtime_checkable,
+)
+
+
+class OllamaToolCallFunction(
+    TypedDict
+):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
+    name: str
+    arguments: dict
+
+
+class OllamaToolCall(TypedDict):
+    function: OllamaToolCallFunction
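A usage sketch for the new types (values illustrative): TypedDicts are plain dicts at runtime, so an OllamaToolCall serializes to exactly the JSON shape Ollama expects.

from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction

call = OllamaToolCall(
    function=OllamaToolCallFunction(
        name="get_current_weather",
        arguments={"location": "Boston", "unit": "fahrenheit"},
    )
)

assert isinstance(call, dict)  # TypedDict constructs a plain dict
assert call["function"]["arguments"]["unit"] == "fahrenheit"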