diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 5972d9e8c..68ca60688 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -255,8 +255,34 @@ def ollama_completion_stream(url, data, logging_obj): custom_llm_provider="ollama", logging_obj=logging_obj, ) - for transformed_chunk in streamwrapper: - yield transformed_chunk + # If format is JSON, this was a function call + # Gather all chunks and return the function call as one delta to simplify parsing + if data.get("format", "") == "json": + first_chunk = next(streamwrapper) + response_content = "".join( + chunk.choices[0].delta.content + for chunk in chain([first_chunk], streamwrapper) + if chunk.choices[0].delta.content + ) + + function_call = json.loads(response_content) + delta = litellm.utils.Delta( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, + "type": "function", + } + ], + ) + model_response = first_chunk + model_response["choices"][0]["delta"] = delta + model_response["choices"][0]["finish_reason"] = "tool_calls" + yield model_response + else: + for transformed_chunk in streamwrapper: + yield transformed_chunk except Exception as e: raise e @@ -278,8 +304,36 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob custom_llm_provider="ollama", logging_obj=logging_obj, ) - async for transformed_chunk in streamwrapper: - yield transformed_chunk + + # If format is JSON, this was a function call + # Gather all chunks and return the function call as one delta to simplify parsing + if data.get("format", "") == "json": + first_chunk = await anext(streamwrapper) + first_chunk_content = first_chunk.choices[0].delta.content or "" + response_content = first_chunk_content + "".join( + [ + chunk.choices[0].delta.content + async for chunk in streamwrapper + if chunk.choices[0].delta.content] + ) + function_call = json.loads(response_content) + delta = litellm.utils.Delta( + content=None, + tool_calls=[ + { + "id": f"call_{str(uuid.uuid4())}", + "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])}, + "type": "function", + } + ], + ) + model_response = first_chunk + model_response["choices"][0]["delta"] = delta + model_response["choices"][0]["finish_reason"] = "tool_calls" + yield model_response + else: + async for transformed_chunk in streamwrapper: + yield transformed_chunk except Exception as e: traceback.print_exc() raise e