Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
fix(ollama_chat.py): fix sync tool calling

Fixes https://github.com/BerriAI/litellm/issues/5245

parent b8e4ef0abf · commit cc42f96d6a

3 changed files with 87 additions and 18 deletions
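The commit message points at the synchronous (non-streaming) tool-calling path for ollama_chat models. For orientation, here is a minimal sketch of that call pattern, mirroring the test added at the bottom of this diff; the locally running Ollama server and the llama3.1 model are assumptions, not something this page confirms:

import litellm

# Assumes a locally running Ollama server with the llama3.1 model pulled.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
    tools=tools,
)
tool_calls = response.choices[0].message.get("tool_calls", None)
print(tool_calls)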
ollama_chat.py

@@ -313,7 +313,7 @@ def get_ollama_response(
     ## RESPONSE OBJECT
     model_response.choices[0].finish_reason = "stop"
-    if data.get("format", "") == "json":
+    if data.get("format", "") == "json" and function_name is not None:
         function_call = json.loads(response_json["message"]["content"])
         message = litellm.Message(
             content=None,
@@ -321,8 +321,10 @@ def get_ollama_response(
                 {
                     "id": f"call_{str(uuid.uuid4())}",
                     "function": {
-                        "name": function_call["name"],
-                        "arguments": json.dumps(function_call["arguments"]),
+                        "name": function_call.get("name", function_name),
+                        "arguments": json.dumps(
+                            function_call.get("arguments", function_call)
+                        ),
                     },
                     "type": "function",
                 }
@@ -331,9 +333,10 @@ def get_ollama_response(
         model_response.choices[0].message = message  # type: ignore
         model_response.choices[0].finish_reason = "tool_calls"
     else:
-        model_response.choices[0].message.content = response_json["message"]["content"]  # type: ignore
+        _message = litellm.Message(**response_json["message"])
+        model_response.choices[0].message = _message  # type: ignore
     model_response.created = int(time.time())
-    model_response.model = "ollama/" + model
+    model_response.model = "ollama_chat/" + model
     prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
     completion_tokens = response_json.get(
         "eval_count", litellm.token_counter(text=response_json["message"]["content"])
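For readers skimming the hunks above: when the request forces format=json for tool calling, the model's JSON reply may come back either as a full {"name": ..., "arguments": ...} object or as just the bare arguments, and "name" may be missing. A standalone sketch of that fallback logic (parse_tool_call is a hypothetical helper name; in the actual code this lives inline in get_ollama_response):

import json
import uuid


def parse_tool_call(content: str, function_name: str) -> dict:
    # Hypothetical helper mirroring the fallback in the diff above: fall back to
    # the requested tool name when "name" is absent, and treat the whole object
    # as the arguments when there is no "arguments" key.
    function_call = json.loads(content)
    return {
        "id": f"call_{str(uuid.uuid4())}",
        "function": {
            "name": function_call.get("name", function_name),
            "arguments": json.dumps(function_call.get("arguments", function_call)),
        },
        "type": "function",
    }


# A reply without the {"name": ..., "arguments": ...} wrapper still parses:
print(parse_tool_call('{"location": "San Francisco"}', "get_current_weather"))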
proxy config (YAML)

@@ -1,4 +1,4 @@
 model_list:
-  - model_name: "*"
+  - model_name: "ollama-llama3.1"
     litellm_params:
-      model: "*"
+      model: "ollama_chat/llama3.1"
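With this config change, the proxy serves the model under the alias ollama-llama3.1 instead of a wildcard. A minimal usage sketch against a locally running proxy; the base_url, port, and api_key value are assumptions for illustration (litellm's proxy defaults to port 4000), not part of this commit:

import openai

# Assumes `litellm --config config.yaml` is running locally on the default port.
client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

response = client.chat.completions.create(
    model="ollama-llama3.1",
    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
)
print(response.choices[0].message)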
Ollama test file (Python)

@@ -1,20 +1,25 @@
 import asyncio
-import sys, os
+import os
+import sys
 import traceback
 
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
-import litellm
 from unittest import mock
 
+import pytest
+
+import litellm
+
 ## for ollama we can't test making the completion call
-from litellm.utils import EmbeddingResponse, get_optional_params, get_llm_provider
+from litellm.utils import EmbeddingResponse, get_llm_provider, get_optional_params
 
 
 def test_get_ollama_params():
@@ -48,21 +53,31 @@ def test_get_ollama_model():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
 # test_get_ollama_model()
 
 
 def test_ollama_json_mode():
     # assert that format: json gets passed as is to ollama
     try:
-        converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", format = "json", temperature=0.5)
+        converted_params = get_optional_params(
+            custom_llm_provider="ollama", model="llama2", format="json", temperature=0.5
+        )
         print("Converted params", converted_params)
-        assert converted_params == {'temperature': 0.5, 'format': 'json'}, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
+        assert converted_params == {
+            "temperature": 0.5,
+            "format": "json",
+        }, f"{converted_params} != {'temperature': 0.5, 'format': 'json'}"
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
 # test_ollama_json_mode()
 
 
 mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
 
 
 @mock.patch(
     "litellm.llms.ollama.ollama_embeddings",
     return_value=mock_ollama_embedding_response,
@@ -70,7 +85,9 @@ mock_ollama_embedding_response = EmbeddingResponse(model="ollama/nomic-embed-text")
 def test_ollama_embeddings(mock_embeddings):
     # assert that ollama_embeddings is called with the right parameters
     try:
-        embeddings = litellm.embedding(model="ollama/nomic-embed-text", input=["hello world"])
+        embeddings = litellm.embedding(
+            model="ollama/nomic-embed-text", input=["hello world"]
+        )
         print(embeddings)
         mock_embeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -83,8 +100,11 @@ def test_ollama_embeddings(mock_embeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
 # test_ollama_embeddings()
 
 
 @mock.patch(
     "litellm.llms.ollama.ollama_aembeddings",
     return_value=mock_ollama_embedding_response,
@@ -92,7 +112,9 @@ def test_ollama_embeddings(mock_embeddings):
 def test_ollama_aembeddings(mock_aembeddings):
     # assert that ollama_aembeddings is called with the right parameters
     try:
-        embeddings = asyncio.run(litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"]))
+        embeddings = asyncio.run(
+            litellm.aembedding(model="ollama/nomic-embed-text", input=["hello world"])
+        )
         print(embeddings)
         mock_aembeddings.assert_called_once_with(
             api_base="http://localhost:11434",
@@ -105,4 +127,48 @@ def test_ollama_aembeddings(mock_aembeddings):
         )
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
 
 # test_ollama_aembeddings()
+
+
+def test_ollama_chat_function_calling():
+    import json
+
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string"},
+                        "unit": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                    },
+                    "required": ["location"],
+                },
+            },
+        },
+    ]
+
+    messages = [
+        {"role": "user", "content": "What's the weather like in San Francisco?"}
+    ]
+
+    response = litellm.completion(
+        model="ollama_chat/llama3.1",
+        messages=messages,
+        tools=tools,
+    )
+    tool_calls = response.choices[0].message.get("tool_calls", None)
+
+    assert tool_calls is not None
+
+    print(json.loads(tool_calls[0].function.arguments))
+
+    print(response)
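Note: unlike the mocked embedding tests above, the new test_ollama_chat_function_calling test makes a real completion call, so running it locally (e.g. pytest -k test_ollama_chat_function_calling) presumably requires an Ollama server with llama3.1 pulled.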