(Refactor) Code Quality improvement - rename text_completion_codestral.py -> codestral/completion/ (#7172)

* rename files

* fix codestral fim organization

* fix CodestralTextCompletionConfig

* fix import CodestralTextCompletion

* fix BaseLLM

* fix imports

* fix CodestralTextCompletionConfig

* fix imports CodestralTextCompletion
Ishaan Jaff 2024-12-11 00:55:47 -08:00 committed by GitHub
parent 400eb28a91
commit 78d132c1fb
10 changed files with 164 additions and 162 deletions
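
In short, the module path and config class name change as follows (a quick sketch based on the diff below; callers importing the old path will need to update):

```python
# Before this commit
from litellm.llms.text_completion_codestral import (
    CodestralTextCompletion,
    MistralTextCompletionConfig,
)

# After this commit
from litellm.llms.codestral.completion.handler import CodestralTextCompletion
from litellm.llms.codestral.completion.transformation import (
    CodestralTextCompletionConfig,  # renamed from MistralTextCompletionConfig
)
```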

View file

@@ -1178,7 +1178,7 @@ from .llms.friendliai.chat.transformation import FriendliaiChatConfig
 from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
 from .llms.xai.chat.transformation import XAIChatConfig
 from .llms.volcengine import VolcEngineConfig
-from .llms.text_completion_codestral import MistralTextCompletionConfig
+from .llms.codestral.completion.transformation import CodestralTextCompletionConfig
 from .llms.azure.azure import (
     AzureOpenAIError,
     AzureOpenAIAssistantsAPIConfig,

View file

@@ -107,7 +107,7 @@ def get_supported_openai_params(  # noqa: PLR0915
         elif request_type == "embeddings":
             return litellm.MistralEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "text-completion-codestral":
-        return litellm.MistralTextCompletionConfig().get_supported_openai_params(
+        return litellm.CodestralTextCompletionConfig().get_supported_openai_params(
            model=model
        )
    elif custom_llm_provider == "sambanova":
@@ -138,7 +138,7 @@ def get_supported_openai_params(  # noqa: PLR0915
        return litellm.MistralConfig().get_supported_openai_params(model=model)
        if model.startswith("codestral"):
            return (
-                litellm.MistralTextCompletionConfig().get_supported_openai_params(
+                litellm.CodestralTextCompletionConfig().get_supported_openai_params(
                    model=model
                )
            )

View file

@@ -1147,7 +1147,7 @@ class CustomStreamWrapper:
                     total_tokens=response_obj["usage"].total_tokens,
                 )
             elif self.custom_llm_provider == "text-completion-codestral":
-                response_obj = litellm.MistralTextCompletionConfig()._chunk_parser(
+                response_obj = litellm.CodestralTextCompletionConfig()._chunk_parser(
                     chunk
                 )
                 completion_obj["content"] = response_obj["text"]
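
For context, a minimal sketch (not part of this commit) of what the renamed config's _chunk_parser does with a raw SSE line from the codestral FIM stream; the chunk payload below is illustrative, not captured from a real response:

```python
import litellm

raw_sse_line = (
    'data: {"id": "chk-1", "object": "text_completion", "created": 0, '
    '"model": "codestral-latest", '
    '"choices": [{"index": 0, "delta": {"content": "def add"}, "finish_reason": null}]}'
)

# Strips the "data:" prefix, parses the JSON, and returns a GenericStreamingChunk dict.
parsed = litellm.CodestralTextCompletionConfig()._chunk_parser(raw_sse_line)
print(parsed["text"])         # "def add"
print(parsed["is_finished"])  # False
```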

View file

@@ -1,5 +1,5 @@
 # What is this?
-## Controller file for TextCompletionCodestral Integration - https://codestral.com/
+## handler file for TextCompletionCodestral Integration - https://codestral.com/

 import copy
 import json
@@ -18,6 +18,11 @@ import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
@@ -32,9 +37,6 @@ from litellm.utils import (
     Usage,
 )

-from .base import BaseLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-

 class TextCompletionCodestralError(Exception):
     def __init__(
@@ -92,111 +94,6 @@ async def make_call(
     return completion_stream


-class MistralTextCompletionConfig(OpenAITextCompletionConfig):
-    """
-    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
-    """
-
-    suffix: Optional[str] = None
-    temperature: Optional[int] = None
-    max_tokens: Optional[int] = None
-    min_tokens: Optional[int] = None
-    stream: Optional[bool] = None
-    random_seed: Optional[int] = None
-
-    def __init__(
-        self,
-        suffix: Optional[str] = None,
-        temperature: Optional[int] = None,
-        top_p: Optional[float] = None,
-        max_tokens: Optional[int] = None,
-        min_tokens: Optional[int] = None,
-        stream: Optional[bool] = None,
-        random_seed: Optional[int] = None,
-        stop: Optional[str] = None,
-    ) -> None:
-        locals_ = locals().copy()
-        for key, value in locals_.items():
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return super().get_config()
-
-    def get_supported_openai_params(self, model: str):
-        return [
-            "suffix",
-            "temperature",
-            "top_p",
-            "max_tokens",
-            "max_completion_tokens",
-            "stream",
-            "seed",
-            "stop",
-        ]
-
-    def map_openai_params(
-        self,
-        non_default_params: dict,
-        optional_params: dict,
-        model: str,
-        drop_params: bool,
-    ) -> dict:
-        for param, value in non_default_params.items():
-            if param == "suffix":
-                optional_params["suffix"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "stream" and value is True:
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop"] = value
-            if param == "seed":
-                optional_params["random_seed"] = value
-            if param == "min_tokens":
-                optional_params["min_tokens"] = value
-        return optional_params
-
-    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
-        text = ""
-        is_finished = False
-        finish_reason = None
-        logprobs = None
-
-        chunk_data = chunk_data.replace("data:", "")
-        chunk_data = chunk_data.strip()
-        if len(chunk_data) == 0 or chunk_data == "[DONE]":
-            return {
-                "text": "",
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        chunk_data_dict = json.loads(chunk_data)
-        original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
-        _choices = chunk_data_dict.get("choices", []) or []
-        _choice = _choices[0]
-        text = _choice.get("delta", {}).get("content", "")
-
-        if _choice.get("finish_reason") is not None:
-            is_finished = True
-            finish_reason = _choice.get("finish_reason")
-            logprobs = _choice.get("logprobs")
-
-        return GenericStreamingChunk(
-            text=text,
-            original_chunk=original_chunk,
-            is_finished=is_finished,
-            finish_reason=finish_reason,
-            logprobs=logprobs,
-        )
-
-
 class CodestralTextCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -351,7 +248,7 @@ class CodestralTextCompletion(BaseLLM):
             prompt = prompt_factory(model=model, messages=messages)

         ## Load Config
-        config = litellm.MistralTextCompletionConfig.get_config()
+        config = litellm.CodestralTextCompletionConfig.get_config()
         for k, v in config.items():
             if (
                 k not in optional_params

View file

@@ -0,0 +1,110 @@
+import json
+from typing import Optional
+
+import litellm
+from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
+from litellm.types.llms.databricks import GenericStreamingChunk
+
+
+class CodestralTextCompletionConfig(OpenAITextCompletionConfig):
+    """
+    Reference: https://docs.mistral.ai/api/#operation/createFIMCompletion
+    """
+
+    suffix: Optional[str] = None
+    temperature: Optional[int] = None
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = None
+    random_seed: Optional[int] = None
+
+    def __init__(
+        self,
+        suffix: Optional[str] = None,
+        temperature: Optional[int] = None,
+        top_p: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        min_tokens: Optional[int] = None,
+        stream: Optional[bool] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[str] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "suffix",
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "max_completion_tokens",
+            "stream",
+            "seed",
+            "stop",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "suffix":
+                optional_params["suffix"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            if param == "stream" and value is True:
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "seed":
+                optional_params["random_seed"] = value
+            if param == "min_tokens":
+                optional_params["min_tokens"] = value
+        return optional_params
+
+    def _chunk_parser(self, chunk_data: str) -> GenericStreamingChunk:
+        text = ""
+        is_finished = False
+        finish_reason = None
+        logprobs = None
+
+        chunk_data = chunk_data.replace("data:", "")
+        chunk_data = chunk_data.strip()
+        if len(chunk_data) == 0 or chunk_data == "[DONE]":
+            return {
+                "text": "",
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        chunk_data_dict = json.loads(chunk_data)
+        original_chunk = litellm.ModelResponse(**chunk_data_dict, stream=True)
+        _choices = chunk_data_dict.get("choices", []) or []
+        _choice = _choices[0]
+        text = _choice.get("delta", {}).get("content", "")
+
+        if _choice.get("finish_reason") is not None:
+            is_finished = True
+            finish_reason = _choice.get("finish_reason")
+            logprobs = _choice.get("logprobs")
+
+        return GenericStreamingChunk(
+            text=text,
+            original_chunk=original_chunk,
+            is_finished=is_finished,
+            finish_reason=finish_reason,
+            logprobs=logprobs,
+        )
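
For reference, a small usage sketch (not part of the diff) showing the relocated config in action; "codestral-latest" is just an illustrative model name:

```python
import litellm

cfg = litellm.CodestralTextCompletionConfig()

# Which OpenAI-style params the codestral FIM route accepts.
print(cfg.get_supported_openai_params(model="codestral-latest"))
# ['suffix', 'temperature', 'top_p', 'max_tokens', 'max_completion_tokens', 'stream', 'seed', 'stop']

# Map OpenAI-style params onto the provider payload
# (max_completion_tokens -> max_tokens, seed -> random_seed).
mapped = cfg.map_openai_params(
    non_default_params={"max_completion_tokens": 64, "seed": 42, "suffix": "return x"},
    optional_params={},
    model="codestral-latest",
    drop_params=False,
)
print(mapped)  # {'max_tokens': 64, 'random_seed': 42, 'suffix': 'return x'}
```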

View file

@@ -91,9 +91,11 @@ class VertexAIPartnerModels(VertexBase):
             from google.cloud import aiplatform

             from litellm.llms.anthropic.chat import AnthropicChatCompletion
+            from litellm.llms.codestral.completion.handler import (
+                CodestralTextCompletion,
+            )
             from litellm.llms.openai.openai import OpenAIChatCompletion
             from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
-            from litellm.llms.text_completion_codestral import CodestralTextCompletion
             from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
                 VertexLLM,
             )

View file

@@ -75,11 +75,8 @@ class VertexAIModelGardenModels(VertexBase):
             import vertexai
             from google.cloud import aiplatform

-            from litellm.llms.anthropic.chat import AnthropicChatCompletion
-            from litellm.llms.openai.openai import OpenAIChatCompletion
             from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
-            from litellm.llms.text_completion_codestral import CodestralTextCompletion
-            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+            from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
                 VertexLLM,
             )
         except Exception:

View file

@@ -56,8 +56,10 @@ from litellm.litellm_core_utils.mock_functions import (
     mock_embedding,
     mock_image_generation,
 )
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_content_from_model_response,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.litellm_core_utils.prompt_templates.common_utils import get_content_from_model_response
 from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
@@ -82,6 +84,15 @@ from litellm.utils import (

 from ._logging import verbose_logger
 from .caching.caching import disable_cache, enable_cache, update_cache
+from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
+from .litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    function_call_prompt,
+    map_system_message_pt,
+    ollama_pt,
+    prompt_factory,
+    stringify_json_tool_call_content,
+)
 from .litellm_core_utils.streaming_chunk_builder_utils import ChunkProcessor
 from .llms import baseten, maritalk, ollama_chat, petals
 from .llms.anthropic.chat import AnthropicChatCompletion
@@ -93,6 +104,7 @@ from .llms.azure_ai.embed import AzureAIEmbedding
 from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
 from .llms.bedrock.embed.embedding import BedrockEmbedding
 from .llms.bedrock.image.image_handler import BedrockImageGeneration
+from .llms.codestral.completion.handler import CodestralTextCompletion
 from .llms.cohere.embed import handler as cohere_embed
 from .llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -102,33 +114,21 @@ from .llms.deprecated_providers import palm, aleph_alpha
 from .llms.groq.chat.handler import GroqChatCompletion
 from .llms.huggingface.chat.handler import Huggingface
 from .llms.nlp_cloud.chat.handler import completion as nlp_cloud_chat_completion
-from .llms.oobabooga.chat import oobabooga
 from .llms.ollama.completion import handler as ollama
-from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
+from .llms.oobabooga.chat import oobabooga
 from .llms.openai.completion.handler import OpenAITextCompletion
 from .llms.openai.openai import OpenAIChatCompletion
+from .llms.openai.transcriptions.handler import OpenAIAudioTranscription
 from .llms.openai_like.chat.handler import OpenAILikeChatHandler
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.predibase import PredibaseChatCompletion
-from .litellm_core_utils.prompt_templates.common_utils import get_completion_messages
-from .litellm_core_utils.prompt_templates.factory import (
-    custom_prompt,
-    function_call_prompt,
-    map_system_message_pt,
-    ollama_pt,
-    prompt_factory,
-    stringify_json_tool_call_content,
-)
+from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
-from .llms.replicate.chat.handler import completion as replicate_chat_completion
-from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
 from .llms.triton import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
-from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
-    VertexLLM,
-)
+from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
     GoogleBatchEmbeddings,
 )
@@ -138,18 +138,10 @@ from .llms.vertex_ai.image_generation.image_generation_handler import (
 from .llms.vertex_ai.multimodal_embeddings.embedding_handler import (
     VertexMultimodalEmbedding,
 )
-from .llms.vertex_ai.text_to_speech.text_to_speech_handler import (
-    VertexTextToSpeechAPI,
-)
-from .llms.vertex_ai.vertex_ai_partner_models.main import (
-    VertexAIPartnerModels,
-)
-from .llms.vertex_ai.vertex_embeddings.embedding_handler import (
-    VertexEmbedding,
-)
-from .llms.vertex_ai.vertex_model_garden.main import (
-    VertexAIModelGardenModels,
-)
+from .llms.vertex_ai.text_to_speech.text_to_speech_handler import VertexTextToSpeechAPI
+from .llms.vertex_ai.vertex_ai_partner_models.main import VertexAIPartnerModels
+from .llms.vertex_ai.vertex_embeddings.embedding_handler import VertexEmbedding
+from .llms.vertex_ai.vertex_model_garden.main import VertexAIModelGardenModels
 from .llms.vllm.completion import handler as vllm_handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
 from .llms.watsonx.completion.handler import IBMWatsonXAI
@@ -2434,9 +2426,7 @@ def completion(  # type: ignore  # noqa: PLR0915
             ## RESPONSE OBJECT
             response = model_response
-        elif (
-            custom_llm_provider == "sagemaker"
-        ):
+        elif custom_llm_provider == "sagemaker":
             # boto3 reads keys from .env
             model_response = sagemaker_llm.completion(
                 model=model,

View file

@@ -135,7 +135,9 @@ from litellm.types.utils import (
     Usage,
 )

-with resources.open_text("litellm.litellm_core_utils.tokenizers", "anthropic_tokenizer.json") as f:
+with resources.open_text(
+    "litellm.litellm_core_utils.tokenizers", "anthropic_tokenizer.json"
+) as f:
     json_data = json.load(f)
     # Convert to str (if necessary)
     claude_json_str = json.dumps(json_data)
@@ -3073,7 +3075,7 @@ def get_optional_params(  # noqa: PLR0915
         )
         _check_valid_arg(supported_params=supported_params)
         if "codestral" in model:
-            optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+            optional_params = litellm.CodestralTextCompletionConfig().map_openai_params(
                 model=model,
                 non_default_params=non_default_params,
                 optional_params=optional_params,
@@ -3415,7 +3417,7 @@ def get_optional_params(  # noqa: PLR0915
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        optional_params = litellm.MistralTextCompletionConfig().map_openai_params(
+        optional_params = litellm.CodestralTextCompletionConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
@@ -6237,7 +6239,7 @@ from litellm.llms.base_llm.transformation import BaseConfig
 class ProviderConfigManager:
     @staticmethod
     def get_provider_chat_config(  # noqa: PLR0915
         model: str, provider: litellm.LlmProviders
     ) -> BaseConfig:
         """
@@ -6344,7 +6346,7 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.VOLCENGINE == provider:
             return litellm.VolcEngineConfig()
         elif litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL == provider:
-            return litellm.MistralTextCompletionConfig()
+            return litellm.CodestralTextCompletionConfig()
         elif litellm.LlmProviders.SAMBANOVA == provider:
             return litellm.SambanovaConfig()
         elif litellm.LlmProviders.MARITALK == provider:
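
A hypothetical spot-check (not from this commit) that the provider config lookup shown in the hunk above now resolves to the renamed class:

```python
import litellm
from litellm.utils import ProviderConfigManager

config = ProviderConfigManager.get_provider_chat_config(
    model="codestral-latest",  # illustrative model name
    provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
)
assert isinstance(config, litellm.CodestralTextCompletionConfig)
```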

View file

@@ -212,13 +212,15 @@ def test_all_model_configs():
         {},
     ) == {"max_new_tokens": 10}

-    from litellm.llms.text_completion_codestral import MistralTextCompletionConfig
+    from litellm.llms.codestral.completion.transformation import (
+        CodestralTextCompletionConfig,
+    )

     assert (
         "max_completion_tokens"
-        in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
+        in CodestralTextCompletionConfig().get_supported_openai_params(model="llama3")
     )
-    assert MistralTextCompletionConfig().map_openai_params(
+    assert CodestralTextCompletionConfig().map_openai_params(
         model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},
@@ -277,13 +279,15 @@ def test_all_model_configs():
         drop_params=False,
     ) == {"maxTokens": 10}

-    from litellm.llms.text_completion_codestral import MistralTextCompletionConfig
+    from litellm.llms.codestral.completion.transformation import (
+        CodestralTextCompletionConfig,
+    )

     assert (
         "max_completion_tokens"
-        in MistralTextCompletionConfig().get_supported_openai_params(model="llama3")
+        in CodestralTextCompletionConfig().get_supported_openai_params(model="llama3")
     )
-    assert MistralTextCompletionConfig().map_openai_params(
+    assert CodestralTextCompletionConfig().map_openai_params(
         model="llama3",
         non_default_params={"max_completion_tokens": 10},
         optional_params={},