Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Commit 1b4aadbb25 (parent 15944eb0f3): "with text-bison"
4 changed files with 85 additions and 29 deletions
@@ -47,7 +47,10 @@ def completion(
     temperature=1, top_p=1, n=1, stream=False, stop=None, max_tokens=float('inf'),
     presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None,
     # Optional liteLLM function params
-    *, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None
+    *, return_async=False, api_key=None, force_timeout=600, logger_fn=None, verbose=False, azure=False, custom_llm_provider=None, custom_api_base=None,
+    # model specific optional params
+    # used by text-bison only
+    top_k=40,
 ):
     try:
         global new_response
@@ -61,7 +64,7 @@ def completion(
         temperature=temperature, top_p=top_p, n=n, stream=stream, stop=stop, max_tokens=max_tokens,
         presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, logit_bias=logit_bias, user=user, deployment_id=deployment_id,
         # params to identify the model
-        model=model, custom_llm_provider=custom_llm_provider
+        model=model, custom_llm_provider=custom_llm_provider, top_k=top_k,
     )
     # For logging - save the values of the litellm-specific params passed in
     litellm_params = get_litellm_params(
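Taken together, the two hunks above let callers pass top_k straight through completion() and on into get_optional_params. A minimal usage sketch, assuming text-bison is registered in litellm.vertex_text_models and that Google credentials are available in the environment; the project id and region below are placeholders:

import litellm
from litellm import completion

# Placeholders: substitute a real GCP project and region.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

response = completion(
    model="text-bison",
    messages=[{"role": "user", "content": "Write a haiku about version control."}],
    top_k=40,  # new keyword introduced by this commit
)
print(response["choices"][0]["message"]["content"])

Only the top_k keyword is new here; everything else follows the existing completion() surface.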
@@ -366,7 +369,7 @@ def completion(
             "total_tokens": prompt_tokens + completion_tokens
         }
         response = model_response
-    elif model in litellm.vertex_models:
+    elif model in litellm.vertex_chat_models:
        # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
         install_and_import("vertexai")
         import vertexai
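The one-line rename above narrows the existing Vertex branch to chat models, so the new text-model branch added below can handle text-bison separately. A self-contained sketch of the resulting dispatch, with assumed list contents since litellm/__init__.py is not part of this diff:

# Hypothetical dispatch sketch; the real lists live in litellm and may differ.
vertex_chat_models = ["chat-bison", "chat-bison@001"]   # assumed contents
vertex_text_models = ["text-bison", "text-bison@001"]   # assumed contents

def vertex_branch(model: str) -> str:
    """Mirror the elif ordering introduced by this commit."""
    if model in vertex_chat_models:
        return "vertex chat branch"
    if model in vertex_text_models:
        return "vertex text branch (TextGenerationModel)"
    return "another provider"

print(vertex_branch("text-bison"))  # -> vertex text branch (TextGenerationModel)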
@@ -387,6 +390,28 @@ def completion(
         ## LOGGING
         logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)

         ## RESPONSE OBJECT
         model_response["choices"][0]["message"]["content"] = completion_response
         model_response["created"] = time.time()
+        model_response["model"] = model
+    elif model in litellm.vertex_text_models:
+        # import vertexai/if it fails then pip install vertexai# import cohere/if it fails then pip install cohere
+        install_and_import("vertexai")
+        import vertexai
+        from vertexai.language_models import TextGenerationModel
+
+        vertexai.init(project=litellm.vertex_project, location=litellm.vertex_location)
+        # vertexai does not use an API key, it looks for credentials.json in the environment
+
+        prompt = " ".join([message["content"] for message in messages])
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, logger_fn=logger_fn)
+        vertex_model = TextGenerationModel.from_pretrained(model)
+        completion_response = vertex_model.predict(prompt, **optional_params)
+
+        ## LOGGING
+        logging(model=model, input=prompt, custom_llm_provider=custom_llm_provider, additional_args={"max_tokens": max_tokens, "original_response": completion_response}, logger_fn=logger_fn)
+
+        ## RESPONSE OBJECT
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()