litellm-mirror/litellm/llms/ollama/completion/handler.py
Krish Dholakia b82add11ba
LITELLM: Remove requests library usage (#7235)
* fix(generic_api_callback.py): remove requests lib usage

* fix(budget_manager.py): remove requests lib usage

* fix(main.py): cleanup requests lib usage

* fix(utils.py): remove requests lib usage

* fix(argilla.py): fix argilla test

* fix(athina.py): replace 'requests' lib usage with litellm module

* fix(greenscale.py): replace 'requests' lib usage with httpx

* fix: remove unused 'requests' lib import + replace usage in some places

* fix(prompt_layer.py): remove 'requests' lib usage from prompt layer

* fix(ollama_chat.py): remove 'requests' lib usage

* fix(baseten.py): replace 'requests' lib usage

* fix(codestral/): replace 'requests' lib usage

* fix(predibase/): replace 'requests' lib usage

* refactor: cleanup unused 'requests' lib imports

* fix(oobabooga.py): cleanup 'requests' lib usage

* fix(invoke_handler.py): remove unused 'requests' lib usage

* refactor: cleanup unused 'requests' lib import

* fix: fix linting errors

* refactor(ollama/): move ollama to using base llm http handler

removes 'requests' lib dep for ollama integration

* fix(ollama_chat.py): fix linting errors

* fix(ollama/completion/transformation.py): convert non-jpeg/png image to jpeg/png before passing to ollama
2024-12-17 12:50:04 -08:00
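
The recurring change in this PR is replacing blocking `requests` calls with `httpx` (directly or via litellm's shared httpx handlers). A minimal sketch of that pattern, assuming nothing about any specific call site (`post_json`, `url`, and `data` are illustrative names, not litellm APIs):

import httpx

# before: response = requests.post(url, json=data, timeout=600)
# after:  httpx is already a litellm dependency, so no extra requirement is needed
def post_json(url: str, data: dict, timeout: float = 600.0) -> dict:
    """POST a JSON payload and return the decoded JSON body."""
    response = httpx.post(url, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()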


"""
Ollama /chat/completion calls handled in llm_http_handler.py
[TODO]: migrate embeddings to a base handler as well.
"""
import asyncio
import json
import time
import traceback
import types
import uuid
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, List, Optional
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import (
    EmbeddingResponse,
    ModelInfo,
    ModelResponse,
    ProviderField,
    StreamingChoices,
)
from ..common_utils import OllamaError
from .transformation import OllamaConfig


# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
async def ollama_aembeddings(
    api_base: str,
    model: str,
    prompts: List[str],
    model_response: EmbeddingResponse,
    optional_params: dict,
    logging_obj: Any,
    encoding: Any,
):
    if api_base.endswith("/api/embed"):
        url = api_base
    else:
        url = f"{api_base}/api/embed"

    ## Load Config
    config = litellm.OllamaConfig.get_config()
    for k, v in config.items():
        if (
            k not in optional_params
        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
            optional_params[k] = v

    data: Dict[str, Any] = {"model": model, "input": prompts}
    special_optional_params = ["truncate", "options", "keep_alive"]

    for k, v in optional_params.items():
        if k in special_optional_params:
            data[k] = v
        else:
            # Ensure "options" is a dictionary before updating it
            data.setdefault("options", {})
            if isinstance(data["options"], dict):
                data["options"].update({k: v})
    total_input_tokens = 0
    output_data = []
    response = await litellm.module_level_aclient.post(url=url, json=data)
    response_json = response.json()  # httpx Response.json() is synchronous; no await needed
    embeddings: List[List[float]] = response_json["embeddings"]
    for idx, emb in enumerate(embeddings):
        output_data.append({"object": "embedding", "index": idx, "embedding": emb})

    input_tokens = response_json.get("prompt_eval_count") or len(
        encoding.encode("".join(prompt for prompt in prompts))
    )
    total_input_tokens += input_tokens

    model_response.object = "list"
    model_response.data = output_data
    model_response.model = "ollama/" + model
    setattr(
        model_response,
        "usage",
        litellm.Usage(
            prompt_tokens=total_input_tokens,
            completion_tokens=total_input_tokens,
            total_tokens=total_input_tokens,
            prompt_tokens_details=None,
            completion_tokens_details=None,
        ),
    )
    return model_response


def ollama_embeddings(
    api_base: str,
    model: str,
    prompts: list,
    optional_params: dict,
    model_response: EmbeddingResponse,
    logging_obj: Any,
    encoding=None,
):
    return asyncio.run(
        ollama_aembeddings(
            api_base=api_base,
            model=model,
            prompts=prompts,
            model_response=model_response,
            optional_params=optional_params,
            logging_obj=logging_obj,
            encoding=encoding,
        )
    )
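
A short usage sketch for the embeddings path above, exercised through litellm's public embedding API. The model name, prompt strings, and api_base are illustrative and assume a local Ollama server with that model already pulled:

import litellm

response = litellm.embedding(
    model="ollama/nomic-embed-text",
    input=["hello world", "how are you"],
    api_base="http://localhost:11434",
)
print(response.model)                      # e.g. "ollama/nomic-embed-text"
print(len(response.data[0]["embedding"]))  # embedding dimension
print(response.usage)                      # token counts populated by ollama_aembeddings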