litellm-mirror/litellm/llms/ollama/completion/handler.py
Krish Dholakia 405080396d
Litellm ollama refactor (#7162)
* refactor(ollama/): refactor ollama `/api/generate` to use base llm config

Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* test: skip unresponsive test

* test(test_secret_manager.py): mark flaky test

* test: fix google sm test

* fix: fix init.py
2024-12-10 21:45:35 -08:00

496 lines
17 KiB
Python

import asyncio
import json
import time
import traceback
import types
import uuid
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, List, Optional
import aiohttp
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import OllamaError
from .transformation import OllamaConfig
# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
def _convert_image(image):
import base64
import io
try:
from PIL import Image
except Exception:
raise Exception(
"ollama image conversion failed please run `pip install Pillow`"
)
orig = image
if image.startswith("data:"):
image = image.split(",")[-1]
try:
image_data = Image.open(io.BytesIO(base64.b64decode(image)))
if image_data.format in ["JPEG", "PNG"]:
return image
except Exception:
return orig
jpeg_image = io.BytesIO()
image_data.convert("RGB").save(jpeg_image, "JPEG")
jpeg_image.seek(0)
return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
# ollama implementation
def get_ollama_response(
    model_response: litellm.ModelResponse,
    model: str,
    prompt: str,
    optional_params: dict,
    logging_obj: Any,
    encoding: Any,
    acompletion: bool = False,
    api_base="http://localhost:11434",
):
    """Run a completion against Ollama's ``/api/generate`` endpoint.

    Dispatches to the async / streaming helpers when requested; otherwise
    performs a blocking request and fills ``model_response`` in place.

    Note: ``optional_params`` is mutated: unset keys are filled from
    ``litellm.OllamaConfig.get_config()`` and ``stream``, ``format`` and
    ``images`` are popped out of it. The remaining entries are sent as the
    request's ``options``.

    Args:
        model_response: response object to populate (non-stream, sync path).
        model: Ollama model name (without the ``ollama/`` prefix).
        prompt: the already-rendered prompt text.
        optional_params: generation parameters (see note above).
        logging_obj: litellm logging object (``pre_call`` / ``post_call``).
        encoding: tokenizer used when Ollama does not report token counts.
        acompletion: return a coroutine / async generator instead.
        api_base: Ollama base URL, or the full ``/api/generate`` URL.

    Returns:
        ``model_response`` (sync, non-stream), a sync generator (sync stream),
        or the coroutine / async generator from the async helpers.

    Raises:
        OllamaError: on a non-200 response in the sync, non-stream path.
    """
    if api_base.endswith("/api/generate"):
        url = api_base
    else:
        url = f"{api_base}/api/generate"

    ## Load Config
    config = litellm.OllamaConfig.get_config()
    for k, v in config.items():
        # completion(top_k=3) > ollama_config(top_k=3) <- explicitly passed
        # params take precedence over config defaults
        if k not in optional_params:
            optional_params[k] = v

    stream = optional_params.pop("stream", False)
    response_format = optional_params.pop("format", None)
    images = optional_params.pop("images", None)
    data = {
        "model": model,
        "prompt": prompt,
        "options": optional_params,
        "stream": stream,
    }
    if response_format is not None:
        data["format"] = response_format
    if images is not None:
        data["images"] = [_convert_image(image) for image in images]

    ## LOGGING
    logging_obj.pre_call(
        input=None,
        api_key=None,
        additional_args={
            "api_base": url,
            "complete_input_dict": data,
            "headers": {},
            "acompletion": acompletion,
        },
    )

    if acompletion is True:
        if stream is True:
            return ollama_async_streaming(
                url=url,
                data=data,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
            )
        return ollama_acompletion(
            url=url,
            data=data,
            model_response=model_response,
            encoding=encoding,
            logging_obj=logging_obj,
        )
    if stream is True:
        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)

    response = requests.post(
        url=url, json={**data, "stream": stream}, timeout=litellm.request_timeout
    )
    if response.status_code != 200:
        raise OllamaError(
            status_code=response.status_code,
            message=response.text,
            headers=dict(response.headers),
        )

    ## LOGGING
    logging_obj.post_call(
        input=prompt,
        api_key="",
        original_response=response.text,
        additional_args={
            "headers": None,
            # log the URL actually called, consistent with ollama_acompletion
            "api_base": url,
        },
    )

    response_json = response.json()

    ## RESPONSE OBJECT
    model_response.choices[0].finish_reason = "stop"
    if data.get("format", "") == "json":
        # format="json" is treated as a function call: the model's JSON output
        # is surfaced as a single tool call.
        function_call = json.loads(response_json["response"])
        message = litellm.Message(
            content=None,
            tool_calls=[
                {
                    "id": f"call_{str(uuid.uuid4())}",
                    "function": {
                        # same name lookup as ollama_acompletion
                        "name": function_call.get(
                            "name", function_call.get("function", None)
                        ),
                        "arguments": json.dumps(function_call["arguments"]),
                    },
                    "type": "function",
                }
            ],
        )
        model_response.choices[0].message = message  # type: ignore
        model_response.choices[0].finish_reason = "tool_calls"
    else:
        model_response.choices[0].message.content = response_json["response"]  # type: ignore
    model_response.created = int(time.time())
    model_response.model = "ollama/" + model

    # Prefer Ollama's own counters; tokenize locally only when one is missing.
    prompt_tokens = response_json.get("prompt_eval_count")
    if prompt_tokens is None:
        prompt_tokens = len(encoding.encode(prompt, disallowed_special=()))  # type: ignore
    completion_tokens = response_json.get("eval_count")
    if completion_tokens is None:
        # /api/generate returns the generated text under "response"
        completion_tokens = len(
            encoding.encode(response_json.get("response", ""), disallowed_special=())  # type: ignore
        )
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return model_response
def ollama_completion_stream(url, data, logging_obj):
    """Stream a synchronous ``/api/generate`` request.

    Yields the chunks produced by ``litellm.CustomStreamWrapper`` over the
    response lines. When ``data["format"] == "json"`` the request is treated
    as a function call: the whole stream is consumed and a single chunk
    carrying one tool call is yielded instead. An empty stream yields nothing.

    Raises:
        OllamaError: if Ollama answers with a non-200 status.
    """
    with httpx.stream(
        url=url, json=data, method="POST", timeout=litellm.request_timeout
    ) as response:
        if response.status_code != 200:
            raise OllamaError(
                status_code=response.status_code,
                message=str(response.read()),
                headers=dict(response.headers),
            )

        streamwrapper = litellm.CustomStreamWrapper(
            completion_stream=response.iter_lines(),
            model=data["model"],
            custom_llm_provider="ollama",
            logging_obj=logging_obj,
        )

        if data.get("format", "") != "json":
            for transformed_chunk in streamwrapper:
                yield transformed_chunk
            return

        # Gather all chunks and return the function call as one delta to
        # simplify parsing. Iterating in one loop (instead of next()) avoids a
        # StopIteration -> RuntimeError on an empty stream.
        first_chunk = None
        content_chunks = []
        for chunk in streamwrapper:
            if first_chunk is None:
                first_chunk = chunk
            choice = chunk.choices[0]
            if (
                isinstance(choice, StreamingChoices)
                and hasattr(choice, "delta")
                and hasattr(choice.delta, "content")
                and choice.delta.content is not None
            ):
                content_chunks.append(choice.delta.content)
        if first_chunk is None:
            return

        function_call = json.loads("".join(content_chunks))
        delta = litellm.utils.Delta(
            content=None,
            tool_calls=[
                {
                    "id": f"call_{str(uuid.uuid4())}",
                    "function": {
                        "name": function_call.get(
                            "name", function_call.get("function", None)
                        ),
                        "arguments": json.dumps(function_call["arguments"]),
                    },
                    "type": "function",
                }
            ],
        )
        model_response = first_chunk
        model_response.choices[0].delta = delta  # type: ignore
        model_response.choices[0].finish_reason = "tool_calls"
        yield model_response
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
    """Stream an ``/api/generate`` request asynchronously.

    Async-generator counterpart of ``ollama_completion_stream``. It yields
    ``litellm.CustomStreamWrapper`` chunks. When ``data["format"] == "json"``
    it instead yields one chunk carrying a single tool call built from the
    concatenated stream content. An empty stream yields nothing.
    ``model_response`` and ``encoding`` are accepted for interface parity
    but are not used here.

    Exceptions propagate to the caller unchanged; they are deliberately not
    logged here.

    Raises:
        OllamaError: if Ollama answers with a non-200 status.
    """
    _async_http_client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.OLLAMA
    )
    client = _async_http_client.client
    async with client.stream(
        url=url, json=data, method="POST", timeout=litellm.request_timeout
    ) as response:
        if response.status_code != 200:
            raise OllamaError(
                status_code=response.status_code,
                message=str(await response.aread()),
                headers=dict(response.headers),
            )

        streamwrapper = litellm.CustomStreamWrapper(
            completion_stream=response.aiter_lines(),
            model=data["model"],
            custom_llm_provider="ollama",
            logging_obj=logging_obj,
        )

        if data.get("format", "") != "json":
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
            return

        # Gather all chunks and return the function call as one delta to
        # simplify parsing. A single `async for` replaces the builtin
        # `anext()`, which only exists on Python >= 3.10.
        first_chunk = None
        content_chunks = []
        async for chunk in streamwrapper:
            if first_chunk is None:
                first_chunk = chunk
            chunk_choice = chunk.choices[0]
            if (
                isinstance(chunk_choice, StreamingChoices)
                and hasattr(chunk_choice, "delta")
                and hasattr(chunk_choice.delta, "content")
                # skip None contents, which would break "".join()
                and chunk_choice.delta.content is not None
            ):
                content_chunks.append(chunk_choice.delta.content)
        if first_chunk is None:
            return

        function_call = json.loads("".join(content_chunks))
        delta = litellm.utils.Delta(
            content=None,
            tool_calls=[
                {
                    "id": f"call_{str(uuid.uuid4())}",
                    "function": {
                        "name": function_call.get(
                            "name", function_call.get("function", None)
                        ),
                        "arguments": json.dumps(function_call["arguments"]),
                    },
                    "type": "function",
                }
            ],
        )
        model_response = first_chunk
        model_response.choices[0].delta = delta  # type: ignore
        model_response.choices[0].finish_reason = "tool_calls"
        yield model_response
