fix async import error

This commit is contained in:
ishaan-jaff 2023-09-21 11:16:50 -07:00
parent e354f39cc1
commit 2b9e3434ff

View file

@ -2,14 +2,15 @@ import requests
import json import json
try: try:
from async_generator import async_generator, yield_ # optional dependancy if you want to use acompletion + streaming from async_generator import async_generator, yield_ # optional dependency
except: async_generator_imported = True
pass # this should not throw an error, it will impact the 'import litellm' statement except ImportError:
async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement
# ollama implementation # ollama implementation
def get_ollama_response_stream( def get_ollama_response_stream(
api_base="http://localhost:11434", api_base="http://localhost:11434",
model="llama2", model="llama2",
prompt="Why is the sky blue?" prompt="Why is the sky blue?"
): ):
url = f"{api_base}/api/generate" url = f"{api_base}/api/generate"
@ -39,36 +40,37 @@ def get_ollama_response_stream(
print(f"Error decoding JSON: {e}") print(f"Error decoding JSON: {e}")
session.close() session.close()
# ollama implementation if async_generator_imported:
@async_generator # ollama implementation
async def async_get_ollama_response_stream( @async_generator
api_base="http://localhost:11434", async def async_get_ollama_response_stream(
model="llama2", api_base="http://localhost:11434",
prompt="Why is the sky blue?" model="llama2",
): prompt="Why is the sky blue?"
url = f"{api_base}/api/generate" ):
data = { url = f"{api_base}/api/generate"
"model": model, data = {
"prompt": prompt, "model": model,
} "prompt": prompt,
session = requests.Session() }
session = requests.Session()
with session.post(url, json=data, stream=True) as resp: with session.post(url, json=data, stream=True) as resp:
for line in resp.iter_lines(): for line in resp.iter_lines():
if line: if line:
try: try:
json_chunk = line.decode("utf-8") json_chunk = line.decode("utf-8")
chunks = json_chunk.split("\n") chunks = json_chunk.split("\n")
for chunk in chunks: for chunk in chunks:
if chunk.strip() != "": if chunk.strip() != "":
j = json.loads(chunk) j = json.loads(chunk)
if "response" in j: if "response" in j:
completion_obj = { completion_obj = {
"role": "assistant", "role": "assistant",
"content": "", "content": "",
} }
completion_obj["content"] = j["response"] completion_obj["content"] = j["response"]
await yield_({"choices": [{"delta": completion_obj}]}) await yield_({"choices": [{"delta": completion_obj}]})
except Exception as e: except Exception as e:
print(f"Error decoding JSON: {e}") print(f"Error decoding JSON: {e}")
session.close() session.close()