Merge branch 'main' into litellm_gemini_refactoring

Commit a80520004e by Krish Dholakia, 2024-06-17 17:28:50 -07:00 (committed by GitHub)
21 changed files with 1889 additions and 1035 deletions


@ -0,0 +1,255 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# Codestral API [Mistral AI]
Codestral is available in select code-completion plugins but can also be queried directly. See the documentation for more details.
## API Key
```python
import os

# set your Codestral API key as an environment variable
os.environ['CODESTRAL_API_KEY'] = "your-api-key"
```
## FIM / Completions
:::info
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createFIMCompletion
:::
<Tabs>
<TabItem value="no-streaming" label="No Streaming">
#### Sample Usage
```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

# run inside an async function / event loop (see the asyncio sketch below)
response = await litellm.atext_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",   # optional
    temperature=0,          # optional
    top_p=1,                # optional
    max_tokens=10,          # optional
    min_tokens=10,          # optional
    seed=10,                # optional
    stop=["return"],        # optional
)
```
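
If you are calling this from a plain Python script rather than an async context, a minimal sketch is to wrap the call in an async function and drive it with `asyncio` (the function name below is just illustrative):

```python
import asyncio
import os

import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

async def fim_example():
    # same arguments as the sample above; only the wrapper differs
    return await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
        max_tokens=10,
    )

response = asyncio.run(fim_example())
print(response.choices[0].text)
```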
#### Expected Response
```json
{
  "id": "b41e0df599f94bc1a46ea9fcdbc2aabe",
  "object": "text_completion",
  "created": 1589478378,
  "model": "codestral-latest",
  "choices": [
    {
      "text": "\n assert is_odd(1)\n assert",
      "index": 0,
      "logprobs": null,
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 5,
    "completion_tokens": 7,
    "total_tokens": 12
  }
}
```
</TabItem>
<TabItem value="stream" label="Streaming">
#### Sample Usage - Streaming
```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

# run inside an async function / event loop
response = await litellm.atext_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",   # optional
    temperature=0,          # optional
    top_p=1,                # optional
    stream=True,
    seed=10,                # optional
    stop=["return"],        # optional
)

async for chunk in response:
    print(chunk)
```
#### Expected Response
```json
{
  "id": "726025d3e2d645d09d475bb0d29e3640",
  "object": "text_completion",
  "created": 1718659669,
  "choices": [
    {
      "text": "This",
      "index": 0,
      "logprobs": null,
      "finish_reason": null
    }
  ],
  "model": "codestral-2405"
}
```
</TabItem>
</Tabs>
### Supported Models
All models listed at https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token windows, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). A synchronous sketch follows the table below.
| Model Name | Function Call |
|----------------|--------------------------------------------------------------|
| Codestral Latest | `completion(model="text-completion-codestral/codestral-latest", messages)` |
| Codestral 2405 | `completion(model="text-completion-codestral/codestral-2405", messages)`|
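
The async examples above use `litellm.atext_completion`. A minimal synchronous sketch, assuming the blocking `litellm.text_completion` entrypoint accepts the same parameters:

```python
import os

import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

# blocking call, no event loop needed
response = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",  # optional
    max_tokens=10,         # optional
)
print(response.choices[0].text)
```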
## Chat Completions
:::info
Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createChatCompletion
:::
<Tabs>
<TabItem value="no-streaming" label="No Streaming">
#### Sample Usage
```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

# run inside an async function / event loop
response = await litellm.acompletion(
    model="codestral/codestral-latest",
    messages=[
        {
            "role": "user",
            "content": "Hey, how's it going?",
        }
    ],
    temperature=0.0,    # optional
    top_p=1,            # optional
    max_tokens=10,      # optional
    safe_prompt=False,  # optional
    seed=12,            # optional
)
```
#### Expected Response
```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "codestral/codestral-latest",
  "system_fingerprint": null,
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "\n\nHello there, how may I assist you today?"
      },
      "logprobs": null,
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 9,
    "completion_tokens": 12,
    "total_tokens": 21
  }
}
```
</TabItem>
<TabItem value="stream" label="Streaming">
#### Sample Usage - Streaming
```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

# run inside an async function / event loop
response = await litellm.acompletion(
    model="codestral/codestral-latest",
    messages=[
        {
            "role": "user",
            "content": "Hey, how's it going?",
        }
    ],
    stream=True,        # optional
    temperature=0.0,    # optional
    top_p=1,            # optional
    max_tokens=10,      # optional
    safe_prompt=False,  # optional
    seed=12,            # optional
)

async for chunk in response:
    print(chunk)
```
#### Expected Response
```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion.chunk",
  "created": 1694268190,
  "model": "codestral/codestral-latest",
  "system_fingerprint": null,
  "choices": [
    {
      "index": 0,
      "delta": {"role": "assistant", "content": "gm"},
      "logprobs": null,
      "finish_reason": null
    }
  ]
}
```
</TabItem>
</Tabs>
### Supported Models
All models listed at https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token windows, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). A short synchronous example follows the table below.
| Model Name | Function Call |
|----------------|--------------------------------------------------------------|
| Codestral Latest | `completion(model="codestral/codestral-latest", messages)` |
| Codestral 2405 | `completion(model="codestral/codestral-2405", messages)`|
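
As with FIM, the chat endpoint can also be called synchronously. A minimal sketch using the blocking `litellm.completion` entrypoint shown in the table above:

```python
import os

import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

response = litellm.completion(
    model="codestral/codestral-latest",
    messages=[{"role": "user", "content": "Write a Python function that reverses a string."}],
    max_tokens=50,  # optional
)
print(response.choices[0].message.content)
```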


@ -22219,9 +22219,9 @@
} }
}, },
"node_modules/webpack-dev-server/node_modules/ws": { "node_modules/webpack-dev-server/node_modules/ws": {
"version": "8.13.0", "version": "8.17.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.13.0.tgz", "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
"integrity": "sha512-x9vcZYTrFPC7aSIbj7sRCYo7L/Xb8Iy+pW0ng0wt2vCJv7M9HOMy0UoN3rr+IFC7hb7vXoqS+P9ktyLLLhO+LA==", "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": { "engines": {
"node": ">=10.0.0" "node": ">=10.0.0"
}, },
@ -22518,9 +22518,9 @@
} }
}, },
"node_modules/ws": { "node_modules/ws": {
"version": "7.5.9", "version": "7.5.10",
"resolved": "https://registry.npmjs.org/ws/-/ws-7.5.9.tgz", "resolved": "https://registry.npmjs.org/ws/-/ws-7.5.10.tgz",
"integrity": "sha512-F+P9Jil7UiSKSkppIiD94dN07AwvFixvLIj1Og1Rl9GGMuNipJnV9JzjD6XuqmAeiswGvUmNLjr5cFuXwNS77Q==", "integrity": "sha512-+dbF1tHwZpXcbOJdVOkzLDxZP1ailvSxM6ZweXTegylPny803bFhA+vqBYw4s31NSAk4S2Qz+AKXK9a4wkdjcQ==",
"engines": { "engines": {
"node": ">=8.3.0" "node": ">=8.3.0"
}, },


@ -134,10 +134,11 @@ const sidebars = {
"providers/vertex", "providers/vertex",
"providers/palm", "providers/palm",
"providers/gemini", "providers/gemini",
"providers/mistral",
"providers/anthropic", "providers/anthropic",
"providers/aws_sagemaker", "providers/aws_sagemaker",
"providers/bedrock", "providers/bedrock",
"providers/mistral",
"providers/codestral",
"providers/cohere", "providers/cohere",
"providers/anyscale", "providers/anyscale",
"providers/huggingface", "providers/huggingface",

File diff suppressed because it is too large.


@ -396,6 +396,8 @@ openai_compatible_endpoints: List = [
"api.endpoints.anyscale.com/v1", "api.endpoints.anyscale.com/v1",
"api.deepinfra.com/v1/openai", "api.deepinfra.com/v1/openai",
"api.mistral.ai/v1", "api.mistral.ai/v1",
"codestral.mistral.ai/v1/chat/completions",
"codestral.mistral.ai/v1/fim/completions",
"api.groq.com/openai/v1", "api.groq.com/openai/v1",
"api.deepseek.com/v1", "api.deepseek.com/v1",
"api.together.xyz/v1", "api.together.xyz/v1",
@ -406,6 +408,7 @@ openai_compatible_providers: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"codestral",
"deepseek", "deepseek",
"deepinfra", "deepinfra",
"perplexity", "perplexity",
@ -633,6 +636,8 @@ provider_list: List = [
"anyscale", "anyscale",
"mistral", "mistral",
"groq", "groq",
"codestral",
"text-completion-codestral",
"deepseek", "deepseek",
"maritalk", "maritalk",
"voyage", "voyage",
@ -801,6 +806,7 @@ from .llms.openai import (
DeepInfraConfig, DeepInfraConfig,
AzureAIStudioConfig, AzureAIStudioConfig,
) )
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.azure import ( from .llms.azure import (
AzureOpenAIConfig, AzureOpenAIConfig,
AzureOpenAIError, AzureOpenAIError,
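
A trivial sketch of checking that the new providers are registered after this change (list names as shown in the diff above):

```python
import litellm

# both new provider aliases should now be present
assert "codestral" in litellm.provider_list
assert "text-completion-codestral" in litellm.provider_list

# codestral is also treated as an OpenAI-compatible provider
assert "codestral" in litellm.openai_compatible_providers
```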


@ -1,20 +1,24 @@
# What is this? # What is this?
## File for 'response_cost' calculation in Logging ## File for 'response_cost' calculation in Logging
from typing import Optional, Union, Literal, List, Tuple from typing import List, Literal, Optional, Tuple, Union
import litellm
import litellm._logging import litellm._logging
from litellm import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.google import (
cost_per_token as google_cost_per_token,
)
from litellm.utils import ( from litellm.utils import (
ModelResponse, CallTypes,
CostPerToken,
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
TranscriptionResponse, ModelResponse,
TextCompletionResponse, TextCompletionResponse,
CallTypes, TranscriptionResponse,
print_verbose, print_verbose,
CostPerToken,
token_counter, token_counter,
) )
import litellm
from litellm import verbose_logger
def _cost_per_token_custom_pricing_helper( def _cost_per_token_custom_pricing_helper(
@ -42,10 +46,10 @@ def _cost_per_token_custom_pricing_helper(
def cost_per_token( def cost_per_token(
model: str = "", model: str = "",
prompt_tokens=0, prompt_tokens: float = 0,
completion_tokens=0, completion_tokens: float = 0,
response_time_ms=None, response_time_ms=None,
custom_llm_provider=None, custom_llm_provider: Optional[str] = None,
region_name=None, region_name=None,
### CUSTOM PRICING ### ### CUSTOM PRICING ###
custom_cost_per_token: Optional[CostPerToken] = None, custom_cost_per_token: Optional[CostPerToken] = None,
@ -66,6 +70,7 @@ def cost_per_token(
Returns: Returns:
tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively. tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
""" """
args = locals()
if model is None: if model is None:
raise Exception("Invalid arg. Model cannot be none.") raise Exception("Invalid arg. Model cannot be none.")
## CUSTOM PRICING ## ## CUSTOM PRICING ##
@ -94,7 +99,8 @@ def cost_per_token(
model_with_provider_and_region in model_cost_ref model_with_provider_and_region in model_cost_ref
): # use region based pricing, if it's available ): # use region based pricing, if it's available
model_with_provider = model_with_provider_and_region model_with_provider = model_with_provider_and_region
else:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
model_without_prefix = model model_without_prefix = model
model_parts = model.split("/") model_parts = model.split("/")
if len(model_parts) > 1: if len(model_parts) > 1:
@ -120,7 +126,14 @@ def cost_per_token(
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
print_verbose(f"Looking up model={model} in model_cost_map") print_verbose(f"Looking up model={model} in model_cost_map")
if model in model_cost_ref: if custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini":
return google_cost_per_token(
model=model_without_prefix,
custom_llm_provider=custom_llm_provider,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
elif model in model_cost_ref:
print_verbose(f"Success: model={model} in model_cost_map") print_verbose(f"Success: model={model} in model_cost_map")
print_verbose( print_verbose(
f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}" f"prompt_tokens={prompt_tokens}; completion_tokens={completion_tokens}"


@ -105,7 +105,6 @@ class LunaryLogger:
end_time=datetime.now(timezone.utc), end_time=datetime.now(timezone.utc),
error=None, error=None,
): ):
# Method definition
try: try:
print_verbose(f"Lunary Logging - Logging request for model {model}") print_verbose(f"Lunary Logging - Logging request for model {model}")
@ -114,10 +113,9 @@ class LunaryLogger:
metadata = litellm_params.get("metadata", {}) or {} metadata = litellm_params.get("metadata", {}) or {}
if optional_params: if optional_params:
# merge into extra
extra = {**extra, **optional_params} extra = {**extra, **optional_params}
tags = litellm_params.pop("tags", None) or [] tags = metadata.get("tags", None)
if extra: if extra:
extra.pop("extra_body", None) extra.pop("extra_body", None)


@ -0,0 +1,82 @@
# What is this?
## Cost calculation for Google AI Studio / Vertex AI models
from typing import Literal, Tuple

import litellm

"""
Gemini pricing covers:
- token
- image
- audio
- video
"""

models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]


def _is_above_128k(tokens: float) -> bool:
    if tokens > 128000:
        return True
    return False


def cost_per_token(
    model: str,
    custom_llm_provider: str,
    prompt_tokens: float,
    completion_tokens: float,
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, either "vertex_ai-*" or "gemini"
        - prompt_tokens: float, the number of input tokens
        - completion_tokens: float, the number of output tokens

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd

    Raises:
        Exception if model requires >128k pricing, but model cost not mapped
    """
    ## GET MODEL INFO
    model_info = litellm.get_model_info(
        model=model, custom_llm_provider=custom_llm_provider
    )

    ## CALCULATE INPUT COST
    if (
        _is_above_128k(tokens=prompt_tokens)
        and model not in models_without_dynamic_pricing
    ):
        assert (
            model_info["input_cost_per_token_above_128k_tokens"] is not None
        ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
            model, model_info
        )
        prompt_cost = (
            prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
        )
    else:
        prompt_cost = prompt_tokens * model_info["input_cost_per_token"]

    ## CALCULATE OUTPUT COST
    if (
        _is_above_128k(tokens=completion_tokens)
        and model not in models_without_dynamic_pricing
    ):
        assert (
            model_info["output_cost_per_token_above_128k_tokens"] is not None
        ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
            model, model_info
        )
        completion_cost = (
            completion_tokens * model_info["output_cost_per_token_above_128k_tokens"]
        )
    else:
        completion_cost = completion_tokens * model_info["output_cost_per_token"]

    return prompt_cost, completion_cost
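
A small sketch of the above-128k branch, calling this helper directly (token counts are illustrative, and it assumes the `gemini/gemini-1.5-flash` pricing added later in this diff is loaded):

```python
from litellm.litellm_core_utils.llm_cost_calc.google import cost_per_token

# 200k prompt tokens trips the above-128k input pricing tier;
# 500 completion tokens stays on the base output rate
prompt_cost, completion_cost = cost_per_token(
    model="gemini-1.5-flash",
    custom_llm_provider="gemini",
    prompt_tokens=200_000,
    completion_tokens=500,
)
# expected with this commit's pricing:
#   prompt_cost     = 200_000 * 0.0000007 = 0.14
#   completion_cost = 500 * 0.00000105    = 0.000525
print(prompt_cost, completion_cost)
```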


@ -27,6 +27,25 @@ class BaseLLM:
""" """
return model_response return model_response
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: litellm.utils.TextCompletionResponse,
stream: bool,
logging_obj: Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
"""
Helper function to process the response across sync + async completion calls
"""
return model_response
def create_client_session(self): def create_client_session(self):
if litellm.client_session: if litellm.client_session:
_client_session = litellm.client_session _client_session = litellm.client_session


@ -0,0 +1,532 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/
from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
TextCompletionResponse,
Usage,
CustomStreamWrapper,
Message,
Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.types.llms.databricks import GenericStreamingChunk
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx # type: ignore
class TextCompletionCodestralError(Exception):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
):
self.status_code = status_code
self.message = message
if request is not None:
self.request = request
else:
self.request = httpx.Request(
method="POST",
url="https://docs.codestral.com/user-guide/inference/rest_api",
)
if response is not None:
self.response = response
else:
self.response = httpx.Response(
status_code=status_code, request=self.request
)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
async def make_call(
client: AsyncHTTPHandler,
api_base: str,
headers: dict,
data: str,
model: str,
messages: list,
logging_obj,
):
response = await client.post(api_base, headers=headers, data=data, stream=True)
if response.status_code != 200:
raise TextCompletionCodestralError(
status_code=response.status_code, message=response.text
)
completion_stream = response.aiter_lines()
# LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=completion_stream, # Pass the completion stream for logging
additional_args={"complete_input_dict": data},
)
return completion_stream
class MistralTextCompletionConfig:
"""
Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
"""
suffix: Optional[str] = None
temperature: Optional[int] = None
top_p: Optional[float] = None
max_tokens: Optional[int] = None
min_tokens: Optional[int] = None
stream: Optional[bool] = None
random_seed: Optional[int] = None
stop: Optional[str] = None
def __init__(
self,
suffix: Optional[str] = None,
temperature: Optional[int] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stream: Optional[bool] = None,
random_seed: Optional[int] = None,
stop: Optional[str] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"suffix",
"temperature",
"top_p",
"max_tokens",
"stream",
"seed",
"stop",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "suffix":
optional_params["suffix"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "seed":
optional_params["random_seed"] = value
if param == "min_tokens":
optional_params["min_tokens"] = value
return optional_params
def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
text = ""
is_finished = False
finish_reason = None
logprobs = None
chunk_data = chunk_data.replace("data:", "")
chunk_data = chunk_data.strip()
if len(chunk_data) == 0 or chunk_data == "[DONE]":
return {
"text": "",
"is_finished": is_finished,
"finish_reason": finish_reason,
}
chunk_data_dict = json.loads(chunk_data)
original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
_choices = chunk_data_dict.get("choices", []) or []
_choice = _choices[0]
text = _choice.get("delta", {}).get("content", "")
if _choice.get("finish_reason") is not None:
is_finished = True
finish_reason = _choice.get("finish_reason")
logprobs = _choice.get("logprobs")
return GenericStreamingChunk(
text=text,
original_chunk=original_chunk,
is_finished=is_finished,
finish_reason=finish_reason,
logprobs=logprobs,
)
class CodestralTextCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
def _validate_environment(
self,
api_key: Optional[str],
user_headers: dict,
) -> dict:
if api_key is None:
raise ValueError(
"Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
)
headers = {
"content-type": "application/json",
"Authorization": "Bearer {}".format(api_key),
}
if user_headers is not None and isinstance(user_headers, dict):
headers = {**headers, **user_headers}
return headers
def output_parser(self, generated_text: str):
"""
Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
"""
chat_template_tokens = [
"<|assistant|>",
"<|system|>",
"<|user|>",
"<s>",
"</s>",
]
for token in chat_template_tokens:
if generated_text.strip().startswith(token):
generated_text = generated_text.replace(token, "", 1)
if generated_text.endswith(token):
generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
return generated_text
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
model_response: TextCompletionResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
messages: list,
print_verbose,
encoding,
) -> TextCompletionResponse:
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
print_verbose(f"codestral api: raw model_response: {response.text}")
## RESPONSE OBJECT
if response.status_code != 200:
raise TextCompletionCodestralError(
message=str(response.text),
status_code=response.status_code,
)
try:
completion_response = response.json()
except:
raise TextCompletionCodestralError(message=response.text, status_code=422)
_original_choices = completion_response.get("choices", [])
_choices: List[litellm.utils.TextChoices] = []
for choice in _original_choices:
# This is what 1 choice looks like from codestral API
# {
# "index": 0,
# "message": {
# "role": "assistant",
# "content": "\n assert is_odd(1)\n assert",
# "tool_calls": null
# },
# "finish_reason": "length",
# "logprobs": null
# }
_finish_reason = None
_index = 0
_text = None
_logprobs = None
_choice_message = choice.get("message", {})
_choice = litellm.utils.TextChoices(
finish_reason=choice.get("finish_reason"),
index=choice.get("index"),
text=_choice_message.get("content"),
logprobs=choice.get("logprobs"),
)
_choices.append(_choice)
_response = litellm.TextCompletionResponse(
id=completion_response.get("id"),
choices=_choices,
created=completion_response.get("created"),
model=completion_response.get("model"),
usage=completion_response.get("usage"),
stream=False,
object=completion_response.get("object"),
)
return _response
def completion(
self,
model: str,
messages: list,
api_base: str,
custom_prompt_dict: dict,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key: str,
logging_obj,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: dict = {},
) -> Union[TextCompletionResponse, CustomStreamWrapper]:
headers = self._validate_environment(api_key, headers)
completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(model=model, messages=messages)
## Load Config
config = litellm.MistralTextCompletionConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
stream = optional_params.pop("stream", False)
data = {
"prompt": prompt,
**optional_params,
}
input_text = prompt
## LOGGING
logging_obj.pre_call(
input=input_text,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
"acompletion": acompletion,
},
)
## COMPLETION CALL
if acompletion is True:
### ASYNC STREAMING
if stream is True:
return self.async_streaming(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
else:
### ASYNC COMPLETION
return self.async_completion(
model=model,
messages=messages,
data=data,
api_base=completion_url,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=api_key,
logging_obj=logging_obj,
optional_params=optional_params,
stream=False,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=headers,
timeout=timeout,
) # type: ignore
### SYNC STREAMING
if stream is True:
response = requests.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=stream,
)
_response = CustomStreamWrapper(
response.iter_lines(),
model,
custom_llm_provider="codestral",
logging_obj=logging_obj,
)
return _response
### SYNC COMPLETION
else:
response = requests.post(
url=completion_url,
headers=headers,
data=json.dumps(data),
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=optional_params.get("stream", False),
logging_obj=logging_obj, # type: ignore
optional_params=optional_params,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
encoding=encoding,
)
async def async_completion(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
stream,
data: dict,
optional_params: dict,
timeout: Union[float, httpx.Timeout],
litellm_params=None,
logger_fn=None,
headers={},
) -> TextCompletionResponse:
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
try:
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
)
except httpx.HTTPStatusError as e:
raise TextCompletionCodestralError(
status_code=e.response.status_code,
message="HTTPStatusError - {}".format(e.response.text),
)
except Exception as e:
raise TextCompletionCodestralError(
status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
)
return self.process_text_completion_response(
model=model,
response=response,
model_response=model_response,
stream=stream,
logging_obj=logging_obj,
api_key=api_key,
data=data,
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
encoding=encoding,
)
async def async_streaming(
self,
model: str,
messages: list,
api_base: str,
model_response: TextCompletionResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
data: dict,
timeout: Union[float, httpx.Timeout],
optional_params=None,
litellm_params=None,
logger_fn=None,
headers={},
) -> CustomStreamWrapper:
data["stream"] = True
streamwrapper = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_call,
api_base=api_base,
headers=headers,
data=json.dumps(data),
model=model,
messages=messages,
logging_obj=logging_obj,
),
model=model,
custom_llm_provider="text-completion-codestral",
logging_obj=logging_obj,
)
return streamwrapper
def embedding(self, *args, **kwargs):
pass
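
A quick sketch of the OpenAI-to-Codestral parameter mapping defined in `MistralTextCompletionConfig` above (for example, `seed` is translated to `random_seed`):

```python
import litellm

config = litellm.MistralTextCompletionConfig()

print(config.get_supported_openai_params())
# ['suffix', 'temperature', 'top_p', 'max_tokens', 'stream', 'seed', 'stop']

mapped = config.map_openai_params(
    non_default_params={"seed": 10, "max_tokens": 10, "stop": ["return"]},
    optional_params={},
)
print(mapped)
# {'random_seed': 10, 'max_tokens': 10, 'stop': ['return']}
```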


@ -107,6 +107,10 @@ from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.prompt_templates.factory import ( from .llms.prompt_templates.factory import (
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
@ -143,6 +147,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion() azure_text_completions = AzureTextCompletion()
huggingface = Huggingface() huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion() predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion() triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM() bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM() bedrock_converse_chat_completion = BedrockConverseLLM()
@ -345,6 +350,8 @@ async def acompletion(
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
or custom_llm_provider == "huggingface" or custom_llm_provider == "huggingface"
@ -374,9 +381,10 @@ async def acompletion(
else: else:
response = init_response # type: ignore response = init_response # type: ignore
if custom_llm_provider == "text-completion-openai" and isinstance( if (
response, TextCompletionResponse custom_llm_provider == "text-completion-openai"
): or custom_llm_provider == "text-completion-codestral"
) and isinstance(response, TextCompletionResponse):
response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=response, response_object=response,
model_response_object=litellm.ModelResponse(), model_response_object=litellm.ModelResponse(),
@ -1069,6 +1077,7 @@ def completion(
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"
or custom_llm_provider == "anyscale" or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral" or custom_llm_provider == "mistral"
@ -2021,6 +2030,46 @@ def completion(
timeout=timeout, timeout=timeout,
) )
if (
"stream" in optional_params
and optional_params["stream"] is True
and acompletion is False
):
return _model_response
response = _model_response
elif custom_llm_provider == "text-completion-codestral":
api_base = (
api_base
or optional_params.pop("api_base", None)
or optional_params.pop("base_url", None)
or litellm.api_base
or "https://codestral.mistral.ai/v1/fim/completions"
)
api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
text_completion_model_response = litellm.TextCompletionResponse(
stream=stream
)
_model_response = codestral_text_completions.completion( # type: ignore
model=model,
messages=messages,
model_response=text_completion_model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
acompletion=acompletion,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
api_key=api_key,
timeout=timeout,
)
if ( if (
"stream" in optional_params "stream" in optional_params
and optional_params["stream"] is True and optional_params["stream"] is True
@ -3410,7 +3459,9 @@ def embedding(
###### Text Completion ################ ###### Text Completion ################
@client @client
async def atext_completion(*args, **kwargs): async def atext_completion(
*args, **kwargs
) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
""" """
Implemented to handle async streaming for the text completion endpoint Implemented to handle async streaming for the text completion endpoint
""" """
@ -3442,6 +3493,7 @@ async def atext_completion(*args, **kwargs):
or custom_llm_provider == "deepinfra" or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity" or custom_llm_provider == "perplexity"
or custom_llm_provider == "groq" or custom_llm_provider == "groq"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "deepseek" or custom_llm_provider == "deepseek"
or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
@ -3703,6 +3755,7 @@ def text_completion(
custom_llm_provider == "openai" custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"
or custom_llm_provider == "azure_text" or custom_llm_provider == "azure_text"
or custom_llm_provider == "text-completion-codestral"
or custom_llm_provider == "text-completion-openai" or custom_llm_provider == "text-completion-openai"
) )
and isinstance(prompt, list) and isinstance(prompt, list)


@ -1564,6 +1564,27 @@
"mode": "completion", "mode": "completion",
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
}, },
"gemini/gemini-1.5-flash": {
"max_tokens": 8192,
"max_input_tokens": 1000000,
"max_output_tokens": 8192,
"max_images_per_prompt": 3000,
"max_videos_per_prompt": 10,
"max_video_length": 1,
"max_audio_length_hours": 8.4,
"max_audio_per_prompt": 1,
"max_pdf_size_mb": 30,
"input_cost_per_token": 0.00000035,
"input_cost_per_token_above_128k_tokens": 0.0000007,
"output_cost_per_token": 0.00000105,
"output_cost_per_token_above_128k_tokens": 0.0000021,
"litellm_provider": "gemini",
"mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true,
"supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"gemini/gemini-1.5-flash-latest": { "gemini/gemini-1.5-flash-latest": {
"max_tokens": 8192, "max_tokens": 8192,
"max_input_tokens": 1000000, "max_input_tokens": 1000000,
@ -1580,6 +1601,7 @@
"output_cost_per_token_above_128k_tokens": 0.0000021, "output_cost_per_token_above_128k_tokens": 0.0000021,
"litellm_provider": "gemini", "litellm_provider": "gemini",
"mode": "chat", "mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models" "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
@ -1607,6 +1629,7 @@
"output_cost_per_token_above_128k_tokens": 0.0000021, "output_cost_per_token_above_128k_tokens": 0.0000021,
"litellm_provider": "gemini", "litellm_provider": "gemini",
"mode": "chat", "mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true, "supports_tool_choice": true,
@ -1622,6 +1645,7 @@
"output_cost_per_token_above_128k_tokens": 0.0000021, "output_cost_per_token_above_128k_tokens": 0.0000021,
"litellm_provider": "gemini", "litellm_provider": "gemini",
"mode": "chat", "mode": "chat",
"supports_system_messages": true,
"supports_function_calling": true, "supports_function_calling": true,
"supports_vision": true, "supports_vision": true,
"supports_tool_choice": true, "supports_tool_choice": true,


@ -6057,8 +6057,11 @@ async def model_info_v2(
model_info[k] = v model_info[k] = v
_model["model_info"] = model_info _model["model_info"] = model_info
# don't return the api key / vertex credentials # don't return the api key / vertex credentials
# don't return the llm credentials
_model["litellm_params"].pop("api_key", None) _model["litellm_params"].pop("api_key", None)
_model["litellm_params"].pop("vertex_credentials", None) _model["litellm_params"].pop("vertex_credentials", None)
_model["litellm_params"].pop("aws_access_key_id", None)
_model["litellm_params"].pop("aws_secret_access_key", None)
verbose_proxy_logger.debug("all_models: %s", all_models) verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models} return {"data": all_models}
@ -6570,8 +6573,11 @@ async def model_info_v1(
if k not in model_info: if k not in model_info:
model_info[k] = v model_info[k] = v
model["model_info"] = model_info model["model_info"] = model_info
# don't return the api key # don't return the llm credentials
model["litellm_params"].pop("api_key", None) model["litellm_params"].pop("api_key", None)
model["litellm_params"].pop("vertex_credentials", None)
model["litellm_params"].pop("aws_access_key_id", None)
model["litellm_params"].pop("aws_secret_access_key", None)
verbose_proxy_logger.debug("all_models: %s", all_models) verbose_proxy_logger.debug("all_models: %s", all_models)
return {"data": all_models} return {"data": all_models}


@ -823,6 +823,34 @@ def test_completion_mistral_api():
pytest.fail(f"Error occurred: {e}") pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_completion_codestral_chat_api():
try:
litellm.set_verbose = True
response = await litellm.acompletion(
model="codestral/codestral-latest",
messages=[
{
"role": "user",
"content": "Hey, how's it going?",
}
],
temperature=0.0,
top_p=1,
max_tokens=10,
safe_prompt=False,
seed=12,
)
# Add any assertions here to check the response
print(response)
# cost = litellm.completion_cost(completion_response=response)
# print("cost to make mistral completion=", cost)
# assert cost > 0.0
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_completion_mistral_api_mistral_large_function_call(): def test_completion_mistral_api_mistral_large_function_call():
litellm.set_verbose = True litellm.set_verbose = True
tools = [ tools = [


@ -1,20 +1,28 @@
import sys, os import os
import sys
import traceback import traceback
import litellm.cost_calculator
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import asyncio
import time import time
from typing import Optional from typing import Optional
import pytest
import litellm import litellm
from litellm import ( from litellm import (
TranscriptionResponse,
completion_cost,
cost_per_token,
get_max_tokens, get_max_tokens,
model_cost, model_cost,
open_ai_chat_completion_models, open_ai_chat_completion_models,
TranscriptionResponse,
) )
from litellm.litellm_core_utils.litellm_logging import CustomLogger from litellm.litellm_core_utils.litellm_logging import CustomLogger
import pytest, asyncio
class CustomLoggingHandler(CustomLogger): class CustomLoggingHandler(CustomLogger):
@ -66,7 +74,7 @@ async def test_custom_pricing(sync_mode):
def test_custom_pricing_as_completion_cost_param(): def test_custom_pricing_as_completion_cost_param():
from litellm import ModelResponse, Choices, Message from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage from litellm.utils import Usage
resp = ModelResponse( resp = ModelResponse(
@ -134,7 +142,7 @@ def test_cost_ft_gpt_35():
try: try:
# this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id # this tests if litellm.completion_cost can calculate cost for ft:gpt-3.5-turbo:my-org:custom_suffix:id
# it needs to lookup ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost # it needs to lookup ft:gpt-3.5-turbo in the litellm model_cost map to get the correct cost
from litellm import ModelResponse, Choices, Message from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage from litellm.utils import Usage
resp = ModelResponse( resp = ModelResponse(
@ -179,7 +187,7 @@ def test_cost_azure_gpt_35():
try: try:
# this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo # this tests if litellm.completion_cost can calculate cost for azure/chatgpt-deployment-2 which maps to azure/gpt-3.5-turbo
# for this test we check if passing `model` to completion_cost overrides the completion cost # for this test we check if passing `model` to completion_cost overrides the completion cost
from litellm import ModelResponse, Choices, Message from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage from litellm.utils import Usage
resp = ModelResponse( resp = ModelResponse(
@ -266,7 +274,7 @@ def test_cost_bedrock_pricing():
""" """
- get pricing specific to region for a model - get pricing specific to region for a model
""" """
from litellm import ModelResponse, Choices, Message from litellm import Choices, Message, ModelResponse
from litellm.utils import Usage from litellm.utils import Usage
litellm.set_verbose = True litellm.set_verbose = True
@ -475,13 +483,13 @@ def test_replicate_llama3_cost_tracking():
@pytest.mark.parametrize("is_streaming", [True, False]) # @pytest.mark.parametrize("is_streaming", [True, False]) #
def test_groq_response_cost_tracking(is_streaming): def test_groq_response_cost_tracking(is_streaming):
from litellm.utils import ( from litellm.utils import (
ModelResponse,
Choices,
Message,
Usage,
CallTypes, CallTypes,
StreamingChoices, Choices,
Delta, Delta,
Message,
ModelResponse,
StreamingChoices,
Usage,
) )
response = ModelResponse( response = ModelResponse(
@ -565,3 +573,58 @@ def test_together_ai_qwen_completion_cost():
) )
assert response == "together-ai-41.1b-80b" assert response == "together-ai-41.1b-80b"
@pytest.mark.parametrize("above_128k", [False, True])
@pytest.mark.parametrize("provider", ["vertex_ai", "gemini"])
def test_gemini_completion_cost(above_128k, provider):
"""
Check if cost correctly calculated for gemini models based on context window
"""
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
if provider == "gemini":
model_name = "gemini-1.5-flash-latest"
else:
model_name = "gemini-1.5-flash-preview-0514"
if above_128k:
prompt_tokens = 128001.0
output_tokens = 228001.0
else:
prompt_tokens = 128.0
output_tokens = 228.0
## GET MODEL FROM LITELLM.MODEL_INFO
model_info = litellm.get_model_info(model=model_name, custom_llm_provider=provider)
## EXPECTED COST
if above_128k:
assert (
model_info["input_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model_name, model_info
)
assert (
model_info["output_cost_per_token_above_128k_tokens"] is not None
), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
model_name, model_info
)
input_cost = (
prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
)
output_cost = (
output_tokens * model_info["output_cost_per_token_above_128k_tokens"]
)
else:
input_cost = prompt_tokens * model_info["input_cost_per_token"]
output_cost = output_tokens * model_info["output_cost_per_token"]
## CALCULATED COST
calculated_input_cost, calculated_output_cost = cost_per_token(
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=output_tokens,
custom_llm_provider=provider,
)
assert calculated_input_cost == input_cost
assert calculated_output_cost == output_cost


@ -24,8 +24,7 @@ def test_lunary_logging():
except Exception as e: except Exception as e:
print(e) print(e)
test_lunary_logging()
# test_lunary_logging()
def test_lunary_template(): def test_lunary_template():
@ -38,8 +37,7 @@ def test_lunary_template():
except Exception as e: except Exception as e:
print(e) print(e)
test_lunary_template()
# test_lunary_template()
def test_lunary_logging_with_metadata(): def test_lunary_logging_with_metadata():
@ -52,16 +50,16 @@ def test_lunary_logging_with_metadata():
metadata={ metadata={
"run_name": "litellmRUN", "run_name": "litellmRUN",
"project_name": "litellm-completion", "project_name": "litellm-completion",
"tags": ["tag1", "tag2"]
}, },
) )
print(response) print(response)
except Exception as e: except Exception as e:
print(e) print(e)
#test_lunary_logging_with_metadata() test_lunary_logging_with_metadata()
def test_lunary_with_tools(): def test_lunary_with_tools():
import litellm import litellm
messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
@ -97,7 +95,7 @@ def test_lunary_with_tools():
print("\nLLM Response:\n", response.choices[0].message) print("\nLLM Response:\n", response.choices[0].message)
#test_lunary_with_tools() test_lunary_with_tools()
def test_lunary_logging_with_streaming_and_metadata(): def test_lunary_logging_with_streaming_and_metadata():
try: try:
@ -117,5 +115,4 @@ def test_lunary_logging_with_streaming_and_metadata():
except Exception as e: except Exception as e:
print(e) print(e)
test_lunary_logging_with_streaming_and_metadata()
# test_lunary_logging_with_streaming_and_metadata()


@ -4076,3 +4076,72 @@ async def test_async_text_completion_chat_model_stream():
# asyncio.run(test_async_text_completion_chat_model_stream()) # asyncio.run(test_async_text_completion_chat_model_stream())
@pytest.mark.asyncio
async def test_completion_codestral_fim_api():
try:
litellm.set_verbose = True
from litellm._logging import verbose_logger
import logging
verbose_logger.setLevel(level=logging.DEBUG)
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True",
temperature=0,
top_p=1,
max_tokens=10,
min_tokens=10,
seed=10,
stop=["return"],
)
# Add any assertions here to check the response
print(response)
assert response.choices[0].text is not None
assert len(response.choices[0].text) > 0
# cost = litellm.completion_cost(completion_response=response)
# print("cost to make mistral completion=", cost)
# assert cost > 0.0
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_completion_codestral_fim_api_stream():
try:
from litellm._logging import verbose_logger
import logging
litellm.set_verbose = False
# verbose_logger.setLevel(level=logging.DEBUG)
response = await litellm.atext_completion(
model="text-completion-codestral/codestral-2405",
prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
suffix="return True",
temperature=0,
top_p=1,
stream=True,
seed=10,
stop=["return"],
)
full_response = ""
# Add any assertions here to check the response
async for chunk in response:
print(chunk)
full_response += chunk.get("choices")[0].get("text") or ""
print("full_response", full_response)
assert len(full_response) > 2 # we at least have a few chars in response :)
# cost = litellm.completion_cost(completion_response=response)
# print("cost to make mistral completion=", cost)
# assert cost > 0.0
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@ -1,14 +1,15 @@
from typing import List, Optional, Union, Dict, Tuple, Literal
from typing_extensions import TypedDict
from enum import Enum
from typing_extensions import override, Required, Dict
from .llms.openai import ChatCompletionUsageBlock, ChatCompletionToolCallChunk
from ..litellm_core_utils.core_helpers import map_finish_reason
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
import uuid
import json import json
import time import time
import uuid
from enum import Enum
from typing import Dict, List, Literal, Optional, Tuple, Union
from openai._models import BaseModel as OpenAIObject
from pydantic import ConfigDict
from typing_extensions import Dict, Required, TypedDict, override
from ..litellm_core_utils.core_helpers import map_finish_reason
from .llms.openai import ChatCompletionToolCallChunk, ChatCompletionUsageBlock
def _generate_id(): # private helper function def _generate_id(): # private helper function
@ -34,21 +35,31 @@ class ProviderField(TypedDict):
field_value: str field_value: str
class ModelInfo(TypedDict): class ModelInfo(TypedDict, total=False):
""" """
Model info for a given model, this is information found in litellm.model_prices_and_context_window.json Model info for a given model, this is information found in litellm.model_prices_and_context_window.json
""" """
max_tokens: Optional[int] max_tokens: Required[Optional[int]]
max_input_tokens: Optional[int] max_input_tokens: Required[Optional[int]]
max_output_tokens: Optional[int] max_output_tokens: Required[Optional[int]]
input_cost_per_token: float input_cost_per_token: Required[float]
output_cost_per_token: float input_cost_per_token_above_128k_tokens: Optional[float]
litellm_provider: str input_cost_per_image: Optional[float]
mode: Literal[ input_cost_per_audio_per_second: Optional[float]
input_cost_per_video_per_second: Optional[float]
output_cost_per_token: Required[float]
output_cost_per_token_above_128k_tokens: Optional[float]
output_cost_per_image: Optional[float]
output_cost_per_video_per_second: Optional[float]
output_cost_per_audio_per_second: Optional[float]
litellm_provider: Required[str]
mode: Required[
Literal[
"completion", "embedding", "image_generation", "chat", "audio_transcription" "completion", "embedding", "image_generation", "chat", "audio_transcription"
] ]
supported_openai_params: Optional[List[str]] ]
supported_openai_params: Required[Optional[List[str]]]
class GenericStreamingChunk(TypedDict): class GenericStreamingChunk(TypedDict):


@ -2366,6 +2366,7 @@ def get_optional_params(
and custom_llm_provider != "together_ai" and custom_llm_provider != "together_ai"
and custom_llm_provider != "groq" and custom_llm_provider != "groq"
and custom_llm_provider != "deepseek" and custom_llm_provider != "deepseek"
and custom_llm_provider != "codestral"
and custom_llm_provider != "mistral" and custom_llm_provider != "mistral"
and custom_llm_provider != "anthropic" and custom_llm_provider != "anthropic"
and custom_llm_provider != "cohere_chat" and custom_llm_provider != "cohere_chat"
@ -2974,7 +2975,7 @@ def get_optional_params(
optional_params["stream"] = stream optional_params["stream"] = stream
if max_tokens: if max_tokens:
optional_params["max_tokens"] = max_tokens optional_params["max_tokens"] = max_tokens
elif custom_llm_provider == "mistral": elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
@ -2982,6 +2983,15 @@ def get_optional_params(
optional_params = litellm.MistralConfig().map_openai_params( optional_params = litellm.MistralConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params non_default_params=non_default_params, optional_params=optional_params
) )
elif custom_llm_provider == "text-completion-codestral":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params
)
elif custom_llm_provider == "databricks": elif custom_llm_provider == "databricks":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3014,7 +3024,6 @@ def get_optional_params(
optional_params["response_format"] = response_format optional_params["response_format"] = response_format
if seed is not None: if seed is not None:
optional_params["seed"] = seed optional_params["seed"] = seed
elif custom_llm_provider == "deepseek": elif custom_llm_provider == "deepseek":
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
@ -3633,11 +3642,14 @@ def get_supported_openai_params(
"tool_choice", "tool_choice",
"max_retries", "max_retries",
] ]
elif custom_llm_provider == "mistral": elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
# mistral and codestral api have the exact same params
if request_type == "chat_completion": if request_type == "chat_completion":
return litellm.MistralConfig().get_supported_openai_params() return litellm.MistralConfig().get_supported_openai_params()
elif request_type == "embeddings": elif request_type == "embeddings":
return litellm.MistralEmbeddingConfig().get_supported_openai_params() return litellm.MistralEmbeddingConfig().get_supported_openai_params()
elif custom_llm_provider == "text-completion-codestral":
return litellm.MistralTextCompletionConfig().get_supported_openai_params()
elif custom_llm_provider == "replicate": elif custom_llm_provider == "replicate":
return [ return [
"stream", "stream",
@ -3874,6 +3886,10 @@ def get_llm_provider(
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1 # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
api_base = "https://api.groq.com/openai/v1" api_base = "https://api.groq.com/openai/v1"
dynamic_api_key = get_secret("GROQ_API_KEY") dynamic_api_key = get_secret("GROQ_API_KEY")
elif custom_llm_provider == "codestral":
# codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
api_base = "https://codestral.mistral.ai/v1"
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif custom_llm_provider == "deepseek": elif custom_llm_provider == "deepseek":
# deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
api_base = "https://api.deepseek.com/v1" api_base = "https://api.deepseek.com/v1"
@ -3966,6 +3982,12 @@ def get_llm_provider(
elif endpoint == "api.groq.com/openai/v1": elif endpoint == "api.groq.com/openai/v1":
custom_llm_provider = "groq" custom_llm_provider = "groq"
dynamic_api_key = get_secret("GROQ_API_KEY") dynamic_api_key = get_secret("GROQ_API_KEY")
elif endpoint == "https://codestral.mistral.ai/v1":
custom_llm_provider = "codestral"
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif endpoint == "https://codestral.mistral.ai/v1":
custom_llm_provider = "text-completion-codestral"
dynamic_api_key = get_secret("CODESTRAL_API_KEY")
elif endpoint == "api.deepseek.com/v1": elif endpoint == "api.deepseek.com/v1":
custom_llm_provider = "deepseek" custom_llm_provider = "deepseek"
dynamic_api_key = get_secret("DEEPSEEK_API_KEY") dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
@ -4286,8 +4308,10 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model) split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
except: except:
pass pass
combined_model_name = model
else: else:
split_model = model split_model = model
combined_model_name = "{}/{}".format(custom_llm_provider, model)
######################### #########################
supported_openai_params = litellm.get_supported_openai_params( supported_openai_params = litellm.get_supported_openai_params(
@@ -4305,33 +4329,58 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
             }
         else:
             """
-            Check if:
-            1. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost
-            2. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost
+            Check if: (in order of specificity)
+            1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
+            2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
+            3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
             """
-            if model in litellm.model_cost:
+            if combined_model_name in litellm.model_cost:
+                _model_info = litellm.model_cost[combined_model_name]
+                _model_info["supported_openai_params"] = supported_openai_params
+                if (
+                    "litellm_provider" in _model_info
+                    and _model_info["litellm_provider"] != custom_llm_provider
+                ):
+                    if custom_llm_provider == "vertex_ai" and _model_info[
+                        "litellm_provider"
+                    ].startswith("vertex_ai"):
+                        pass
+                    else:
+                        raise Exception
+                return _model_info
+            elif model in litellm.model_cost:
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
                     "litellm_provider" in _model_info
                     and _model_info["litellm_provider"] != custom_llm_provider
                 ):
+                    if custom_llm_provider == "vertex_ai" and _model_info[
+                        "litellm_provider"
+                    ].startswith("vertex_ai"):
+                        pass
+                    else:
                         raise Exception
                 return _model_info
-            if split_model in litellm.model_cost:
+            elif split_model in litellm.model_cost:
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
                 if (
                     "litellm_provider" in _model_info
                     and _model_info["litellm_provider"] != custom_llm_provider
                 ):
+                    if custom_llm_provider == "vertex_ai" and _model_info[
+                        "litellm_provider"
+                    ].startswith("vertex_ai"):
+                        pass
+                    else:
                         raise Exception
                 return _model_info
             else:
                 raise ValueError(
                     "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
                 )
-    except:
+    except Exception:
         raise Exception(
             "This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
         )
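The reworked lookup now prefers the most specific key first. A sketch of that order, assuming the usual entries exist in the local `model_prices_and_context_window.json` copy:

```python
import litellm

# With custom_llm_provider given, "gemini/gemini-1.5-flash" is tried first,
# then the raw model name, then the name with any provider prefix stripped.
info = litellm.get_model_info(model="gemini-1.5-flash", custom_llm_provider="gemini")
print(info["litellm_provider"], info["max_input_tokens"])  # expected: gemini 1000000
```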
@@ -4650,6 +4699,14 @@ def validate_environment(model: Optional[str] = None) -> dict:
                 keys_in_environment = True
             else:
                 missing_keys.append("GROQ_API_KEY")
+        elif (
+            custom_llm_provider == "codestral"
+            or custom_llm_provider == "text-completion-codestral"
+        ):
+            if "CODESTRAL_API_KEY" in os.environ:
+                keys_in_environment = True
+            else:
+                missing_keys.append("CODESTRAL_API_KEY")
         elif custom_llm_provider == "deepseek":
             if "DEEPSEEK_API_KEY" in os.environ:
                 keys_in_environment = True
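With the Codestral branch in place, environment validation should flag the missing key by name. A quick sketch, assuming `validate_environment` keeps returning a dict with `keys_in_environment` and `missing_keys`:

```python
import litellm

# With CODESTRAL_API_KEY unset, the new branch reports it as missing.
result = litellm.validate_environment(model="codestral/codestral-latest")
print(result["keys_in_environment"])  # False
print(result["missing_keys"])         # ["CODESTRAL_API_KEY"]
```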
@@ -8523,6 +8580,25 @@ class CustomStreamWrapper:
                         completion_tokens=response_obj["usage"].completion_tokens,
                         total_tokens=response_obj["usage"].total_tokens,
                     )
+            elif self.custom_llm_provider == "text-completion-codestral":
+                response_obj = litellm.MistralTextCompletionConfig()._chunk_parser(
+                    chunk
+                )
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "databricks":
                 response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
                 completion_obj["content"] = response_obj["text"]
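The streaming branch routes Codestral FIM chunks through `MistralTextCompletionConfig()._chunk_parser` (an internal helper) and, when usage is requested, attaches token counts to the final chunk. A hedged consumer-side sketch, assuming `stream_options` is forwarded for this provider and `CODESTRAL_API_KEY` is set in the environment:

```python
import litellm

async def fim_stream_with_usage():
    response = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def add(a, b):",
        stream=True,
        stream_options={"include_usage": True},  # assumed to be passed through
    )
    async for chunk in response:
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
```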
@@ -8996,6 +9072,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
+                or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"

model_prices_and_context_window.json
View file

@@ -1564,6 +1564,27 @@
         "mode": "completion",
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
     },
+    "gemini/gemini-1.5-flash": {
+        "max_tokens": 8192,
+        "max_input_tokens": 1000000,
+        "max_output_tokens": 8192,
+        "max_images_per_prompt": 3000,
+        "max_videos_per_prompt": 10,
+        "max_video_length": 1,
+        "max_audio_length_hours": 8.4,
+        "max_audio_per_prompt": 1,
+        "max_pdf_size_mb": 30,
+        "input_cost_per_token": 0.00000035,
+        "input_cost_per_token_above_128k_tokens": 0.0000007,
+        "output_cost_per_token": 0.00000105,
+        "output_cost_per_token_above_128k_tokens": 0.0000021,
+        "litellm_provider": "gemini",
+        "mode": "chat",
+        "supports_system_messages": true,
+        "supports_function_calling": true,
+        "supports_vision": true,
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
+    },
     "gemini/gemini-1.5-flash-latest": {
         "max_tokens": 8192,
         "max_input_tokens": 1000000,
@@ -1580,6 +1601,7 @@
         "output_cost_per_token_above_128k_tokens": 0.0000021,
         "litellm_provider": "gemini",
         "mode": "chat",
+        "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_vision": true,
         "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
@@ -1607,6 +1629,7 @@
         "output_cost_per_token_above_128k_tokens": 0.0000021,
         "litellm_provider": "gemini",
         "mode": "chat",
+        "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,
@@ -1622,6 +1645,7 @@
         "output_cost_per_token_above_128k_tokens": 0.0000021,
         "litellm_provider": "gemini",
         "mode": "chat",
+        "supports_system_messages": true,
         "supports_function_calling": true,
         "supports_vision": true,
         "supports_tool_choice": true,

poetry.lock (generated)
View file

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.

 [[package]]
 name = "aiohttp"
@@ -2174,6 +2174,7 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
+    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -2820,13 +2821,13 @@ files = [

 [[package]]
 name = "urllib3"
-version = "2.2.1"
+version = "2.2.2"
 description = "HTTP library with thread-safe connection pooling, file post, and more."
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
-    {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
+    {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"},
+    {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"},
 ]

 [package.extras]