feat - add fim codestral api
parent ad47fee181
commit 364492297d
7 changed files with 549 additions and 132 deletions
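
For orientation, here is a usage sketch of the fill-in-the-middle (FIM) route this commit adds, modeled on the test included further down in the diff; it assumes a CODESTRAL_API_KEY environment variable and is illustrative, not part of the change itself.

import asyncio

import litellm


async def main():
    # Fill-in-the-middle completion against Codestral: `prompt` is the prefix,
    # `suffix` is the code that should follow the generated infill.
    response = await litellm.atext_completion(
        model="text-completion-codestral/codestral-2405",
        prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
        suffix="return True",
        temperature=0,
        max_tokens=10,
    )
    print(response)


asyncio.run(main())
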
@@ -795,11 +795,11 @@ from .llms.openai import (
    OpenAIConfig,
    OpenAITextCompletionConfig,
    MistralConfig,
    MistralTextCompletionConfig,
    MistralEmbeddingConfig,
    DeepInfraConfig,
    AzureAIStudioConfig,
)
from .llms.text_completion_codestral import MistralTextCompletionConfig
from .llms.azure import (
    AzureOpenAIConfig,
    AzureOpenAIError,

@@ -27,6 +27,25 @@ class BaseLLM:
        """
        return model_response

    def process_text_completion_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: litellm.utils.TextCompletionResponse,
        stream: bool,
        logging_obj: Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> Union[litellm.utils.TextCompletionResponse, litellm.utils.CustomStreamWrapper]:
        """
        Helper function to process the response across sync + async completion calls
        """
        return model_response

    def create_client_session(self):
        if litellm.client_session:
            _client_session = litellm.client_session

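The process_text_completion_response stub added to BaseLLM above is a no-op default; concrete integrations subclass BaseLLM and fill it in, as CodestralTextCompletion does later in this diff. A minimal sketch of that override pattern, with illustrative names that are not part of the change:

import litellm
from litellm.llms.base import BaseLLM


class MyProviderTextCompletion(BaseLLM):
    def process_text_completion_response(
        self, model, response, model_response, stream, logging_obj,
        optional_params, api_key, data, messages, print_verbose, encoding,
    ):
        # Turn the provider's raw HTTP response into litellm's response type.
        return litellm.utils.TextCompletionResponse(**response.json())
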
@@ -208,85 +208,6 @@ class MistralEmbeddingConfig:
        return optional_params


class MistralTextCompletionConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
    """

    suffix: Optional[str] = None
    temperature: Optional[int] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = None
    random_seed: Optional[int] = None
    stop: Optional[str] = None

    def __init__(
        self,
        suffix: Optional[str] = None,
        temperature: Optional[int] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
        random_seed: Optional[int] = None,
        stop: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "suffix",
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "seed",
            "stop",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "suffix":
                optional_params["suffix"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "stream" and value == True:
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "seed":
                optional_params["extra_body"] = {"random_seed": value}

        return optional_params


class AzureAIStudioConfig:
    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return it's required fields with a description"""

litellm/llms/text_completion_codestral.py (new file, 461 lines)

@@ -0,0 +1,461 @@
# What is this?
## Controller file for TextCompletionCodestral Integration - https://codestral.com/

from functools import partial
import os, types
import traceback
import json
from enum import Enum
import requests, copy  # type: ignore
import time
from typing import Callable, Optional, List, Literal, Union
from litellm.utils import (
    TextCompletionResponse,
    Usage,
    CustomStreamWrapper,
    Message,
    Choices,
)
from litellm.litellm_core_utils.core_helpers import map_finish_reason
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from .base import BaseLLM
import httpx  # type: ignore

class TextCompletionCodestralError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream

class MistralTextCompletionConfig:
    """
    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
    """

    suffix: Optional[str] = None
    temperature: Optional[int] = None
    top_p: Optional[float] = None
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = None
    random_seed: Optional[int] = None
    stop: Optional[str] = None

    def __init__(
        self,
        suffix: Optional[str] = None,
        temperature: Optional[int] = None,
        top_p: Optional[float] = None,
        max_tokens: Optional[int] = None,
        min_tokens: Optional[int] = None,
        stream: Optional[bool] = None,
        random_seed: Optional[int] = None,
        stop: Optional[str] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "suffix",
            "temperature",
            "top_p",
            "max_tokens",
            "stream",
            "seed",
            "stop",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "suffix":
                optional_params["suffix"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "stream" and value == True:
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "seed":
                optional_params["random_seed"] = value
            if param == "min_tokens":
                optional_params["min_tokens"] = value

        return optional_params


class CodestralTextCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_Key - Please add CODESTRAL_API_Key to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_text_completion_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")
        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except:
            raise TextCompletionCodestralError(message=response.text, status_code=422)

        _response = litellm.TextCompletionResponse(**completion_response)
        return _response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)

        completion_url = api_base or "https://codestral.mistral.ai/v1/fim/completions"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.MistralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "prompt": prompt,
            **optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = requests.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="codestral",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = requests.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )
            return self.process_text_completion_response(
                model=model,
                response=response,
                model_response=model_response,
                stream=optional_params.get("stream", False),
                logging_obj=logging_obj,  # type: ignore
                optional_params=optional_params,
                api_key=api_key,
                data=data,
                messages=messages,
                print_verbose=print_verbose,
                encoding=encoding,
            )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:

        async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=timeout))
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise TextCompletionCodestralError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            raise TextCompletionCodestralError(
                status_code=500, message="{}\n{}".format(str(e), traceback.format_exc())
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="codestral",
            logging_obj=logging_obj,
        )
        return streamwrapper

    def embedding(self, *args, **kwargs):
        pass

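To make the new module concrete, a small illustrative snippet (not part of the diff) exercising two pieces defined above: the OpenAI-to-FIM parameter mapping in MistralTextCompletionConfig and the ChatML-token cleanup in CodestralTextCompletion.output_parser. The expected outputs follow directly from the code above.

from litellm.llms.text_completion_codestral import (
    CodestralTextCompletion,
    MistralTextCompletionConfig,
)

# "seed" is translated to the provider's "random_seed" field; other params pass through.
params = MistralTextCompletionConfig().map_openai_params(
    non_default_params={"suffix": "return True", "max_tokens": 10, "seed": 10},
    optional_params={},
)
print(params)  # {'suffix': 'return True', 'max_tokens': 10, 'random_seed': 10}

# Leading/trailing ChatML-style wrapper tokens are stripped from generated text.
handler = CodestralTextCompletion()
cleaned = handler.output_parser("<s>def add(a, b):\n    return a + b</s>")
print(cleaned)  # the <s> and </s> tokens are removed
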
@@ -82,6 +82,7 @@ from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.prompt_templates.factory import (
    prompt_factory,
    custom_prompt,

@@ -120,6 +121,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()

@@ -2027,6 +2029,46 @@ def completion(
            timeout=timeout,
        )

        if (
            "stream" in optional_params
            and optional_params["stream"] is True
            and acompletion is False
        ):
            return _model_response
        response = _model_response
    elif custom_llm_provider == "text-completion-codestral":

        api_base = (
            api_base
            or optional_params.pop("api_base", None)
            or optional_params.pop("base_url", None)
            or litellm.api_base
            or "https://codestral.mistral.ai/v1/fim/completions"
        )

        api_key = api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY")

        text_completion_model_response = litellm.TextCompletionResponse(
            stream=stream
        )

        _model_response = codestral_text_completions.completion(  # type: ignore
            model=model,
            messages=messages,
            model_response=text_completion_model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            encoding=encoding,
            logging_obj=logging,
            acompletion=acompletion,
            api_base=api_base,
            custom_prompt_dict=custom_prompt_dict,
            api_key=api_key,
            timeout=timeout,
        )

        if (
            "stream" in optional_params
            and optional_params["stream"] is True

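Per the resolution chain above, an explicit api_base or api_key wins over litellm.api_base and the CODESTRAL_API_KEY secret. A hedged illustration of a per-call override using standard litellm keyword arguments (not part of the diff):

import litellm

response = litellm.text_completion(
    model="text-completion-codestral/codestral-2405",
    prompt="def fib(n):",
    api_base="https://codestral.mistral.ai/v1/fim/completions",  # explicit base takes precedence
    api_key="sk-...",  # otherwise CODESTRAL_API_KEY is read from the environment
)
print(response)
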
@@ -4078,29 +4078,30 @@ async def test_async_text_completion_chat_model_stream():
# asyncio.run(test_async_text_completion_chat_model_stream())


# @pytest.mark.asyncio
# async def test_completion_codestral_fim_api():
#     try:
#         litellm.set_verbose = True
#         from litellm._logging import verbose_logger
#         import logging
#         verbose_logger.setLevel(level=logging.DEBUG)
#         response = await litellm.atext_completion(
#             model="text-completion-codestral/codestral-2405",
#             prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
#             suffix="return True",
#             temperature=0,
#             top_p=0.4,
#             max_tokens=10,
#             # min_tokens=10,
#             seed=10,
#             stop=["return"],
#         )
#         # Add any assertions here to check the response
#         print(response)
#         # cost = litellm.completion_cost(completion_response=response)
#         # print("cost to make mistral completion=", cost)
#         # assert cost > 0.0
#     except Exception as e:
#         pytest.fail(f"Error occurred: {e}")

@pytest.mark.asyncio
async def test_completion_codestral_fim_api():
    try:
        litellm.set_verbose = True
        from litellm._logging import verbose_logger
        import logging

        verbose_logger.setLevel(level=logging.DEBUG)
        response = await litellm.atext_completion(
            model="text-completion-codestral/codestral-2405",
            prompt="def is_odd(n): \n return n % 2 == 1 \ndef test_is_odd():",
            suffix="return True",
            temperature=0,
            top_p=1,
            max_tokens=10,
            min_tokens=10,
            seed=10,
            stop=["return"],
        )
        # Add any assertions here to check the response
        print(response)

        # cost = litellm.completion_cost(completion_response=response)
        # print("cost to make mistral completion=", cost)
        # assert cost > 0.0
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

@@ -3017,29 +3017,6 @@ def get_optional_params(
        optional_params["response_format"] = response_format
        if seed is not None:
            optional_params["seed"] = seed
    elif custom_llm_provider == "codestral":
        # supported_params = get_supported_openai_params(
        #     model=model, custom_llm_provider=custom_llm_provider
        # )
        # _check_valid_arg(supported_params=supported_params)
        # optional_params = litellm.DeepInfraConfig().map_openai_params(
        #     non_default_params=non_default_params,
        #     optional_params=optional_params,
        #     model=model,
        # )
        pass
    elif custom_llm_provider == "text-completion-codestral":
        # supported_params = get_supported_openai_params(
        #     model=model, custom_llm_provider=custom_llm_provider
        # )
        # _check_valid_arg(supported_params=supported_params)
        # optional_params = litellm.DeepInfraConfig().map_openai_params(
        #     non_default_params=non_default_params,
        #     optional_params=optional_params,
        #     model=model,
        # )
        pass

    elif custom_llm_provider == "deepseek":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider

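Both branches above are placeholders for now; the OpenAI-style parameters the FIM route actually understands are listed by the config class added earlier in this commit. For illustration (not part of the diff):

import litellm

print(litellm.MistralTextCompletionConfig().get_supported_openai_params())
# ['suffix', 'temperature', 'top_p', 'max_tokens', 'stream', 'seed', 'stop']
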
@@ -3906,10 +3883,6 @@ def get_llm_provider(
            # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
            api_base = "https://codestral.mistral.ai/v1"
            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
        elif custom_llm_provider == "text-completion-codestral":
            # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1
            api_base = "https://codestral.mistral.ai/v1"
            dynamic_api_key = get_secret("CODESTRAL_API_KEY")
        elif custom_llm_provider == "deepseek":
            # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1
            api_base = "https://api.deepseek.com/v1"

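Note the split this routing encodes: the "codestral" prefix rides the OpenAI-compatible chat endpoint at https://codestral.mistral.ai/v1, while "text-completion-codestral" goes through the FIM handler added in this commit. A hedged chat-side example, assuming the "codestral-latest" model alias is available (illustrative, not part of the diff):

import litellm

response = litellm.completion(
    model="codestral/codestral-latest",  # routed as an OpenAI-compatible provider
    messages=[{"role": "user", "content": "Write a one-line docstring for a bubble sort."}],
)
print(response.choices[0].message.content)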