llama-stack-mirror/llama_stack/providers/remote/inference/tgi/tgi.py
Xi Yan 3c72c034e6
[remove import *] clean up import *'s (#689)
# What does this PR do?

- As the title says, clean up `import *` usages
- upgrade tests to make them more robust to bad model outputs
- remove import *'s in llama_stack/apis/* (skip __init__ modules)
<img width="465" alt="image"
src="https://github.com/user-attachments/assets/d8339c13-3b40-4ba5-9c53-0d2329726ee2"
/>

- Ran `sh run_openapi_generator.sh`; no types are affected

## Test Plan

### Providers Tests

**agents**
```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

**inference**
```bash
# meta-reference
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

# together
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.1-8B-Instruct" ./llama_stack/providers/tests/inference/test_text_inference.py
pytest -v -s -k "together" --inference-model="meta-llama/Llama-3.2-11B-Vision-Instruct" ./llama_stack/providers/tests/inference/test_vision_inference.py

pytest ./llama_stack/providers/tests/inference/test_prompt_adapter.py 
```

**safety**
```
pytest -v -s llama_stack/providers/tests/safety/test_safety.py -m together --safety-shield meta-llama/Llama-Guard-3-8B
```

**memory**
```
pytest -v -s llama_stack/providers/tests/memory/test_memory.py -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
```

**scoring**
```
pytest -v -s -m llm_as_judge_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py --judge-model meta-llama/Llama-3.2-3B-Instruct
pytest -v -s -m basic_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
pytest -v -s -m braintrust_scoring_together_inference llama_stack/providers/tests/scoring/test_scoring.py
```


**datasetio**
```
pytest -v -s -m localfs llama_stack/providers/tests/datasetio/test_datasetio.py
pytest -v -s -m huggingface llama_stack/providers/tests/datasetio/test_datasetio.py
```


**eval**
```
pytest -v -s -m meta_reference_eval_together_inference llama_stack/providers/tests/eval/test_eval.py
pytest -v -s -m meta_reference_eval_together_inference_huggingface_datasetio llama_stack/providers/tests/eval/test_eval.py
```

### Client-SDK Tests
```
LLAMA_STACK_BASE_URL=http://localhost:5000 pytest -v ./tests/client-sdk
```

### llama-stack-apps
```
PORT=5000
LOCALHOST=localhost

python -m examples.agents.hello $LOCALHOST $PORT
python -m examples.agents.inflation $LOCALHOST $PORT
python -m examples.agents.podcast_transcript $LOCALHOST $PORT
python -m examples.agents.rag_as_attachments $LOCALHOST $PORT
python -m examples.agents.rag_with_memory_bank $LOCALHOST $PORT
python -m examples.safety.llama_guard_demo_mm $LOCALHOST $PORT
python -m examples.agents.e2e_loop_with_custom_tools $LOCALHOST $PORT

# Vision model
python -m examples.interior_design_assistant.app
python -m examples.agent_store.app $LOCALHOST $PORT
```

### CLI
```
which llama
llama model prompt-format -m Llama3.2-11B-Vision-Instruct
llama model list
llama stack list-apis
llama stack list-providers inference

llama stack build --template ollama --image-type conda
```

### Distributions Tests
**ollama**
```
llama stack build --template ollama --image-type conda
ollama run llama3.2:1b-instruct-fp16
llama stack run ./llama_stack/templates/ollama/run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
```

**fireworks**
```
llama stack build --template fireworks --image-type conda
llama stack run ./llama_stack/templates/fireworks/run.yaml
```

**together**
```
llama stack build --template together --image-type conda
llama stack run ./llama_stack/templates/together/run.yaml
```

**tgi**
```
llama stack run ./llama_stack/templates/tgi/run.yaml --env TGI_URL=http://0.0.0.0:5009 --env INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
```

## Sources

Please link relevant resources if necessary.


## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
2024-12-27 15:45:44 -08:00

323 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from huggingface_hub import AsyncInferenceClient, HfApi
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    CompletionResponse,
    EmbeddingsResponse,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    ResponseFormatType,
    SamplingParams,
    ToolChoice,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_model_input_info,
    completion_request_to_prompt_model_input_info,
)

from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def build_model_aliases():
    """Return one model alias per registered model that has a Hugging Face repo.

    Each alias is built from the model's ``huggingface_repo`` and its
    ``descriptor()``; registered models without a Hugging Face repo are skipped.
    """
    aliases = []
    for registered in all_registered_models():
        if not registered.huggingface_repo:
            continue
        aliases.append(
            build_model_alias(
                registered.huggingface_repo,
                registered.descriptor(),
            )
        )
    return aliases
class _HfAdapter(Inference, ModelsProtocolPrivate):
    """Shared inference logic for the Hugging Face text-generation adapters.

    Subclasses provide an ``initialize`` coroutine that sets ``client``,
    ``max_tokens`` and ``model_id``. Every completion and chat completion is
    rendered to a prompt and sent through ``client.text_generation``.
    """

    # Async client every text-generation request is sent through.
    client: AsyncInferenceClient
    # Token budget; prompt tokens are subtracted from it to size the generation.
    max_tokens: int
    # Model served by the backend; registered models must match it.
    model_id: str

    def __init__(self) -> None:
        """Create the chat formatter, the model registry helper and the repo map."""
        self.formatter = ChatFormat(Tokenizer.get_instance())
        self.register_helper = ModelRegistryHelper(build_model_aliases())
        # Maps each Hugging Face repo to the descriptor of its registered model.
        self.huggingface_repo_to_llama_model_id = {
            model.huggingface_repo: model.descriptor()
            for model in all_registered_models()
            if model.huggingface_repo
        }

    async def shutdown(self) -> None:
        """Do nothing; this adapter has no shutdown work."""
        pass

    async def register_model(self, model: Model) -> Model:
        """Register ``model`` and check that it is the model the backend serves.

        Returns:
            The model as returned by the registry helper.

        Raises:
            ValueError: If the registered model's ``provider_resource_id``
                differs from ``self.model_id``.
        """
        model = await self.register_helper.register_model(model)
        if model.provider_resource_id != self.model_id:
            raise ValueError(
                f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
            )
        return model

    async def unregister_model(self, model_id: str) -> None:
        """Do nothing; unregistering needs no work in this adapter."""
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a completion for ``content`` against the model ``model_id``.

        Returns:
            The async generator from ``_stream_completion`` when ``stream`` is
            truthy, otherwise the awaited result of ``_nonstream_completion``.
        """
        # NOTE(review): model_store is not assigned anywhere in this class;
        # presumably it is injected from outside before use - confirm.
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    def _get_max_new_tokens(
        self, sampling_params: Optional[SamplingParams], input_tokens: int
    ) -> int:
        """Return how many new tokens to request from the backend.

        The caller's ``max_tokens`` is used when it is set and non-zero,
        otherwise the room left after the prompt. Either way the result is
        capped at ``self.max_tokens - input_tokens - 1``. A missing
        ``sampling_params`` is treated like an unset ``max_tokens``.

        No lower bound is applied, so the result is non-positive when the
        prompt already uses the whole budget.
        """
        remaining = self.max_tokens - input_tokens
        # sampling_params is Optional throughout this class; guard it here so an
        # explicit None does not raise AttributeError.
        requested = sampling_params.max_tokens if sampling_params is not None else None
        return min(requested or remaining, remaining - 1)

    def _build_options(
        self,
        sampling_params: Optional[SamplingParams] = None,
        fmt: Optional[ResponseFormat] = None,
    ):
        """Build the extra keyword options for ``client.text_generation``.

        Starts from ``get_sampling_options(sampling_params)`` and, for a JSON
        schema response format, adds a ``grammar`` entry carrying the schema.

        Raises:
            ValueError: If ``fmt`` is a grammar response format or any type
                other than JSON schema.
        """
        options = get_sampling_options(sampling_params)
        # delete key "max_tokens" from options since its not supported by the API
        options.pop("max_tokens", None)
        if fmt:
            if fmt.type == ResponseFormatType.json_schema.value:
                options["grammar"] = {
                    "type": "json",
                    "value": fmt.json_schema,
                }
            elif fmt.type == ResponseFormatType.grammar.value:
                raise ValueError("Grammar response format not supported yet")
            else:
                raise ValueError(f"Unexpected response format: {fmt.type}")
        return options

    def _text_generation_params(self, request, prompt, input_tokens: int) -> dict:
        """Assemble the ``text_generation`` keyword arguments for a rendered prompt."""
        return dict(
            prompt=prompt,
            stream=request.stream,
            details=True,
            max_new_tokens=self._get_max_new_tokens(
                request.sampling_params, input_tokens
            ),
            stop_sequences=["<|eom_id|>", "<|eot_id|>"],
            **self._build_options(request.sampling_params, request.response_format),
        )

    async def _generate_openai_compat_response(
        self, params: dict
    ) -> OpenAICompatCompletionResponse:
        """Run one non-streaming generation and wrap it as a single-choice response."""
        r = await self.client.text_generation(**params)
        choice = OpenAICompatCompletionChoice(
            finish_reason=r.details.finish_reason,
            # The text is rebuilt from the per-token details of the result.
            text="".join(t.text for t in r.details.tokens),
        )
        return OpenAICompatCompletionResponse(
            choices=[choice],
        )

    async def _get_params_for_completion(self, request: CompletionRequest) -> dict:
        """Render a completion request to a prompt and build the generation kwargs."""
        prompt, input_tokens = await completion_request_to_prompt_model_input_info(
            request, self.formatter
        )
        return self._text_generation_params(request, prompt, input_tokens)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        """Yield processed completion stream chunks for ``request``."""
        params = await self._get_params_for_completion(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token
                finish_reason = None
                # A finish reason is forwarded only for chunks that carry details.
                if chunk.details:
                    finish_reason = chunk.details.finish_reason
                choice = OpenAICompatCompletionChoice(
                    text=token_result.text, finish_reason=finish_reason
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream, self.formatter):
            yield chunk

    async def _nonstream_completion(
        self, request: CompletionRequest
    ) -> CompletionResponse:
        """Run a single non-streaming completion and return the processed response."""
        params = await self._get_params_for_completion(request)
        response = await self._generate_openai_compat_response(params)
        return process_completion_response(response, self.formatter)

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        """Run a chat completion for ``messages`` against the model ``model_id``.

        Returns:
            The async generator from ``_stream_chat_completion`` when ``stream``
            is truthy, otherwise the awaited result of
            ``_nonstream_chat_completion``.
        """
        # NOTE(review): model_store is not assigned anywhere in this class;
        # presumably it is injected from outside before use - confirm.
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            # A missing tool list is passed on as an empty list.
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> ChatCompletionResponse:
        """Run a single non-streaming chat completion and return the processed response."""
        params = await self._get_params(request)
        response = await self._generate_openai_compat_response(params)
        return process_chat_completion_response(response, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest
    ) -> AsyncGenerator:
        """Yield processed chat completion stream chunks for ``request``."""
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = await self.client.text_generation(**params)
            async for chunk in s:
                token_result = chunk.token
                # NOTE(review): unlike _stream_completion, no finish_reason is
                # attached to the choice here - confirm this is intentional
                # before changing it, since it affects how the stream is consumed.
                choice = OpenAICompatCompletionChoice(text=token_result.text)
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    async def _get_params(self, request: ChatCompletionRequest) -> dict:
        """Render a chat request to a prompt and build the generation kwargs.

        The llama model for ``request.model`` is looked up through the registry
        helper and passed to the prompt adapter together with the formatter.
        """
        prompt, input_tokens = await chat_completion_request_to_model_input_info(
            request, self.register_helper.get_llama_model(request.model), self.formatter
        )
        return self._text_generation_params(request, prompt, input_tokens)

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedContent],
    ) -> EmbeddingsResponse:
        """Embeddings are not implemented by this adapter.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
class TGIAdapter(_HfAdapter):
    """Adapter whose client targets the URL given in a ``TGIImplConfig``."""

    async def initialize(self, config: TGIImplConfig) -> None:
        """Create the client for ``config.url`` and read the endpoint's limits.

        ``max_tokens`` and ``model_id`` are taken from the ``max_total_tokens``
        and ``model_id`` entries of the endpoint info reported by the client.
        """
        log.info(f"Initializing TGI client with url={config.url}")
        inference_client = AsyncInferenceClient(
            model=config.url,
            token=config.api_token,
        )
        self.client = inference_client
        info = await inference_client.get_endpoint_info()
        self.max_tokens = info["max_total_tokens"]
        self.model_id = info["model_id"]
class InferenceAPIAdapter(_HfAdapter):
    """Adapter whose client targets the Hugging Face repo named in the config."""

    async def initialize(self, config: InferenceAPIImplConfig) -> None:
        """Create the client for ``config.huggingface_repo`` and read its limits.

        ``max_tokens`` and ``model_id`` are taken from the ``max_total_tokens``
        and ``model_id`` entries of the endpoint info reported by the client.
        """
        inference_client = AsyncInferenceClient(
            model=config.huggingface_repo,
            token=config.api_token,
        )
        self.client = inference_client
        info = await inference_client.get_endpoint_info()
        self.max_tokens = info["max_total_tokens"]
        self.model_id = info["model_id"]
class InferenceEndpointAdapter(_HfAdapter):
    """Adapter backed by a named Hugging Face Inference Endpoint."""

    async def initialize(self, config: InferenceEndpointImplConfig) -> None:
        """Look up the endpoint, wait for it to be ready and adopt its client.

        ``model_id`` comes from the endpoint's repository and ``max_tokens``
        from the ``MAX_TOTAL_TOKENS`` entry of the endpoint's custom image
        environment.

        The lookup and the readiness wait are synchronous calls, so they run
        in a worker thread to keep the event loop free while waiting.
        """
        # Get the inference endpoint details
        api = HfApi(token=config.api_token)
        endpoint = await asyncio.to_thread(
            api.get_inference_endpoint, config.endpoint_name
        )

        # Wait for the endpoint to be ready (if not already)
        await asyncio.to_thread(endpoint.wait, timeout=60)

        # Initialize the adapter
        self.client = endpoint.async_client
        self.model_id = endpoint.repository
        self.max_tokens = int(
            endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"]
        )