feat(vertex_ai_llama.py): Vertex AI Llama 3.1 API support

Initial working commit for Vertex AI Llama 3.1 API support.
Krrish Dholakia 2024-07-23 17:07:30 -07:00
parent 169da8b8d0
commit 83ef52e180
5 changed files with 355 additions and 19 deletions
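
For context, a minimal usage sketch of the new route. This assumes Llama 3.1 is enabled as a partner model ("model-as-a-service") in the target GCP project; the model id, project, and region below are illustrative, not taken from this commit:

```python
import litellm

# Illustrative GCP settings -- substitute your own project and region.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-central1"

# After the "vertex_ai/" provider prefix is stripped, the remaining
# "meta/..." id is what the new branch in completion() matches on.
response = litellm.completion(
    model="vertex_ai/meta/llama3-405b-instruct-maas",
    messages=[{"role": "user", "content": "Hello from Llama 3.1 on Vertex AI!"}],
)
print(response.choices[0].message.content)
```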

litellm/main.py

@@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
+from .llms.vertex_ai_llama import VertexAILlama3
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
@@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+vertex_llama_chat_completion = VertexAILlama3()
 watsonxai = IBMWatsonXAI()
 ####### COMPLETION ENDPOINTS ################
@@ -2064,7 +2066,26 @@ def completion(
                 timeout=timeout,
                 client=client,
             )
+        elif model.startswith("meta/"):
+            model_response = vertex_llama_chat_completion.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                optional_params=new_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                encoding=encoding,
+                vertex_location=vertex_ai_location,
+                vertex_project=vertex_ai_project,
+                vertex_credentials=vertex_credentials,
+                logging_obj=logging,
+                acompletion=acompletion,
+                headers=headers,
+                custom_prompt_dict=custom_prompt_dict,
+                timeout=timeout,
+                client=client,
+            )
         else:
             model_response = vertex_ai.completion(
                 model=model,
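
The keyword arguments at this call site fix the interface the new handler has to expose. A skeleton consistent with them, offered as an assumption about the shape of `litellm/llms/vertex_ai_llama.py` rather than its actual contents:

```python
from typing import Any, Callable, Optional


class VertexAILlama3:
    """Handler for "meta/..." Llama models served through Vertex AI.

    Skeleton inferred from the call site in completion(); the real
    implementation lives in litellm/llms/vertex_ai_llama.py.
    """

    def completion(
        self,
        model: str,
        messages: list,
        model_response: Any,
        print_verbose: Callable,
        optional_params: dict,  # the router passes optional_params=new_params
        litellm_params: dict,
        logger_fn: Optional[Callable],
        encoding: Any,
        vertex_location: Optional[str],
        vertex_project: Optional[str],
        vertex_credentials: Optional[str],
        logging_obj: Any,
        acompletion: bool,
        headers: Optional[dict],
        custom_prompt_dict: dict,
        timeout: Optional[float],
        client: Optional[Any] = None,
    ):
        # Placeholder body: the real handler builds the Vertex AI request,
        # handles sync vs. async (acompletion), and fills model_response.
        raise NotImplementedError
```

Mirroring the signature of the existing `vertex_chat_completion.completion` branch keeps the router's vertex_ai dispatch uniform across model families.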
@@ -2478,28 +2499,25 @@ def completion(
             return generator
         response = generator
     elif custom_llm_provider == "triton":
-        api_base = (
-            litellm.api_base or api_base
-        )
+        api_base = litellm.api_base or api_base
         model_response = triton_chat_completions.completion(
-            api_base=api_base,
-            timeout=timeout,  # type: ignore
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            optional_params=optional_params,
-            logging_obj=logging,
-            stream=stream,
-            acompletion=acompletion
+            api_base=api_base,
+            timeout=timeout,  # type: ignore
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            optional_params=optional_params,
+            logging_obj=logging,
+            stream=stream,
+            acompletion=acompletion,
         )
         ## RESPONSE OBJECT
         response = model_response
         return response
     elif custom_llm_provider == "cloudflare":
         api_key = (
             api_key