diff --git a/scripts/finetune.py b/scripts/finetune.py
index e807456d87..1b1e994dda 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -163,15 +163,17 @@ def train(
     cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
         cfg.batch_size // cfg.micro_batch_size
     )
+    cfg.batch_size = (
+        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
+    )
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
         cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
-        cfg.gradient_accumulation_steps = (
-            cfg.gradient_accumulation_steps // cfg.world_size
-        )
+        cfg.batch_size = cfg.batch_size * cfg.world_size
+
     setup_wandb_env_vars(cfg)
     if cfg.device == "mps":
         cfg.load_in_8bit = False
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 9534323de8..037fa45bf9 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -233,6 +233,7 @@ def load_tokenized_prepared_datasets(
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+                raise ValueError(f"unhandled prompt tokenization strategy: {d.type}")
         logging.info("tokenizing, merging, and shuffling master dataset")
 
         samples: List[int] = []
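
For context, the net effect of the scripts/finetune.py change is that the effective batch size is now derived from micro_batch_size and gradient_accumulation_steps (and scaled up by world_size under DDP) instead of dividing gradient_accumulation_steps across ranks. Below is a minimal sketch of that resolution order, not part of the patch; a SimpleNamespace stands in for the real cfg object, which is an assumption for illustration only.

# Sketch of the resolved config values after this change (illustrative only).
import os
from types import SimpleNamespace

cfg = SimpleNamespace(
    batch_size=None,               # assume batch_size is not set in the config
    micro_batch_size=2,
    gradient_accumulation_steps=4,
)

# gradient_accumulation_steps falls back to batch_size // micro_batch_size ...
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
    cfg.batch_size // cfg.micro_batch_size
)
# ... and batch_size now falls back to micro_batch_size * gradient_accumulation_steps.
cfg.batch_size = (
    cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)

# Under DDP, batch_size is multiplied by world_size rather than dividing
# gradient_accumulation_steps across ranks, so per-rank accumulation is unchanged.
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
if cfg.world_size != 1:
    cfg.batch_size = cfg.batch_size * cfg.world_size

print(cfg.gradient_accumulation_steps, cfg.batch_size)  # 4 8 on a single process

With micro_batch_size=2 and gradient_accumulation_steps=4, this resolves to batch_size=8 per rank, or 8 * world_size reported across a DDP run.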