fix(ollama): Handle non-tool-call JSON response when format=json - fix lint

This commit is contained in:
Arjun Prabhulal 2025-04-13 20:45:13 -04:00
parent c3b4290adc
commit 653a47aea6

View file

@ -19,6 +19,7 @@ from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMExcepti
from litellm.secret_managers.main import get_secret_str from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
from litellm.types.utils import ( from litellm.types.utils import (
Choices,
GenericStreamingChunk, GenericStreamingChunk,
ModelInfoBase, ModelInfoBase,
ModelResponse, ModelResponse,
@ -257,12 +258,9 @@ class OllamaConfig(BaseConfig):
model_response.choices[0].finish_reason = "stop" model_response.choices[0].finish_reason = "stop"
if request_data.get("format", "") == "json": if request_data.get("format", "") == "json":
try: try:
# Try to parse the response string as JSON
parsed_response_content = json.loads(response_json["response"]) parsed_response_content = json.loads(response_json["response"])
# Check if the parsed content looks like the expected tool call format
if isinstance(parsed_response_content, dict) and "name" in parsed_response_content and "arguments" in parsed_response_content: if isinstance(parsed_response_content, dict) and "name" in parsed_response_content and "arguments" in parsed_response_content:
# Looks like a tool call, proceed as before
function_call = parsed_response_content function_call = parsed_response_content
message = litellm.Message( message = litellm.Message(
content=None, content=None,
@ -270,32 +268,40 @@ class OllamaConfig(BaseConfig):
{ {
"id": f"call_{str(uuid.uuid4())}", "id": f"call_{str(uuid.uuid4())}",
"function": { "function": {
"name": function_call["name"], # type: ignore "name": function_call["name"],
"arguments": json.dumps(function_call["arguments"]), # type: ignore "arguments": json.dumps(function_call["arguments"]),
}, },
"type": "function", "type": "function",
} }
], ],
) )
model_response.choices[0].message = message choice = model_response.choices[0]
model_response.choices[0].finish_reason = "tool_calls" if isinstance(choice, Choices):
choice.message = message # type: ignore[attr-defined]
choice.finish_reason = "tool_calls"
else: else:
# Parsed JSON doesn't have "name"/"arguments" - treat as plain text choice.message.content = response_json["response"] # type: ignore[attr-defined]
# Fallback: Use the original JSON string as the text content choice.finish_reason = "stop"
model_response.choices[0].message.content = response_json["response"]
model_response.choices[0].finish_reason = "stop"
except json.JSONDecodeError:
# If response_json["response"] wasn't valid JSON, treat as plain text
model_response.choices[0].message.content = response_json["response"]
model_response.choices[0].finish_reason = "stop"
else: else:
model_response.choices[0].message.content = response_json["response"] # type: ignore choice = model_response.choices[0]
if isinstance(choice, Choices):
choice.message.content = response_json["response"] # type: ignore[attr-defined]
choice.finish_reason = "stop"
except json.JSONDecodeError:
choice = model_response.choices[0]
if isinstance(choice, Choices):
choice.message.content = response_json["response"] # type: ignore[attr-defined]
choice.finish_reason = "stop"
else:
choice = model_response.choices[0]
if isinstance(choice, Choices):
choice.message.content = response_json["response"] # type: ignore[attr-defined]
model_response.created = int(time.time()) model_response.created = int(time.time())
model_response.model = "ollama/" + model model_response.model = "ollama/" + model
_prompt = request_data.get("prompt", "") _prompt = request_data.get("prompt", "")
prompt_tokens = response_json.get( prompt_tokens = response_json.get(
"prompt_eval_count", len(encoding.encode(_prompt, disallowed_special=())) # type: ignore "prompt_eval_count", len(encoding.encode(_prompt))
) )
completion_tokens = response_json.get( completion_tokens = response_json.get(
"eval_count", len(response_json.get("message", dict()).get("content", "")) "eval_count", len(response_json.get("message", dict()).get("content", ""))