mirror of https://github.com/BerriAI/litellm.git
fix async import error

This commit is contained in:
parent e354f39cc1
commit 2b9e3434ff

1 changed file with 39 additions and 37 deletions
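The diff below swaps a bare try/except around the optional async_generator import for an explicit ImportError guard plus a module-level flag, so `import litellm` still succeeds when the package is absent and the async streaming helper is only defined when it is installed. A minimal sketch of that guard pattern, using only names that appear in the diff (stream_demo is an illustrative stand-in, not part of the commit):

try:
    # async_generator backports async generators to older Pythons:
    # decorate with @async_generator and emit values via 'await yield_(...)'.
    from async_generator import async_generator, yield_
    async_generator_imported = True
except ImportError:
    async_generator_imported = False  # keep 'import litellm' working

if async_generator_imported:
    @async_generator
    async def stream_demo():  # illustrative only, not in the commit
        await yield_("hello")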
@@ -2,14 +2,15 @@ import requests
 import json
 
 try:
-    from async_generator import async_generator, yield_ # optional dependancy if you want to use acompletion + streaming
-except:
-    pass # this should not throw an error, it will impact the 'import litellm' statement
+    from async_generator import async_generator, yield_ # optional dependency
+    async_generator_imported = True
+except ImportError:
+    async_generator_imported = False # this should not throw an error, it will impact the 'import litellm' statement
 
 # ollama implementation
 def get_ollama_response_stream(
     api_base="http://localhost:11434",
     model="llama2",
     prompt="Why is the sky blue?"
 ):
     url = f"{api_base}/api/generate"
@@ -39,36 +40,37 @@ def get_ollama_response_stream(
             print(f"Error decoding JSON: {e}")
     session.close()
-
-# ollama implementation
-@async_generator
-async def async_get_ollama_response_stream(
-    api_base="http://localhost:11434",
-    model="llama2",
-    prompt="Why is the sky blue?"
-):
-    url = f"{api_base}/api/generate"
-    data = {
-        "model": model,
-        "prompt": prompt,
-    }
-    session = requests.Session()
-
-    with session.post(url, json=data, stream=True) as resp:
-        for line in resp.iter_lines():
-            if line:
-                try:
-                    json_chunk = line.decode("utf-8")
-                    chunks = json_chunk.split("\n")
-                    for chunk in chunks:
-                        if chunk.strip() != "":
-                            j = json.loads(chunk)
-                            if "response" in j:
-                                completion_obj = {
-                                    "role": "assistant",
-                                    "content": "",
-                                }
-                                completion_obj["content"] = j["response"]
-                                await yield_({"choices": [{"delta": completion_obj}]})
-                except Exception as e:
-                    print(f"Error decoding JSON: {e}")
-    session.close()
+
+if async_generator_imported:
+    # ollama implementation
+    @async_generator
+    async def async_get_ollama_response_stream(
+        api_base="http://localhost:11434",
+        model="llama2",
+        prompt="Why is the sky blue?"
+    ):
+        url = f"{api_base}/api/generate"
+        data = {
+            "model": model,
+            "prompt": prompt,
+        }
+        session = requests.Session()
+
+        with session.post(url, json=data, stream=True) as resp:
+            for line in resp.iter_lines():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "response" in j:
+                                    completion_obj = {
+                                        "role": "assistant",
+                                        "content": "",
+                                    }
+                                    completion_obj["content"] = j["response"]
+                                    await yield_({"choices": [{"delta": completion_obj}]})
+                    except Exception as e:
+                        print(f"Error decoding JSON: {e}")
+        session.close()
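For reference, the decorated coroutine behaves like a native async generator and can be consumed with a plain async for. A minimal usage sketch, assuming a local Ollama server on the default port and an illustrative import path (the diff does not show the file name):

import asyncio

from litellm.llms.ollama import async_get_ollama_response_stream  # illustrative path

async def main():
    # Each chunk follows the OpenAI streaming shape: {"choices": [{"delta": {...}}]}
    async for chunk in async_get_ollama_response_stream(
        api_base="http://localhost:11434",
        model="llama2",
        prompt="Why is the sky blue?",
    ):
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)

asyncio.run(main())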