Skip to content

Commit

Permalink
Merge pull request #134 from OpenAccess-AI-Collective/gas-batch-fix
Browse files (browse the repository at this point in the history)
fix batch size calculation
  • Loading branch information
winglian authored May 31, 2023
2 parents cbf705a + baa49cb commit 909ef0c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
8 changes: 5 additions & 3 deletions scripts/finetune.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -163,15 +163,17 @@ def train(
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.gradient_accumulation_steps = (
cfg.gradient_accumulation_steps // cfg.world_size
)
cfg.batch_size = cfg.batch_size * cfg.world_size

setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/data.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
datasets.append(ds_wrapper)
else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
logging.info("tokenizing, merging, and shuffling master dataset")

samples: List[int] = []
Expand Down

0 comments on commit 909ef0c

Please sign in to comment.