async def ollama_acompletion(
    url, data, model_response: litellm.ModelResponse, encoding, logging_obj
):
    """Perform a non-streaming ``/api/generate`` request asynchronously.

    Forces ``data["stream"] = False``, posts ``data`` to ``url`` with aiohttp
    and fills ``model_response`` in place. If ``data["format"] == "json"``,
    the JSON output is returned as a single tool call.

    Exceptions propagate to the caller unchanged; they are deliberately not
    logged here.

    Returns:
        The populated ``model_response``.

    Raises:
        OllamaError: if Ollama answers with a non-200 status.
    """
    data["stream"] = False
    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        # context-manage the response so its connection is always released
        async with session.post(url, json=data) as resp:
            response_text = await resp.text()
            if resp.status != 200:
                raise OllamaError(
                    status_code=resp.status,
                    message=response_text,
                    headers=dict(resp.headers),
                )

            ## LOGGING
            logging_obj.post_call(
                input=data["prompt"],
                api_key="",
                # log the body text (previously the bound `resp.text` method)
                original_response=response_text,
                additional_args={
                    "headers": None,
                    "api_base": url,
                },
            )

            # aiohttp caches the body read by text(), so this does not re-read
            response_json = await resp.json()

    ## RESPONSE OBJECT
    model_response.choices[0].finish_reason = "stop"
    if data.get("format", "") == "json":
        function_call = json.loads(response_json["response"])
        message = litellm.Message(
            content=None,
            tool_calls=[
                {
                    "id": f"call_{str(uuid.uuid4())}",
                    "function": {
                        "name": function_call.get(
                            "name", function_call.get("function", None)
                        ),
                        "arguments": json.dumps(function_call["arguments"]),
                    },
                    "type": "function",
                }
            ],
        )
        model_response.choices[0].message = message  # type: ignore
        model_response.choices[0].finish_reason = "tool_calls"
    else:
        model_response.choices[0].message.content = response_json["response"]  # type: ignore
    model_response.created = int(time.time())
    model_response.model = "ollama/" + data["model"]

    # Prefer Ollama's own counters; tokenize locally only when one is missing.
    prompt_tokens = response_json.get("prompt_eval_count")
    if prompt_tokens is None:
        prompt_tokens = len(encoding.encode(data["prompt"], disallowed_special=()))  # type: ignore
    completion_tokens = response_json.get("eval_count")
    if completion_tokens is None:
        # /api/generate returns the generated text under "response"
        completion_tokens = len(
            encoding.encode(response_json.get("response", ""), disallowed_special=())  # type: ignore
        )
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return model_response
async def ollama_aembeddings(
    api_base: str,
    model: str,
    prompts: List[str],
    model_response: litellm.EmbeddingResponse,
    optional_params: dict,
    logging_obj: Any,
    encoding: Any,
):
    """Embed ``prompts`` through Ollama's ``/api/embed`` endpoint.

    The request body is ``{"model", "input"}`` plus:

    * ``truncate``, ``options`` and ``keep_alive`` copied through at top level;
    * every other parameter (including ``litellm.OllamaConfig`` defaults for
      keys not passed explicitly) placed under ``options``.

    If an ``options`` dict is passed explicitly, its entries win over loose
    parameters with the same key. Neither it nor ``optional_params`` is
    mutated.

    Returns:
        ``model_response`` populated with one embedding entry per returned
        vector, plus usage.

    Raises:
        OllamaError: if Ollama answers with a non-200 status.
    """
    if api_base.endswith("/api/embed"):
        url = api_base
    else:
        url = f"{api_base}/api/embed"

    ## Load Config
    # explicitly passed params take precedence over config defaults
    params: Dict[str, Any] = {
        **litellm.OllamaConfig.get_config(),
        **optional_params,
    }

    special_optional_params = ["truncate", "options", "keep_alive"]
    data: Dict[str, Any] = {"model": model, "input": prompts}
    extra_options: Dict[str, Any] = {}
    for k, v in params.items():
        if k in special_optional_params:
            data[k] = v
        else:
            extra_options[k] = v
    if extra_options:
        if "options" not in data:
            data["options"] = extra_options
        elif isinstance(data["options"], dict):
            # merge into a new dict (order-independent, caller's dict untouched)
            data["options"] = {**extra_options, **data["options"]}
        # otherwise a non-dict "options" was passed explicitly; send it as-is

    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        ## LOGGING
        logging_obj.pre_call(
            input=None,
            api_key=None,
            additional_args={
                "api_base": url,
                "complete_input_dict": data,
                "headers": {},
            },
        )
        # context-manage the response so its connection is always released
        async with session.post(url, json=data) as response:
            if response.status != 200:
                text = await response.text()
                raise OllamaError(
                    status_code=response.status,
                    message=text,
                    headers=dict(response.headers),
                )
            response_json = await response.json()

    embeddings: List[List[float]] = response_json["embeddings"]
    output_data = [
        {"object": "embedding", "index": idx, "embedding": emb}
        for idx, emb in enumerate(embeddings)
    ]
    input_tokens = response_json.get("prompt_eval_count") or len(
        encoding.encode("".join(prompts))
    )

    model_response.object = "list"
    model_response.data = output_data
    model_response.model = "ollama/" + model
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=input_tokens,
            # NOTE(review): embeddings produce no completion tokens, yet this
            # reports input_tokens here too (kept for backward compatibility
            # with existing usage consumers) — confirm intended.
            completion_tokens=input_tokens,
            total_tokens=input_tokens,
            prompt_tokens_details=None,
            completion_tokens_details=None,
        ),
    )
    return model_response
def ollama_embeddings(
    api_base: str,
    model: str,
    prompts: list,
    optional_params: dict,
    model_response: litellm.EmbeddingResponse,
    logging_obj: Any,
    encoding=None,
):
    """Blocking entry point for Ollama embeddings.

    Builds the ``ollama_aembeddings`` coroutine with the same arguments and
    drives it to completion via ``asyncio.run``. Because ``asyncio.run``
    starts a fresh event loop, this must not be called from code already
    running inside one.

    Returns:
        Whatever ``ollama_aembeddings`` returns (the populated
        ``model_response``).
    """
    pending = ollama_aembeddings(
        api_base=api_base,
        model=model,
        prompts=prompts,
        model_response=model_response,
        optional_params=optional_params,
        logging_obj=logging_obj,
        encoding=encoding,
    )
    result = asyncio.run(pending)
    return result