Merge pull request #4247 from BerriAI/litellm_add_codestral_fim
[Feat] Add Codestral FIM API
Commit fc23399b6f · 9 changed files with 1018 additions and 8 deletions
255	docs/my-website/docs/providers/codestral.md	(new file)

@@ -0,0 +1,255 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Codestral API [Mistral AI]

Codestral is available in select code-completion plugins but can also be queried directly. See the documentation for more details.

## API Key

```python
# env variable
os.environ['CODESTRAL_API_KEY']
```

## FIM / Completions

:::info

Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createFIMCompletion

:::

<Tabs>
<TabItem value="no-streaming" label="No Streaming">

### Sample Usage

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

response = await litellm.atext_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",  # optional
    temperature=0,  # optional
    top_p=1,  # optional
    max_tokens=10,  # optional
    min_tokens=10,  # optional
    seed=10,  # optional
    stop=["return"],  # optional
)
```

#### Expected Response

```json
{
  "id": "b41e0df599f94bc1a46ea9fcdbc2aabe",
  "object": "text_completion",
  "created": 1589478378,
  "model": "codestral-latest",
  "choices": [
    {
      "text": "\n assert is_odd(1)\n assert",
      "index": 0,
      "logprobs": null,
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 5,
    "completion_tokens": 7,
    "total_tokens": 12
  }
}
```
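
The async call above is what this PR wires up end-to-end. For a synchronous script, the same FIM parameters should work through `litellm.text_completion` (a minimal sketch, assuming the sync entrypoint accepts the same params; the key value is a placeholder):

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"  # placeholder

# Same FIM request as above, without an event loop
response = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",
    max_tokens=10,
)
print(response.choices[0].text)
```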

</TabItem>
<TabItem value="stream" label="Streaming">

### Sample Usage - Streaming

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

response = await litellm.atext_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    suffix="return True",  # optional
    temperature=0,  # optional
    top_p=1,  # optional
    stream=True,
    seed=10,  # optional
    stop=["return"],  # optional
)

async for chunk in response:
    print(chunk)
```

#### Expected Response

```json
{
  "id": "726025d3e2d645d09d475bb0d29e3640",
  "object": "text_completion",
  "created": 1718659669,
  "choices": [
    {
      "text": "This",
      "index": 0,
      "logprobs": null,
      "finish_reason": null
    }
  ],
  "model": "codestral-2405"
}
```
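
Synchronous streaming should work the same way with `stream=True` on `litellm.text_completion`, yielding chunks to a plain `for` loop (a sketch under the same assumptions as above; chunk access mirrors the tests added later in this PR):

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"  # placeholder

response = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
    stream=True,
)
for chunk in response:
    # each chunk carries at most a few characters of generated text
    print(chunk.get("choices")[0].get("text") or "", end="")
```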

</TabItem>
</Tabs>

### Supported Models

All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).

| Model Name       | Function Call                                                               |
|------------------|-----------------------------------------------------------------------------|
| Codestral Latest | `completion(model="text-completion-codestral/codestral-latest", messages)` |
| Codestral 2405   | `completion(model="text-completion-codestral/codestral-2405", messages)`   |

## Chat Completions

:::info

Official Mistral API Docs: https://docs.mistral.ai/api/#operation/createChatCompletion

:::

<Tabs>
<TabItem value="no-streaming" label="No Streaming">

### Sample Usage

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

response = await litellm.acompletion(
    model="codestral/codestral-latest",
    messages=[
        {
            "role": "user",
            "content": "Hey, how's it going?",
        }
    ],
    temperature=0.0,  # optional
    top_p=1,  # optional
    max_tokens=10,  # optional
    safe_prompt=False,  # optional
    seed=12,  # optional
)
```

#### Expected Response

```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion",
  "created": 1677652288,
  "model": "codestral/codestral-latest",
  "system_fingerprint": null,
  "choices": [{
    "index": 0,
    "message": {
      "role": "assistant",
      "content": "\n\nHello there, how may I assist you today?"
    },
    "logprobs": null,
    "finish_reason": "stop"
  }],
  "usage": {
    "prompt_tokens": 9,
    "completion_tokens": 12,
    "total_tokens": 21
  }
}
```

</TabItem>
<TabItem value="stream" label="Streaming">

### Sample Usage - Streaming

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"

response = await litellm.acompletion(
    model="codestral/codestral-latest",
    messages=[
        {
            "role": "user",
            "content": "Hey, how's it going?",
        }
    ],
    stream=True,  # optional
    temperature=0.0,  # optional
    top_p=1,  # optional
    max_tokens=10,  # optional
    safe_prompt=False,  # optional
    seed=12,  # optional
)
async for chunk in response:
    print(chunk)
```

#### Expected Response

```json
{
  "id": "chatcmpl-123",
  "object": "chat.completion.chunk",
  "created": 1694268190,
  "model": "codestral/codestral-latest",
  "system_fingerprint": null,
  "choices": [
    {
      "index": 0,
      "delta": {"role": "assistant", "content": "gm"},
      "logprobs": null,
      "finish_reason": null
    }
  ]
}
```

</TabItem>
</Tabs>

### Supported Models

All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).

| Model Name       | Function Call                                              |
|------------------|------------------------------------------------------------|
| Codestral Latest | `completion(model="codestral/codestral-latest", messages)` |
| Codestral 2405   | `completion(model="codestral/codestral-2405", messages)`   |
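
The table's `completion(...)` signatures can be used directly for a synchronous chat call (a minimal sketch; the key value and prompt are illustrative):

```python
import os
import litellm

os.environ['CODESTRAL_API_KEY'] = "your-api-key"  # placeholder

response = litellm.completion(
    model="codestral/codestral-latest",
    messages=[{"role": "user", "content": "Write a function that checks if a number is odd."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```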

@@ -134,10 +134,11 @@ const sidebars = {
         "providers/vertex",
         "providers/palm",
         "providers/gemini",
-        "providers/mistral",
         "providers/anthropic",
         "providers/aws_sagemaker",
         "providers/bedrock",
+        "providers/mistral",
+        "providers/codestral",
         "providers/cohere",
         "providers/anyscale",
         "providers/huggingface",
@@ -393,6 +393,8 @@ openai_compatible_endpoints: List = [
     "api.endpoints.anyscale.com/v1",
     "api.deepinfra.com/v1/openai",
     "api.mistral.ai/v1",
+    "codestral.mistral.ai/v1/chat/completions",
+    "codestral.mistral.ai/v1/fim/completions",
     "api.groq.com/openai/v1",
     "api.deepseek.com/v1",
     "api.together.xyz/v1",

@@ -403,6 +405,7 @@ openai_compatible_providers: List = [
     "anyscale",
     "mistral",
     "groq",
+    "codestral",
     "deepseek",
     "deepinfra",
     "perplexity",

@@ -630,6 +633,8 @@ provider_list: List = [
     "anyscale",
     "mistral",
     "groq",
+    "codestral",
+    "text-completion-codestral",
     "deepseek",
     "maritalk",
     "voyage",

@@ -798,6 +803,7 @@ from .llms.openai import (
     DeepInfraConfig,
     AzureAIStudioConfig,
 )
+from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.azure import (
     AzureOpenAIConfig,
     AzureOpenAIError,

@@ -27,6 +27,25 @@ class BaseLLM:
         """
         return model_response

+    def process_text_completion_response(
+        self,
+        model: str,
+        response: Union[requests.Response, httpx.Response],
+        model_response: litellm.utils.TextCompletionResponse,
+        stream: bool,
+        logging_obj: Logging,
+        optional_params: dict,
+        api_key: str,
+        data: Union[dict, str],
+        messages: list,
+        print_verbose,
+        encoding,
+    ) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
+        """
+        Helper function to process the response across sync + async completion calls
+        """
+        return model_response
+
     def create_client_session(self):
         if litellm.client_session:
             _client_session = litellm.client_session
532	litellm/llms/text_completion_codestral.py	(new file)

@@ -0,0 +1,532 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/

from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
    TextCompletionResponse,
    Usage,
    CustomStreamWrapper,
    Message,
    Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.types.llms.databricks import GenericStreamingChunk
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore


class TextCompletionCodestralError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class MistralTextCompletionConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
    """

    suffix: Optional[str] = None
    temperature: Optional[int] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = None
    random_seed: Optional[int] = None
    stop: Optional[str] = None

    def __init__(
        self,
        suffix: Optional[str] = None,
        temperature: Optional[int] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
        random_seed: Optional[int] = None,
        stop: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "suffix",
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "seed",
            "stop",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "suffix":
                optional_params["suffix"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "stream" and value == True:
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "seed":
                optional_params["random_seed"] = value
            if param == "min_tokens":
                optional_params["min_tokens"] = value

        return optional_params
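
    # Example (illustrative, not part of the diff): OpenAI-style params are
    # renamed where Codestral's FIM API differs; note `seed` -> `random_seed`:
    #
    #   MistralTextCompletionConfig().map_openai_params(
    #       non_default_params={"seed": 10, "max_tokens": 10},
    #       optional_params={},
    #   )
    #   # -> {"random_seed": 10, "max_tokens": 10}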

    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
        text = ""
        is_finished = False
        finish_reason = None
        logprobs = None

        chunk_data = chunk_data.replace("data:", "")
        chunk_data = chunk_data.strip()
        if len(chunk_data) == 0 or chunk_data == "[DONE]":
            return {
                "text": "",
                "is_finished": is_finished,
                "finish_reason": finish_reason,
            }
        chunk_data_dict = json.loads(chunk_data)
        original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
        _choices = chunk_data_dict.get("choices", []) or []
        _choice = _choices[0]
        text = _choice.get("delta", {}).get("content", "")

        if _choice.get("finish_reason") is not None:
            is_finished = True
            finish_reason = _choice.get("finish_reason")
            logprobs = _choice.get("logprobs")

        return GenericStreamingChunk(
            text=text,
            original_chunk=original_chunk,
            is_finished=is_finished,
            finish_reason=finish_reason,
            logprobs=logprobs,
        )
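
    # Example (illustrative, not part of the diff): a raw SSE line such as
    #   'data: {"id":"abc","created":1718659669,"model":"codestral-2405",'
    #   '"choices":[{"index":0,"delta":{"content":"gm"},"finish_reason":null}]}'
    # parses to a GenericStreamingChunk with text="gm" and is_finished=False;
    # blank keep-alive lines and "data: [DONE]" yield an empty, unfinished chunk.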


class CodestralTextCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_KEY - Please add CODESTRAL_API_KEY to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text
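
    # Example (illustrative, not part of the diff):
    #   output_parser("<s> def foo():\n    pass</s>")
    #   # -> " def foo():\n    pass"  (wrapping ChatML-style tokens stripped once)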

    def process_text_completion_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"codestral api: raw model_response: {response.text}")
        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except Exception:
            raise TextCompletionCodestralError(message=response.text, status_code=422)

        _original_choices = completion_response.get("choices", [])
        _choices: List[litellm.utils.TextChoices] = []
        for choice in _original_choices:
            # This is what 1 choice looks like from codestral API
            # {
            #     "index": 0,
            #     "message": {
            #         "role": "assistant",
            #         "content": "\n assert is_odd(1)\n assert",
            #         "tool_calls": null
            #     },
            #     "finish_reason": "length",
            #     "logprobs": null
            # }
            _finish_reason = None
            _index = 0
            _text = None
            _logprobs = None

            _choice_message = choice.get("message", {})
            _choice = litellm.utils.TextChoices(
                finish_reason=choice.get("finish_reason"),
                index=choice.get("index"),
                text=_choice_message.get("content"),
                logprobs=choice.get("logprobs"),
            )

            _choices.append(_choice)

        _response = litellm.TextCompletionResponse(
            id=completion_response.get("id"),
            choices=_choices,
            created=completion_response.get("created"),
            model=completion_response.get("model"),
            usage=completion_response.get("usage"),
            stream=False,
            object=completion_response.get("object"),
        )
        return _response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)

        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.MistralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "prompt": prompt,
            **optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = requests.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="codestral",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = requests.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )
            return self.process_text_completion_response(
                model=model,
                response=response,
                model_response=model_response,
                stream=optional_params.get("stream", False),
                logging_obj=logging_obj,  # type: ignore
                optional_params=optional_params,
                api_key=api_key,
                data=data,
                messages=messages,
                print_verbose=print_verbose,
                encoding=encoding,
            )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:

        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise TextCompletionCodestralError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            raise TextCompletionCodestralError(
                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="text-completion-codestral",
            logging_obj=logging_obj,
        )
        return streamwrapper
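
    # Note (illustrative): `make_call` is bound with everything except `client`;
    # CustomStreamWrapper is expected to supply the AsyncHTTPHandler when it
    # first iterates, so no HTTP request is made until the stream is consumed.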

    def embedding(self, *args, **kwargs):
        pass

@@ -82,6 +82,7 @@ from .llms.predibase import PredibaseChatCompletion
 from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
 from .llms.vertex_httpx import VertexLLM
 from .llms.triton import TritonChatCompletion
+from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,

@@ -120,6 +121,7 @@ azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
+codestral_text_completions = CodestralTextCompletion()
 triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
@@ -322,6 +324,8 @@ async def acompletion(
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
+        or custom_llm_provider == "codestral"
+        or custom_llm_provider == "text-completion-codestral"
         or custom_llm_provider == "deepseek"
         or custom_llm_provider == "text-completion-openai"
         or custom_llm_provider == "huggingface"

@@ -351,9 +355,10 @@ async def acompletion(
     else:
         response = init_response  # type: ignore

-    if custom_llm_provider == "text-completion-openai" and isinstance(
-        response, TextCompletionResponse
-    ):
+    if (
+        custom_llm_provider == "text-completion-openai"
+        or custom_llm_provider == "text-completion-codestral"
+    ) and isinstance(response, TextCompletionResponse):
         response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
             response_object=response,
             model_response_object=litellm.ModelResponse(),
@@ -1046,6 +1051,7 @@ def completion(
             or custom_llm_provider == "deepinfra"
             or custom_llm_provider == "perplexity"
             or custom_llm_provider == "groq"
+            or custom_llm_provider == "codestral"
             or custom_llm_provider == "deepseek"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
@@ -2024,6 +2030,46 @@ def completion(
             timeout=timeout,
         )

+        if (
+            "stream" in optional_params
+            and optional_params["stream"] is True
+            and acompletion is False
+        ):
+            return _model_response
+        response = _model_response
+    elif custom_llm_provider == "text-completion-codestral":
+
+        api_base = (
+            api_base
+            or optional_params.pop("api_base", None)
+            or optional_params.pop("base_url", None)
+            or litellm.api_base
+            or "https://codestral.mistral.ai/v1/fim/completions"
+        )
+
+        api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")
+
+        text_completion_model_response = litellm.TextCompletionResponse(
+            stream=stream
+        )
+
+        _model_response = codestral_text_completions.completion(  # type: ignore
+            model=model,
+            messages=messages,
+            model_response=text_completion_model_response,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            encoding=encoding,
+            logging_obj=logging,
+            acompletion=acompletion,
+            api_base=api_base,
+            custom_prompt_dict=custom_prompt_dict,
+            api_key=api_key,
+            timeout=timeout,
+        )
+
         if (
             "stream" in optional_params
             and optional_params["stream"] is True
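
Taken together, the new branch above means a plain synchronous call with the `text-completion-codestral/` prefix is routed through `codestral_text_completions.completion()` (a minimal sketch, assuming `litellm.text_completion` forwards to `completion()` as it does for other providers; model and key values are illustrative):

```python
import os
import litellm

os.environ["CODESTRAL_API_KEY"] = "your-api-key"  # illustrative key

# Routes through completion() -> codestral_text_completions.completion()
resp = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def fib(n):",
    max_tokens=16,
)
print(resp.choices[0].text)
```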
@@ -3413,7 +3459,9 @@ def embedding(

 ###### Text Completion ################
 @client
-async def atext_completion(*args, **kwargs):
+async def atext_completion(
+    *args, **kwargs
+) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]:
     """
     Implemented to handle async streaming for the text completion endpoint
     """
@@ -3445,6 +3493,7 @@ async def atext_completion(*args, **kwargs):
         or custom_llm_provider == "deepinfra"
         or custom_llm_provider == "perplexity"
         or custom_llm_provider == "groq"
+        or custom_llm_provider == "text-completion-codestral"
         or custom_llm_provider == "deepseek"
         or custom_llm_provider == "fireworks_ai"
         or custom_llm_provider == "text-completion-openai"
@@ -3706,6 +3755,7 @@ def text_completion(
         custom_llm_provider == "openai"
         or custom_llm_provider == "azure"
         or custom_llm_provider == "azure_text"
+        or custom_llm_provider == "text-completion-codestral"
         or custom_llm_provider == "text-completion-openai"
     )
     and isinstance(prompt, list)
@@ -817,6 +817,34 @@ def test_completion_mistral_api():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_completion_codestral_chat_api():
+    try:
+        litellm.set_verbose = True
+        response = await litellm.acompletion(
+            model="codestral/codestral-latest",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Hey, how's it going?",
+                }
+            ],
+            temperature=0.0,
+            top_p=1,
+            max_tokens=10,
+            safe_prompt=False,
+            seed=12,
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api_mistral_large_function_call():
     litellm.set_verbose = True
     tools = [
@@ -4076,3 +4076,72 @@ async def test_async_text_completion_chat_model_stream():


 # asyncio.run(test_async_text_completion_chat_model_stream())
+
+
+@pytest.mark.asyncio
+async def test_completion_codestral_fim_api():
+    try:
+        litellm.set_verbose = True
+        from litellm._logging import verbose_logger
+        import logging
+
+        verbose_logger.setLevel(level=logging.DEBUG)
+        response = await litellm.atext_completion(
+            model="text-completion-codestral/codestral-2405",
+            prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
+            suffix="return True",
+            temperature=0,
+            top_p=1,
+            max_tokens=10,
+            min_tokens=10,
+            seed=10,
+            stop=["return"],
+        )
+        # Add any assertions here to check the response
+        print(response)
+
+        assert response.choices[0].text is not None
+        assert len(response.choices[0].text) > 0
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_completion_codestral_fim_api_stream():
+    try:
+        from litellm._logging import verbose_logger
+        import logging
+
+        litellm.set_verbose = False
+
+        # verbose_logger.setLevel(level=logging.DEBUG)
+        response = await litellm.atext_completion(
+            model="text-completion-codestral/codestral-2405",
+            prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
+            suffix="return True",
+            temperature=0,
+            top_p=1,
+            stream=True,
+            seed=10,
+            stop=["return"],
+        )
+
+        full_response = ""
+        # Add any assertions here to check the response
+        async for chunk in response:
+            print(chunk)
+            full_response += chunk.get("choices")[0].get("text") or ""
+
+        print("full_response", full_response)
+
+        assert len(full_response) > 2  # we at least have a few chars in response :)
+
+        # cost = litellm.completion_cost(completion_response=response)
+        # print("cost to make mistral completion=", cost)
+        # assert cost > 0.0
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
@@ -2366,6 +2366,7 @@ def get_optional_params(
         and custom_llm_provider != "together_ai"
         and custom_llm_provider != "groq"
         and custom_llm_provider != "deepseek"
+        and custom_llm_provider != "codestral"
         and custom_llm_provider != "mistral"
         and custom_llm_provider != "anthropic"
         and custom_llm_provider != "cohere_chat"
@@ -2974,7 +2975,7 @@ def get_optional_params(
         optional_params["stream"] = stream
         if max_tokens:
             optional_params["max_tokens"] = max_tokens
-    elif custom_llm_provider == "mistral":
+    elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
@@ -2982,6 +2983,15 @@ def get_optional_params(
         optional_params = litellm.MistralConfig().map_openai_params(
             non_default_params=non_default_params, optional_params=optional_params
         )
+    elif custom_llm_provider == "text-completion-codestral":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
+
     elif custom_llm_provider == "databricks":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3014,7 +3024,6 @@ def get_optional_params(
             optional_params["response_format"] = response_format
         if seed is not None:
             optional_params["seed"] = seed
-
     elif custom_llm_provider == "deepseek":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3633,11 +3642,14 @@ def get_supported_openai_params(
             "tool_choice",
             "max_retries",
         ]
-    elif custom_llm_provider == "mistral":
+    elif custom_llm_provider == "mistral" or custom_llm_provider == "codestral":
+        # mistral and codestral api have the exact same params
         if request_type == "chat_completion":
             return litellm.MistralConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.MistralEmbeddingConfig().get_supported_openai_params()
+    elif custom_llm_provider == "text-completion-codestral":
+        return litellm.MistralTextCompletionConfig().get_supported_openai_params()
     elif custom_llm_provider == "replicate":
         return [
             "stream",
@@ -3874,6 +3886,10 @@ def get_llm_provider(
             # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
             api_base = "https://api.groq.com/openai/v1"
             dynamic_api_key = get_secret("GROQ_API_KEY")
+        elif custom_llm_provider == "codestral":
+            # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
+            api_base = "https://codestral.mistral.ai/v1"
+            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
         elif custom_llm_provider == "deepseek":
             # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
             api_base = "https://api.deepseek.com/v1"
@@ -3966,6 +3982,12 @@ def get_llm_provider(
         elif endpoint == "api.groq.com/openai/v1":
             custom_llm_provider = "groq"
             dynamic_api_key = get_secret("GROQ_API_KEY")
+        elif endpoint == "https://codestral.mistral.ai/v1/chat/completions":
+            custom_llm_provider = "codestral"
+            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
+        elif endpoint == "https://codestral.mistral.ai/v1/fim/completions":
+            custom_llm_provider = "text-completion-codestral"
+            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
         elif endpoint == "api.deepseek.com/v1":
             custom_llm_provider = "deepseek"
             dynamic_api_key = get_secret("DEEPSEEK_API_KEY")
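
With the two endpoint branches distinguished (the original hunk tested the identical URL twice, leaving the second branch unreachable; the chat and FIM URLs above come from the `openai_compatible_endpoints` list added earlier in this PR), provider inference from an `api_base` would look like this (an illustrative sketch; the return tuple of `get_llm_provider` is not shown in this diff and is assumed):

```python
# Hypothetical usage of the branch above; values are illustrative.
from litellm import get_llm_provider

model, provider, key, api_base = get_llm_provider(
    model="codestral-2405",
    api_base="https://codestral.mistral.ai/v1/fim/completions",
)
# provider == "text-completion-codestral"; key read from CODESTRAL_API_KEY
```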
@@ -4677,6 +4699,14 @@ def validate_environment(model: Optional[str] = None) -> dict:
             keys_in_environment = True
         else:
             missing_keys.append("GROQ_API_KEY")
+    elif (
+        custom_llm_provider == "codestral"
+        or custom_llm_provider == "text-completion-codestral"
+    ):
+        if "CODESTRAL_API_KEY" in os.environ:
+            keys_in_environment = True
+        else:
+            missing_keys.append("CODESTRAL_API_KEY")
     elif custom_llm_provider == "deepseek":
         if "DEEPSEEK_API_KEY" in os.environ:
             keys_in_environment = True
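
Note the `else` branch above originally appended `GROQ_API_KEY`, a copy-paste slip corrected here to `CODESTRAL_API_KEY`. A quick sanity check of the new branch (a sketch; `validate_environment`'s return shape is assumed from the keys it sets in this hunk):

```python
import os
import litellm

os.environ["CODESTRAL_API_KEY"] = "your-api-key"  # placeholder
env_check = litellm.validate_environment(model="codestral/codestral-latest")
# expected: env_check["keys_in_environment"] is True and no missing keys
print(env_check)
```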
@@ -8548,6 +8578,25 @@ class CustomStreamWrapper:
                     completion_tokens=response_obj["usage"].completion_tokens,
                     total_tokens=response_obj["usage"].total_tokens,
                 )
+            elif self.custom_llm_provider == "text-completion-codestral":
+                response_obj = litellm.MistralTextCompletionConfig()._chunk_parser(
+                    chunk
+                )
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    self.received_finish_reason = response_obj["finish_reason"]
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) == True
+                    and response_obj["usage"] is not None
+                ):
+                    self.sent_stream_usage = True
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"].prompt_tokens,
+                        completion_tokens=response_obj["usage"].completion_tokens,
+                        total_tokens=response_obj["usage"].total_tokens,
+                    )
             elif self.custom_llm_provider == "databricks":
                 response_obj = litellm.DatabricksConfig()._chunk_parser(chunk)
                 completion_obj["content"] = response_obj["text"]

@@ -9021,6 +9070,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "azure"
                 or self.custom_llm_provider == "custom_openai"
                 or self.custom_llm_provider == "text-completion-openai"
+                or self.custom_llm_provider == "text-completion-codestral"
                 or self.custom_llm_provider == "azure_text"
                 or self.custom_llm_provider == "anthropic"
                 or self.custom_llm_provider == "anthropic_text"