-
Notifications
You must be signed in to change notification settings - Fork 460
Expand sharding dump #3034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Expand sharding dump #3034
Conversation
964aca4 to
bd07254
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
cdbeb48 to
5e9737b
Compare
richjames0
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
| @@ -23,7 +23,7 @@ | |||
| from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like this test only verifies get_shaped_inputs and get_topology_mesh from train_compile. Please explicitly specify that this test is for train_compile only if this is the intention. Otherwise, I think it is better to test the get_abstract_state from maxtext_utils.
bf0879b to
7ba2a20
Compare
7ba2a20 to
277368f
Compare
277368f to
1321643
Compare
Description
Expand
sharding_dump.pyto output logical axes and update unit testsharding_compare_test.pyCurrently, two json files (
logical_shardings.json and named_shardings.json) will be generated per model/device_type/slice_number.To prevent negatively impacting GitHub Actions CI run times, we will only check a limited set of golden files (
Deepeek2 -16b/gpt-oss-20b/qwen-0.6b with tpu7x-16/v5p-16/v6e-16)into the MaxText repositoryCommand:
Get baseline sharding info:
There are two primary ways to use the script
run_sharding_dump.py:Run the script without any command-line arguments to iterate through all test
cases defined in
tests.utils.sharding_dump.TEST_CASES. It will skip anycombination for which the output files already exist.
Command:
python3 -m tests.utils.run_sharding_dumpProvide the
model_name,topology, andnum_sliceas command-line argumentsto generate sharding information for a single configuration. You must provide
all three arguments.
Command:
python3 -m tests.utils.run_sharding_dump --model_name <model> --topology <topology> --num_slice <slices>Example:
python3 -m tests.utils.run_sharding_dump --model_name gemma-7b --topology v5p-256 --num_slice 1Compare sharding info:
python3 -m pytest tests/unit/sharding_compare_test.py -s -v -k "llama3.1-70b" 2>&1 | tee test_output.logExample
Content in logical_shardings.json
Content in named_shardings.json
Tests
UT for sharding dump comparison Failed (physical weight) : https://paste.googleplex.com/6032726044049408
UT for sharding dump comparison Failed (logical): https://paste.googleplex.com/5855428334452736
UT for sharding dump comparison Successed: https://paste.googleplex.com/6737857618247680
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.