From cc42f96d6a9ab9cc461fc5f7e6312af1aaeb788a Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 19 Aug 2024 08:31:46 -0700
Subject: [PATCH] fix(ollama_chat.py): fix sync tool calling

Fixes https://github.com/BerriAI/litellm/issues/5245
---
 litellm/llms/ollama_chat.py           | 13 ++--
 litellm/proxy/_new_secret_config.yaml |  4 +-
 litellm/tests/test_ollama.py          | 88 +++++++++++++++++++++++----
 3 files changed, 87 insertions(+), 18 deletions(-)

diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py
index 2c55a3c0a1..21056cee3c 100644
--- a/litellm/llms/ollama_chat.py
+++ b/litellm/llms/ollama_chat.py
@@ -313,7 +313,7 @@ def get_ollama_response(
 
     ## RESPONSE OBJECT
     model_response.choices[0].finish_reason = "stop"
-    if data.get("format", "") == "json":
+    if data.get("format", "") == "json" and function_name is not None:
         function_call = json.loads(response_json["message"]["content"])
         message = litellm.Message(
             content=None,
@@ -321,8 +321,10 @@ def get_ollama_response(
                 {
                     "id": f"call_{str(uuid.uuid4())}",
                     "function": {
-                        "name": function_call["name"],
-                        "arguments": json.dumps(function_call["arguments"]),
+                        "name": function_call.get("name", function_name),
+                        "arguments": json.dumps(
+                            function_call.get("arguments", function_call)
+                        ),
                     },
                     "type": "function",
                 }
@@ -331,9 +333,10 @@ def get_ollama_response(
         model_response.choices[0].message = message  # type: ignore
         model_response.choices[0].finish_reason = "tool_calls"
     else:
-        model_response.choices[0].message.content = response_json["message"]["content"]  # type: ignore
+        _message = litellm.Message(**response_json["message"])
+        model_response.choices[0].message = _message  # type: ignore
     model_response.created = int(time.time())
-    model_response.model = "ollama/" + model
+    model_response.model = "ollama_chat/" + model
     prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
     completion_tokens = response_json.get(
         "eval_count", litellm.token_counter(text=response_json["message"]["content"])
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index eff98ae672..b81f6abde9 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,4 +1,4 @@
 model_list:
-  - model_name: "*"
+  - model_name: "ollama-llama3.1"
     litellm_params:
-      model: "*"
\ No newline at end of file
+      model: "ollama_chat/llama3.1"
\ No newline at end of file
diff --git a/litellm/tests/test_ollama.py b/litellm/tests/test_ollama.py
index 77a6c91c3e..0d0b076721 100644
--- a/litellm/tests/test_ollama.py
+++ b/litellm/tests/test_ollama.py
@@ -1,20 +1,25 @@
 import asyncio
-import sys, os
+import os
+import sys
 import traceback
+
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
-import litellm
 from unittest import mock
 
+import pytest
+
+import litellm
+
 ## for ollama we can't test making the completion call
-from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params
 
 
 def test_get_ollama_params():
@@ -48,21 +53,31 @@ def test_get_ollama_model():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+
 # test_get_ollama_model()
-
+
+
 def test_ollama_json_mode():
-    # assert that format: json gets passed as is to ollama
+    # assert that format: json gets passed as is to ollama
     try:
-        converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", format = "json", temperature=0.5)
+        converted_params = get_optional_params(
+            custom_llm_provider="ollama", model="llama2", format="json", temperature=0.5
+        )
         print("Converted params", converted_params)
-        assert converted_params == {'temperature': 0.5, 'format': 'json'}, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
+        assert converted_params == {
+            "temperature": 0.5,
+            "format": "json",
+        }, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
 # test_ollama_json_mode()
 
 
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
 
+
 @mock.patch(
     "litellm.llms.ollama.ollama_embeddings",
     return_value=mock_ollama_embedding_response,
@@ -70,7 +85,9 @@ mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-tex
 def test_ollama_embeddings(mock_embeddings):
     # assert that ollama_embeddings is called with the right parameters
     try:
-        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        embeddings = litellm.embedding(
+            model="ollama/nomic-embed-text", input=["hello world"]
+        )
         print(embeddings)
         mock_embeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -83,8 +100,11 @@ def test_ollama_embeddings(mock_embeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
 # test_ollama_embeddings()
 
+
 @mock.patch(
     "litellm.llms.ollama.ollama_aembeddings",
     return_value=mock_ollama_embedding_response,
@@ -92,7 +112,9 @@ def test_ollama_embeddings(mock_embeddings):
 def test_ollama_aembeddings(mock_aembeddings):
     # assert that ollama_aembeddings is called with the right parameters
    try:
-        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        embeddings = asyncio.run(
+            litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"])
+        )
         print(embeddings)
         mock_aembeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -105,4 +127,48 @@ def test_ollama_aembeddings(mock_aembeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
 # test_ollama_aembeddings()
+
+
+def test_ollama_chat_function_calling():
+    import json
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string"},
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+
+    messages = [
+        {"role": "user", "content": "What's the weather like in San Francisco?"}
+    ]
+
+    response = litellm.completion(
+        model="ollama_chat/llama3.1",
+        messages=messages,
+        tools=tools,
+    )
+    tool_calls = response.choices[0].message.get("tool_calls", None)
+
+    assert tool_calls is not None
+
+    print(json.loads(tool_calls[0].function.arguments))
+
+    print(response)
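
Note on the ollama_chat.py hunks above: the tool-call branch is now only taken when a function was actually requested (function_name is not None), and the parser tolerates JSON-mode replies that omit the {"name": ..., "arguments": ...} envelope by falling back to the requested function name and, when there is no "arguments" key, to the whole decoded payload. A standalone sketch of that fallback, assuming a raw JSON-mode reply string; build_tool_call and raw_content are illustrative names, not litellm code:

import json
import uuid


def build_tool_call(raw_content, function_name):
    # raw_content stands in for response_json["message"]["content"];
    # function_name is the tool the caller originally asked for.
    function_call = json.loads(raw_content)
    return {
        "id": f"call_{uuid.uuid4()}",
        "type": "function",
        "function": {
            # Fall back to the requested name if the model omitted it ...
            "name": function_call.get("name", function_name),
            # ... and to the whole payload if there is no "arguments" key.
            "arguments": json.dumps(function_call.get("arguments", function_call)),
        },
    }


# Both reply shapes yield a usable OpenAI-style tool call:
print(build_tool_call('{"name": "get_current_weather", "arguments": {"location": "SF"}}', "get_current_weather"))
print(build_tool_call('{"location": "SF", "unit": "celsius"}', "get_current_weather"))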