fix(ollama.py): enable parallel ollama completion calls

Krrish Dholakia 2023-12-11 23:18:25 -08:00
parent eb8514ddf6
commit 2c1c75fdf0
3 changed files with 72 additions and 6 deletions
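
For context, the behavior this commit enables looks like the sketch below: several non-streaming Ollama requests issued through litellm.acompletion and awaited concurrently. The model name, prompt, and api_base are illustrative assumptions, not part of the commit, and a local Ollama server must be running.

import asyncio
import litellm

async def main():
    # Three non-streaming Ollama completions fired concurrently; with this
    # commit they are served by the new aiohttp-based ollama_acompletion
    # instead of the blocking requests-based path.
    tasks = [
        litellm.acompletion(
            model="ollama/llama2",  # assumed locally pulled model
            messages=[{"role": "user", "content": f"Say hello, call #{i}"}],
            api_base="http://localhost:11434",  # default Ollama endpoint
        )
        for i in range(3)
    ]
    responses = await asyncio.gather(*tasks)
    for response in responses:
        print(response["choices"][0]["message"]["content"])

asyncio.run(main())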

.vscode/settings.json (new file)

@@ -0,0 +1,4 @@
+{
+    "python.analysis.typeCheckingMode": "off",
+    "python.analysis.autoImportCompletions": true
+}

litellm/llms/ollama.py

@@ -1,10 +1,12 @@
-import requests, types
+from email import header
+from re import T
+from tkinter import N
+import requests, types, time
 import json
 import traceback
 from typing import Optional
 import litellm
-import httpx
+import httpx, aiohttp, asyncio
 try:
     from async_generator import async_generator, yield_  # optional dependency
     async_generator_imported = True
@@ -115,6 +117,9 @@ def get_ollama_response_stream(
     prompt="Why is the sky blue?",
     optional_params=None,
     logging_obj=None,
+    acompletion: bool = False,
+    model_response=None,
+    encoding=None
 ):
     if api_base.endswith("/api/generate"):
         url = api_base
@@ -136,8 +141,15 @@ def get_ollama_response_stream(
     logging_obj.pre_call(
         input=None,
         api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data},
+        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
     )
+    if acompletion is True:
+        response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+        return response
+    else:
+        return ollama_completion_stream(url=url, data=data)
+
+def ollama_completion_stream(url, data):
     session = requests.Session()
     with session.post(url, json=data, stream=True) as resp:
@@ -169,6 +181,52 @@ def get_ollama_response_stream(
                     traceback.print_exc()
     session.close()
+
+async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
+    try:
+        timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            resp = await session.post(url, json=data)
+
+            if resp.status != 200:
+                text = await resp.text()
+                raise OllamaError(status_code=resp.status, message=text)
+
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        completion_string = ""
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "error" in j:
+                                    completion_obj = {
+                                        "role": "assistant",
+                                        "content": "",
+                                        "error": j
+                                    }
+                                if "response" in j:
+                                    completion_obj = {
+                                        "role": "assistant",
+                                        "content": j["response"],
+                                    }
+                                completion_string += completion_obj["content"]
+                    except Exception as e:
+                        traceback.print_exc()
+
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["finish_reason"] = "stop"
+            model_response["choices"][0]["message"]["content"] = completion_string
+            model_response["created"] = int(time.time())
+            model_response["model"] = "ollama/" + data['model']
+            prompt_tokens = len(encoding.encode(data['prompt']))  # type: ignore
+            completion_tokens = len(encoding.encode(completion_string))
+            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
+            return model_response
+    except Exception as e:
+        traceback.print_exc()
+
 if async_generator_imported:
     # ollama implementation
     @async_generator
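
For readers outside the litellm codebase, the new ollama_acompletion boils down to one technique: POST to Ollama's /api/generate with aiohttp and fold the newline-delimited JSON stream into a single string. A self-contained sketch of that technique follows; the endpoint URL, model name, and helper name are assumptions, and error handling is trimmed to the status check the patch performs.

import asyncio
import json
import aiohttp

async def ollama_generate(prompt: str) -> str:
    # Hypothetical standalone helper; endpoint and model are assumptions.
    data = {"model": "llama2", "prompt": prompt}
    timeout = aiohttp.ClientTimeout(total=600)  # same 10-minute cap as the patch
    async with aiohttp.ClientSession(timeout=timeout) as session:
        resp = await session.post("http://localhost:11434/api/generate", json=data)
        if resp.status != 200:
            raise RuntimeError(await resp.text())
        completion = ""
        async for line in resp.content.iter_any():
            # /api/generate streams newline-delimited JSON objects, each with
            # a partial "response" field (or an "error" key on failure).
            for chunk in line.decode("utf-8").split("\n"):
                if chunk.strip():
                    j = json.loads(chunk)
                    completion += j.get("response", "")
        return completion

print(asyncio.run(ollama_generate("Why is the sky blue?")))

Note that iter_any() can in principle split a JSON object across chunk boundaries; the committed code wraps the parse in a try/except, and a production version of this sketch would want the same guard.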

litellm/main.py

@@ -8,6 +8,7 @@
 # Thank you ! We ❤️ you! - Krrish & Ishaan
 import os, openai, sys, json, inspect, uuid, datetime, threading
+from re import T
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
@@ -175,7 +176,8 @@ async def acompletion(*args, **kwargs):
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "text-completion-openai"
-            or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            or custom_llm_provider == "huggingface"
+            or custom_llm_provider == "ollama"): # currently implemented aiohttp calls for just azure and openai, soon all.
         if kwargs.get("stream", False):
             response = completion(*args, **kwargs)
         else:
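
The new ollama entry above relies on a convention visible in the next hunk: when acompletion is True, completion() hands back the coroutine from ollama_acompletion instead of a finished response, and acompletion() awaits it. A minimal sketch of that dispatch pattern, with hypothetical names and none of litellm's argument plumbing:

import asyncio
from functools import partial

async def dispatch(completion_fn, async_providers, provider, **kwargs):
    # Hypothetical, simplified version of the branch above: providers on the
    # async list may return a coroutine from completion_fn; everything else
    # is pushed onto a worker thread.
    loop = asyncio.get_running_loop()
    func = partial(completion_fn, acompletion=True, **kwargs)
    if provider in async_providers:
        init_response = await loop.run_in_executor(None, func)
        if asyncio.iscoroutine(init_response):
            # e.g. ollama: completion() returned the ollama_acompletion coroutine
            return await init_response
        return init_response
    return await loop.run_in_executor(None, func)

In litellm itself the partial also freezes the positional args and the calling context; only the coroutine check matters for this commit.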
@@ -1318,7 +1320,9 @@ def completion(
                 async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
                 return async_generator
-            generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
+            generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
+            if acompletion is True:
+                return generator
             if optional_params.get("stream", False) == True:
                 # assume all ollama responses are streamed
                 response = CustomStreamWrapper(
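
The streaming path is unchanged by this hunk: when acompletion is True the coroutine is returned as-is, otherwise streamed responses are still wrapped in CustomStreamWrapper. For contrast with the async sketch near the top, a synchronous streaming call would look roughly like this (the model and api_base are assumptions; chunk access follows litellm's OpenAI-style delta format):

import litellm

# Synchronous streaming still goes through the requests-based generator and
# CustomStreamWrapper; only the non-streaming async path changed in this commit.
stream = litellm.completion(
    model="ollama/llama2",  # assumed locally pulled model
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    api_base="http://localhost:11434",  # default Ollama endpoint
    stream=True,
)
for chunk in stream:
    content = chunk["choices"][0]["delta"]["content"]
    print(content or "", end="", flush=True)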