LiteLLM Minor Fixes & Improvements (10/24/2024) (#6421)
* fix(utils.py): support passing dynamic api base to validate_environment. Returns True if just the api base is required and an api base is passed.
* fix(litellm_pre_call_utils.py): feature flag sending client headers to llm api. Fixes https://github.com/BerriAI/litellm/issues/6410
* fix(anthropic/chat/transformation.py): return correct error message
* fix(http_handler.py): add error response text in places where we expect it
* fix(factory.py): handle base case of no non-system messages to bedrock. Fixes https://github.com/BerriAI/litellm/issues/6411
* feat(cohere/embed): support Cohere image embeddings. Closes https://github.com/BerriAI/litellm/issues/6413
* fix(__init__.py): fix linting error
* docs(supported_embedding.md): add image embedding example to docs
* feat(cohere/embed): use Cohere embedding returned usage for cost calc
* build(model_prices_and_context_window.json): add embed-english-v3.0 details (image cost + 'supports_image_input' flag)
* fix(cohere_transformation.py): fix linting error
* test(test_proxy_server.py): cleanup test
* test: cleanup test
* fix: fix linting errors
parent 38708a355a, commit c03e5da41f
23 changed files with 417 additions and 150 deletions
@@ -84,6 +84,60 @@ print(query_result[:5])
</TabItem>
</Tabs>
+
+## Image Embeddings
+
+For models that support image embeddings, you can pass in a base64 encoded image string to the `input` param.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+from litellm import embedding
+import os
+
+# set your api key
+os.environ["COHERE_API_KEY"] = ""
+
+response = embedding(model="cohere/embed-english-v3.0", input=["<base64 encoded image>"])
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Setup config.yaml
+
+```yaml
+model_list:
+  - model_name: cohere-embed
+    litellm_params:
+      model: cohere/embed-english-v3.0
+      api_key: os.environ/COHERE_API_KEY
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+
+# RUNNING on http://0.0.0.0:4000
+```
+
+3. Test it!
+
+```bash
+curl -X POST 'http://0.0.0.0:4000/v1/embeddings' \
+-H 'Authorization: Bearer sk-54d77cd67b9febbb' \
+-H 'Content-Type: application/json' \
+-d '{
+  "model": "cohere/embed-english-v3.0",
+  "input": ["<base64 encoded image>"]
+}'
+```
+
+</TabItem>
+</Tabs>
+
## Input Params for `litellm.embedding()`
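Complementing the image-embedding SDK snippet in the hunk above, here is an illustrative sketch (not part of the commit) of producing the base64 string from a local file. The file name is a placeholder, and whether a raw base64 string or a `data:` URI is expected should be checked against the Cohere image embedding docs.

```python
import base64
import os

from litellm import embedding

os.environ["COHERE_API_KEY"] = ""

# "dog.png" is a hypothetical path; any local image file works the same way.
with open("dog.png", "rb") as f:
    base64_image = base64.b64encode(f.read()).decode("utf-8")

response = embedding(
    model="cohere/embed-english-v3.0",
    input=[base64_image],
)
print(response.usage)  # image usage is what the new per-image pricing is based on
```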
@@ -814,6 +814,7 @@ general_settings:
| pass_through_endpoints | List[Dict[str, Any]] | Define the pass through endpoints. [Docs](./pass_through) |
| enable_oauth2_proxy_auth | boolean | (Enterprise Feature) If true, enables oauth2.0 authentication |
| forward_openai_org_id | boolean | If true, forwards the OpenAI Organization ID to the backend LLM call (if it's OpenAI). |
+| forward_client_headers_to_llm_api | boolean | If true, forwards the client headers (any `x-` headers) to the backend LLM call |

### router_settings - Reference
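For reference, a minimal config.yaml sketch (illustrative, not part of this commit) of enabling the new `forward_client_headers_to_llm_api` flag documented in the general_settings table above:

```yaml
general_settings:
  forward_client_headers_to_llm_api: true  # forward any `x-` client headers to the backend LLM call
```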
@@ -8,6 +8,7 @@ import os
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
+from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
from litellm._logging import (
    set_verbose,
    _turn_on_debug,
@@ -136,7 +137,7 @@ enable_azure_ad_token_refresh: Optional[bool] = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-08-01-preview"  # this is updated to the latest
### COHERE EMBEDDINGS DEFAULT TYPE ###
-COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
+COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
openai_moderations_model_name: Optional[str] = None
@@ -333,6 +333,14 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
    api_key: Optional[str],
    dynamic_api_key: Optional[str],
) -> Tuple[str, str, Optional[str], Optional[str]]:
+    """
+    Returns:
+        Tuple[str, str, Optional[str], Optional[str]]:
+            model: str
+            custom_llm_provider: str
+            dynamic_api_key: Optional[str]
+            api_base: Optional[str]
+    """
    custom_llm_provider = model.split("/", 1)[0]
    model = model.split("/", 1)[1]
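To make the new docstring concrete, a small usage sketch (illustrative, not part of the diff) of the documented 4-tuple, mirroring the `hosted_vllm` test added later in this commit:

```python
import litellm

# The provider prefix is split off the model name; for OpenAI-compatible
# providers a dynamic api key (possibly "") and a default api base may also be returned.
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="hosted_vllm/llama-3.1-70b-instruct",
)
assert custom_llm_provider == "hosted_vllm"
assert model == "llama-3.1-70b-instruct"
```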
@@ -398,6 +398,8 @@ class AnthropicChatCompletion(BaseLLM):
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
+            if error_response and hasattr(error_response, "text"):
+                error_text = getattr(error_response, "text", error_text)
            raise AnthropicError(
                message=error_text,
                status_code=status_code,
@@ -9,7 +9,7 @@ import httpx
from openai import OpenAI

import litellm
-from litellm.llms.cohere.embed import embedding as cohere_embedding
+from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
@@ -7,6 +7,7 @@ Why separate file? Make it easy to see how transformation works
from typing import List

import litellm
+from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig
from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse
from litellm.types.utils import Embedding, EmbeddingResponse

@@ -26,15 +27,21 @@ class BedrockCohereEmbeddingConfig:
            optional_params["embedding_types"] = v
        return optional_params

+    def _is_v3_model(self, model: str) -> bool:
+        return "3" in model
+
    def _transform_request(
-        self, input: List[str], inference_params: dict
+        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequest:
-        transformed_request = CohereEmbeddingRequest(
-            texts=input,
-            input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,  # type: ignore
+        transformed_request = CohereEmbeddingConfig()._transform_request(
+            model, input, inference_params
        )

-        for k, v in inference_params.items():
-            transformed_request[k] = v  # type: ignore
+        new_transformed_request = CohereEmbeddingRequest(
+            input_type=transformed_request["input_type"],
+        )
+        for k in CohereEmbeddingRequest.__annotations__.keys():
+            if k in transformed_request:
+                new_transformed_request[k] = transformed_request[k]  # type: ignore

-        return transformed_request
+        return new_transformed_request
@@ -11,7 +11,7 @@ from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx

import litellm
-from litellm.llms.cohere.embed import embedding as cohere_embedding
+from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,

@@ -369,7 +369,7 @@ class BedrockEmbedding(BaseAWSLLM):
        batch_data: Optional[List] = None
        if provider == "cohere":
            data = BedrockCohereEmbeddingConfig()._transform_request(
-                input=input, inference_params=inference_params
+                model=model, input=input, inference_params=inference_params
            )
        elif provider == "amazon" and model in [
            "amazon.titan-embed-image-v1",
@@ -12,8 +12,11 @@ import requests  # type: ignore
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.bedrock import CohereEmbeddingRequest
from litellm.utils import Choices, Message, ModelResponse, Usage

+from .transformation import CohereEmbeddingConfig
+

def validate_environment(api_key, headers: dict):
    headers.update(
@@ -41,39 +44,9 @@ class CohereError(Exception):
        )  # Call the base class constructor with the parameters it needs


-def _process_embedding_response(
-    embeddings: list,
-    model_response: litellm.EmbeddingResponse,
-    model: str,
-    encoding: Any,
-    input: list,
-) -> litellm.EmbeddingResponse:
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
-    )
-
-    return model_response
-
-
async def async_embedding(
    model: str,
-    data: dict,
+    data: Union[dict, CohereEmbeddingRequest],
    input: list,
    model_response: litellm.utils.EmbeddingResponse,
    timeout: Optional[Union[float, httpx.Timeout]],
@@ -121,19 +94,12 @@ async def async_embedding(
        )
        raise e

-    ## LOGGING
-    logging_obj.post_call(
-        input=input,
-        api_key=api_key,
-        additional_args={"complete_input_dict": data},
-        original_response=response.text,
-    )
-
-    embeddings = response.json()["embeddings"]
-
    ## PROCESS RESPONSE ##
-    return _process_embedding_response(
-        embeddings=embeddings,
+    return CohereEmbeddingConfig()._transform_response(
+        response=response,
+        api_key=api_key,
+        logging_obj=logging_obj,
+        data=data,
        model_response=model_response,
        model=model,
        encoding=encoding,
@@ -149,7 +115,7 @@ def embedding(
    optional_params: dict,
    headers: dict,
    encoding: Any,
-    data: Optional[dict] = None,
+    data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
    complete_api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    aembedding: Optional[bool] = None,
@@ -159,11 +125,10 @@
    headers = validate_environment(api_key, headers=headers)
    embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
    model = model
-    data = data or {"model": model, "texts": input, **optional_params}
-
-    if "3" in model and "input_type" not in data:
-        # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
-        data["input_type"] = "search_document"
+    data = data or CohereEmbeddingConfig()._transform_request(
+        model=model, input=input, inference_params=optional_params
+    )

    ## ROUTING
    if aembedding is True:
@@ -193,30 +158,12 @@
        client = HTTPHandler(concurrent_limit=1)

    response = client.post(embed_url, headers=headers, data=json.dumps(data))
-    ## LOGGING
-    logging_obj.post_call(
-        input=input,
+
+    return CohereEmbeddingConfig()._transform_response(
+        response=response,
        api_key=api_key,
-        additional_args={"complete_input_dict": data},
-        original_response=response,
-    )
-    """
-    response
-    {
-        'object': "list",
-        'data': [
-
-        ]
-        'model',
-        'usage'
-    }
-    """
-    if response.status_code != 200:
-        raise CohereError(message=response.text, status_code=response.status_code)
-    embeddings = response.json()["embeddings"]
-
-    return _process_embedding_response(
-        embeddings=embeddings,
+        logging_obj=logging_obj,
+        data=data,
        model_response=model_response,
        model=model,
        encoding=encoding,
litellm/llms/cohere/embed/transformation.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.

Why separate file? Make it easy to see how transformation works

Convers
- v3 embedding models
- v2 embedding models

Docs - https://docs.cohere.com/v2/reference/embed
"""

import types
from typing import Any, List, Optional, Union

import httpx

from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.bedrock import (
    COHERE_EMBEDDING_INPUT_TYPES,
    CohereEmbeddingRequest,
    CohereEmbeddingRequestWithModel,
)
from litellm.types.utils import (
    Embedding,
    EmbeddingResponse,
    PromptTokensDetailsWrapper,
    Usage,
)
from litellm.utils import is_base64_encoded


class CohereEmbeddingConfig:
    """
    Reference: https://docs.cohere.com/v2/reference/embed
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["encoding_format"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "encoding_format":
                optional_params["embedding_types"] = v
        return optional_params

    def _is_v3_model(self, model: str) -> bool:
        return "3" in model

    def _transform_request(
        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequestWithModel:
        is_encoded = False
        for input_str in input:
            is_encoded = is_base64_encoded(input_str)

        if is_encoded:  # check if string is b64 encoded image or not
            transformed_request = CohereEmbeddingRequestWithModel(
                model=model,
                images=input,
                input_type="image",
            )
        else:
            transformed_request = CohereEmbeddingRequestWithModel(
                model=model,
                texts=input,
                input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
            )

        for k, v in inference_params.items():
            transformed_request[k] = v  # type: ignore

        return transformed_request

    def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:

        input_tokens = 0

        text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")

        image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")

        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_tokens is None and text_tokens is None:
            for text in input:
                input_tokens += len(encoding.encode(text))
        else:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                image_tokens=image_tokens,
                text_tokens=text_tokens,
            )
            if image_tokens:
                input_tokens += image_tokens
            if text_tokens:
                input_tokens += text_tokens

        return Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
            prompt_tokens_details=prompt_tokens_details,
        )

    def _transform_response(
        self,
        response: httpx.Response,
        api_key: Optional[str],
        logging_obj: LiteLLMLoggingObj,
        data: Union[dict, CohereEmbeddingRequest],
        model_response: EmbeddingResponse,
        model: str,
        encoding: Any,
        input: list,
    ) -> EmbeddingResponse:

        response_json = response.json()
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response_json,
        )
        """
        response
        {
            'object': "list",
            'data': [

            ]
            'model',
            'usage'
        }
        """
        embeddings = response_json["embeddings"]
        output_data = []
        for idx, embedding in enumerate(embeddings):
            output_data.append(
                {"object": "embedding", "index": idx, "embedding": embedding}
            )
        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model
        input_tokens = 0
        for text in input:
            input_tokens += len(encoding.encode(text))

        setattr(
            model_response,
            "usage",
            self._calculate_usage(input, encoding, response_json.get("meta", {})),
        )

        return model_response
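For readers skimming the new module above, a small illustrative sketch (not part of the commit) of the request shape produced by `CohereEmbeddingConfig._transform_request`; the sample strings are placeholders.

```python
from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig

config = CohereEmbeddingConfig()

# Plain text becomes a `texts` payload with the default input_type.
request = config._transform_request(
    model="embed-english-v3.0",
    input=["hello world"],
    inference_params={},
)
print(request)
# roughly: {'model': 'embed-english-v3.0', 'texts': ['hello world'], 'input_type': 'search_document'}
# If the input strings are base64 encoded images (per litellm.utils.is_base64_encoded),
# the same call instead returns an `images` payload with input_type="image".
```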
@@ -152,8 +152,10 @@ class AsyncHTTPHandler:
            setattr(e, "status_code", e.response.status_code)
            if stream is True:
                setattr(e, "message", await e.response.aread())
+                setattr(e, "text", await e.response.aread())
            else:
                setattr(e, "message", e.response.text)
+                setattr(e, "text", e.response.text)
            raise e
        except Exception as e:
            raise e
@@ -2429,6 +2429,15 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
    contents: List[BedrockMessageBlock] = []
    msg_i = 0
+
+    ## BASE CASE ##
+    if len(messages) == 0:
+        raise litellm.BadRequestError(
+            message=BAD_MESSAGE_ERROR_STR
+            + "bedrock requires at least one non-system message",
+            model=model,
+            llm_provider=llm_provider,
+        )
+
    # if initial message is assistant message
    if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
        if user_continue_message is not None:
@@ -113,7 +113,7 @@ from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion  # type: ignore
-from .llms.cohere import embed as cohere_embed
+from .llms.cohere.embed import handler as cohere_embed
from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat import DatabricksChatCompletion
from .llms.groq.chat.handler import GroqChatCompletion
@@ -3364,54 +3364,56 @@
        "litellm_provider": "cohere",
        "mode": "rerank"
    },
-    "embed-english-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
-        "input_cost_per_token": 0.00000010,
-        "output_cost_per_token": 0.00000,
-        "litellm_provider": "cohere",
-        "mode": "embedding"
-    },
    "embed-english-light-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-multilingual-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-english-v2.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 4096,
+        "max_input_tokens": 4096,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-english-light-v2.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-multilingual-v2.0": {
-        "max_tokens": 256,
-        "max_input_tokens": 256,
+        "max_tokens": 768,
+        "max_input_tokens": 768,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
+    "embed-english-v3.0": {
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
+        "input_cost_per_token": 0.00000010,
+        "input_cost_per_image": 0.0001,
+        "output_cost_per_token": 0.00000,
+        "litellm_provider": "cohere",
+        "mode": "embedding",
+        "supports_image_input": true
+    },
    "replicate/meta/llama-2-13b": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
@@ -238,6 +238,10 @@ class LiteLLMProxyRequestSetup:
        - Adds org id
        """
        data = LitellmDataForBackendLLMCall()
+        if (
+            general_settings
+            and general_settings.get("forward_client_headers_to_llm_api") is True
+        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
@@ -210,15 +210,23 @@ class ServerSentEvent:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"


-class CohereEmbeddingRequest(TypedDict, total=False):
-    texts: Required[List[str]]
-    input_type: Required[
-        Literal["search_document", "search_query", "classification", "clustering"]
+COHERE_EMBEDDING_INPUT_TYPES = Literal[
+    "search_document", "search_query", "classification", "clustering", "image"
]
+
+
+class CohereEmbeddingRequest(TypedDict, total=False):
+    texts: List[str]
+    images: List[str]
+    input_type: Required[COHERE_EMBEDDING_INPUT_TYPES]
    truncate: Literal["NONE", "START", "END"]
    embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]


+class CohereEmbeddingRequestWithModel(CohereEmbeddingRequest):
+    model: Required[str]
+
+
class CohereEmbeddingResponse(TypedDict):
    embeddings: List[List[float]]
    id: str
@@ -5197,7 +5197,9 @@ def create_proxy_transport_and_mounts():


def validate_environment(  # noqa: PLR0915
-    model: Optional[str] = None, api_key: Optional[str] = None
+    model: Optional[str] = None,
+    api_key: Optional[str] = None,
+    api_base: Optional[str] = None,
) -> dict:
    """
    Checks if the environment variables are valid for the given model.
|
||||||
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||||
except Exception:
|
except Exception:
|
||||||
custom_llm_provider = None
|
custom_llm_provider = None
|
||||||
# # check if llm provider part of model name
|
|
||||||
# if model.split("/",1)[0] in litellm.provider_list:
|
|
||||||
# custom_llm_provider = model.split("/", 1)[0]
|
|
||||||
# model = model.split("/", 1)[1]
|
|
||||||
# custom_llm_provider_passed_in = True
|
|
||||||
|
|
||||||
if custom_llm_provider:
|
if custom_llm_provider:
|
||||||
if custom_llm_provider == "openai":
|
if custom_llm_provider == "openai":
|
||||||
|
@@ -5497,6 +5494,17 @@ def validate_environment(  # noqa: PLR0915
            if "api_key" not in key.lower():
                new_missing_keys.append(key)
        missing_keys = new_missing_keys
+
+    if api_base is not None:
+        new_missing_keys = []
+        for key in missing_keys:
+            if "api_base" not in key.lower():
+                new_missing_keys.append(key)
+        missing_keys = new_missing_keys
+
+    if len(missing_keys) == 0:  # no missing keys
+        keys_in_environment = True
+
    return {"keys_in_environment": keys_in_environment, "missing_keys": missing_keys}
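A short usage sketch of the new `api_base` handling (illustrative, mirroring the test added later in this commit; the import path assumes the function lives in `litellm.utils`, per the `fix(utils.py)` entry in the commit message):

```python
from litellm.utils import validate_environment

# Providers such as ollama only need a base URL; passing api_base now clears
# the corresponding missing key so the environment validates.
result = validate_environment(model="ollama/mistral", api_base="https://example.com")
print(result)  # expected: {"keys_in_environment": True, "missing_keys": []}
```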
@@ -3364,54 +3364,56 @@
        "litellm_provider": "cohere",
        "mode": "rerank"
    },
-    "embed-english-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
-        "input_cost_per_token": 0.00000010,
-        "output_cost_per_token": 0.00000,
-        "litellm_provider": "cohere",
-        "mode": "embedding"
-    },
    "embed-english-light-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-multilingual-v3.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-english-v2.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 4096,
+        "max_input_tokens": 4096,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-english-light-v2.0": {
-        "max_tokens": 512,
-        "max_input_tokens": 512,
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
    "embed-multilingual-v2.0": {
-        "max_tokens": 256,
-        "max_input_tokens": 256,
+        "max_tokens": 768,
+        "max_input_tokens": 768,
        "input_cost_per_token": 0.00000010,
        "output_cost_per_token": 0.00000,
        "litellm_provider": "cohere",
        "mode": "embedding"
    },
+    "embed-english-v3.0": {
+        "max_tokens": 1024,
+        "max_input_tokens": 1024,
+        "input_cost_per_token": 0.00000010,
+        "input_cost_per_image": 0.0001,
+        "output_cost_per_token": 0.00000,
+        "litellm_provider": "cohere",
+        "mode": "embedding",
+        "supports_image_input": true
+    },
    "replicate/meta/llama-2-13b": {
        "max_tokens": 4096,
        "max_input_tokens": 4096,
File diff suppressed because one or more lines are too long
@@ -160,3 +160,12 @@ def test_get_llm_provider_jina_ai():
    assert custom_llm_provider == "openai_like"
    assert api_base == "https://api.jina.ai/v1"
    assert model == "jina-embeddings-v3"
+
+
+def test_get_llm_provider_hosted_vllm():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="hosted_vllm/llama-3.1-70b-instruct",
+    )
+    assert custom_llm_provider == "hosted_vllm"
+    assert model == "llama-3.1-70b-instruct"
+    assert dynamic_api_key == ""
@@ -675,3 +675,15 @@ def test_alternating_roles_e2e():
            "stream": False,
        }
    )
+
+
+def test_just_system_message():
+    from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+
+    with pytest.raises(litellm.BadRequestError) as e:
+        _bedrock_converse_messages_pt(
+            messages=[],
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            llm_provider="bedrock",
+        )
+    assert "bedrock requires at least one non-system message" in str(e.value)
@@ -225,12 +225,20 @@ def test_add_headers_to_request(litellm_key_header_name):
    "litellm_key_header_name",
    ["x-litellm-key", None],
)
+@pytest.mark.parametrize(
+    "forward_headers",
+    [True, False],
+)
@mock_patch_acompletion()
def test_chat_completion_forward_headers(
-    mock_acompletion, client_no_auth, litellm_key_header_name
+    mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
):
    global headers
    try:
+        if forward_headers:
+            gs = getattr(litellm.proxy.proxy_server, "general_settings")
+            gs["forward_client_headers_to_llm_api"] = True
+            setattr(litellm.proxy.proxy_server, "general_settings", gs)
        if litellm_key_header_name is not None:
            gs = getattr(litellm.proxy.proxy_server, "general_settings")
            gs["litellm_key_header_name"] = litellm_key_header_name
@@ -260,23 +268,14 @@ def test_chat_completion_forward_headers(
        response = client_no_auth.post(
            "/v1/chat/completions", json=test_data, headers=received_headers
        )
-        mock_acompletion.assert_called_once_with(
-            model="gpt-3.5-turbo",
-            messages=[
-                {"role": "user", "content": "hi"},
-            ],
-            max_tokens=10,
-            litellm_call_id=mock.ANY,
-            litellm_logging_obj=mock.ANY,
-            request_timeout=mock.ANY,
-            specific_deployment=True,
-            metadata=mock.ANY,
-            proxy_server_request=mock.ANY,
-            headers={
+        if not forward_headers:
+            assert "headers" not in mock_acompletion.call_args.kwargs
+        else:
+            assert mock_acompletion.call_args.kwargs["headers"] == {
                "x-custom-header": "Custom-Value",
                "x-another-header": "Another-Value",
-            },
-        )
+            }
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()
@@ -331,6 +331,13 @@ def test_validate_environment_api_key():
    ), f"Missing keys={response_obj['missing_keys']}"


+def test_validate_environment_api_base_dynamic():
+    for provider in ["ollama", "ollama_chat"]:
+        kv = validate_environment(provider + "/mistral", api_base="https://example.com")
+        assert kv["keys_in_environment"]
+        assert kv["missing_keys"] == []
+
+
@mock.patch.dict(os.environ, {"OLLAMA_API_BASE": "foo"}, clear=True)
def test_validate_environment_ollama():
    for provider in ["ollama", "ollama_chat"]: