LiteLLM Minor Fixes & Improvements (10/24/2024) (#6421)

* fix(utils.py): support passing dynamic api base to validate_environment

Returns True if only an api base is required and an api base is passed
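A minimal sketch of the new behavior, assuming an Ollama-style provider whose only requirement is an api base (the URL below is a placeholder):

```python
# Sketch: validate_environment now accepts a dynamic api_base. If the only
# missing requirement for the provider is an api base and one is supplied,
# the environment is reported as valid.
import litellm

result = litellm.validate_environment(
    model="ollama/mistral",
    api_base="http://localhost:11434",  # placeholder local endpoint
)
print(result)  # expected: {'keys_in_environment': True, 'missing_keys': []}
```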

* fix(litellm_pre_call_utils.py): feature flag sending client headers to llm api

Fixes https://github.com/BerriAI/litellm/issues/6410
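A hedged client-side sketch of what the flag enables (the proxy URL and key are placeholders; the flag itself is set under the proxy's `general_settings`):

```python
# Sketch: with forward_client_headers_to_llm_api enabled in the proxy's
# general_settings, `x-` headers sent by the client are forwarded to the
# backend LLM call.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    extra_headers={"x-custom-header": "Custom-Value"},  # forwarded when the flag is on
)
print(response.choices[0].message.content)
```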

* fix(anthropic/chat/transformation.py): return correct error message

* fix(http_handler.py): add error response text in places where we expect it

* fix(factory.py): handle base case of no non-system messages to bedrock

Fixes https://github.com/BerriAI/litellm/issues/6411
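A hedged sketch of the new behavior, assuming the error surfaces from `litellm.completion` as a `BadRequestError`:

```python
# Sketch: a Bedrock request whose messages contain no non-system entries now
# fails fast with a descriptive error instead of an unhandled exception.
import litellm

try:
    litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
except litellm.BadRequestError as e:
    print(e)  # "... bedrock requires at least one non-system message"
```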

* feat(cohere/embed): Support cohere image embeddings

Closes https://github.com/BerriAI/litellm/issues/6413
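A hedged sketch of preparing an image input (the helper name and file path are illustrative; whether a raw base64 string or a data URI is expected may depend on the model):

```python
# Sketch: encode a local image and pass it as the embedding input.
import base64
import os

from litellm import embedding

os.environ["COHERE_API_KEY"] = ""  # set your key


def image_to_data_uri(path: str) -> str:
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


response = embedding(
    model="cohere/embed-english-v3.0",
    input=[image_to_data_uri("cat.png")],  # hypothetical local file
)
print(response)
```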

* fix(__init__.py): fix linting error

* docs(supported_embedding.md): add image embedding example to docs

* feat(cohere/embed): use cohere embedding returned usage for cost calc

* build(model_prices_and_context_window.json): add embed-english-v3.0 details (image cost + 'supports_image_input' flag)
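A hedged sketch of cost calculation from the returned usage (assumes `completion_cost` accepts embedding responses):

```python
# Sketch: Cohere's returned usage (text and image billed units) now drives
# cost calculation rather than a local token estimate.
from litellm import completion_cost, embedding

response = embedding(
    model="cohere/embed-english-v3.0",
    input=["good morning from litellm"],
)
print(response.usage)                                 # usage reported by Cohere
print(completion_cost(completion_response=response))  # USD cost derived from that usage
```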

* fix(cohere_transformation.py): fix linting error

* test(test_proxy_server.py): cleanup test

* test: cleanup test

* fix: fix linting errors
Krish Dholakia 2024-10-25 15:55:56 -07:00 committed by GitHub
parent 38708a355a
commit c03e5da41f
23 changed files with 417 additions and 150 deletions

@@ -84,6 +84,60 @@ print(query_result[:5])
</TabItem>
</Tabs>
## Image Embeddings
For models that support image embeddings, you can pass in a base64 encoded image string to the `input` param.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
from litellm import embedding
import os
# set your api key
os.environ["COHERE_API_KEY"] = ""
response = embedding(model="cohere/embed-english-v3.0", input=["<base64 encoded image>"])
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Setup config.yaml
```yaml
model_list:
- model_name: cohere-embed
litellm_params:
model: cohere/embed-english-v3.0
api_key: os.environ/COHERE_API_KEY
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
# RUNNING on http://0.0.0.0:4000
```
3. Test it!
```bash
curl -X POST 'http://0.0.0.0:4000/v1/embeddings' \
-H 'Authorization: Bearer sk-54d77cd67b9febbb' \
-H 'Content-Type: application/json' \
-d '{
"model": "cohere/embed-english-v3.0",
"input": ["<base64 encoded image>"]
}'
```
</TabItem>
</Tabs>
## Input Params for `litellm.embedding()`

@@ -814,6 +814,7 @@ general_settings:
| pass_through_endpoints | List[Dict[str, Any]] | Define the pass through endpoints. [Docs](./pass_through) |
| enable_oauth2_proxy_auth | boolean | (Enterprise Feature) If true, enables oauth2.0 authentication |
| forward_openai_org_id | boolean | If true, forwards the OpenAI Organization ID to the backend LLM call (if it's OpenAI). |
| forward_client_headers_to_llm_api | boolean | If true, forwards the client headers (any `x-` headers) to the backend LLM call |
### router_settings - Reference

@@ -8,6 +8,7 @@ import os
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
from litellm._logging import (
set_verbose,
_turn_on_debug,
@@ -136,7 +137,7 @@ enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview" # this is updated to the latest
### COHERE EMBEDDINGS DEFAULT TYPE ###
-COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
+COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
openai_moderations_model_name: Optional[str] = None

@@ -333,6 +333,14 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
api_key: Optional[str],
dynamic_api_key: Optional[str],
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""
Returns:
Tuple[str, str, Optional[str], Optional[str]]:
model: str
custom_llm_provider: str
dynamic_api_key: Optional[str]
api_base: Optional[str]
"""
custom_llm_provider = model.split("/", 1)[0]
model = model.split("/", 1)[1]

@@ -398,6 +398,8 @@ class AnthropicChatCompletion(BaseLLM):
error_response = getattr(e, "response", None)
if error_headers is None and error_response:
error_headers = getattr(error_response, "headers", None)
if error_response and hasattr(error_response, "text"):
error_text = getattr(error_response, "text", error_text)
raise AnthropicError(
message=error_text,
status_code=status_code,

@@ -9,7 +9,7 @@ import httpx
from openai import OpenAI
import litellm
-from litellm.llms.cohere.embed import embedding as cohere_embedding
+from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,

@@ -7,6 +7,7 @@ Why separate file? Make it easy to see how transformation works
from typing import List
import litellm
from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig
from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse
from litellm.types.utils import Embedding, EmbeddingResponse
@@ -26,15 +27,21 @@ class BedrockCohereEmbeddingConfig:
optional_params["embedding_types"] = v
return optional_params
+def _is_v3_model(self, model: str) -> bool:
+return "3" in model
def _transform_request(
-self, input: List[str], inference_params: dict
+self, model: str, input: List[str], inference_params: dict
) -> CohereEmbeddingRequest:
-transformed_request = CohereEmbeddingRequest(
-texts=input,
-input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE, # type: ignore
-)
-for k, v in inference_params.items():
-transformed_request[k] = v # type: ignore
-return transformed_request
+transformed_request = CohereEmbeddingConfig()._transform_request(
+model, input, inference_params
+)
+new_transformed_request = CohereEmbeddingRequest(
+input_type=transformed_request["input_type"],
+)
+for k in CohereEmbeddingRequest.__annotations__.keys():
+if k in transformed_request:
+new_transformed_request[k] = transformed_request[k] # type: ignore
+return new_transformed_request

@@ -11,7 +11,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx
import litellm
-from litellm.llms.cohere.embed import embedding as cohere_embedding
+from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@@ -369,7 +369,7 @@ class BedrockEmbedding(BaseAWSLLM):
batch_data: Optional[List] = None
if provider == "cohere":
data = BedrockCohereEmbeddingConfig()._transform_request(
-input=input, inference_params=inference_params
+model=model, input=input, inference_params=inference_params
)
elif provider == "amazon" and model in [
"amazon.titan-embed-image-v1",

@@ -12,8 +12,11 @@ import requests  # type: ignore
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.utils import Choices, Message, ModelResponse, Usage
from .transformation import CohereEmbeddingConfig
def validate_environment(api_key, headers: dict):
headers.update(
@@ -41,39 +44,9 @@ class CohereError(Exception):
) # Call the base class constructor with the parameters it needs
def _process_embedding_response(
embeddings: list,
model_response: litellm.EmbeddingResponse,
model: str,
encoding: Any,
input: list,
) -> litellm.EmbeddingResponse:
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
),
)
return model_response
async def async_embedding(
model: str,
-data: dict,
+data: Union[dict, CohereEmbeddingRequest],
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Optional[Union[float, httpx.Timeout]],
@@ -121,19 +94,12 @@ async def async_embedding(
)
raise e
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response.text,
)
embeddings = response.json()["embeddings"]
## PROCESS RESPONSE ##
-return _process_embedding_response(
-embeddings=embeddings,
+return CohereEmbeddingConfig()._transform_response(
+response=response,
api_key=api_key,
logging_obj=logging_obj,
data=data,
model_response=model_response,
model=model,
encoding=encoding,
@@ -149,7 +115,7 @@ def embedding(
optional_params: dict,
headers: dict,
encoding: Any,
-data: Optional[dict] = None,
+data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
complete_api_base: Optional[str] = None,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
@@ -159,11 +125,10 @@ def embedding(
headers = validate_environment(api_key, headers=headers)
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
model = model
-data = data or {"model": model, "texts": input, **optional_params}
-if "3" in model and "input_type" not in data:
-# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
-data["input_type"] = "search_document"
+data = data or CohereEmbeddingConfig()._transform_request(
+model=model, input=input, inference_params=optional_params
+)
## ROUTING
if aembedding is True:
@@ -193,30 +158,12 @@ def embedding(
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
-## LOGGING
-logging_obj.post_call(
-input=input,
-api_key=api_key,
-additional_args={"complete_input_dict": data},
-original_response=response,
-)
-"""
-response
-{
-'object': "list",
-'data': [
-]
-'model',
-'usage'
-}
-"""
-if response.status_code != 200:
-raise CohereError(message=response.text, status_code=response.status_code)
-embeddings = response.json()["embeddings"]
-return _process_embedding_response(
-embeddings=embeddings,
+return CohereEmbeddingConfig()._transform_response(
+response=response,
+api_key=api_key,
+logging_obj=logging_obj,
+data=data,
model_response=model_response,
model=model,
encoding=encoding,

@@ -0,0 +1,160 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.
Why separate file? Make it easy to see how transformation works
Covers
- v3 embedding models
- v2 embedding models
Docs - https://docs.cohere.com/v2/reference/embed
"""
import types
from typing import Any, List, Optional, Union
import httpx
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.bedrock import (
COHERE_EMBEDDING_INPUT_TYPES,
CohereEmbeddingRequest,
CohereEmbeddingRequestWithModel,
)
from litellm.types.utils import (
Embedding,
EmbeddingResponse,
PromptTokensDetailsWrapper,
Usage,
)
from litellm.utils import is_base64_encoded
class CohereEmbeddingConfig:
"""
Reference: https://docs.cohere.com/v2/reference/embed
"""
def __init__(self) -> None:
pass
def get_supported_openai_params(self) -> List[str]:
return ["encoding_format"]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for k, v in non_default_params.items():
if k == "encoding_format":
optional_params["embedding_types"] = v
return optional_params
def _is_v3_model(self, model: str) -> bool:
return "3" in model
def _transform_request(
self, model: str, input: List[str], inference_params: dict
) -> CohereEmbeddingRequestWithModel:
is_encoded = False
for input_str in input:
is_encoded = is_base64_encoded(input_str)
if is_encoded: # check if string is b64 encoded image or not
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
images=input,
input_type="image",
)
else:
transformed_request = CohereEmbeddingRequestWithModel(
model=model,
texts=input,
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
input_tokens = 0
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if image_tokens is None and text_tokens is None:
for text in input:
input_tokens += len(encoding.encode(text))
else:
prompt_tokens_details = PromptTokensDetailsWrapper(
image_tokens=image_tokens,
text_tokens=text_tokens,
)
if image_tokens:
input_tokens += image_tokens
if text_tokens:
input_tokens += text_tokens
return Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
prompt_tokens_details=prompt_tokens_details,
)
def _transform_response(
self,
response: httpx.Response,
api_key: Optional[str],
logging_obj: LiteLLMLoggingObj,
data: Union[dict, CohereEmbeddingRequest],
model_response: EmbeddingResponse,
model: str,
encoding: Any,
input: list,
) -> EmbeddingResponse:
response_json = response.json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response_json,
)
"""
response
{
'object': "list",
'data': [
]
'model',
'usage'
}
"""
embeddings = response_json["embeddings"]
output_data = []
for idx, embedding in enumerate(embeddings):
output_data.append(
{"object": "embedding", "index": idx, "embedding": embedding}
)
model_response.object = "list"
model_response.data = output_data
model_response.model = model
input_tokens = 0
for text in input:
input_tokens += len(encoding.encode(text))
setattr(
model_response,
"usage",
self._calculate_usage(input, encoding, response_json.get("meta", {})),
)
return model_response

@@ -152,8 +152,10 @@ class AsyncHTTPHandler:
setattr(e, "status_code", e.response.status_code)
if stream is True:
setattr(e, "message", await e.response.aread())
setattr(e, "text", await e.response.aread())
else:
setattr(e, "message", e.response.text)
setattr(e, "text", e.response.text)
raise e
except Exception as e:
raise e

@@ -2429,6 +2429,15 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
contents: List[BedrockMessageBlock] = []
msg_i = 0
## BASE CASE ##
if len(messages) == 0:
raise litellm.BadRequestError(
message=BAD_MESSAGE_ERROR_STR
+ "bedrock requires at least one non-system message",
model=model,
llm_provider=llm_provider,
)
# if initial message is assistant message
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
if user_continue_message is not None:

@@ -113,7 +113,7 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
-from .llms.cohere import embed as cohere_embed
+from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat import DatabricksChatCompletion
from .llms.groq.chat.handler import GroqChatCompletion

@@ -3364,54 +3364,56 @@
"litellm_provider": "cohere",
"mode": "rerank"
},
-"embed-english-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
-"input_cost_per_token": 0.00000010,
-"output_cost_per_token": 0.00000,
-"litellm_provider": "cohere",
-"mode": "embedding"
-},
"embed-english-light-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-multilingual-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-english-v2.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 4096,
+"max_input_tokens": 4096,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-english-light-v2.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-multilingual-v2.0": {
-"max_tokens": 256,
-"max_input_tokens": 256,
+"max_tokens": 768,
+"max_input_tokens": 768,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
+"embed-english-v3.0": {
+"max_tokens": 1024,
+"max_input_tokens": 1024,
+"input_cost_per_token": 0.00000010,
+"input_cost_per_image": 0.0001,
+"output_cost_per_token": 0.00000,
+"litellm_provider": "cohere",
+"mode": "embedding",
+"supports_image_input": true
+},
"replicate/meta/llama-2-13b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

@@ -238,6 +238,10 @@ class LiteLLMProxyRequestSetup:
- Adds org id
"""
data = LitellmDataForBackendLLMCall()
if (
general_settings
and general_settings.get("forward_client_headers_to_llm_api") is True
):
_headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
headers, user_api_key_dict
)

@@ -210,15 +210,23 @@ class ServerSentEvent:
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
-class CohereEmbeddingRequest(TypedDict, total=False):
-texts: Required[List[str]]
-input_type: Required[
-Literal["search_document", "search_query", "classification", "clustering"]
-]
+COHERE_EMBEDDING_INPUT_TYPES = Literal[
+"search_document", "search_query", "classification", "clustering", "image"
+]
+class CohereEmbeddingRequest(TypedDict, total=False):
+texts: List[str]
+images: List[str]
+input_type: Required[COHERE_EMBEDDING_INPUT_TYPES]
truncate: Literal["NONE", "START", "END"]
embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]
+class CohereEmbeddingRequestWithModel(CohereEmbeddingRequest):
+model: Required[str]
class CohereEmbeddingResponse(TypedDict):
embeddings: List[List[float]]
id: str

@@ -5197,7 +5197,9 @@ def create_proxy_transport_and_mounts():
def validate_environment(  # noqa: PLR0915
-model: Optional[str] = None, api_key: Optional[str] = None
+model: Optional[str] = None,
+api_key: Optional[str] = None,
+api_base: Optional[str] = None,
) -> dict:
"""
Checks if the environment variables are valid for the given model.
@@ -5224,11 +5226,6 @@ validate_environment(  # noqa: PLR0915
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
except Exception:
custom_llm_provider = None
# # check if llm provider part of model name
# if model.split("/",1)[0] in litellm.provider_list:
# custom_llm_provider = model.split("/", 1)[0]
# model = model.split("/", 1)[1]
# custom_llm_provider_passed_in = True
if custom_llm_provider:
if custom_llm_provider == "openai":
@@ -5497,6 +5494,17 @@ validate_environment(  # noqa: PLR0915
if "api_key" not in key.lower():
new_missing_keys.append(key)
missing_keys = new_missing_keys
if api_base is not None:
new_missing_keys = []
for key in missing_keys:
if "api_base" not in key.lower():
new_missing_keys.append(key)
missing_keys = new_missing_keys
if len(missing_keys) == 0: # no missing keys
keys_in_environment = True
return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys} return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}

@@ -3364,54 +3364,56 @@
"litellm_provider": "cohere",
"mode": "rerank"
},
-"embed-english-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
-"input_cost_per_token": 0.00000010,
-"output_cost_per_token": 0.00000,
-"litellm_provider": "cohere",
-"mode": "embedding"
-},
"embed-english-light-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-multilingual-v3.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-english-v2.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 4096,
+"max_input_tokens": 4096,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-english-light-v2.0": {
-"max_tokens": 512,
-"max_input_tokens": 512,
+"max_tokens": 1024,
+"max_input_tokens": 1024,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
"embed-multilingual-v2.0": {
-"max_tokens": 256,
-"max_input_tokens": 256,
+"max_tokens": 768,
+"max_input_tokens": 768,
"input_cost_per_token": 0.00000010,
"output_cost_per_token": 0.00000,
"litellm_provider": "cohere",
"mode": "embedding"
},
+"embed-english-v3.0": {
+"max_tokens": 1024,
+"max_input_tokens": 1024,
+"input_cost_per_token": 0.00000010,
+"input_cost_per_image": 0.0001,
+"output_cost_per_token": 0.00000,
+"litellm_provider": "cohere",
+"mode": "embedding",
+"supports_image_input": true
+},
"replicate/meta/llama-2-13b": {
"max_tokens": 4096,
"max_input_tokens": 4096,

File diff suppressed because one or more lines are too long

@@ -160,3 +160,12 @@ def test_get_llm_provider_jina_ai():
assert custom_llm_provider == "openai_like"
assert api_base == "https://api.jina.ai/v1"
assert model == "jina-embeddings-v3"
def test_get_llm_provider_hosted_vllm():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="hosted_vllm/llama-3.1-70b-instruct",
)
assert custom_llm_provider == "hosted_vllm"
assert model == "llama-3.1-70b-instruct"
assert dynamic_api_key == ""

@@ -675,3 +675,15 @@ def test_alternating_roles_e2e():
"stream": False,
}
)
def test_just_system_message():
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
with pytest.raises(litellm.BadRequestError) as e:
_bedrock_converse_messages_pt(
messages=[],
model="anthropic.claude-3-sonnet-20240229-v1:0",
llm_provider="bedrock",
)
assert "bedrock requires at least one non-system message" in str(e.value)

@@ -225,12 +225,20 @@ def test_add_headers_to_request(litellm_key_header_name):
"litellm_key_header_name",
["x-litellm-key", None],
)
@pytest.mark.parametrize(
"forward_headers",
[True, False],
)
@mock_patch_acompletion()
def test_chat_completion_forward_headers(
-mock_acompletion, client_no_auth, litellm_key_header_name
+mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
):
global headers
try:
if forward_headers:
gs = getattr(litellm.proxy.proxy_server, "general_settings")
gs["forward_client_headers_to_llm_api"] = True
setattr(litellm.proxy.proxy_server, "general_settings", gs)
if litellm_key_header_name is not None:
gs = getattr(litellm.proxy.proxy_server, "general_settings")
gs["litellm_key_header_name"] = litellm_key_header_name
@@ -260,23 +268,14 @@ def test_chat_completion_forward_headers(
response = client_no_auth.post(
"/v1/chat/completions", json=test_data, headers=received_headers
)
-mock_acompletion.assert_called_once_with(
-model="gpt-3.5-turbo",
-messages=[
-{"role": "user", "content": "hi"},
-],
-max_tokens=10,
-litellm_call_id=mock.ANY,
-litellm_logging_obj=mock.ANY,
-request_timeout=mock.ANY,
-specific_deployment=True,
-metadata=mock.ANY,
-proxy_server_request=mock.ANY,
-headers={
-"x-custom-header": "Custom-Value",
-"x-another-header": "Another-Value",
-},
-)
+if not forward_headers:
+assert "headers" not in mock_acompletion.call_args.kwargs
+else:
+assert mock_acompletion.call_args.kwargs["headers"] == {
+"x-custom-header": "Custom-Value",
+"x-another-header": "Another-Value",
+}
print(f"response - {response.text}")
assert response.status_code == 200
result = response.json()

@@ -331,6 +331,13 @@ def test_validate_environment_api_key():
), f"Missing keys={response_obj['missing_keys']}"
def test_validate_environment_api_base_dynamic():
for provider in ["ollama", "ollama_chat"]:
kv = validate_environment(provider + "/mistral", api_base="https://example.com")
assert kv["keys_in_environment"]
assert kv["missing_keys"] == []
@mock.patch.dict(os.environ, {"OLLAMA_API_BASE": "foo"}, clear=True)
def test_validate_environment_ollama():
for provider in ["ollama", "ollama_chat"]: