diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index 21056cee3c..7c4cf7b370 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -4,14 +4,17 @@ import traceback
 import types
 import uuid
 from itertools import chain
-from typing import Optional
+from typing import List, Optional
 
 import aiohttp
 import httpx
 import requests
+from pydantic import BaseModel
 
 import litellm
 from litellm import verbose_logger
+from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
+from litellm.types.llms.openai import ChatCompletionAssistantToolCall
 
 
 class OllamaError(Exception):
@@ -175,7 +178,7 @@ class OllamaChatConfig:
                 ## CHECK IF MODEL SUPPORTS TOOL CALLING ##
                 try:
                     model_info = litellm.get_model_info(
-                        model=model, custom_llm_provider="ollama_chat"
+                        model=model, custom_llm_provider="ollama"
                     )
                     if model_info.get("supports_function_calling") is True:
                         optional_params["tools"] = value
@@ -237,13 +240,30 @@ def get_ollama_response(
     function_name = optional_params.pop("function_name", None)
     tools = optional_params.pop("tools", None)
 
+    new_messages = []
     for m in messages:
-        if "role" in m and m["role"] == "tool":
-            m["role"] = "assistant"
+        if isinstance(
+            m, BaseModel
+        ):  # avoid message serialization issues - https://github.com/BerriAI/litellm/issues/5319
+            m = m.model_dump(exclude_none=True)
+        if m.get("tool_calls") is not None and isinstance(m["tool_calls"], list):
+            new_tools: List[OllamaToolCall] = []
+            for tool in m["tool_calls"]:
+                typed_tool = ChatCompletionAssistantToolCall(**tool)  # type: ignore
+                if typed_tool["type"] == "function":
+                    ollama_tool_call = OllamaToolCall(
+                        function=OllamaToolCallFunction(
+                            name=typed_tool["function"]["name"],
+                            arguments=json.loads(typed_tool["function"]["arguments"]),
+                        )
+                    )
+                    new_tools.append(ollama_tool_call)
+            m["tool_calls"] = new_tools
+        new_messages.append(m)
 
     data = {
         "model": model,
-        "messages": messages,
+        "messages": new_messages,
         "options": optional_params,
         "stream": stream,
     }
@@ -263,7 +283,7 @@ def get_ollama_response(
         },
     )
     if acompletion is True:
-        if stream == True:
+        if stream is True:
             response = ollama_async_streaming(
                 url=url,
                 api_key=api_key,
@@ -283,7 +303,7 @@ def get_ollama_response(
                 function_name=function_name,
             )
         return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(
             url=url, api_key=api_key, data=data, logging_obj=logging_obj
         )
diff --git a/litellm/main.py b/litellm/main.py
index ee327c2f7c..80a9a94a34 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2464,7 +2464,7 @@ def completion(
                 model_response=model_response,
                 encoding=encoding,
             )
-            if acompletion is True or optional_params.get("stream", False) == True:
+            if acompletion is True or optional_params.get("stream", False) is True:
                 return generator
 
             response = generator
diff --git a/litellm/tests/test_function_calling.py b/litellm/tests/test_function_calling.py
index aa88161df5..6bd0c42cf7 100644
--- a/litellm/tests/test_function_calling.py
+++ b/litellm/tests/test_function_calling.py
@@ -54,6 +54,7 @@ def get_current_weather(location, unit="fahrenheit"):
 )
 def test_parallel_function_call(model):
     try:
+        litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
         messages = [
             {
diff --git a/litellm/types/llms/ollama.py b/litellm/types/llms/ollama.py
new file mode 100644
index 0000000000..0ffa4e0f60
--- /dev/null
+++ b/litellm/types/llms/ollama.py
@@ -0,0 +1,24 @@
+import json
+from typing import Any, Optional, TypedDict, Union
+
+from pydantic import BaseModel
+from typing_extensions import (
+    Protocol,
+    Required,
+    Self,
+    TypeGuard,
+    get_origin,
+    override,
+    runtime_checkable,
+)
+
+
+class OllamaToolCallFunction(
+    TypedDict
+):  # follows - https://github.com/ollama/ollama/blob/6bd8a4b0a1ac15d5718f52bbe1cd56f827beb694/api/types.go#L148
+    name: str
+    arguments: dict
+
+
+class OllamaToolCall(TypedDict):
+    function: OllamaToolCallFunction