Skip to content

Commit

Permalink
Merge pull request #87 from microsoft/longnet/longvit
Browse files Browse the repository at this point in the history
Release LongNet and LongViT
  • Loading branch information
shumingma authored Dec 20, 2023
2 parents 3ff2f1f + ef15951 commit 11e745e
Show file tree
Hide file tree
Showing 31 changed files with 4,210 additions and 76 deletions.
50 changes: 50 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Fundamental research to develop new architectures for foundation models and A(G)

## News

- December, 2023: [LongNet](torchscale/model/LongNet.py) and [LongViT](examples/longvit/README.md) released
- October, 2023: Update RMSNorm and SwiGLU as the default module in RetNet
- November, 2022: TorchScale 0.1.1 released [[Paper](https://arxiv.org/abs/2211.13184)] [[PyPI](https://pypi.org/project/torchscale/)]

Expand All @@ -37,6 +38,18 @@ cd torchscale
pip install -e .
```

For faster training install [Flash Attention](https://github.com/Dao-AILab/flash-attention) for Turing, Ampere, Ada, or Hopper GPUs:
```
pip install flash-attn
```
or [xFormers](https://github.com/facebookresearch/xformers) for Volta, Turing, Ampere, Ada, or Hopper GPUs:
```
# cuda 11.8 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
# cuda 12.1 version
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu121
```

## Getting Started

It takes only several lines of code to create a model with the above fundamental research features enabled. Here is how to quickly obtain a BERT-like encoder:
Expand Down Expand Up @@ -85,6 +98,21 @@ It takes only several lines of code to create a RetNet model:
>>> print(retnet)
```

For LongNet models ([Flash Attention](https://github.com/Dao-AILab/flash-attention) required):
```python
>>> import torch
>>> from torchscale.architecture.config import EncoderConfig, DecoderConfig
>>> from torchscale.model.longnet import LongNetEncoder, LongNetDecoder

# Creating a LongNet encoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
>>> config = EncoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
>>> longnet = LongNetEncoder(config)

# Creating a LongNet decoder with the dilated pattern of segment_length=[2048,4096] and dilated_ratio=[1,2]
>>> config = DecoderConfig(vocab_size=64000, segment_length='[2048,4096]', dilated_ratio='[1,2]', flash_attention=True)
>>> longnet = LongNetDecoder(config)
```

## Key Features

- [DeepNorm to improve the training stability of Post-LayerNorm Transformers](https://arxiv.org/abs/2203.00555)
Expand Down Expand Up @@ -142,6 +170,8 @@ We have examples of how to use TorchScale in the following scenarios/tasks:

- Vision

* [LongViT](examples/longvit/README.md)

* ViT/BEiT [In progress]

- Speech
Expand Down Expand Up @@ -228,6 +258,26 @@ If you find this repository useful, please consider citing our work:
}
```

```
@article{longnet,
author={Jiayu Ding and Shuming Ma and Li Dong and Xingxing Zhang and Shaohan Huang and Wenhui Wang and Nanning Zheng and Furu Wei},
title = {{LongNet}: Scaling Transformers to 1,000,000,000 Tokens},
journal = {ArXiv},
volume = {abs/2307.02486},
year = {2023}
}
```

```
@article{longvit,
title = {When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
author = {Wenhui Wang and Shuming Ma and Hanwen Xu and Naoto Usuyama and Jiayu Ding and Hoifung Poon and Furu Wei},
journal = {ArXiv},
volume = {abs/2312.03558},
year = {2023}
}
```

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Expand Down
39 changes: 39 additions & 0 deletions examples/fairseq/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,45 @@ python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
--use-xmoe
```

### LongNet Model

```bash
cd examples/fairseq/
python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 train.py \
${PATH_TO_DATA} \
--num-workers 2 \
--activation-fn gelu \
--share-decoder-input-output-embed \
--validate-interval-updates 1000 \
--save-interval-updates 1000 \
--no-epoch-checkpoints \
--memory-efficient-fp16 \
--fp16-init-scale 4 \
--arch lm_base \
--task language_modeling \
--sample-break-mode none \
--tokens-per-sample 4096 \
--optimizer adam --adam-betas "(0.9, 0.98)" \
--adam-eps 1e-08 \
--clip-norm 0.0 \
--lr 5e-4 \
--lr-scheduler polynomial_decay \
--warmup-updates 750 \
--dropout 0.1 \
--attention-dropout 0.1 \
--weight-decay 0.01 \
--batch-size 4 \
--update-freq 1 \
--required-batch-size-multiple 1 \
--total-num-update 50000 \
--max-update 50000 \
--seed 1 \
--ddp-backend=c10d \
--flash-attention \
--segment-length [2048,4096] \
--dilated-ratio [1,2]
```

## Example: Machine Translation

### Data Format
Expand Down
41 changes: 40 additions & 1 deletion examples/fairseq/models/language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from torchscale.architecture.config import DecoderConfig
from torchscale.architecture.decoder import Decoder
from torchscale.model.LongNet import LongNetDecoder

DEFAULT_MAX_TARGET_POSITIONS = 1024
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -196,6 +197,19 @@ class LanguageConfig(FairseqDataclass):
xpos_scale_base: Optional[int] = field(
default=512,
)
flash_attention: Optional[bool] = field(
default=False,
)
seq_parallel: Optional[bool] = field(
default=False,
)
segment_length: Optional[str] = field(
default='',
)
dilated_ratio: Optional[str] = field(
default='',
)



@register_model("lm", dataclass=LanguageConfig)
Expand Down Expand Up @@ -256,7 +270,13 @@ def build_model(cls, args, task):
config = DecoderConfig()
config.override(args)

decoder = LMDecoder(
if args.segment_length != '':
assert args.dilated_ratio != ''
DECODER_CLASS = LongNetLMDecoder
else:
DECODER_CLASS = LMDecoder

decoder = DECODER_CLASS(
config,
embed_tokens,
embed_positions,
Expand Down Expand Up @@ -291,6 +311,25 @@ def reorder_incremental_state_scripting(
incremental_state[module][key] = result


class LongNetLMDecoder(LongNetDecoder, FairseqIncrementalDecoder):
def forward(self, src_tokens, **kwargs):
self_attn_padding_mask = src_tokens.eq(self.dictionary.pad())
return super().forward(src_tokens, self_attn_padding_mask, **kwargs)

def max_positions(self):
return self.embed_positions.max_positions

def reorder_incremental_state_scripting(
self,
incremental_state,
new_order,
):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result


@register_model_architecture("lm", "lm_base")
def base_lm_architecture(args):
# backward compatibility for older model checkpoints
Expand Down
71 changes: 71 additions & 0 deletions examples/longvit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# [(LongViT) When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology](https://arxiv.org/abs/2312.03558)

**LongViT** is a vision Transformer that can process gigapixel images (e.g., 32,768x32,768 images) in an end-to-end manner. We split the image into millions of patches and employ [LongNet](https://arxiv.org/abs/2307.02486) to directly model the extremely long sequence. We apply LongViT in the field of computational pathology and achieve remarkable performance on cancer subtyping and survival prediction tasks.


## Setup
```
pip install -r requirements.txt
pip install git+https://github.com/shumingma/fairseq.git@moe
pip install -v -U git+https://github.com/facebookresearch/[email protected]#egg=xformers
```


## Pretraining

We perform self-supervised pretraining on TCGA diagnostic slides using [DINO](https://arxiv.org/abs/2104.14294) objective. The detailed instructions can be found at [`get_started_for_tcga_pretraining.md`](get_started/get_started_for_tcga_pretraining.md).

The link to the pretrained LongViT model on TCGA diagnostic slides:
- [`LongViT`](https://conversationhub.blob.core.windows.net/beit-share-public/longvit/longvit_small_patch32_1024.pth?sv=2021-10-04&st=2023-06-08T11%3A16%3A02Z&se=2033-06-09T11%3A16%3A00Z&sr=c&sp=r&sig=N4pfCVmSeq4L4tS8QbrFVsX6f6q844eft8xSuXdxU48%3D): #layer=12; hidden=384; FFN factor=4x; #head=16; patch=32x32


## Fine-tuning on Subtyping Classification

We perform finetuning on cancer subtyping on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_subtyping.md`](get_started/get_started_for_tcga_subtyping.md).


## Fine-tuning on Survival Prediction

We perform finetuning on survival prediction on images with sizes up to 32,768x32,768 (1M patches). The detailed instructions can be found at [`get_started_for_tcga_survival_prediction.md`](get_started/get_started_for_tcga_survival_prediction.md).


## Citation

If you find this repository useful, please consider citing our work:
```
@article{longvit,
title={When an Image is Worth 1,024 x 1,024 Words: A Case Study in Computational Pathology},
author={Wang, Wenhui and Ma, Shuming and Xu, Hanwen and Usuyama, Naoto and Ding, Jiayu and Poon, Hoifung and Wei, Furu},
journal={arXiv preprint arXiv:2312.03558},
year={2023}
}
@article{longnet,
title={LongNet: Scaling transformers to 1,000,000,000 tokens},
author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Zheng, Nanning and Wei, Furu},
journal={arXiv preprint arXiv:2307.02486},
year={2023}
}
@article{torchscale,
title={TorchScale: Transformers at scale},
author={Ma, Shuming and Wang, Hongyu and Huang, Shaohan and Wang, Wenhui and Chi, Zewen and Dong, Li and Benhaim, Alon and Patra, Barun and Chaudhary, Vishrav and Song, Xia and others},
journal={arXiv preprint arXiv:2211.13184},
year={2022}
}
```


## Acknowledgement

This repository is built using the [BEiT-3](https://github.com/microsoft/unilm/tree/master/beit3), the [MCAT](https://github.com/mahmoodlab/MCAT), the [DINO](https://github.com/facebookresearch/dino), the [HIPT](https://github.com/mahmoodlab/HIPT) repository and the [timm](https://github.com/rwightman/pytorch-image-models) library.


## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.

[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)

### Contact Information

For help or issues using LongViT models, please submit a GitHub issue.
78 changes: 78 additions & 0 deletions examples/longvit/data_preprocessing/cache_transformed_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
import sys
import torch
import random
import argparse
from PIL import Image, ImageFilter, ImageOps
from multiprocessing import Pool, cpu_count
from timm.data.transforms import RandomResizedCropAndInterpolation
import torchvision.transforms as transforms

Image.MAX_IMAGE_PIXELS = 6400000000


def build_transform(input_size):
train_interpolation = "bicubic"
t = [
RandomResizedCropAndInterpolation(input_size, scale=(0.5, 1.0), interpolation=train_interpolation),
transforms.RandomHorizontalFlip(),
]
t = transforms.Compose(t)

return t


def pil_loader(path):
with open(path, "rb") as f:
img = Image.open(f)
return img.convert("RGB")


def save_image(transformed_img, output_image_path):
if isinstance(transformed_img, torch.Tensor):
transformed_img = transforms.ToPILImage()(transformed_img)
transformed_img.save(output_image_path)


def get_image_files(input_dir):
for root, _, files in os.walk(input_dir):
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
yield os.path.join(root, file)


def transform_and_save_crops(args):
input_path, input_dir, output_dir, transform = args
print(input_path)
file_basename = os.path.basename(input_path)

img = pil_loader(input_path)
transformed_img = transform(img)
output_image_path = os.path.join(output_dir, file_basename)
save_image(transformed_img, output_image_path)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Save transformed images in a directory.')
parser.add_argument('input_dir', help='Path to the input directory.')
parser.add_argument('output_dir', help='Path to the output directory.')
parser.add_argument('-p', '--processes', type=int, default=cpu_count(), help='Number of processes to use. Default: number of CPU cores')
parser.add_argument('--input_size', type=int, default=16384, help='input image size')
args = parser.parse_args()

input_dir = args.input_dir
output_dir = args.output_dir
num_processes = args.processes
input_size = args.input_size
print("num_processes: {}".format(num_processes))
print("input_size: {}".format(input_size))

transform = build_transform(input_size=input_size)

image_files = list(get_image_files(input_dir))
task_args = [(file, input_dir, output_dir, transform) for file in image_files]

os.makedirs(output_dir, exist_ok=True)

with Pool(processes=num_processes) as pool:
pool.map(transform_and_save_crops, task_args)
45 changes: 45 additions & 0 deletions examples/longvit/data_preprocessing/convert_wsi_to_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import glob
import argparse
import openslide

from PIL import Image
from concurrent.futures import ProcessPoolExecutor


def convert_wsi_to_images(slide_path, image_path, target_size, level=0):
slide = openslide.open_slide(slide_path)
level_dims = slide.level_dimensions
region = slide.read_region((0,0), level, level_dims[level])
region = region.convert("RGB")
print("convert: {}({}) -> {}".format(slide_path, region.size, image_path))
resized_img = region.resize((target_size, target_size), Image.BICUBIC)
resized_img.save(image_path)


def process_slides(input_folder, output_folder, target_size, level=0):
if not os.path.exists(output_folder):
os.makedirs(output_folder)

slide_paths = glob.glob(os.path.join(input_folder, "*.svs"))

with ProcessPoolExecutor(max_workers=1) as executor:
for slide_path in slide_paths:
image_path = os.path.join(output_folder, os.path.basename(slide_path).split(".svs")[0] + ".jpg")
executor.submit(convert_wsi_to_images, slide_path, image_path, target_size, level=level)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert slides into images")
parser.add_argument("input_folder", type=str, help="")
parser.add_argument("output_folder", type=str, help="")
parser.add_argument("target_size", type=int, help="")
parser.add_argument("level", type=int, help="")

args = parser.parse_args()
input_folder = args.input_folder
output_folder = args.output_folder
target_size = args.target_size
level = args.level

process_slides(input_folder, output_folder, target_size, level=level)
Loading

0 comments on commit 11e745e

Please sign in to comment.