Merge branch 'main' into litellm_vertex_migration

Krish Dholakia 2024-07-27 20:25:12 -07:00 committed by GitHub
commit 0525fb75f3
319 changed files with 23692 additions and 5152 deletions


@@ -38,6 +38,7 @@ import dotenv
import httpx
import openai
import tiktoken
from pydantic import BaseModel
from typing_extensions import overload
import litellm
@@ -48,6 +49,7 @@ from litellm import ( # type: ignore
get_litellm_params,
get_optional_params,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.utils import (
CustomStreamWrapper,
@@ -61,6 +63,7 @@ from litellm.utils import (
get_llm_provider,
get_optional_params_embeddings,
get_optional_params_image_gen,
get_optional_params_transcription,
get_secret,
mock_completion_streaming_obj,
read_config_args,
@@ -104,6 +107,7 @@ from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@@ -117,7 +121,9 @@ from .llms.prompt_templates.factory import (
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_llama import VertexAILlama3
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import ChatCompletionMessageToolCall
@@ -152,6 +158,8 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_llama_chat_completion = VertexAILlama3()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@@ -242,6 +250,7 @@ async def acompletion(
seed: Optional[int] = None,
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
parallel_tool_calls: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
@@ -317,6 +326,7 @@ async def acompletion(
"seed": seed,
"tools": tools,
"tool_choice": tool_choice,
"parallel_tool_calls": parallel_tool_calls,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"deployment_id": deployment_id,
@@ -368,8 +378,11 @@ async def acompletion(
or custom_llm_provider == "predibase"
or custom_llm_provider == "bedrock"
or custom_llm_provider == "databricks"
or custom_llm_provider == "triton"
or custom_llm_provider == "clarifai"
or custom_llm_provider == "watsonx"
or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider in litellm._custom_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
@@ -471,7 +484,7 @@ def mock_completion(
if isinstance(mock_response, Exception):
if isinstance(mock_response, openai.APIError):
raise mock_response
raise litellm.APIError(
raise litellm.MockException(
status_code=getattr(mock_response, "status_code", 500), # type: ignore
message=getattr(mock_response, "text", str(mock_response)),
llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore
@@ -514,7 +527,7 @@ def mock_completion(
)
return response
if n is None:
model_response["choices"][0]["message"]["content"] = mock_response
model_response.choices[0].message.content = mock_response # type: ignore
else:
_all_choices = []
for i in range(n):
@@ -525,12 +538,12 @@ def mock_completion(
),
)
_all_choices.append(_choice)
model_response["choices"] = _all_choices
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.choices = _all_choices # type: ignore
model_response.created = int(time.time())
model_response.model = model
if mock_tool_calls:
model_response["choices"][0]["message"]["tool_calls"] = [
model_response.choices[0].message.tool_calls = [ # type: ignore
ChatCompletionMessageToolCall(**tool_call)
for tool_call in mock_tool_calls
]
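
mock_completion() now sets choices, created, and model through attributes, and non-APIError exception mocks are raised as litellm.MockException rather than litellm.APIError. A short sketch of both mock paths through the public completion() entry point (litellm.MockException is referenced in this diff; the model name is illustrative):

import litellm

# String mock: returned verbatim as the assistant message content.
resp = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="mocked reply",
)
print(resp.choices[0].message.content)  # -> "mocked reply"

# Exception mock: surfaced as MockException (status_code defaults to 500 when the mock carries none).
try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("simulated provider outage"),
    )
except litellm.MockException as e:
    print("mock error:", e)
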
@@ -590,6 +603,7 @@ def completion(
tool_choice: Optional[Union[str, dict]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
deployment_id=None,
extra_headers: Optional[dict] = None,
# soon to be deprecated params by OpenAI
@@ -719,12 +733,14 @@ def completion(
"tools",
"tool_choice",
"max_retries",
"parallel_tool_calls",
"logprobs",
"top_logprobs",
"extra_headers",
]
litellm_params = [
"metadata",
"tags",
"acompletion",
"atext_completion",
"text_completion",
@@ -893,7 +909,7 @@ def completion(
if (
supports_system_message is not None
and isinstance(supports_system_message, bool)
and supports_system_message == False
and supports_system_message is False
):
messages = map_system_message_pt(messages=messages)
model_api_key = get_api_key(
@@ -929,6 +945,7 @@ def completion(
top_logprobs=top_logprobs,
extra_headers=extra_headers,
api_version=api_version,
parallel_tool_calls=parallel_tool_calls,
**non_default_params,
)
@@ -1164,6 +1181,7 @@ def completion(
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
drop_params=non_default_params.get("drop_params"),
)
except Exception as e:
## LOGGING - log the original exception returned
@@ -1475,8 +1493,13 @@ def completion(
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or get_secret("ANTHROPIC_BASE_URL")
or "https://api.anthropic.com/v1/complete"
)
if api_base is not None and not api_base.endswith("/v1/complete"):
api_base += "/v1/complete"
response = anthropic_text_completions.completion(
model=model,
messages=messages,
@@ -1500,8 +1523,13 @@ def completion(
api_base
or litellm.api_base
or get_secret("ANTHROPIC_API_BASE")
or get_secret("ANTHROPIC_BASE_URL")
or "https://api.anthropic.com/v1/messages"
)
if api_base is not None and not api_base.endswith("/v1/messages"):
api_base += "/v1/messages"
response = anthropic_chat_completions.completion(
model=model,
messages=messages,
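
Both Anthropic branches now also read ANTHROPIC_BASE_URL and append the expected route suffix (/v1/complete or /v1/messages) when a bare base URL is supplied, so a proxy host can be configured without the path. Sketch; the proxy hostname is illustrative:

import os
import litellm

# A bare base URL now works; "/v1/messages" is appended automatically for chat models.
os.environ["ANTHROPIC_API_BASE"] = "https://anthropic-proxy.internal"

response = litellm.completion(
    model="claude-3-haiku-20240307",
    messages=[{"role": "user", "content": "hi"}],
)
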
@@ -1517,6 +1545,8 @@ def completion(
api_key=api_key,
logging_obj=logging,
headers=headers,
timeout=timeout,
client=client,
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
@@ -1923,51 +1953,7 @@ def completion(
"""
Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility
"""
custom_llm_provider = "together_ai"
together_ai_key = (
api_key
or litellm.togetherai_api_key
or get_secret("TOGETHER_AI_TOKEN")
or get_secret("TOGETHERAI_API_KEY")
or litellm.api_key
)
api_base = (
api_base
or litellm.api_base
or get_secret("TOGETHERAI_API_BASE")
or "https://api.together.xyz/inference"
)
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = together_ai.completion(
model=model,
messages=messages,
api_base=api_base,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
api_key=together_ai_key,
logging_obj=logging,
custom_prompt_dict=custom_prompt_dict,
)
if (
"stream_tokens" in optional_params
and optional_params["stream_tokens"] == True
):
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
model,
custom_llm_provider="together_ai",
logging_obj=logging,
)
return response
response = model_response
pass
elif custom_llm_provider == "palm":
palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key
@@ -2079,6 +2065,28 @@ def completion(
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
elif model.startswith("meta/"):
model_response = vertex_llama_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
elif "gemini" in model:
model_response = vertex_chat_completion.completion( # type: ignore
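
Within the vertex_ai branch, model ids starting with "meta/" are now routed to the new VertexAILlama3 handler. A hedged sketch of a call that would take this route; the model id, project, and region are illustrative:

import litellm

response = litellm.completion(
    model="vertex_ai/meta/llama3-405b-instruct-maas",  # "meta/" prefix selects vertex_llama_chat_completion
    messages=[{"role": "user", "content": "hi"}],
    vertex_project="my-gcp-project",
    vertex_location="us-central1",
)
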
@@ -2374,7 +2382,7 @@ def completion(
response = response
elif custom_llm_provider == "watsonx":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonx.IBMWatsonXAI().completion(
response = watsonxai.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
@@ -2386,6 +2394,7 @@ def completion(
encoding=encoding,
logging_obj=logging,
timeout=timeout, # type: ignore
acompletion=acompletion,
)
if (
"stream" in optional_params
@@ -2471,10 +2480,10 @@ def completion(
## LOGGING
generator = ollama.get_ollama_response(
api_base,
model,
prompt,
optional_params,
api_base=api_base,
model=model,
prompt=prompt,
optional_params=optional_params,
logging_obj=logging,
acompletion=acompletion,
model_response=model_response,
@@ -2500,11 +2509,11 @@ def completion(
)
## LOGGING
generator = ollama_chat.get_ollama_response(
api_base,
api_key,
model,
messages,
optional_params,
api_base=api_base,
api_key=api_key,
model=model,
messages=messages,
optional_params=optional_params,
logging_obj=logging,
acompletion=acompletion,
model_response=model_response,
@@ -2514,6 +2523,25 @@ def completion(
return generator
response = generator
elif custom_llm_provider == "triton":
api_base = litellm.api_base or api_base
model_response = triton_chat_completions.completion(
api_base=api_base,
timeout=timeout, # type: ignore
model=model,
messages=messages,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion,
)
## RESPONSE OBJECT
response = model_response
return response
elif custom_llm_provider == "cloudflare":
api_key = (
api_key
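
completion() gains a dedicated triton branch (matching the triton entry added to the async fast-path near the top of this diff). A sketch, assuming a Triton inference server with a generate-style endpoint; the model name and URL are illustrative:

import litellm

response = litellm.completion(
    model="triton/llama-3-8b-instruct",
    messages=[{"role": "user", "content": "hi"}],
    api_base="http://localhost:8000/generate",  # point at your Triton server's generate endpoint
)
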
@@ -2682,10 +2710,58 @@ def completion(
"""
string_response = response_json["data"][0]["output"][0]
## RESPONSE OBJECT
model_response["choices"][0]["message"]["content"] = string_response
model_response["created"] = int(time.time())
model_response["model"] = model
model_response.choices[0].message.content = string_response # type: ignore
model_response.created = int(time.time())
model_response.model = model
response = model_response
elif (
custom_llm_provider in litellm._custom_providers
): # Assume custom LLM provider
# Get the Custom Handler
custom_handler: Optional[CustomLLM] = None
for item in litellm.custom_provider_map:
if item["provider"] == custom_llm_provider:
custom_handler = item["custom_handler"]
if custom_handler is None:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
)
## ROUTE LLM CALL ##
handler_fn = custom_chat_llm_router(
async_fn=acompletion, stream=stream, custom_llm=custom_handler
)
headers = headers or litellm.headers
## CALL FUNCTION
response = handler_fn(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
encoding=encoding,
)
if stream is True:
return CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider=custom_llm_provider,
logging_obj=logging,
)
else:
raise ValueError(
f"Unable to map your input to a model. Check your input - {args}"
@@ -3052,6 +3128,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3203,6 +3280,7 @@ def embedding(
"allowed_model_region",
"model_config",
"cooldown_time",
"tags",
]
default_params = openai_params + litellm_params
non_default_params = {
@@ -3474,7 +3552,7 @@ def embedding(
or api_base
or get_secret("OLLAMA_API_BASE")
or "http://localhost:11434"
)
) # type: ignore
if isinstance(input, str):
input = [input]
if not all(isinstance(item, str) for item in input):
@@ -3484,9 +3562,11 @@ def embedding(
llm_provider="ollama", # type: ignore
)
ollama_embeddings_fn = (
ollama.ollama_aembeddings if aembedding else ollama.ollama_embeddings
ollama.ollama_aembeddings
if aembedding is True
else ollama.ollama_embeddings
)
response = ollama_embeddings_fn(
response = ollama_embeddings_fn( # type: ignore
api_base=api_base,
model=model,
prompts=input,
@@ -3559,13 +3639,14 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonx.IBMWatsonXAI().embedding(
response = watsonxai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
)
else:
args = locals()
@@ -3824,7 +3905,7 @@ def text_completion(
optional_params["custom_llm_provider"] = custom_llm_provider
# get custom_llm_provider
_, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
_model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -3907,10 +3988,12 @@ def text_completion(
kwargs.pop("prompt", None)
if model is not None and model.startswith(
"openai/"
if (
_model is not None and custom_llm_provider == "openai"
): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
model = model.replace("openai/", "text-completion-openai/")
if _model not in litellm.open_ai_chat_completion_models:
model = "text-completion-openai/" + _model
optional_params.pop("custom_llm_provider", None)
kwargs["text_completion"] = True
response = completion(
@@ -3953,6 +4036,63 @@ def text_completion(
return text_completion_response
###### Adapter Completion ################
async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
"""
Implemented to handle async calls for adapter_completion()
"""
try:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
except Exception as e:
raise e
def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
translation_obj: Optional[CustomLogger] = None
for item in litellm.adapters:
if item["id"] == adapter_id:
translation_obj = item["adapter"]
if translation_obj is None:
raise ValueError(
"No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format(
adapter_id, litellm.adapters
)
)
new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
response: ModelResponse = completion(**new_kwargs) # type: ignore
translated_response = translation_obj.translate_completion_output_params(
response=response
)
return translated_response
##### Moderation #######################
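
adapter_completion() and aadapter_completion() are new: they look up a registered adapter by id in litellm.adapters, translate the incoming kwargs into an OpenAI-style completion call, and translate the result back. A hedged sketch, assuming adapter_completion is exported at the package level; the adapter class here is illustrative and simply passes everything through, where a real adapter would reshape both sides:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class PassthroughAdapter(CustomLogger):
    def translate_completion_input_params(self, kwargs):
        # Hypothetical mapping: forward the caller's kwargs to completion() unchanged.
        return dict(kwargs)

    def translate_completion_output_params(self, response):
        # Hypothetical mapping: return the ModelResponse as-is.
        return response

litellm.adapters = [{"id": "passthrough", "adapter": PassthroughAdapter()}]

translated = litellm.adapter_completion(
    adapter_id="passthrough",
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="adapted reply",  # keeps this sketch offline
)
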
@@ -4273,6 +4413,7 @@ def image_generation(
model_response=model_response,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
aimg_generation=aimg_generation,
)
@@ -4292,7 +4433,7 @@ def image_generation(
@client
async def atranscription(*args, **kwargs):
async def atranscription(*args, **kwargs) -> TranscriptionResponse:
"""
Calls openai + azure whisper endpoints.
@@ -4317,9 +4458,9 @@ async def atranscription(*args, **kwargs):
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, TranscriptionResponse
): ## CACHING SCENARIO
if isinstance(init_response, dict):
response = TranscriptionResponse(**init_response)
elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@@ -4359,7 +4500,7 @@ def transcription(
litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
custom_llm_provider=None,
**kwargs,
):
) -> TranscriptionResponse:
"""
Calls openai + azure whisper endpoints.
@@ -4371,6 +4512,9 @@ def transcription(
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
tags = kwargs.pop("tags", [])
drop_params = kwargs.get("drop_params", None)
client: Optional[
Union[
openai.AsyncOpenAI,
@@ -4392,12 +4536,22 @@ def transcription(
if dynamic_api_key is not None:
api_key = dynamic_api_key
optional_params = {
"language": language,
"prompt": prompt,
"response_format": response_format,
"temperature": None, # openai defaults this to 0
}
optional_params = get_optional_params_transcription(
model=model,
language=language,
prompt=prompt,
response_format=response_format,
temperature=temperature,
custom_llm_provider=custom_llm_provider,
drop_params=drop_params,
)
# optional_params = {
# "language": language,
# "prompt": prompt,
# "response_format": response_format,
# "temperature": None, # openai defaults this to 0
# }
if custom_llm_provider == "azure":
# azure configs
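
transcription() now builds optional_params through get_optional_params_transcription, which also accepts temperature and a drop_params flag so provider-unsupported fields can be dropped instead of erroring. A minimal sketch; the audio file name is illustrative:

import litellm

with open("sample.wav", "rb") as audio_file:
    transcript = litellm.transcription(
        model="whisper-1",
        file=audio_file,
        language="en",
        temperature=0.0,
        drop_params=True,  # drop any param the selected provider does not support
    )
print(transcript.text)
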
@@ -4532,6 +4686,7 @@ def speech(
) -> HttpxBinaryResponseContent:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
tags = kwargs.pop("tags", [])
optional_params = {}
if response_format is not None:
@@ -4664,12 +4819,12 @@ async def ahealth_check(
raise Exception("model not set")
if model in litellm.model_cost and mode is None:
mode = litellm.model_cost[model]["mode"]
mode = litellm.model_cost[model].get("mode")
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
if model in litellm.model_cost and mode is None:
mode = litellm.model_cost[model]["mode"]
mode = litellm.model_cost[model].get("mode")
mode = mode or "chat" # default to chat completion calls
@@ -4769,9 +4924,10 @@ async def ahealth_check(
if isinstance(stack_trace, str):
stack_trace = stack_trace[:1000]
if model not in litellm.model_cost and mode is None:
raise Exception(
"Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
)
return {
"error": "Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models"
}
error_to_return = str(e) + " stack trace: " + stack_trace
return {"error": error_to_return}