import json

import requests

try:
    # Optional dependency: only required for async streaming.
    from async_generator import async_generator, yield_
    async_generator_imported = True
except ImportError:
    # Must not raise: an error here would break the `import litellm` statement.
    async_generator_imported = False


def _iter_ollama_deltas(line):
    """Yield one streaming chunk per JSON object found in a raw response line.

    Args:
        line: A non-empty ``bytes`` line read from the ``/api/generate``
            response stream.

    Yields:
        Dicts of the form ``{"choices": [{"delta": {"role": "assistant",
        "content": <text>}}]}`` for every JSON object in ``line`` that has a
        ``"response"`` key.

    Decoding problems are reported with ``print`` and the rest of the line is
    dropped; they never propagate to the caller.
    """
    try:
        text = line.decode("utf-8")
    except UnicodeDecodeError as e:
        print(f"Error decoding JSON: {e}")
        return

    for part in text.split("\n"):
        if part.strip() == "":
            continue
        try:
            payload = json.loads(part)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            # Give up on the remainder of this line after the first bad fragment.
            return
        # Only JSON objects can carry a "response" field; skip anything else
        # instead of failing on the subscript below.
        if isinstance(payload, dict) and "response" in payload:
            yield {
                "choices": [
                    {"delta": {"role": "assistant", "content": payload["response"]}}
                ]
            }


# The async variant can only be defined when the optional `async_generator`
# package imported successfully; otherwise the decorator name does not exist.
if async_generator_imported:

    @async_generator
    async def async_get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        """Stream a completion from an Ollama server as delta chunks.

        Args:
            api_base: Base URL of the Ollama server.
            model: Name of the model to send in the request.
            prompt: Prompt text to send in the request.

        Yields (via ``yield_``):
            ``{"choices": [{"delta": {"role": "assistant", "content": ...}}]}``
            for each streamed JSON object that contains a ``"response"`` key.

        NOTE(review): the HTTP call goes through ``requests`` and is not
        awaited, so reading the stream blocks the caller while waiting for
        data — confirm this is acceptable for async callers.
        """
        url = f"{api_base}/api/generate"
        data = {
            "model": model,
            "prompt": prompt,
        }

        # Both the session and the response are context-managed so they are
        # released even if the request fails or the consumer stops iterating.
        with requests.Session() as session:
            with session.post(url, json=data, stream=True) as resp:
                for line in resp.iter_lines():
                    if not line:
                        continue
                    # Parsing errors are handled inside the helper; the yield
                    # stays outside any try/except so exceptions raised by the
                    # consumer are not swallowed as "JSON" errors.
                    for chunk in _iter_ollama_deltas(line):
                        await yield_(chunk)