forked from phoenix-oss/llama-stack-mirror
		
	Convert SamplingParams.strategy to a union (#767)
				
					
				
			# What does this PR do?
Cleans up how we provide sampling params. Earlier, strategy was an enum
and all params (top_p, temperature, top_k) across all strategies were
grouped. We now have a strategy union object with each strategy (greedy,
top_p, top_k) having its corresponding params.
Earlier, 
```
class SamplingParams: 
    strategy: enum ()
    top_p, temperature, top_k and other params
```
However, the `strategy` field was not being used in any providers making
it confusing to know the exact sampling behavior purely based on the
params since you could pass temperature, top_p, top_k and how the
provider would interpret those would not be clear.
Hence we introduced -- a union where the strategy and relevant params
are all clubbed together to avoid this confusion.
Have updated all providers, tests, notebooks, readme and otehr places
where sampling params was being used to use the new format.
   
## Test Plan
`pytest llama_stack/providers/tests/inference/groq/test_groq_utils.py`
// inference on ollama, fireworks and together 
`with-proxy pytest -v -s -k "ollama"
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/inference/test_text_inference.py `
// agents on fireworks 
`pytest -v -s -k 'fireworks and create_agent'
--inference-model="meta-llama/Llama-3.1-8B-Instruct"
llama_stack/providers/tests/agents/test_agents.py
--safety-shield="meta-llama/Llama-Guard-3-8B"`
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor
guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [X] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.
---------
Co-authored-by: Hardik Shah <hjshah@fb.com>
			
			
This commit is contained in:
		
							parent
							
								
									300e6e2702
								
							
						
					
					
						commit
						a51c8b4efc
					
				
					 29 changed files with 611 additions and 388 deletions
				
			
		|  | @ -13,7 +13,6 @@ from termcolor import colored | |||
| 
 | ||||
| from llama_stack.cli.subcommand import Subcommand | ||||
| from llama_stack.cli.table import print_table | ||||
| from llama_stack.distribution.utils.serialize import EnumEncoder | ||||
| 
 | ||||
| 
 | ||||
| class ModelDescribe(Subcommand): | ||||
|  | @ -72,7 +71,7 @@ class ModelDescribe(Subcommand): | |||
|             rows.append( | ||||
|                 ( | ||||
|                     "Recommended sampling params", | ||||
|                     json.dumps(sampling_params, cls=EnumEncoder, indent=4), | ||||
|                     json.dumps(sampling_params, indent=4), | ||||
|                 ) | ||||
|             ) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue