# What does this PR do?

Before this change, `distro_codegen.py` would only work if the user manually installed multiple provider-specific dependencies (see #1122). Now, users can run `distro_codegen.py` without any provider-specific dependencies, because we avoid importing the entire provider implementations just to get the config needed to build the provider template.

Concretely, this mostly means moving the MODEL_ALIASES (and related variants) definitions into a new models.py module within the provider implementation for those providers that require additional dependencies (see the sketches after this description). It also meant moving a couple of imports from top-level imports to inside `get_adapter_impl` for some providers, which follows the pattern already used by multiple existing providers.

To ensure we don't regress and accidentally add new imports that cause distro_codegen.py to fail, the stubbed-in pre-commit hook for distro_codegen.py was uncommented and slightly tweaked: it now runs via `uv run python ...` so that it executes with only the project's default dependencies, and it runs automatically instead of manually.

Lastly, this updates distro_codegen.py itself to keep track of the paths it might have changed and to only `git diff` those specific paths when checking for changed files, instead of diffing the entire working tree. The latter was overly broad and required a user to have no other unstaged changes in their working tree, even if those unstaged changes were unrelated to generated code. Now it only flags uncommitted changes for paths distro_codegen.py actually writes to.

Our generated code was also out of date, presumably because of these issues, so this commit also includes updates to the generated code purely because it was out of sync, and the pre-commit hook now enforces that it stays up to date.

(Closes #1122)

## Test Plan

I manually tested distro_codegen.py and the pre-commit hook to verify that both work as expected, flagging any uncommitted changes and catching any imports that attempt to pull in provider-specific dependencies. However, I do not have valid API keys for the impacted provider implementations and am unable to easily run the inference tests against each changed provider.

There are no functional changes to the provider implementations here, but I'd appreciate a second set of eyes on the changed import statements and the moving of MODEL_ALIASES-type code into a separate models.py to ensure I didn't make any obvious errors.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
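To illustrate the MODEL_ALIASES move for the Bedrock provider shown below: the new models.py ends up holding only data definitions, so importing it never pulls in botocore. This is a minimal sketch, assuming the existing `build_model_alias` helper and `CoreModelId` enum; the Bedrock model IDs listed are illustrative, not an exhaustive copy of the real file:

```python
# llama_stack/providers/remote/inference/bedrock/models.py (sketch)
from llama_models.datatypes import CoreModelId

from llama_stack.providers.utils.inference.model_registry import build_model_alias

# Pure data: mapping Bedrock model IDs to Llama core model IDs.
# Importing this module requires no provider-specific dependencies.
MODEL_ALIASES = [
    build_model_alias(
        "meta.llama3-1-8b-instruct-v1:0",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_model_alias(
        "meta.llama3-1-70b-instruct-v1:0",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
]
```

The adapter itself then only needs `from .models import MODEL_ALIASES`, as in the Bedrock implementation below.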
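The deferred-import change follows the pattern other providers already use: keep the SDK import inside `get_adapter_impl` so that importing the provider package (for example, just to read its config class) stays cheap. A rough sketch, with the surrounding names assumed for illustration:

```python
# A provider's __init__.py with the heavy import deferred (sketch).
from .config import BedrockConfig


async def get_adapter_impl(config: BedrockConfig, _deps):
    # Imported here rather than at module top level, so tools like
    # distro_codegen.py can import this package without botocore installed.
    from .bedrock import BedrockInferenceAdapter

    impl = BedrockInferenceAdapter(config)
    await impl.initialize()
    return impl
```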
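For the path-scoped change check in distro_codegen.py, the idea is to record every path the generator writes and then ask git about only those paths instead of the whole working tree. The function and variable names below are hypothetical, not the actual implementation:

```python
# Path-scoped dirty check (sketch); names are hypothetical.
import subprocess
from pathlib import Path

# Paths that distro_codegen.py actually wrote to during this run.
changed_paths: set[Path] = set()


def check_for_changed_files() -> bool:
    """Return True if any generated path has uncommitted changes."""
    dirty = False
    for path in sorted(changed_paths):
        # `git diff --exit-code <path>` exits non-zero when that path differs,
        # so unrelated unstaged changes elsewhere no longer trip the check.
        result = subprocess.run(["git", "diff", "--exit-code", str(path)], capture_output=True)
        if result.returncode != 0:
            print(f"Uncommitted changes in {path}")
            dirty = True
    return dirty
```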
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

from botocore.client import BaseClient
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChunk,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    get_sampling_strategy_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    content_has_media,
    interleaved_content_as_str,
)

from .models import MODEL_ALIASES


class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
    def __init__(self, config: BedrockConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        self._config = config

        self._client = create_bedrock_client(config)
        self.formatter = ChatFormat(Tokenizer.get_instance())

    @property
    def client(self) -> BaseClient:
        return self._client

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        self.client.close()

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params_for_chat_completion(request)
        res = self.client.invoke_model(**params)
        chunk = next(res["body"])
        result = json.loads(chunk.decode("utf-8"))

        choice = OpenAICompatCompletionChoice(
            finish_reason=result["stop_reason"],
            text=result["generation"],
        )

        response = OpenAICompatCompletionResponse(choices=[choice])
        return process_chat_completion_response(response, self.formatter, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params_for_chat_completion(request)
        res = self.client.invoke_model_with_response_stream(**params)
        event_stream = res["body"]

        async def _generate_and_convert_to_openai_compat():
            for chunk in event_stream:
                chunk = chunk["chunk"]["bytes"]
                result = json.loads(chunk.decode("utf-8"))
                choice = OpenAICompatCompletionChoice(
                    finish_reason=result["stop_reason"],
                    text=result["generation"],
                )
                yield OpenAICompatCompletionResponse(choices=[choice])

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(stream, self.formatter, request):
            yield chunk

    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict:
        bedrock_model = request.model

        sampling_params = request.sampling_params
        options = get_sampling_strategy_options(sampling_params)

        if sampling_params.max_tokens:
            options["max_gen_len"] = sampling_params.max_tokens
        if sampling_params.repetition_penalty > 0:
            options["repetition_penalty"] = sampling_params.repetition_penalty

        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model), self.formatter)
        return {
            "modelId": bedrock_model,
            "body": json.dumps(
                {
                    "prompt": prompt,
                    **options,
                }
            ),
        }

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        embeddings = []
        for content in contents:
            assert not content_has_media(content), "Bedrock does not support media for embeddings"
            input_text = interleaved_content_as_str(content)
            input_body = {"inputText": input_text}
            body = json.dumps(input_body)
            response = self.client.invoke_model(
                body=body,
                modelId=model.provider_resource_id,
                accept="application/json",
                contentType="application/json",
            )
            response_body = json.loads(response.get("body").read())
            embeddings.append(response_body.get("embedding"))
        return EmbeddingsResponse(embeddings=embeddings)