From 1262d89ab385d16220d1578a4908f53b9bc5a075 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Sun, 24 Dec 2023 09:42:45 +0530
Subject: [PATCH] feat(gemini.py): add support for completion calls for
 gemini-pro (google ai studio)

---
 litellm/__init__.py                      |   2 +
 litellm/llms/gemini.py                   | 186 +++++++++++++++++++++++
 litellm/llms/prompt_templates/factory.py |  43 ++++++
 litellm/main.py                          |  25 +++
 litellm/tests/test_completion.py         |  20 ++-
 5 files changed, 272 insertions(+), 4 deletions(-)
 create mode 100644 litellm/llms/gemini.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index bd7e6b11f..ce606c8fd 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -292,6 +292,7 @@ provider_list: List = [
     "openrouter",
     "vertex_ai",
     "palm",
+    "gemini",
     "ai21",
     "baseten",
     "azure",
@@ -406,6 +407,7 @@ from .llms.cohere import CohereConfig
 from .llms.ai21 import AI21Config
 from .llms.together_ai import TogetherAIConfig
 from .llms.palm import PalmConfig
+from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
new file mode 100644
index 000000000..ebbad901a
--- /dev/null
+++ b/litellm/llms/gemini.py
@@ -0,0 +1,186 @@
+import os, types, traceback, copy
+import json
+from enum import Enum
+import time
+from typing import Callable, Optional
+from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
+import litellm
+import sys, httpx
+from .prompt_templates.factory import prompt_factory, custom_prompt
+
+class GeminiError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(method="POST", url="https://developers.generativeai.google/api/python/google/generativeai/chat")
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+class GeminiConfig():
+    """
+    Reference: https://ai.google.dev/api/python/google/generativeai/GenerationConfig
+
+    The class `GeminiConfig` provides configuration for the Gemini API interface. Here are the parameters:
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `max_output_tokens` (int): The maximum number of tokens to include in a candidate. If unset, this will default to the output_token_limit specified in the model's specification.
+
+    - `temperature` (float): Controls the randomness of the output. Note: The default value varies by model; see the Model.temperature attribute of the Model returned by the genai.get_model function. Values can range from [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied and creative, while a value closer to 0.0 will typically result in more straightforward responses from the model.
+
+    - `top_p` (float): Optional. The maximum cumulative probability of tokens to consider when sampling.
+
+    - `top_k` (int): Optional. The maximum number of tokens to consider when sampling.
+    """
+
+    candidate_count: Optional[int]=None
+    stop_sequences: Optional[list]=None
+    max_output_tokens: Optional[int]=None
+    temperature: Optional[float]=None
+    top_p: Optional[float]=None
+    top_k: Optional[int]=None
+
+    def __init__(self,
+                 candidate_count: Optional[int]=None,
+                 stop_sequences: Optional[list]=None,
+                 max_output_tokens: Optional[int]=None,
+                 temperature: Optional[float]=None,
+                 top_p: Optional[float]=None,
+                 top_k: Optional[int]=None) -> None:
+
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+
+
+def completion(
+    model: str,
+    messages: list,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    api_key,
+    encoding,
+    logging_obj,
+    custom_prompt_dict: dict,
+    acompletion: bool = False,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+):
+    try:
+        import google.generativeai as genai
+    except:
+        raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai'")
+    genai.configure(api_key=api_key)
+
+
+    if model in custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="gemini")
+
+
+    ## Load Config
+    inference_params = copy.deepcopy(optional_params)
+    inference_params.pop("stream", None)  # gemini does not support streaming, so we handle this by fake streaming in main.py
+    config = litellm.GeminiConfig.get_config()
+    for k, v in config.items():
+        if k not in inference_params:  # completion(top_k=3) > gemini_config(top_k=3) <- allows for dynamic variables to be passed in
+            inference_params[k] = v
+
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=prompt,
+        api_key="",
+        additional_args={"complete_input_dict": {"inference_params": inference_params}},
+    )
+    ## COMPLETION CALL
+    try:
+        _model = genai.GenerativeModel(f'models/{model}')
+        response = _model.generate_content(contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params))
+    except Exception as e:
+        raise GeminiError(
+            message=str(e),
+            status_code=500,
+        )
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": {}},
+    )
+    print_verbose(f"raw model_response: {response}")
+    ## RESPONSE OBJECT
+    completion_response = response
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response.candidates):
+            if len(item.content.parts) > 0:
+                message_obj = Message(content=item.content.parts[0].text)
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(index=idx+1, message=message_obj)
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        traceback.print_exc()
+        raise GeminiError(message=traceback.format_exc(), status_code=response.status_code)
+
+    try:
+        completion_response = model_response["choices"][0]["message"].get("content")
+    except:
+        raise GeminiError(status_code=400, message=f"No response received. Original response - {response}")
+
+    ## CALCULATING USAGE
+    prompt_str = ""
+    for m in messages:
+        if isinstance(m["content"], str):
+            prompt_str += m["content"]
+        elif isinstance(m["content"], list):
+            for content in m["content"]:
+                if content["type"] == "text":
+                    prompt_str += content["text"]
+
+    prompt_tokens = len(
+        encoding.encode(prompt_str)
+    )
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+    )
+
+    model_response["created"] = int(time.time())
+    model_response["model"] = "gemini/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
+    model_response.usage = usage
+    return model_response
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index d908231cb..265cae941 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -346,6 +346,47 @@ def anthropic_pt(messages: list): # format - https://docs.anthropic.com/claude/r
     prompt += f"{AnthropicConstants.AI_PROMPT.value}"
     return prompt
 
+def gemini_text_image_pt(messages: list):
+    """
+    {
+        "contents":[
+            {
+                "parts":[
+                    {"text": "What is this picture?"},
+                    {
+                        "inline_data": {
+                            "mime_type":"image/jpeg",
+                            "data": "'$(base64 -w0 image.jpg)'"
+                        }
+                    }
+                ]
+            }
+        ]
+    }
+    """
+    try:
+        import google.generativeai as genai
+    except:
+        raise Exception("Importing google.generativeai failed, please run 'pip install -q google-generativeai'")
+
+    prompt = ""
+    images = []
+    for message in messages:
+        if isinstance(message["content"], str):
+            prompt += message["content"]
+        elif isinstance(message["content"], list):
+            # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+            for element in message["content"]:
+                if isinstance(element, dict):
+                    if element["type"] == "text":
+                        prompt += element["text"]
+                    elif element["type"] == "image_url":
+                        image_url = element["image_url"]["url"]
+                        images.append(image_url)
+
+    content = [prompt] + images
+    return content
+
 # Function call template
 def function_call_prompt(messages: list, functions: list):
     function_prompt = "Produce JSON OUTPUT ONLY! The following functions are available to you:"
@@ -401,6 +442,8 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
     elif custom_llm_provider == "together_ai":
         prompt_format, chat_template = get_model_info(token=api_key, model=model)
         return format_prompt_togetherai(messages=messages, prompt_format=prompt_format, chat_template=chat_template)
+    elif custom_llm_provider == "gemini":
+        return gemini_text_image_pt(messages=messages)
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
diff --git a/litellm/main.py b/litellm/main.py
index 8d2cabdfb..0c3ab4562 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -54,6 +54,7 @@ from .llms import (
     oobabooga,
     openrouter,
     palm,
+    gemini,
     vertex_ai,
     maritalk)
 from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@@ -1137,6 +1138,30 @@ def completion(
                 )
                 return response
             response = model_response
+        elif custom_llm_provider == "gemini":
+            gemini_api_key = (
+                api_key
+                or get_secret("GEMINI_API_KEY")
+                or get_secret("PALM_API_KEY")  # older palm api key should also work
+                or litellm.api_key
+            )
+
+            # gemini does not support streaming as yet :(
+            model_response = gemini.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                api_key=gemini_api_key,
+                logging_obj=logging,
+                acompletion=acompletion,
+                custom_prompt_dict=custom_prompt_dict
+            )
+            response = model_response
         elif custom_llm_provider == "vertex_ai":
             vertex_ai_project = (litellm.vertex_project
                                  or get_secret("VERTEXAI_PROJECT"))
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 365348433..2f4195717 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -12,7 +12,7 @@ import pytest
 import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
-litellm.num_retries = 3
+# litellm.num_retries = 3
 litellm.cache = None
 litellm.success_callback = []
 user_message = "Write a short poem about the sky"
@@ -668,7 +668,7 @@ def test_completion_azure_key_completion_arg():
     except Exception as e:
         os.environ["AZURE_API_KEY"] = old_key
         pytest.fail(f"Error occurred: {e}")
-test_completion_azure_key_completion_arg()
+# test_completion_azure_key_completion_arg()
 
 
 async def test_re_use_azure_async_client():
@@ -745,7 +745,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_azure()
+# test_completion_azure()
 
 def test_azure_openai_ad_token():
     # this tests if the azure ad token is set in the request header
@@ -1082,7 +1082,7 @@ def test_completion_together_ai_mixtral():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_together_ai_mixtral()
+# test_completion_together_ai_mixtral()
 
 def test_completion_together_ai_yi_chat():
     model_name = "together_ai/zero-one-ai/Yi-34B-Chat"
@@ -1623,6 +1623,18 @@ def test_completion_deep_infra_mistral():
         pytest.fail(f"Error occurred: {e}")
 # test_completion_deep_infra_mistral()
 
+# Gemini tests
+def test_completion_gemini():
+    litellm.set_verbose = True
+    model_name = "gemini/gemini-pro"
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    try:
+        response = completion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+test_completion_gemini()
 # Palm tests
 def test_completion_palm():
     litellm.set_verbose = True
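
Usage sketch for reviewers (not part of the diff): this is a minimal example of calling the new provider through litellm.completion, mirroring test_completion_gemini above. It assumes google-generativeai is installed and that GEMINI_API_KEY holds a real key; the key value below is a placeholder, and the GeminiConfig call is optional.

    import os
    import litellm

    # The new main.py branch reads GEMINI_API_KEY (PALM_API_KEY is accepted as a fallback).
    os.environ["GEMINI_API_KEY"] = "your-api-key"  # placeholder

    # Optional: set provider-wide defaults via the new GeminiConfig; gemini.completion
    # merges these into the genai GenerationConfig through litellm.GeminiConfig.get_config().
    litellm.GeminiConfig(max_output_tokens=256)

    response = litellm.completion(
        model="gemini/gemini-pro",  # the "gemini/" prefix routes to llms/gemini.py
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)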