Mirror of https://github.com/meta-llama/llama-stack.git
commit 021dd0d35d (parent e58c7f6c37)

    make TGI work well

9 changed files with 617 additions and 326 deletions
```diff
@@ -33,10 +33,9 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.models.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    convert_chat_completion_request_to_openai_params,
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
```
```diff
@@ -55,7 +54,9 @@ class LiteLLMOpenAIMixin(
     Inference,
     NeedsRequestProviderData,
 ):
-    def __init__(self, model_entries, api_key_from_config: str, provider_data_api_key_field: str):
+    def __init__(
+        self, model_entries, api_key_from_config: str, provider_data_api_key_field: str
+    ):
         ModelRegistryHelper.__init__(self, model_entries)
         self.api_key_from_config = api_key_from_config
         self.provider_data_api_key_field = provider_data_api_key_field
```
```diff
@@ -95,7 +96,9 @@ class LiteLLMOpenAIMixin(
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
         if sampling_params is None:
             sampling_params = SamplingParams()
         model = await self.model_store.get_model(model_id)
```
```diff
@@ -110,7 +113,17 @@ class LiteLLMOpenAIMixin(
             tool_config=tool_config,
         )
 
-        params = await self._get_params(request)
+        params = await convert_chat_completion_request_to_openai_params(request)
+
+        # add api_key to params if available
+        provider_data = self.get_request_provider_data()
+        key_field = self.provider_data_api_key_field
+        if provider_data and getattr(provider_data, key_field, None):
+            api_key = getattr(provider_data, key_field)
+        else:
+            api_key = self.api_key_from_config
+        params["api_key"] = api_key
+
         logger.debug(f"params to litellm (openai compat): {params}")
         # unfortunately, we need to use synchronous litellm.completion here because litellm
         # caches various httpx.client objects in a non-eventloop aware manner
```
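With `_get_params` gone, the chat-completion path now injects the API key itself: a key supplied per request via provider data takes precedence over the key from the provider config. A minimal standalone sketch of that precedence check, assuming an illustrative field name `tgi_api_key` (in practice the field is whatever `provider_data_api_key_field` names):

```python
from types import SimpleNamespace


def resolve_api_key(provider_data, key_field: str, api_key_from_config: str) -> str:
    """Prefer a per-request key from provider data; fall back to the config key."""
    if provider_data and getattr(provider_data, key_field, None):
        return getattr(provider_data, key_field)
    return api_key_from_config


# The per-request key wins over the config key ...
assert resolve_api_key(SimpleNamespace(tgi_api_key="req-key"), "tgi_api_key", "cfg-key") == "req-key"
# ... and the config key is used when no provider data carries one.
assert resolve_api_key(None, "tgi_api_key", "cfg-key") == "cfg-key"
```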
```diff
@@ -132,87 +145,6 @@ class LiteLLMOpenAIMixin(
         ):
             yield chunk
 
-    def _add_additional_properties_recursive(self, schema):
-        """
-        Recursively add additionalProperties: False to all object schemas
-        """
-        if isinstance(schema, dict):
-            if schema.get("type") == "object":
-                schema["additionalProperties"] = False
-
-                # Add required field with all property keys if properties exist
-                if "properties" in schema and schema["properties"]:
-                    schema["required"] = list(schema["properties"].keys())
-
-            if "properties" in schema:
-                for prop_schema in schema["properties"].values():
-                    self._add_additional_properties_recursive(prop_schema)
-
-            for key in ["anyOf", "allOf", "oneOf"]:
-                if key in schema:
-                    for sub_schema in schema[key]:
-                        self._add_additional_properties_recursive(sub_schema)
-
-            if "not" in schema:
-                self._add_additional_properties_recursive(schema["not"])
-
-            # Handle $defs/$ref
-            if "$defs" in schema:
-                for def_schema in schema["$defs"].values():
-                    self._add_additional_properties_recursive(def_schema)
-
-        return schema
-
-    async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
-
-        input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages]
-        if fmt := request.response_format:
-            if not isinstance(fmt, JsonSchemaResponseFormat):
-                raise ValueError(
-                    f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
-                )
-
-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
-
-            # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
-
-            input_dict["response_format"] = {
-                "type": "json_schema",
-                "json_schema": {
-                    "name": name,
-                    "schema": fmt,
-                    "strict": True,
-                },
-            }
-        if request.tools:
-            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-            if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = (
-                    request.tool_config.tool_choice.value
-                    if isinstance(request.tool_config.tool_choice, ToolChoice)
-                    else request.tool_config.tool_choice
-                )
-
-        provider_data = self.get_request_provider_data()
-        key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
-            api_key = self.api_key_from_config
-
-        return {
-            "model": request.model,
-            "api_key": api_key,
-            **input_dict,
-            "stream": request.stream,
-            **get_sampling_options(request.sampling_params),
-        }
-
     async def embeddings(
         self,
         model_id: str,
```
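The schema helper deleted above is self-contained and easy to exercise on its own. A minimal sketch of the same recursion, lifted from the removed method into a module-level function, with a toy schema (illustrative, not from the repository) showing that nested objects get the same treatment:

```python
def add_additional_properties_recursive(schema):
    """Recursively force additionalProperties: False on every object schema."""
    if isinstance(schema, dict):
        if schema.get("type") == "object":
            schema["additionalProperties"] = False
            # Mark every declared property as required, as the deleted method did.
            if "properties" in schema and schema["properties"]:
                schema["required"] = list(schema["properties"].keys())
        if "properties" in schema:
            for prop_schema in schema["properties"].values():
                add_additional_properties_recursive(prop_schema)
        for key in ["anyOf", "allOf", "oneOf"]:
            if key in schema:
                for sub_schema in schema[key]:
                    add_additional_properties_recursive(sub_schema)
        if "not" in schema:
            add_additional_properties_recursive(schema["not"])
        if "$defs" in schema:
            for def_schema in schema["$defs"].values():
                add_additional_properties_recursive(def_schema)
    return schema


# Toy schema: the recursion reaches the nested "address" object too.
toy = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "address": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}
out = add_additional_properties_recursive(toy)
assert out["additionalProperties"] is False
assert out["required"] == ["name", "address"]
assert out["properties"]["address"]["additionalProperties"] is False
assert out["properties"]["address"]["required"] == ["city"]
```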
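For reference, the removed `_get_params` converted a `JsonSchemaResponseFormat` into an OpenAI-style strict response format. The shape of the dict it produced looked like this ("Person" and the schema contents are illustrative, not taken from the repository):

```python
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "Person",  # pulled from the schema's "title", which was then deleted
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],  # filled in by the recursive helper
            "additionalProperties": False,  # forced off at every level
        },
        "strict": True,
    },
}
```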