ishaan-jaff 2023-09-26 10:00:56 -07:00
parent 6eea9da4ab
commit d6bc20d5be
2 changed files with 130 additions and 0 deletions

litellm/llms/palm.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import os
import json
from enum import Enum
import requests
import time
from typing import Callable
from litellm.utils import ModelResponse, get_secret
import sys


class PalmError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    api_key: str,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params=None,
    litellm_params=None,
    logger_fn=None,
):
    import google.generativeai as palm

    palm.configure(api_key=api_key)

    # palm.chat takes a single prompt string, so concatenate message contents
    # regardless of role
    prompt = ""
    for message in messages:
        prompt += f"{message['content']}"

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={"complete_input_dict": {}},
    )
    ## COMPLETION CALL
    response = palm.chat(messages=prompt)

    if "stream" in optional_params and optional_params["stream"] == True:
        return response.iter_lines()
    else:
        ## LOGGING
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=response,
            additional_args={"complete_input_dict": {}},
        )
        print_verbose(f"raw model_response: {response}")
        ## RESPONSE OBJECT
        completion_response = response.last
        if "error" in completion_response:
            raise PalmError(
                message=completion_response["error"],
                status_code=500,  # palm.chat responses do not expose an HTTP status code
            )
        else:
            try:
                model_response["choices"][0]["message"]["content"] = completion_response
            except Exception:
                raise PalmError(
                    message=json.dumps(completion_response),
                    status_code=500,
                )

        ## CALCULATING USAGE - token counts via the encoding passed in from main.py
        prompt_tokens = len(
            encoding.encode(prompt)
        )
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"]["content"])
        )

        model_response["created"] = time.time()
        model_response["model"] = "palm/" + model
        model_response["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return model_response


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
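
For reference, a minimal usage sketch of the new provider path (not part of this commit). It assumes PALM_API_KEY is set in the environment, that completion() is invoked with custom_llm_provider="palm" so the branch added to main.py is reached, and that the model name is only a label here, since the handler above calls palm.chat() with its default model.

import os
import litellm

# Assumption: the key is picked up via get_secret("PALM_API_KEY") when api_key is not passed
os.environ["PALM_API_KEY"] = "your-palm-api-key"

response = litellm.completion(
    model="chat-bison",  # assumed name; only used to label the response as "palm/chat-bison"
    messages=[{"role": "user", "content": "Hello from litellm"}],
    custom_llm_provider="palm",  # routes to the new palm branch in completion()
)
print(response["choices"][0]["message"]["content"])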

litellm/main.py (27 additions)

@@ -44,6 +44,7 @@ from .llms import ollama
from .llms import cohere
from .llms import petals
from .llms import oobabooga
from .llms import palm
import tiktoken
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
@@ -792,6 +793,32 @@ def completion(
            )
            return response
        response = model_response
    elif custom_llm_provider == "palm":
        api_key = (
            api_key
            or get_secret("PALM_API_KEY")
            or litellm.api_key
        )

        model_response = palm.completion(
            model=model,
            messages=messages,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging
        )
        if "stream" in optional_params and optional_params["stream"] == True:
            # don't try to access stream object,
            response = CustomStreamWrapper(
                model_response, model, custom_llm_provider="palm", logging_obj=logging
            )
            return response
        response = model_response
    elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models:
        try:
            import vertexai