fix(ollama_chat.py): fix sync tool calling

Fixes https://github.com/BerriAI/litellm/issues/5245
Krrish Dholakia 2024-08-19 08:31:46 -07:00
parent b8e4ef0abf
commit cc42f96d6a
3 changed files with 87 additions and 18 deletions
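For context, the linked issue concerns synchronous tool calling through the ollama_chat provider. A hedged reproduction of that call path is sketched below; the model name and tool schema are placeholders, not taken from the issue.

import litellm

# Placeholder model and tool schema, for illustration only.
response = litellm.completion(
    model="ollama_chat/llama3.1",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
)
print(response.choices[0].message.tool_calls)

Before this change, converting a format=json response into tool_calls could fail when the model's JSON omitted the "name"/"arguments" keys; the hunks below make that conversion tolerant of such responses.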

@@ -313,7 +313,7 @@ def get_ollama_response(
     ## RESPONSE OBJECT
     model_response.choices[0].finish_reason = "stop"
-    if data.get("format", "") == "json":
+    if data.get("format", "") == "json" and function_name is not None:
         function_call = json.loads(response_json["message"]["content"])
         message = litellm.Message(
             content=None,
@@ -321,8 +321,10 @@ def get_ollama_response(
                 {
                     "id": f"call_{str(uuid.uuid4())}",
                     "function": {
-                        "name": function_call["name"],
-                        "arguments": json.dumps(function_call["arguments"]),
+                        "name": function_call.get("name", function_name),
+                        "arguments": json.dumps(
+                            function_call.get("arguments", function_call)
+                        ),
                     },
                     "type": "function",
                 }
@@ -331,9 +333,10 @@ def get_ollama_response(
         model_response.choices[0].message = message  # type: ignore
         model_response.choices[0].finish_reason = "tool_calls"
     else:
-        model_response.choices[0].message.content = response_json["message"]["content"]  # type: ignore
+        _message = litellm.Message(**response_json["message"])
+        model_response.choices[0].message = _message  # type: ignore
     model_response.created = int(time.time())
-    model_response.model = "ollama/" + model
+    model_response.model = "ollama_chat/" + model
     prompt_tokens = response_json.get("prompt_eval_count", litellm.token_counter(messages=messages))  # type: ignore
     completion_tokens = response_json.get(
         "eval_count", litellm.token_counter(text=response_json["message"]["content"])