Mirror of https://github.com/BerriAI/litellm.git
Synced 2025-04-27 03:34:10 +00:00

refactor: add black formatting

Commit: 4905929de3
Parent: b87d630b0a

156 changed files with 19723 additions and 10869 deletions
@@ -2,10 +2,11 @@ import requests, types, time
 import json, uuid
 import traceback
 from typing import Optional
 import litellm
 import httpx, aiohttp, asyncio
 from .prompt_templates.factory import prompt_factory, custom_prompt


+
 class OllamaError(Exception):
     def __init__(self, status_code, message):
         self.status_code = status_code
@@ -16,14 +17,15 @@ class OllamaError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs

-class OllamaConfig():
+
+class OllamaConfig:
     """
     Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters

     The class `OllamaConfig` provides the configuration for the Ollama's API interface. Below are the parameters:

     - `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0

     - `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1

     - `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0
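
Note (not part of the commit): a minimal sketch of how the Mirostat parameters documented above can be set through OllamaConfig. The constructor shown later in this diff records any non-None keyword as a class-level default; the values here are purely illustrative.

    import litellm

    # Illustrative values only; parameter names come from the OllamaConfig docstring.
    # Non-None arguments are stored as class-level defaults by __init__ (see below).
    litellm.OllamaConfig(
        mirostat=2,        # 2 = Mirostat 2.0 sampling
        mirostat_eta=0.1,  # how quickly Mirostat reacts to feedback
        mirostat_tau=5.0,  # coherence vs. diversity balance
    )
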
@@ -56,102 +58,134 @@ class OllamaConfig():
     - `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
     """
-    mirostat: Optional[int]=None
-    mirostat_eta: Optional[float]=None
-    mirostat_tau: Optional[float]=None
-    num_ctx: Optional[int]=None
-    num_gqa: Optional[int]=None
-    num_thread: Optional[int]=None
-    repeat_last_n: Optional[int]=None
-    repeat_penalty: Optional[float]=None
-    temperature: Optional[float]=None
-    stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
-    tfs_z: Optional[float]=None
-    num_predict: Optional[int]=None
-    top_k: Optional[int]=None
-    top_p: Optional[float]=None
-    system: Optional[str]=None
-    template: Optional[str]=None
-
-    def __init__(self,
-                 mirostat: Optional[int]=None,
-                 mirostat_eta: Optional[float]=None,
-                 mirostat_tau: Optional[float]=None,
-                 num_ctx: Optional[int]=None,
-                 num_gqa: Optional[int]=None,
-                 num_thread: Optional[int]=None,
-                 repeat_last_n: Optional[int]=None,
-                 repeat_penalty: Optional[float]=None,
-                 temperature: Optional[float]=None,
-                 stop: Optional[list]=None,
-                 tfs_z: Optional[float]=None,
-                 num_predict: Optional[int]=None,
-                 top_k: Optional[int]=None,
-                 top_p: Optional[float]=None,
-                 system: Optional[str]=None,
-                 template: Optional[str]=None) -> None:
+    mirostat: Optional[int] = None
+    mirostat_eta: Optional[float] = None
+    mirostat_tau: Optional[float] = None
+    num_ctx: Optional[int] = None
+    num_gqa: Optional[int] = None
+    num_thread: Optional[int] = None
+    repeat_last_n: Optional[int] = None
+    repeat_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+    stop: Optional[
+        list
+    ] = None  # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
+    tfs_z: Optional[float] = None
+    num_predict: Optional[int] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    system: Optional[str] = None
+    template: Optional[str] = None
+
+    def __init__(
+        self,
+        mirostat: Optional[int] = None,
+        mirostat_eta: Optional[float] = None,
+        mirostat_tau: Optional[float] = None,
+        num_ctx: Optional[int] = None,
+        num_gqa: Optional[int] = None,
+        num_thread: Optional[int] = None,
+        repeat_last_n: Optional[int] = None,
+        repeat_penalty: Optional[float] = None,
+        temperature: Optional[float] = None,
+        stop: Optional[list] = None,
+        tfs_z: Optional[float] = None,
+        num_predict: Optional[int] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        system: Optional[str] = None,
+        template: Optional[str] = None,
+    ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
-            if key != 'self' and value is not None:
+            if key != "self" and value is not None:
                 setattr(self.__class__, key, value)

     @classmethod
     def get_config(cls):
-        return {k: v for k, v in cls.__dict__.items()
-                if not k.startswith('__')
-                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
-                and v is not None}
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }


 # ollama implementation
 def get_ollama_response(
-        api_base="http://localhost:11434",
-        model="llama2",
-        prompt="Why is the sky blue?",
-        optional_params=None,
-        logging_obj=None,
-        acompletion: bool = False,
-        model_response=None,
-        encoding=None
-    ):
+    api_base="http://localhost:11434",
+    model="llama2",
+    prompt="Why is the sky blue?",
+    optional_params=None,
+    logging_obj=None,
+    acompletion: bool = False,
+    model_response=None,
+    encoding=None,
+):
     if api_base.endswith("/api/generate"):
         url = api_base
     else:
         url = f"{api_base}/api/generate"

     ## Load Config
-    config=litellm.OllamaConfig.get_config()
+    config = litellm.OllamaConfig.get_config()
     for k, v in config.items():
-        if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+        if (
+            k not in optional_params
+        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

     optional_params["stream"] = optional_params.get("stream", False)
-    data = {
-        "model": model,
-        "prompt": prompt,
-        **optional_params
-    }
+    data = {"model": model, "prompt": prompt, **optional_params}
     ## LOGGING
     logging_obj.pre_call(
         input=None,
         api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
+        additional_args={
+            "api_base": url,
+            "complete_input_dict": data,
+            "headers": {},
+            "acompletion": acompletion,
+        },
     )
     if acompletion is True:
         if optional_params.get("stream", False) == True:
-            response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+            response = ollama_async_streaming(
+                url=url,
+                data=data,
+                model_response=model_response,
+                encoding=encoding,
+                logging_obj=logging_obj,
+            )
         else:
-            response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+            response = ollama_acompletion(
+                url=url,
+                data=data,
+                model_response=model_response,
+                encoding=encoding,
+                logging_obj=logging_obj,
+            )
         return response
     elif optional_params.get("stream", False) == True:
         return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)

     response = requests.post(
         url=f"{url}",
         json=data,
     )
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)

     ## LOGGING
     logging_obj.post_call(
         input=prompt,
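
Note (not part of the commit): the config-merge loop reformatted above works as follows. OllamaConfig.get_config() returns the non-None class-level settings, and each one is copied into optional_params only if the caller did not already pass that key, so per-request values always win. A small self-contained illustration with hypothetical dictionaries:

    # Hypothetical stand-ins for litellm.OllamaConfig.get_config() and the caller's kwargs.
    config = {"mirostat": 2, "temperature": 0.7}
    optional_params = {"temperature": 0.1}

    for k, v in config.items():
        if k not in optional_params:  # per-call params take precedence over config defaults
            optional_params[k] = v

    print(optional_params)  # {'temperature': 0.1, 'mirostat': 2}
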
@@ -168,52 +202,76 @@ def get_ollama_response(
     ## RESPONSE OBJECT
     model_response["choices"][0]["finish_reason"] = "stop"
     if optional_params.get("format", "") == "json":
-        message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
+        message = litellm.Message(
+            content=None,
+            tool_calls=[
+                {
+                    "id": f"call_{str(uuid.uuid4())}",
+                    "function": {"arguments": response_json["response"], "name": ""},
+                    "type": "function",
+                }
+            ],
+        )
         model_response["choices"][0]["message"] = message
     else:
         model_response["choices"][0]["message"]["content"] = response_json["response"]
     model_response["created"] = int(time.time())
     model_response["model"] = "ollama/" + model
-    prompt_tokens = response_json["prompt_eval_count"] # type: ignore
+    prompt_tokens = response_json["prompt_eval_count"]  # type: ignore
     completion_tokens = response_json["eval_count"]
-    model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
+    model_response["usage"] = litellm.Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
     return model_response


 def ollama_completion_stream(url, data, logging_obj):
     with httpx.stream(
-        url=url,
-        json=data,
-        method="POST",
-        timeout=litellm.request_timeout
+        url=url, json=data, method="POST", timeout=litellm.request_timeout
     ) as response:
         try:
             if response.status_code != 200:
-                raise OllamaError(status_code=response.status_code, message=response.text)
+                raise OllamaError(
+                    status_code=response.status_code, message=response.text
+                )

-            streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
+            streamwrapper = litellm.CustomStreamWrapper(
+                completion_stream=response.iter_lines(),
+                model=data["model"],
+                custom_llm_provider="ollama",
+                logging_obj=logging_obj,
+            )
             for transformed_chunk in streamwrapper:
                 yield transformed_chunk
         except Exception as e:
             raise e


 async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
     try:
         client = httpx.AsyncClient()
         async with client.stream(
-            url=f"{url}",
-            json=data,
-            method="POST",
-            timeout=litellm.request_timeout
+            url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
         ) as response:
             if response.status_code != 200:
-                raise OllamaError(status_code=response.status_code, message=response.text)
+                raise OllamaError(
+                    status_code=response.status_code, message=response.text
+                )

-            streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
+            streamwrapper = litellm.CustomStreamWrapper(
+                completion_stream=response.aiter_lines(),
+                model=data["model"],
+                custom_llm_provider="ollama",
+                logging_obj=logging_obj,
+            )
             async for transformed_chunk in streamwrapper:
                 yield transformed_chunk
     except Exception as e:
         traceback.print_exc()


 async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
     data["stream"] = False
     try:
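
Note (not part of the commit): when the request asked for format="json", the branch reformatted above does not place the raw Ollama output in message.content; it wraps the JSON string as a single synthetic tool call with a generated id. A sketch of the resulting shape using plain dicts (the real code builds a litellm.Message); the sample payload is hypothetical:

    import json, uuid

    # Hypothetical raw output from Ollama when format="json" was requested.
    response_json = {"response": '{"city": "Paris", "unit": "celsius"}'}

    # Same shape as the tool_calls entry constructed in the diff above.
    message = {
        "content": None,
        "tool_calls": [
            {
                "id": f"call_{str(uuid.uuid4())}",
                "function": {"arguments": response_json["response"], "name": ""},
                "type": "function",
            }
        ],
    }

    # The arguments field is the JSON text itself, so callers can parse it.
    print(json.loads(message["tool_calls"][0]["function"]["arguments"])["city"])  # Paris
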
@@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
             if resp.status != 200:
                 text = await resp.text()
                 raise OllamaError(status_code=resp.status, message=text)

             ## LOGGING
             logging_obj.post_call(
-                input=data['prompt'],
+                input=data["prompt"],
                 api_key="",
                 original_response=resp.text,
                 additional_args={
@@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
             ## RESPONSE OBJECT
             model_response["choices"][0]["finish_reason"] = "stop"
             if data.get("format", "") == "json":
-                message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
+                message = litellm.Message(
+                    content=None,
+                    tool_calls=[
+                        {
+                            "id": f"call_{str(uuid.uuid4())}",
+                            "function": {
+                                "arguments": response_json["response"],
+                                "name": "",
+                            },
+                            "type": "function",
+                        }
+                    ],
+                )
                 model_response["choices"][0]["message"] = message
             else:
-                model_response["choices"][0]["message"]["content"] = response_json["response"]
+                model_response["choices"][0]["message"]["content"] = response_json[
+                    "response"
+                ]
             model_response["created"] = int(time.time())
-            model_response["model"] = "ollama/" + data['model']
-            prompt_tokens = response_json["prompt_eval_count"] # type: ignore
+            model_response["model"] = "ollama/" + data["model"]
+            prompt_tokens = response_json["prompt_eval_count"]  # type: ignore
             completion_tokens = response_json["eval_count"]
-            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
+            model_response["usage"] = litellm.Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            )
             return model_response
     except Exception as e:
         traceback.print_exc()
         raise e


-async def ollama_aembeddings(api_base="http://localhost:11434",
-                             model="llama2",
-                             prompt="Why is the sky blue?",
-                             optional_params=None,
-                             logging_obj=None,
-                             model_response=None,
-                             encoding=None):
+async def ollama_aembeddings(
+    api_base="http://localhost:11434",
+    model="llama2",
+    prompt="Why is the sky blue?",
+    optional_params=None,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
     if api_base.endswith("/api/embeddings"):
         url = api_base
     else:
         url = f"{api_base}/api/embeddings"

     ## Load Config
-    config=litellm.OllamaConfig.get_config()
+    config = litellm.OllamaConfig.get_config()
     for k, v in config.items():
-        if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
+        if (
+            k not in optional_params
+        ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
             optional_params[k] = v

     data = {
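
Note (not part of the commit): the usage object built above maps Ollama's prompt_eval_count and eval_count to prompt and completion tokens, and the total is simply their sum. With made-up counters:

    # Hypothetical counters from an Ollama /api/generate response.
    response_json = {"prompt_eval_count": 11, "eval_count": 42}

    prompt_tokens = response_json["prompt_eval_count"]
    completion_tokens = response_json["eval_count"]
    total_tokens = prompt_tokens + completion_tokens  # 53, as litellm.Usage reports it
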
@@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
         if response.status != 200:
             text = await response.text()
             raise OllamaError(status_code=response.status, message=text)

         ## LOGGING
         logging_obj.post_call(
             input=prompt,
@@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
         output_data = []
         for idx, embedding in enumerate(embeddings):
             output_data.append(
-                {
-                    "object": "embedding",
-                    "index": idx,
-                    "embedding": embedding
-                }
+                {"object": "embedding", "index": idx, "embedding": embedding}
             )
         model_response["object"] = "list"
         model_response["data"] = output_data
         model_response["model"] = model

         input_tokens = len(encoding.encode(prompt))

         model_response["usage"] = {
             "prompt_tokens": input_tokens,
             "total_tokens": input_tokens,
         }
         return model_response
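
Note (not part of the commit): a sketch of the embedding response assembled above, written with plain dicts rather than the ModelResponse object. Each vector becomes an OpenAI-style embedding entry, and usage counts only prompt tokens; the vectors and token count below are hypothetical:

    # Hypothetical embedding vectors and token count.
    embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    prompt_token_count = 7  # stand-in for len(encoding.encode(prompt))

    output_data = [
        {"object": "embedding", "index": idx, "embedding": embedding}
        for idx, embedding in enumerate(embeddings)
    ]
    model_response = {
        "object": "list",
        "data": output_data,
        "model": "llama2",
        "usage": {"prompt_tokens": prompt_token_count, "total_tokens": prompt_token_count},
    }
    print(model_response["data"][1]["index"])  # 1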