Commit: Update on "Move and rename GranularityType -> Granularity"
Summary: Move GranularityType to quant_primitives.py to be
consistent with other similar fields like MappingType and
ZeroPointDomain.

Test Plan: CI

[ghstack-poisoned]
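As a sketch of what the move means for imports (the module path and the sibling names are taken from the summary above; the grouping shown is an assumption, not part of this diff):

from torchao.quantization.quant_primitives import (
    MappingType,      # existing sibling field
    ZeroPointDomain,  # existing sibling field
    Granularity,      # moved here and renamed from GranularityType
)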
andrewor14 committed Oct 9, 2024
2 parents b28be42 + dab1ec2 commit f7f7864
Showing 20 changed files with 566 additions and 90 deletions.
5 changes: 5 additions & 0 deletions examples/sam2_amg_server/README.md
@@ -0,0 +1,5 @@
To run this example, download the vit_h checkpoint and place it in a local folder named checkpoints.

You can find the vit_h checkpoint here: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

To read the image, you also need to install opencv-python: https://pypi.org/project/opencv-python/
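A minimal sketch of the same setup in code (folder and file names assumed to match the description above):

import os
import urllib.request

# Fetch the vit_h checkpoint into the expected checkpoints/ folder.
os.makedirs("checkpoints", exist_ok=True)
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
urllib.request.urlretrieve(url, "checkpoints/sam_vit_h_4b8939.pth")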
133 changes: 133 additions & 0 deletions examples/sam2_amg_server/amg_example.py
@@ -0,0 +1,133 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import torch.utils.benchmark as benchmark

from torch._inductor import config as inductorconfig
inductorconfig.triton.unique_kernel_names = True
inductorconfig.coordinate_descent_tuning = True
inductorconfig.coordinate_descent_check_all_directions = True

def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    print(f"Saving trace under {path}")
    prof.export_chrome_trace(path)
    return result

def show_anns(anns):
    # Overlay each mask in a random translucent color and return all masks
    # stacked into a single tensor.
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    ms = []
    for ann in sorted_anns:
        m = ann['segmentation']
        ms.append(torch.as_tensor(m))
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    return torch.stack(ms)

image = cv2.imread('dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

device = "cuda"

# Equivalent segment-anything-fast pipeline, kept for reference:
# from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
#
# sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
# model_type = "vit_h"
# sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
sam2.to(device=device)

# mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=256)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=None)

## NOTE: Causes numerical differences
## TODO: Implement mIoU to allow approximations.
# torch.set_float32_matmul_precision('high')
# torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
##

## TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
    mask_generator.predictor.model.image_encoder,
    # mode="max-autotune-no-cudagraphs",
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)

# mask_generator.predictor._predict = torch.compile(
#     mask_generator.predictor._predict,
#     mode="max-autotune-no-cudagraphs",
#     fullgraph=True,
#     dynamic=False,
# )

torch._dynamo.config.capture_dynamic_output_shape_ops = True
mask_generator._process_batch = torch.compile(
    mask_generator._process_batch,
    mode="max-autotune-no-cudagraphs",
    fullgraph=True,
    dynamic=True,
)

# with torch.backends.cuda.sdp_kernel(enable_cudnn=False): #, enable_math=False, enable_mem_efficient=False):
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):  # , enable_math=False, enable_mem_efficient=False):
    # Run three times for warmup
    masks = mask_generator.generate(image)
    masks = mask_generator.generate(image)
    masks = mask_generator.generate(image)

# Save an example
plt.figure(figsize=(image.shape[1] / 100., image.shape[0] / 100.), dpi=100)
plt.imshow(image)
ms = show_anns(masks)
# Compare against a saved reference; on a first run, uncomment the torch.save
# line below instead to create dog_mask_fast.pt.
ms_ref = torch.load("dog_mask_fast.pt")
torch.testing.assert_allclose(ms, ms_ref)
print("Masks match reference")
# torch.save(ms, "dog_mask_fast.pt")
plt.axis('off')
plt.tight_layout()
plt.savefig('dog_mask_fast.png', format='png')

# Benchmark
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(10):
    masks = mask_generator.generate(image)
end_event.record()
torch.cuda.synchronize()
print(start_event.elapsed_time(end_event) / 10.)
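
# The torch.utils.benchmark import above is otherwise unused; as a sketch, the
# same measurement could be taken with benchmark.Timer, which handles warmup
# and CUDA synchronization internally (an alternative, not what this script
# reports above).
timer = benchmark.Timer(
    stmt="mask_generator.generate(image)",
    globals={"mask_generator": mask_generator, "image": image},
)
print(timer.timeit(10))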

# Save a GPU trace
profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)

# Write out memory usage
max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
_, total_memory = torch.cuda.mem_get_info()
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
57 changes: 57 additions & 0 deletions examples/sam2_amg_server/example.html
@@ -0,0 +1,57 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Upload and Display Image from FastAPI Response</title>
    <style>
        #preview {
            margin-top: 20px;
            max-width: 100%;
            max-height: 400px;
            display: none;
        }
    </style>
</head>
<body>
    <h1>Upload an Image and Display the Response</h1>
    <form id="uploadForm">
        <label for="image">Choose an image to upload:</label><br>
        <input type="file" id="image" name="image" accept="image/*" required><br><br>
        <input type="submit" value="Upload Image">
    </form>

    <h2>Received Image Preview:</h2>
    <img id="preview" alt="Received Image">

    <script>
        document.getElementById('uploadForm').addEventListener('submit', function (e) {
            e.preventDefault();

            const formData = new FormData();
            const fileInput = document.getElementById('image');
            const file = fileInput.files[0];

            if (file) {
                formData.append('image', file);

                // Perform the image upload via the Fetch API
                fetch('http://127.0.0.1:5000/upload', {
                    method: 'POST',
                    body: formData
                })
                .then(response => response.blob()) // Get the image as a Blob from the response
                .then(imageBlob => {
                    const imageObjectURL = URL.createObjectURL(imageBlob);
                    const preview = document.getElementById('preview');
                    preview.src = imageObjectURL;
                    preview.style.display = 'block';
                })
                .catch(error => {
                    console.error('Error:', error);
                });
            }
        });
    </script>
</body>
</html>
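The page posts to a server on 127.0.0.1:5000 that is not part of this commit. A minimal FastAPI sketch of a compatible /upload endpoint (the endpoint shape and response type are assumptions inferred from the fetch call above; a real server would run mask generation before responding):

# Hypothetical counterpart to example.html; run with: uvicorn server:app --port 5000
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response

app = FastAPI()
# The page may be opened from file:// or another origin, so allow CORS.
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["*"], allow_headers=["*"])

@app.post("/upload")
async def upload(image: UploadFile = File(...)):
    data = await image.read()
    # Placeholder: echo the uploaded bytes back as the "annotated" image.
    return Response(content=data, media_type="image/png")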
117 changes: 117 additions & 0 deletions examples/sam2_amg_server/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 144
      num_heads: 2
      stages: [2, 6, 36, 4]
      global_att_blocks: [23, 33, 43]
      window_pos_embed_bkg_spatial_size: [7, 7]
      window_spec: [8, 4, 16, 8]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [1152, 576, 288, 144]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: false
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
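The _target_ entries are Hydra instantiation hooks; build_sam2 in the example script resolves them into the model. A minimal sketch of the same mechanism (the config_path is an assumption; build_sam2 handles this internally):

from hydra import compose, initialize
from hydra.utils import instantiate

# Assumes this yaml is visible to Hydra in the current directory as sam2_hiera_l.
with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="sam2_hiera_l")
# Recursively builds SAM2Base and its nested encoder/attention modules
# from the _target_ class paths.
model = instantiate(cfg.model, _recursive_=True)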