fix(ollama_chat.py): fix sync tool calling

Fixes https://github.com/BerriAI/litellm/issues/5245
Krrish Dholakia 2024-08-19 08:31:46 -07:00
parent b8e4ef0abf
commit cc42f96d6a
3 changed files with 87 additions and 18 deletions
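
Only the test-file diff is reproduced below; the ollama_chat.py change itself is not shown in this view. As a rough, hedged sketch of the behavior the new sync tool-calling test exercises (the helper name and response shape here are assumptions, not the committed code), the sync path has to surface Ollama's /api/chat tool calls, whose arguments arrive as a dict, as OpenAI-style tool_calls whose arguments are a JSON string:

import json
import uuid


def map_ollama_tool_calls(ollama_message: dict) -> list:
    # Hypothetical helper for illustration only; not the code committed in ollama_chat.py.
    # Assumes an Ollama /api/chat response message shaped like:
    #   {"role": "assistant", "tool_calls": [{"function": {"name": ..., "arguments": {...}}}]}
    tool_calls = []
    for call in ollama_message.get("tool_calls") or []:
        tool_calls.append(
            {
                "id": f"call_{uuid.uuid4()}",  # OpenAI-style tool calls carry an id; Ollama's do not
                "type": "function",
                "function": {
                    "name": call["function"]["name"],
                    # Ollama returns arguments as a dict; OpenAI-format clients expect a JSON string,
                    # which is what lets the new test call json.loads() on them
                    "arguments": json.dumps(call["function"]["arguments"]),
                },
            }
        )
    return tool_calls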


@@ -1,20 +1,25 @@
 import asyncio
-import sys, os
+import os
+import sys
 import traceback
 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
-import litellm
 from unittest import mock

+import pytest
+
+import litellm
+
 ## for ollama we can't test making the completion call
-from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params


 def test_get_ollama_params():
@@ -48,21 +53,31 @@ def test_get_ollama_model():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_get_ollama_model()


 def test_ollama_json_mode():
     # assert that format: json gets passed as is to ollama
     try:
-        converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", format = "json", temperature=0.5)
+        converted_params = get_optional_params(
+            custom_llm_provider="ollama", model="llama2", format="json", temperature=0.5
+        )
         print("Converted params", converted_params)
-        assert converted_params == {'temperature': 0.5, 'format': 'json'}, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
+        assert converted_params == {
+            "temperature": 0.5,
+            "format": "json",
+        }, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_ollama_json_mode()


 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")


 @mock.patch(
     "litellm.llms.ollama.ollama_embeddings",
     return_value=mock_ollama_embedding_response,
@@ -70,7 +85,9 @@ mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-tex
 def test_ollama_embeddings(mock_embeddings):
     # assert that ollama_embeddings is called with the right parameters
     try:
-        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        embeddings = litellm.embedding(
+            model="ollama/nomic-embed-text", input=["hello world"]
+        )
         print(embeddings)
         mock_embeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -83,8 +100,11 @@ def test_ollama_embeddings(mock_embeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_ollama_embeddings()


 @mock.patch(
     "litellm.llms.ollama.ollama_aembeddings",
     return_value=mock_ollama_embedding_response,
@@ -92,7 +112,9 @@ def test_ollama_embeddings(mock_embeddings):
 def test_ollama_aembeddings(mock_aembeddings):
     # assert that ollama_aembeddings is called with the right parameters
     try:
-        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        embeddings = asyncio.run(
+            litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"])
+        )
         print(embeddings)
         mock_aembeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -105,4 +127,48 @@ def test_ollama_aembeddings(mock_aembeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


 # test_ollama_aembeddings()


+def test_ollama_chat_function_calling():
+    import json
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string"},
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+    messages = [
+        {"role": "user", "content": "What's the weather like in San Francisco?"}
+    ]
+    response = litellm.completion(
+        model="ollama_chat/llama3.1",
+        messages=messages,
+        tools=tools,
+    )
+    tool_calls = response.choices[0].message.get("tool_calls", None)
+    assert tool_calls is not None
+    print(json.loads(tool_calls[0].function.arguments))
+    print(response)
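
Beyond what the new test asserts, here is a hedged sketch of the full tool-call round trip a caller might run on top of this fix; the get_current_weather implementation and the follow-up message shapes are illustrative assumptions, not part of the commit:

import json

import litellm


def get_current_weather(location: str, unit: str = "fahrenheit") -> str:
    # Stand-in tool implementation matching the schema declared in the test; not real weather data.
    return json.dumps({"location": location, "temperature": "72", "unit": unit})


def run_weather_round_trip(messages: list, tools: list) -> str:
    # First turn: same sync ollama_chat path the test covers; the model should emit a tool call.
    response = litellm.completion(
        model="ollama_chat/llama3.1", messages=messages, tools=tools
    )
    tool_call = response.choices[0].message.tool_calls[0]
    args = json.loads(tool_call.function.arguments)

    # Feed the assistant turn and the tool result back, then let the model answer in prose.
    messages.append(
        {
            "role": "assistant",
            "content": response.choices[0].message.content,
            "tool_calls": [
                {
                    "id": tool_call.id,
                    "type": "function",
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
            ],
        }
    )
    messages.append(
        {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": get_current_weather(**args),
        }
    )
    followup = litellm.completion(model="ollama_chat/llama3.1", messages=messages)
    return followup.choices[0].message.content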