fix: move to using pydantic obj for setting values

Krrish Dholakia 2024-07-11 13:18:36 -07:00
parent dd1048cb35
commit 6e9f048618
30 changed files with 1018 additions and 886 deletions
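In litellm, ModelResponse is a pydantic object, so this commit moves from dict-style indexing to attribute assignment (or setattr) when populating it; the change is visible in the hunks below, e.g. model_response["choices"] = choices_list becoming model_response.choices = choices_list. A minimal sketch of the pattern follows, using a hypothetical stand-in class (not litellm's actual ModelResponse definition):

import time
from typing import List, Optional
from pydantic import BaseModel

class ModelResponse(BaseModel):
    # hypothetical stand-in for litellm's ModelResponse, only to illustrate the pattern
    choices: List[dict] = []
    created: Optional[int] = None
    model: Optional[str] = None

model_response = ModelResponse()

# before: dict-style assignment -- a plain pydantic model is not subscriptable
# model_response["created"] = int(time.time())  # would raise TypeError here

# after: set values as attributes on the pydantic object ...
model_response.created = int(time.time())
model_response.model = "gemini/gemini-pro"

# ... or via setattr(), as the diff below does for the usage field
setattr(model_response, "choices", [{"index": 0, "message": {"content": "ok"}}])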


@@ -1,7 +1,7 @@
-####################################
-######### DEPRECATED FILE ##########
-####################################
-# logic moved to `vertex_httpx.py` #
+# ####################################
+# ######### DEPRECATED FILE ##########
+# ####################################
+# # logic moved to `vertex_httpx.py` #
 import copy
 import time
@@ -92,332 +92,332 @@ class GeminiConfig:
         }
-class TextStreamer:
-    """
-    A class designed to return an async stream from AsyncGenerateContentResponse object.
-    """
+# class TextStreamer:
+#     """
+#     A class designed to return an async stream from AsyncGenerateContentResponse object.
+#     """
-    def __init__(self, response):
-        self.response = response
-        self._aiter = self.response.__aiter__()
+#     def __init__(self, response):
+#         self.response = response
+#         self._aiter = self.response.__aiter__()
-    async def __aiter__(self):
-        while True:
-            try:
-                # This will manually advance the async iterator.
-                # In case the next object doesn't exist, __anext__() will simply raise a StopAsyncIteration exception
-                next_object = await self._aiter.__anext__()
-                yield next_object
-            except StopAsyncIteration:
-                # After getting all items from the async iterator, stop iterating
-                break
+#     async def __aiter__(self):
+#         while True:
+#             try:
+#                 # This will manually advance the async iterator.
+#                 # In case the next object doesn't exist, __anext__() will simply raise a StopAsyncIteration exception
+#                 next_object = await self._aiter.__anext__()
+#                 yield next_object
+#             except StopAsyncIteration:
+#                 # After getting all items from the async iterator, stop iterating
+#                 break
-def supports_system_instruction():
-    import google.generativeai as genai
+# def supports_system_instruction():
+#     import google.generativeai as genai
-    gemini_pkg_version = Version(genai.__version__)
-    return gemini_pkg_version >= Version("0.5.0")
+#     gemini_pkg_version = Version(genai.__version__)
+#     return gemini_pkg_version >= Version("0.5.0")
-def completion(
-    model: str,
-    messages: list,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    api_key,
-    encoding,
-    logging_obj,
-    custom_prompt_dict: dict,
-    acompletion: bool = False,
-    optional_params=None,
-    litellm_params=None,
-    logger_fn=None,
-):
-    try:
-        import google.generativeai as genai  # type: ignore
-    except:
-        raise Exception(
-            "Importing google.generativeai failed, please run 'pip install -q google-generativeai'"
-        )
-    genai.configure(api_key=api_key)
-    system_prompt = ""
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_details = custom_prompt_dict[model]
-        prompt = custom_prompt(
-            role_dict=model_prompt_details["roles"],
-            initial_prompt_value=model_prompt_details["initial_prompt_value"],
-            final_prompt_value=model_prompt_details["final_prompt_value"],
-            messages=messages,
-        )
-    else:
-        system_prompt, messages = get_system_prompt(messages=messages)
-        prompt = prompt_factory(
-            model=model, messages=messages, custom_llm_provider="gemini"
-        )
+# def completion(
+#     model: str,
+#     messages: list,
+#     model_response: ModelResponse,
+#     print_verbose: Callable,
+#     api_key,
+#     encoding,
+#     logging_obj,
+#     custom_prompt_dict: dict,
+#     acompletion: bool = False,
+#     optional_params=None,
+#     litellm_params=None,
+#     logger_fn=None,
+# ):
+#     try:
+#         import google.generativeai as genai  # type: ignore
+#     except:
+#         raise Exception(
+#             "Importing google.generativeai failed, please run 'pip install -q google-generativeai'"
+#         )
+#     genai.configure(api_key=api_key)
+#     system_prompt = ""
+#     if model in custom_prompt_dict:
+#         # check if the model has a registered custom prompt
+#         model_prompt_details = custom_prompt_dict[model]
+#         prompt = custom_prompt(
+#             role_dict=model_prompt_details["roles"],
+#             initial_prompt_value=model_prompt_details["initial_prompt_value"],
+#             final_prompt_value=model_prompt_details["final_prompt_value"],
+#             messages=messages,
+#         )
+#     else:
+#         system_prompt, messages = get_system_prompt(messages=messages)
+#         prompt = prompt_factory(
+#             model=model, messages=messages, custom_llm_provider="gemini"
+#         )
-    ## Load Config
-    inference_params = copy.deepcopy(optional_params)
-    stream = inference_params.pop("stream", None)
+#     ## Load Config
+#     inference_params = copy.deepcopy(optional_params)
+#     stream = inference_params.pop("stream", None)
-    # Handle safety settings
-    safety_settings_param = inference_params.pop("safety_settings", None)
-    safety_settings = None
-    if safety_settings_param:
-        safety_settings = [
-            genai.types.SafetySettingDict(x) for x in safety_settings_param
-        ]
+#     # Handle safety settings
+#     safety_settings_param = inference_params.pop("safety_settings", None)
+#     safety_settings = None
+#     if safety_settings_param:
+#         safety_settings = [
+#             genai.types.SafetySettingDict(x) for x in safety_settings_param
+#         ]
-    config = litellm.GeminiConfig.get_config()
-    for k, v in config.items():
-        if (
-            k not in inference_params
-        ):  # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
-            inference_params[k] = v
+#     config = litellm.GeminiConfig.get_config()
+#     for k, v in config.items():
+#         if (
+#             k not in inference_params
+#         ):  # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
+#             inference_params[k] = v
-    ## LOGGING
-    logging_obj.pre_call(
-        input=prompt,
-        api_key="",
-        additional_args={
-            "complete_input_dict": {
-                "inference_params": inference_params,
-                "system_prompt": system_prompt,
-            }
-        },
-    )
-    ## COMPLETION CALL
-    try:
-        _params = {"model_name": "models/{}".format(model)}
-        _system_instruction = supports_system_instruction()
-        if _system_instruction and len(system_prompt) > 0:
-            _params["system_instruction"] = system_prompt
-        _model = genai.GenerativeModel(**_params)
-        if stream is True:
-            if acompletion is True:
+#     ## LOGGING
+#     logging_obj.pre_call(
+#         input=prompt,
+#         api_key="",
+#         additional_args={
+#             "complete_input_dict": {
+#                 "inference_params": inference_params,
+#                 "system_prompt": system_prompt,
+#             }
+#         },
+#     )
+#     ## COMPLETION CALL
+#     try:
+#         _params = {"model_name": "models/{}".format(model)}
+#         _system_instruction = supports_system_instruction()
+#         if _system_instruction and len(system_prompt) > 0:
+#             _params["system_instruction"] = system_prompt
+#         _model = genai.GenerativeModel(**_params)
+#         if stream is True:
+#             if acompletion is True:
-                async def async_streaming():
-                    try:
-                        response = await _model.generate_content_async(
-                            contents=prompt,
-                            generation_config=genai.types.GenerationConfig(
-                                **inference_params
-                            ),
-                            safety_settings=safety_settings,
-                            stream=True,
-                        )
+#                 async def async_streaming():
+#                     try:
+#                         response = await _model.generate_content_async(
+#                             contents=prompt,
+#                             generation_config=genai.types.GenerationConfig(
+#                                 **inference_params
+#                             ),
+#                             safety_settings=safety_settings,
+#                             stream=True,
+#                         )
-                        response = litellm.CustomStreamWrapper(
-                            TextStreamer(response),
-                            model,
-                            custom_llm_provider="gemini",
-                            logging_obj=logging_obj,
-                        )
-                        return response
-                    except Exception as e:
-                        raise GeminiError(status_code=500, message=str(e))
+#                         response = litellm.CustomStreamWrapper(
+#                             TextStreamer(response),
+#                             model,
+#                             custom_llm_provider="gemini",
+#                             logging_obj=logging_obj,
+#                         )
+#                         return response
+#                     except Exception as e:
+#                         raise GeminiError(status_code=500, message=str(e))
-                return async_streaming()
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-                stream=True,
-            )
-            return response
-        elif acompletion == True:
-            return async_completion(
-                _model=_model,
-                model=model,
-                prompt=prompt,
-                inference_params=inference_params,
-                safety_settings=safety_settings,
-                logging_obj=logging_obj,
-                print_verbose=print_verbose,
-                model_response=model_response,
-                messages=messages,
-                encoding=encoding,
-            )
-        else:
-            params = {
-                "contents": prompt,
-                "generation_config": genai.types.GenerationConfig(**inference_params),
-                "safety_settings": safety_settings,
-            }
-            response = _model.generate_content(**params)
-    except Exception as e:
-        raise GeminiError(
-            message=str(e),
-            status_code=500,
-        )
+#                 return async_streaming()
+#             response = _model.generate_content(
+#                 contents=prompt,
+#                 generation_config=genai.types.GenerationConfig(**inference_params),
+#                 safety_settings=safety_settings,
+#                 stream=True,
+#             )
+#             return response
+#         elif acompletion == True:
+#             return async_completion(
+#                 _model=_model,
+#                 model=model,
+#                 prompt=prompt,
+#                 inference_params=inference_params,
+#                 safety_settings=safety_settings,
+#                 logging_obj=logging_obj,
+#                 print_verbose=print_verbose,
+#                 model_response=model_response,
+#                 messages=messages,
+#                 encoding=encoding,
+#             )
+#         else:
+#             params = {
+#                 "contents": prompt,
+#                 "generation_config": genai.types.GenerationConfig(**inference_params),
+#                 "safety_settings": safety_settings,
+#             }
+#             response = _model.generate_content(**params)
+#     except Exception as e:
+#         raise GeminiError(
+#             message=str(e),
+#             status_code=500,
+#         )
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt,
-        api_key="",
-        original_response=response,
-        additional_args={"complete_input_dict": {}},
-    )
-    print_verbose(f"raw model_response: {response}")
-    ## RESPONSE OBJECT
-    completion_response = response
-    try:
-        choices_list = []
-        for idx, item in enumerate(completion_response.candidates):
-            if len(item.content.parts) > 0:
-                message_obj = Message(content=item.content.parts[0].text)
-            else:
-                message_obj = Message(content=None)
-            choice_obj = Choices(index=idx, message=message_obj)
-            choices_list.append(choice_obj)
-        model_response["choices"] = choices_list
-    except Exception as e:
-        verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
-        verbose_logger.debug(traceback.format_exc())
-        raise GeminiError(
-            message=traceback.format_exc(), status_code=response.status_code
-        )
+#     ## LOGGING
+#     logging_obj.post_call(
+#         input=prompt,
+#         api_key="",
+#         original_response=response,
+#         additional_args={"complete_input_dict": {}},
+#     )
+#     print_verbose(f"raw model_response: {response}")
+#     ## RESPONSE OBJECT
+#     completion_response = response
+#     try:
+#         choices_list = []
+#         for idx, item in enumerate(completion_response.candidates):
+#             if len(item.content.parts) > 0:
+#                 message_obj = Message(content=item.content.parts[0].text)
+#             else:
+#                 message_obj = Message(content=None)
+#             choice_obj = Choices(index=idx, message=message_obj)
+#             choices_list.append(choice_obj)
+#         model_response.choices = choices_list
+#     except Exception as e:
+#         verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
+#         verbose_logger.debug(traceback.format_exc())
+#         raise GeminiError(
+#             message=traceback.format_exc(), status_code=response.status_code
+#         )
-    try:
-        completion_response = model_response["choices"][0]["message"].get("content")
-        if completion_response is None:
-            raise Exception
-    except:
-        original_response = f"response: {response}"
-        if hasattr(response, "candidates"):
-            original_response = f"response: {response.candidates}"
-            if "SAFETY" in original_response:
-                original_response += (
-                    "\nThe candidate content was flagged for safety reasons."
-                )
-            elif "RECITATION" in original_response:
-                original_response += (
-                    "\nThe candidate content was flagged for recitation reasons."
-                )
-        raise GeminiError(
-            status_code=400,
-            message=f"No response received. Original response - {original_response}",
-        )
+#     try:
+#         completion_response = model_response["choices"][0]["message"].get("content")
+#         if completion_response is None:
+#             raise Exception
+#     except:
+#         original_response = f"response: {response}"
+#         if hasattr(response, "candidates"):
+#             original_response = f"response: {response.candidates}"
+#             if "SAFETY" in original_response:
+#                 original_response += (
+#                     "\nThe candidate content was flagged for safety reasons."
+#                 )
+#             elif "RECITATION" in original_response:
+#                 original_response += (
+#                     "\nThe candidate content was flagged for recitation reasons."
+#                 )
+#         raise GeminiError(
+#             status_code=400,
+#             message=f"No response received. Original response - {original_response}",
+#         )
-    ## CALCULATING USAGE
-    prompt_str = ""
-    for m in messages:
-        if isinstance(m["content"], str):
-            prompt_str += m["content"]
-        elif isinstance(m["content"], list):
-            for content in m["content"]:
-                if content["type"] == "text":
-                    prompt_str += content["text"]
-    prompt_tokens = len(encoding.encode(prompt_str))
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
+#     ## CALCULATING USAGE
+#     prompt_str = ""
+#     for m in messages:
+#         if isinstance(m["content"], str):
+#             prompt_str += m["content"]
+#         elif isinstance(m["content"], list):
+#             for content in m["content"]:
+#                 if content["type"] == "text":
+#                     prompt_str += content["text"]
+#     prompt_tokens = len(encoding.encode(prompt_str))
+#     completion_tokens = len(
+#         encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+#     )
-    model_response["created"] = int(time.time())
-    model_response["model"] = "gemini/" + model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
+#     model_response.created = int(time.time())
+#     model_response.model = "gemini/" + model
+#     usage = Usage(
+#         prompt_tokens=prompt_tokens,
+#         completion_tokens=completion_tokens,
+#         total_tokens=prompt_tokens + completion_tokens,
+#     )
+#     setattr(model_response, "usage", usage)
+#     return model_response
-async def async_completion(
-    _model,
-    model,
-    prompt,
-    inference_params,
-    safety_settings,
-    logging_obj,
-    print_verbose,
-    model_response,
-    messages,
-    encoding,
-):
-    import google.generativeai as genai  # type: ignore
+# async def async_completion(
+#     _model,
+#     model,
+#     prompt,
+#     inference_params,
+#     safety_settings,
+#     logging_obj,
+#     print_verbose,
+#     model_response,
+#     messages,
+#     encoding,
+# ):
+#     import google.generativeai as genai  # type: ignore
-    response = await _model.generate_content_async(
-        contents=prompt,
-        generation_config=genai.types.GenerationConfig(**inference_params),
-        safety_settings=safety_settings,
-    )
+#     response = await _model.generate_content_async(
+#         contents=prompt,
+#         generation_config=genai.types.GenerationConfig(**inference_params),
+#         safety_settings=safety_settings,
+#     )
-    ## LOGGING
-    logging_obj.post_call(
-        input=prompt,
-        api_key="",
-        original_response=response,
-        additional_args={"complete_input_dict": {}},
-    )
-    print_verbose(f"raw model_response: {response}")
-    ## RESPONSE OBJECT
-    completion_response = response
-    try:
-        choices_list = []
-        for idx, item in enumerate(completion_response.candidates):
-            if len(item.content.parts) > 0:
-                message_obj = Message(content=item.content.parts[0].text)
-            else:
-                message_obj = Message(content=None)
-            choice_obj = Choices(index=idx, message=message_obj)
-            choices_list.append(choice_obj)
-        model_response["choices"] = choices_list
-    except Exception as e:
-        verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
-        verbose_logger.debug(traceback.format_exc())
-        raise GeminiError(
-            message=traceback.format_exc(), status_code=response.status_code
-        )
+#     ## LOGGING
+#     logging_obj.post_call(
+#         input=prompt,
+#         api_key="",
+#         original_response=response,
+#         additional_args={"complete_input_dict": {}},
+#     )
+#     print_verbose(f"raw model_response: {response}")
+#     ## RESPONSE OBJECT
+#     completion_response = response
+#     try:
+#         choices_list = []
+#         for idx, item in enumerate(completion_response.candidates):
+#             if len(item.content.parts) > 0:
+#                 message_obj = Message(content=item.content.parts[0].text)
+#             else:
+#                 message_obj = Message(content=None)
+#             choice_obj = Choices(index=idx, message=message_obj)
+#             choices_list.append(choice_obj)
+#         model_response["choices"] = choices_list
+#     except Exception as e:
+#         verbose_logger.error("LiteLLM.gemini.py: Exception occurred - {}".format(str(e)))
+#         verbose_logger.debug(traceback.format_exc())
+#         raise GeminiError(
+#             message=traceback.format_exc(), status_code=response.status_code
+#         )
-    try:
-        completion_response = model_response["choices"][0]["message"].get("content")
-        if completion_response is None:
-            raise Exception
-    except:
-        original_response = f"response: {response}"
-        if hasattr(response, "candidates"):
-            original_response = f"response: {response.candidates}"
-            if "SAFETY" in original_response:
-                original_response += (
-                    "\nThe candidate content was flagged for safety reasons."
-                )
-            elif "RECITATION" in original_response:
-                original_response += (
-                    "\nThe candidate content was flagged for recitation reasons."
-                )
-        raise GeminiError(
-            status_code=400,
-            message=f"No response received. Original response - {original_response}",
-        )
+#     try:
+#         completion_response = model_response["choices"][0]["message"].get("content")
+#         if completion_response is None:
+#             raise Exception
+#     except:
+#         original_response = f"response: {response}"
+#         if hasattr(response, "candidates"):
+#             original_response = f"response: {response.candidates}"
+#             if "SAFETY" in original_response:
+#                 original_response += (
+#                     "\nThe candidate content was flagged for safety reasons."
+#                 )
+#             elif "RECITATION" in original_response:
+#                 original_response += (
+#                     "\nThe candidate content was flagged for recitation reasons."
+#                 )
+#         raise GeminiError(
+#             status_code=400,
+#             message=f"No response received. Original response - {original_response}",
+#         )
-    ## CALCULATING USAGE
-    prompt_str = ""
-    for m in messages:
-        if isinstance(m["content"], str):
-            prompt_str += m["content"]
-        elif isinstance(m["content"], list):
-            for content in m["content"]:
-                if content["type"] == "text":
-                    prompt_str += content["text"]
-    prompt_tokens = len(encoding.encode(prompt_str))
-    completion_tokens = len(
-        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-    )
+#     ## CALCULATING USAGE
+#     prompt_str = ""
+#     for m in messages:
+#         if isinstance(m["content"], str):
+#             prompt_str += m["content"]
+#         elif isinstance(m["content"], list):
+#             for content in m["content"]:
+#                 if content["type"] == "text":
+#                     prompt_str += content["text"]
+#     prompt_tokens = len(encoding.encode(prompt_str))
+#     completion_tokens = len(
+#         encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+#     )
-    model_response["created"] = int(time.time())
-    model_response["model"] = "gemini/" + model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    model_response.usage = usage
-    return model_response
+#     model_response["created"] = int(time.time())
+#     model_response["model"] = "gemini/" + model
+#     usage = Usage(
+#         prompt_tokens=prompt_tokens,
+#         completion_tokens=completion_tokens,
+#         total_tokens=prompt_tokens + completion_tokens,
+#     )
+#     model_response.usage = usage
+#     return model_response
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+# def embedding():
+#     # logic for parsing in - calling - parsing out model embedding calls
+#     pass