Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
* LiteLLM Minor Fixes & Improvements (09/23/2024) (#5842)
  * feat(auth_utils.py): enable admin to allow client-side credentials to be passed. Makes it easier for devs to experiment with finetuned fireworks ai models
  * feat(router.py): allow setting configurable_clientside_auth_params for a model. Closes https://github.com/BerriAI/litellm/issues/5843
  * build(model_prices_and_context_window.json): fix anthropic claude-3-5-sonnet max output token limit. Fixes https://github.com/BerriAI/litellm/issues/5850
  * fix(azure_ai/): support content list for azure ai. Fixes https://github.com/BerriAI/litellm/issues/4237
  * fix(litellm_logging.py): always set saved_cache_cost (0 by default)
  * fix(fireworks_ai/cost_calculator.py): add fireworks ai default pricing; handles calling 405b+ size models
  * fix(slack_alerting.py): fix error alerting for failed spend tracking. Fixes regression with slack alerting error monitoring
  * fix(vertex_and_google_ai_studio_gemini.py): handle gemini "no candidates in streaming chunk" error
  * docs(bedrock.md): add llama3-1 models
  * test: fix tests
  * fix(azure_ai/chat): fix transformation for azure ai calls
  * feat(azure_ai/embed): add azure ai embeddings support. Closes https://github.com/BerriAI/litellm/issues/5861
  * fix(azure_ai/embed): enable async embedding
  * feat(azure_ai/embed): support azure ai multimodal embeddings
  * fix(azure_ai/embed): support async multimodal embeddings
  * feat(together_ai/embed): support together ai embedding calls
  * feat(rerank/main.py): log source documents for rerank endpoints to langfuse; improves rerank endpoint logging
  * fix(langfuse.py): support logging `/audio/speech` input to langfuse
  * test(test_embedding.py): fix test
  * test(test_completion_cost.py): fix helper util
261 lines · 9.3 KiB · Python
import asyncio
import contextvars
from functools import partial
from typing import Any, Coroutine, Dict, List, Literal, Optional, Union

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.azure_ai.rerank import AzureAIRerank
from litellm.llms.cohere.rerank import CohereRerank
from litellm.llms.together_ai.rerank import TogetherAIRerank
from litellm.secret_managers.main import get_secret
from litellm.types.router import *
from litellm.utils import client, exception_type, supports_httpx_timeout

from .types import RerankRequest, RerankResponse

####### ENVIRONMENT VARIABLES ###################
# Provider handler instances, initialized once and reused across calls
cohere_rerank = CohereRerank()
together_rerank = TogetherAIRerank()
azure_ai_rerank = AzureAIRerank()
#################################################
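
# Credential resolution (see the per-provider chains inside rerank() below):
# keys may be passed per-call (api_key=...), set as module-level globals, or
# read from env vars (COHERE_API_KEY / CO_API_KEY, TOGETHERAI_API_KEY,
# AZURE_AI_API_KEY). A minimal sketch with placeholder values:
#
#   import litellm
#   litellm.cohere_key = "co-..."          # module-level default for Cohere
#   # or: os.environ["COHERE_API_KEY"] = "co-..."
#   # or: litellm.rerank(..., api_key="co-...")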


@client
async def arerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = None,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Async: Reranks a list of documents based on their relevance to the query
    """
    try:
        loop = asyncio.get_event_loop()
        kwargs["arerank"] = True

        # Call the sync `rerank` in a thread-pool executor; with arerank=True
        # it returns a coroutine, which is awaited below.
        func = partial(
            rerank,
            model,
            query,
            documents,
            custom_llm_provider,
            top_n,
            rank_fields,
            return_documents,
            max_chunks_per_doc,
            **kwargs,
        )

        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)

        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response
        return response
    except Exception as e:
        raise e
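
# Usage sketch (illustrative only; the model name and documents below are
# assumptions, not part of this module):
#
#   response = await arerank(
#       model="cohere/rerank-english-v3.0",
#       query="What is the capital of the United States?",
#       documents=[
#           "Carson City is the capital city of the American state of Nevada.",
#           "Washington, D.C. is the capital of the United States.",
#       ],
#       top_n=1,
#   )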


@client
def rerank(
    model: str,
    query: str,
    documents: List[Union[str, Dict[str, Any]]],
    custom_llm_provider: Optional[Literal["cohere", "together_ai", "azure_ai"]] = None,
    top_n: Optional[int] = None,
    rank_fields: Optional[List[str]] = None,
    return_documents: Optional[bool] = True,
    max_chunks_per_doc: Optional[int] = None,
    **kwargs,
) -> Union[RerankResponse, Coroutine[Any, Any, RerankResponse]]:
    """
    Reranks a list of documents based on their relevance to the query
    """
    headers: Optional[dict] = kwargs.get("headers")  # type: ignore
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
    proxy_server_request = kwargs.get("proxy_server_request", None)
    model_info = kwargs.get("model_info", None)
    metadata = kwargs.get("metadata", {})
    user = kwargs.get("user", None)
    try:
        _is_async = kwargs.pop("arerank", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)

        # Resolve the provider from the explicit argument or the model-name
        # prefix, plus any dynamically passed credentials.
        model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = (
            litellm.get_llm_provider(
                model=model,
                custom_llm_provider=custom_llm_provider,
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
            )
        )
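
        # Sketch of what the resolution returns (hedged; exact values depend on
        # litellm.get_llm_provider). For a prefixed model name:
        #   "cohere/rerank-english-v3.0"
        #     -> model="rerank-english-v3.0", _custom_llm_provider="cohere",
        #        dynamic_api_key=None, dynamic_api_base=None (unless passed in)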

        model_params_dict = {
            "top_n": top_n,
            "rank_fields": rank_fields,
            "return_documents": return_documents,
            "max_chunks_per_doc": max_chunks_per_doc,
            "documents": documents,
        }

        litellm_logging_obj.update_environment_variables(
            model=model,
            user=user,
            optional_params=model_params_dict,
            litellm_params={
                "litellm_call_id": litellm_call_id,
                "proxy_server_request": proxy_server_request,
                "model_info": model_info,
                "metadata": metadata,
                "preset_cache_key": None,
                "stream_response": {},
                **optional_params.model_dump(exclude_unset=True),
            },
            custom_llm_provider=_custom_llm_provider,
        )

        # Route to the provider-specific rerank handler
        if _custom_llm_provider == "cohere":
            api_key: Optional[str] = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.cohere_key
                or get_secret("COHERE_API_KEY")  # type: ignore
                or get_secret("CO_API_KEY")  # type: ignore
                or litellm.api_key
            )

            if api_key is None:
                raise ValueError(
                    "Cohere API key is required, please set 'COHERE_API_KEY' in your environment"
                )

            api_base: Optional[str] = (
                dynamic_api_base
                or optional_params.api_base
                or litellm.api_base
                or get_secret("COHERE_API_BASE")  # type: ignore
                or "https://api.cohere.com/v1/rerank"
            )

            if api_base is None:
                raise Exception(
                    "Invalid api base. api_base=None. Set in call or via `COHERE_API_BASE` env var."
                )

            headers = headers or litellm.headers or {}

            response = cohere_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers,
                litellm_logging_obj=litellm_logging_obj,
            )
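            # Note: the Cohere default api_base above already includes the full
            # /v1/rerank route. An override sketch (URL is a placeholder):
            #   litellm.rerank(..., api_base="https://my-gateway.example/v1/rerank")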
        elif _custom_llm_provider == "azure_ai":
            api_base = (
                dynamic_api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
                or optional_params.api_base
                or litellm.api_base
                or get_secret("AZURE_AI_API_BASE")  # type: ignore
            )
            # set API KEY
            api_key = (
                dynamic_api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or get_secret("AZURE_AI_API_KEY")  # type: ignore
            )

            headers = headers or litellm.headers or {}

            if api_key is None:
                raise ValueError(
                    "Azure AI API key is required, please set 'AZURE_AI_API_KEY' in your environment"
                )

            if api_base is None:
                raise Exception(
                    "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
                )

            ## LOAD CONFIG - if set
            config = litellm.OpenAIConfig.get_config()
            for k, v in config.items():
                if (
                    k not in optional_params
                ):  # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in
                    optional_params[k] = v

            response = azure_ai_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                api_base=api_base,
                _is_async=_is_async,
                headers=headers,
                litellm_logging_obj=litellm_logging_obj,
            )
        elif _custom_llm_provider == "together_ai":
            api_key = (
                dynamic_api_key
                or optional_params.api_key
                or litellm.togetherai_api_key
                or get_secret("TOGETHERAI_API_KEY")  # type: ignore
                or litellm.api_key
            )

            if api_key is None:
                raise ValueError(
                    "TogetherAI API key is required, please set 'TOGETHERAI_API_KEY' in your environment"
                )

            response = together_rerank.rerank(
                model=model,
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
                api_key=api_key,
                _is_async=_is_async,
            )

        else:
            raise ValueError(f"Unsupported provider: {_custom_llm_provider}")

        return response
    except Exception as e:
        verbose_logger.error(f"Error in rerank: {str(e)}")
        raise exception_type(
            model=model, custom_llm_provider=custom_llm_provider, original_exception=e
        )
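
# Sync usage sketch (illustrative; model name and inputs are assumptions):
#
#   import litellm
#   result = litellm.rerank(
#       model="together_ai/Salesforce/Llama-Rank-V1",
#       query="machine learning frameworks",
#       documents=[
#           "PyTorch is an open-source deep learning framework.",
#           "Paris is the capital of France.",
#       ],
#       top_n=1,
#   )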