refactor: add black formatting

Krrish Dholakia 2023-12-25 14:10:38 +05:30
parent b87d630b0a
commit 4905929de3
156 changed files with 19723 additions and 10869 deletions
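As context for the diff below, a minimal sketch of the kind of rewrite Black performs, applied to a fragment resembling the pre-formatting lines; this assumes the black package is installed and is illustrative only, not part of the commit:

import black

# A fragment in the old style: empty class parentheses, no spaces around "=".
source = "class OllamaConfig():\n    mirostat: Optional[int]=None\n"

# black.format_str applies the same rules as the `black` command line:
# the empty parentheses are dropped and the default becomes "= None".
print(black.format_str(source, mode=black.Mode()))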


@@ -2,10 +2,11 @@ import requests, types, time
import json, uuid
import traceback
from typing import Optional
import litellm
import litellm
import httpx, aiohttp, asyncio
from .prompt_templates.factory import prompt_factory, custom_prompt
class OllamaError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
@@ -16,14 +17,15 @@ class OllamaError(Exception):
self.message
) # Call the base class constructor with the parameters it needs
class OllamaConfig():
class OllamaConfig:
"""
Reference: https://github.com/jmorganca/ollama/blob/main/docs/api.md#parameters
The class `OllamaConfig` provides the configuration for Ollama's API interface. Below are the parameters:
- `mirostat` (int): Enable Mirostat sampling for controlling perplexity. Default is 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0. Example usage: mirostat 0
- `mirostat_eta` (float): Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. Default: 0.1. Example usage: mirostat_eta 0.1
- `mirostat_tau` (float): Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. Default: 5.0. Example usage: mirostat_tau 5.0
@@ -56,102 +58,134 @@ class OllamaConfig():
- `template` (string): the full prompt or prompt template (overrides what is defined in the Modelfile)
"""
mirostat: Optional[int]=None
mirostat_eta: Optional[float]=None
mirostat_tau: Optional[float]=None
num_ctx: Optional[int]=None
num_gqa: Optional[int]=None
num_thread: Optional[int]=None
repeat_last_n: Optional[int]=None
repeat_penalty: Optional[float]=None
temperature: Optional[float]=None
stop: Optional[list]=None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float]=None
num_predict: Optional[int]=None
top_k: Optional[int]=None
top_p: Optional[float]=None
system: Optional[str]=None
template: Optional[str]=None
def __init__(self,
mirostat: Optional[int]=None,
mirostat_eta: Optional[float]=None,
mirostat_tau: Optional[float]=None,
num_ctx: Optional[int]=None,
num_gqa: Optional[int]=None,
num_thread: Optional[int]=None,
repeat_last_n: Optional[int]=None,
repeat_penalty: Optional[float]=None,
temperature: Optional[float]=None,
stop: Optional[list]=None,
tfs_z: Optional[float]=None,
num_predict: Optional[int]=None,
top_k: Optional[int]=None,
top_p: Optional[float]=None,
system: Optional[str]=None,
template: Optional[str]=None) -> None:
mirostat: Optional[int] = None
mirostat_eta: Optional[float] = None
mirostat_tau: Optional[float] = None
num_ctx: Optional[int] = None
num_gqa: Optional[int] = None
num_thread: Optional[int] = None
repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None
temperature: Optional[float] = None
stop: Optional[
list
] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
tfs_z: Optional[float] = None
num_predict: Optional[int] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
system: Optional[str] = None
template: Optional[str] = None
def __init__(
self,
mirostat: Optional[int] = None,
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,
num_ctx: Optional[int] = None,
num_gqa: Optional[int] = None,
num_thread: Optional[int] = None,
repeat_last_n: Optional[int] = None,
repeat_penalty: Optional[float] = None,
temperature: Optional[float] = None,
stop: Optional[list] = None,
tfs_z: Optional[float] = None,
num_predict: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
system: Optional[str] = None,
template: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
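A brief illustration of how the reformatted OllamaConfig above is used, based only on the code in this diff: __init__ stores any non-None argument on the class, and get_config() returns those stored, non-callable entries so get_ollama_response() can merge them into optional_params. The values here are hypothetical:

import litellm

# Hypothetical defaults: cap the context window and soften sampling.
litellm.OllamaConfig(num_ctx=2048, temperature=0.7)

# Reports the fields set above; get_ollama_response() copies them into
# optional_params unless the caller already supplied the same key.
print(litellm.OllamaConfig.get_config())  # e.g. {'num_ctx': 2048, 'temperature': 0.7}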
# ollama implementation
def get_ollama_response(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None
):
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
acompletion: bool = False,
model_response=None,
encoding=None,
):
if api_base.endswith("/api/generate"):
url = api_base
else:
else:
url = f"{api_base}/api/generate"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
optional_params["stream"] = optional_params.get("stream", False)
data = {
"model": model,
"prompt": prompt,
**optional_params
}
data = {"model": model, "prompt": prompt, **optional_params}
## LOGGING
logging_obj.pre_call(
input=None,
api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
"acompletion": acompletion,
},
)
if acompletion is True:
if acompletion is True:
if optional_params.get("stream", False) == True:
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_async_streaming(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
else:
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
response = ollama_acompletion(
url=url,
data=data,
model_response=model_response,
encoding=encoding,
logging_obj=logging_obj,
)
return response
elif optional_params.get("stream", False) == True:
return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
response = requests.post(
url=f"{url}",
json=data,
)
url=f"{url}",
json=data,
)
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
raise OllamaError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
@@ -168,52 +202,76 @@ def get_ollama_response(
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if optional_params.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {"arguments": response_json["response"], "name": ""},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + model
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
def ollama_completion_stream(url, data, logging_obj):
with httpx.stream(
url=url,
json=data,
method="POST",
timeout=litellm.request_timeout
) as response:
try:
url=url, json=data, method="POST", timeout=litellm.request_timeout
) as response:
try:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.iter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.iter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
except Exception as e:
raise e
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
try:
client = httpx.AsyncClient()
async with client.stream(
url=f"{url}",
json=data,
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(status_code=response.status_code, message=response.text)
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
url=f"{url}", json=data, method="POST", timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OllamaError(
status_code=response.status_code, message=response.text
)
streamwrapper = litellm.CustomStreamWrapper(
completion_stream=response.aiter_lines(),
model=data["model"],
custom_llm_provider="ollama",
logging_obj=logging_obj,
)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
traceback.print_exc()
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
data["stream"] = False
try:
@@ -224,10 +282,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
if resp.status != 200:
text = await resp.text()
raise OllamaError(status_code=resp.status, message=text)
## LOGGING
logging_obj.post_call(
input=data['prompt'],
input=data["prompt"],
api_key="",
original_response=resp.text,
additional_args={
@@ -240,37 +298,59 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
## RESPONSE OBJECT
model_response["choices"][0]["finish_reason"] = "stop"
if data.get("format", "") == "json":
message = litellm.Message(content=None, tool_calls=[{"id": f"call_{str(uuid.uuid4())}", "function": {"arguments": response_json["response"], "name": ""}, "type": "function"}])
message = litellm.Message(
content=None,
tool_calls=[
{
"id": f"call_{str(uuid.uuid4())}",
"function": {
"arguments": response_json["response"],
"name": "",
},
"type": "function",
}
],
)
model_response["choices"][0]["message"] = message
else:
model_response["choices"][0]["message"]["content"] = response_json["response"]
model_response["choices"][0]["message"]["content"] = response_json[
"response"
]
model_response["created"] = int(time.time())
model_response["model"] = "ollama/" + data['model']
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
model_response["model"] = "ollama/" + data["model"]
prompt_tokens = response_json["prompt_eval_count"] # type: ignore
completion_tokens = response_json["eval_count"]
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
model_response["usage"] = litellm.Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
return model_response
except Exception as e:
traceback.print_exc()
raise e
async def ollama_aembeddings(api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None):
async def ollama_aembeddings(
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model_response=None,
encoding=None,
):
if api_base.endswith("/api/embeddings"):
url = api_base
else:
else:
url = f"{api_base}/api/embeddings"
## Load Config
config=litellm.OllamaConfig.get_config()
config = litellm.OllamaConfig.get_config()
for k, v in config.items():
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
data = {
@@ -290,7 +370,7 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
if response.status != 200:
text = await response.text()
raise OllamaError(status_code=response.status, message=text)
## LOGGING
logging_obj.post_call(
input=prompt,
@@ -308,20 +388,16 @@ async def ollama_aembeddings(api_base="http://localhost:11434",
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{
"object": "embedding",
"index": idx,
"embedding": embedding
}
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response["object"] = "list"
model_response["data"] = output_data
model_response["model"] = model
input_tokens = len(encoding.encode(prompt))
input_tokens = len(encoding.encode(prompt))
model_response["usage"] = {
"prompt_tokens": input_tokens,
model_response["usage"] = {
"prompt_tokens": input_tokens,
"total_tokens": input_tokens,
}
return model_response
return model_response
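Not shown in the diff: a sketch of how these handlers are typically reached through litellm's public completion API, assuming a local Ollama server on its default port; the model name and prompt are placeholders.

import litellm

# The "ollama/" prefix routes the call to get_ollama_response() above;
# a top_k passed here takes precedence over any OllamaConfig value.
response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    api_base="http://localhost:11434",
    top_k=3,
)
print(response["choices"][0]["message"]["content"])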