From 2c1c75fdf0fcc1de3d7ba6a90ae97b219fba13b3 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 11 Dec 2023 23:18:25 -0800
Subject: [PATCH] fix(ollama.py): enable parallel ollama completion calls

---
 .vscode/settings.json  |  4 ++++
 litellm/llms/ollama.py | 63 ++++++++++++++++++++++++++++++++++++++++++++---
 litellm/main.py        |  7 +++++--
 3 files changed, 68 insertions(+), 6 deletions(-)
 create mode 100644 .vscode/settings.json

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 000000000..242c7c86f
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,4 @@
+{
+    "python.analysis.typeCheckingMode": "off",
+    "python.analysis.autoImportCompletions": true
+}
\ No newline at end of file
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index a24e47c07..40ff94390 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -1,10 +1,9 @@
-import requests, types
+import requests, types, time
 import json
 import traceback
 from typing import Optional
 import litellm
-import httpx
-
+import httpx, aiohttp, asyncio
 try:
     from async_generator import async_generator, yield_  # optional dependency
     async_generator_imported = True
@@ -115,6 +114,9 @@ def get_ollama_response_stream(
     prompt="Why is the sky blue?",
     optional_params=None,
     logging_obj=None,
+    acompletion: bool = False,
+    model_response=None,
+    encoding=None
 ):
     if api_base.endswith("/api/generate"):
         url = api_base
@@ -136,8 +138,15 @@ def get_ollama_response_stream(
     logging_obj.pre_call(
         input=None,
         api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data},
+        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
     )
+    if acompletion is True:
+        response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+        return response
+    else:
+        return ollama_completion_stream(url=url, data=data)
+
+def ollama_completion_stream(url, data):
     session = requests.Session()
 
     with session.post(url, json=data, stream=True) as resp:
@@ -169,6 +178,52 @@ def get_ollama_response_stream(
                 traceback.print_exc()
     session.close()
 
+async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
+
+    try:
+        timeout = aiohttp.ClientTimeout(total=600)  # 10 minutes
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            resp = await session.post(url, json=data)
+
+            if resp.status != 200:
+                text = await resp.text()
+                raise OllamaError(status_code=resp.status, message=text)
+
+            completion_string = ""
+            async for line in resp.content.iter_any():
+                if line:
+                    try:
+                        json_chunk = line.decode("utf-8")
+                        chunks = json_chunk.split("\n")
+                        for chunk in chunks:
+                            if chunk.strip() != "":
+                                j = json.loads(chunk)
+                                if "error" in j:
+                                    completion_obj = {
+                                        "role": "assistant",
+                                        "content": "",
+                                        "error": j
+                                    }
+                                if "response" in j:
+                                    completion_obj = {
+                                        "role": "assistant",
+                                        "content": j["response"],
+                                    }
+                                completion_string += completion_obj["content"]
+                    except Exception as e:
+                        traceback.print_exc()
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["finish_reason"] = "stop"
+            model_response["choices"][0]["message"]["content"] = completion_string
+            model_response["created"] = int(time.time())
+            model_response["model"] = "ollama/" + data['model']
+            prompt_tokens = len(encoding.encode(data['prompt']))  # type: ignore
+            completion_tokens = len(encoding.encode(completion_string))
+            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
+            return model_response
+    except Exception as e:
+        traceback.print_exc()
+
 if async_generator_imported:
     # ollama implementation
     @async_generator
diff --git a/litellm/main.py b/litellm/main.py
index ddeff3414..828497820 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -175,7 +175,8 @@ async def acompletion(*args, **kwargs):
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "text-completion-openai"
-            or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
+            or custom_llm_provider == "huggingface"
+            or custom_llm_provider == "ollama"): # currently implemented aiohttp calls for just azure and openai, soon all.
             if kwargs.get("stream", False):
                 response = completion(*args, **kwargs)
             else:
@@ -1318,7 +1319,9 @@ def completion(
             async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
             return async_generator
 
-        generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
+        generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
+        if acompletion is True:
+            return generator
         if optional_params.get("stream", False) == True:
             # assume all ollama responses are streamed
             response = CustomStreamWrapper(
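
For context, here is a minimal caller-side sketch of the parallel path this patch enables: several non-streaming Ollama completions dispatched concurrently through litellm.acompletion() and asyncio.gather(). The model name "ollama/llama2", the default http://localhost:11434 endpoint, and the sample prompts are illustrative assumptions, not part of the patch; adjust them for your local setup.

# sketch.py -- illustrative only; assumes a local Ollama server with a pulled llama2 model
import asyncio
import litellm

async def main():
    prompts = [
        "Why is the sky blue?",
        "Why is the ocean salty?",
        "What causes tides?",
    ]
    tasks = [
        litellm.acompletion(
            model="ollama/llama2",                  # litellm's ollama/<model> naming
            messages=[{"role": "user", "content": p}],
            api_base="http://localhost:11434",      # default Ollama endpoint (assumption)
        )
        for p in prompts
    ]
    # With this patch, each call routes through the aiohttp-based
    # ollama_acompletion() instead of the blocking requests session,
    # so the three generations run in parallel rather than serially.
    responses = await asyncio.gather(*tasks)
    for resp in responses:
        print(resp["choices"][0]["message"]["content"])

if __name__ == "__main__":
    asyncio.run(main())

Because completion() now forwards acompletion, model_response, and encoding into get_ollama_response_stream(), the non-streaming async path can return a fully formed ModelResponse (including token usage) instead of a generator that has to be drained.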