diff --git a/litellm/llms/palm.py b/litellm/llms/palm.py new file mode 100644 index 000000000..ee1fc19bb --- /dev/null +++ b/litellm/llms/palm.py @@ -0,0 +1,103 @@ +import os +import json +from enum import Enum +import requests +import time +from typing import Callable +from litellm.utils import ModelResponse, get_secret +import sys + +class PalmError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + +def completion( + model: str, + messages: list, + model_response: ModelResponse, + api_key: str, + print_verbose: Callable, + encoding, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, +): + + import google.generativeai as palm + palm.configure(api_key=api_key) + + model = model + prompt = "" + for message in messages: + if "role" in message: + if message["role"] == "user": + prompt += ( + f"{message['content']}" + ) + else: + prompt += ( + f"{message['content']}" + ) + else: + prompt += f"{message['content']}" + + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key="", + additional_args={"complete_input_dict": {}}, + ) + ## COMPLETION CALL + response = palm.chat(messages=prompt) + + + if "stream" in optional_params and optional_params["stream"] == True: + return response.iter_lines() + else: + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": {}}, + ) + print_verbose(f"raw model_response: {response}") + ## RESPONSE OBJECT + completion_response = response.last + + if "error" in completion_response: + raise PalmError( + message=completion_response["error"], + status_code=response.status_code, + ) + else: + try: + model_response["choices"][0]["message"]["content"] = completion_response + except: + raise PalmError(message=json.dumps(completion_response), status_code=response.status_code) + + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len( + encoding.encode(prompt) + ) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"]["content"]) + ) + + model_response["created"] = time.time() + model_response["model"] = "palm/" + model + model_response["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + return model_response + +def embedding(): + # logic for parsing in - calling - parsing out model embedding calls + pass diff --git a/litellm/main.py b/litellm/main.py index df4047c74..a17a8a707 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -44,6 +44,7 @@ from .llms import ollama from .llms import cohere from .llms import petals from .llms import oobabooga +from .llms import palm import tiktoken from concurrent.futures import ThreadPoolExecutor from typing import Callable, List, Optional, Dict @@ -792,6 +793,32 @@ def completion( ) return response response = model_response + elif custom_llm_provider == "palm": + api_key = ( + api_key + or get_secret("PALM_API_KEY") + or litellm.api_key + ) + + model_response = palm.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=api_key, + logging_obj=logging + ) + if "stream_tokens" in optional_params and optional_params["stream_tokens"] == True: + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, model, custom_llm_provider="palm", logging_obj=logging + ) + return response + response = model_response elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models: try: import vertexai