spine updates
louisblankemeier committed Dec 3, 2023
1 parent 3532bb4 commit bc21bd8
Showing 2 changed files with 127 additions and 68 deletions.
121 changes: 71 additions & 50 deletions comp2comp/spine/spine.py
@@ -57,51 +57,66 @@ def __call__(self, inference_pipeline):
# )
os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir

seg = totalsegmentator(
input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
output=os.path.join(self.output_dir_segmentations, "segmentation.nii"),
task_ids=[292],
ml=True,
nr_thr_resamp=1,
nr_thr_saving=6,
fast=False,
nora_tag="None",
preview=False,
task="total",
# roi_subset=[
# "vertebrae_T12",
# "vertebrae_L1",
# "vertebrae_L2",
# "vertebrae_L3",
# "vertebrae_L4",
# "vertebrae_L5",
# ],
roi_subset=None,
statistics=False,
radiomics=False,
crop_path=None,
body_seg=False,
force_split=False,
output_type="nifti",
quiet=False,
verbose=False,
test=0,
skip_saving=True,
device="gpu",
license_number=None,
statistics_exclude_masks_at_border=True,
no_derived_masks=False,
v1_order=False,
)
mv = nib.load(
os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz")
)
cache_dir = "/dataNAS/people/lblankem/install_testing_v1/Comp2Comp/outputs/2023-11-30_14-21-29/medstar/abdct/studies"

# save the seg
nib.save(
seg,
os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"),
)
split_output_dir = str(self.output_dir).split("/")
study_name = split_output_dir[-2]
series_name = split_output_dir[-1]

cache_image_path = os.path.join(cache_dir, study_name, series_name, "segmentations/converted_dcm.nii.gz")
cache_seg_path = os.path.join(cache_dir, study_name, series_name, "segmentations/spine_seg.nii.gz")

if os.path.exists(cache_image_path) and os.path.exists(cache_seg_path):
print("Using cached spine segmentation.")
seg = nib.load(cache_seg_path)
mv = nib.load(cache_image_path)

else:
seg = totalsegmentator(
input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
output=os.path.join(self.output_dir_segmentations, "segmentation.nii"),
task_ids=[292],
ml=True,
nr_thr_resamp=1,
nr_thr_saving=6,
fast=False,
nora_tag="None",
preview=False,
task="total",
# roi_subset=[
# "vertebrae_T12",
# "vertebrae_L1",
# "vertebrae_L2",
# "vertebrae_L3",
# "vertebrae_L4",
# "vertebrae_L5",
# ],
roi_subset=None,
statistics=False,
radiomics=False,
crop_path=None,
body_seg=False,
force_split=False,
output_type="nifti",
quiet=False,
verbose=False,
test=0,
skip_saving=True,
device="gpu",
license_number=None,
statistics_exclude_masks_at_border=True,
no_derived_masks=False,
v1_order=False,
)
mv = nib.load(
os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz")
)

# save the seg
nib.save(
seg,
os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"),
)

# inference_pipeline.segmentation = nib.load(
# os.path.join(self.output_dir_segmentations, "segmentation.nii")
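For reference, a minimal sketch of the caching pattern this hunk adds: reuse a previously written image/segmentation pair when both files exist, and fall back to running TotalSegmentator otherwise (the helper name and callable argument are assumptions for illustration, not part of the commit).

import os
import nibabel as nib

def load_or_run_spine_seg(cache_dir, study_name, series_name, run_totalsegmentator):
    """Reuse cached NIfTI outputs when present; otherwise run segmentation (hypothetical helper)."""
    base = os.path.join(cache_dir, study_name, series_name, "segmentations")
    cache_image_path = os.path.join(base, "converted_dcm.nii.gz")
    cache_seg_path = os.path.join(base, "spine_seg.nii.gz")
    if os.path.exists(cache_image_path) and os.path.exists(cache_seg_path):
        # Cache hit: skip the expensive TotalSegmentator call entirely.
        return nib.load(cache_seg_path), nib.load(cache_image_path)
    # Cache miss: run segmentation and return it with the converted image.
    seg, mv = run_totalsegmentator()
    return seg, mv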
@@ -265,12 +280,18 @@ def __call__(self, inference_pipeline):
"""
segmentation = inference_pipeline.segmentation
segmentation_data = segmentation.get_fdata()
upper_level_index = np.where(segmentation_data == self.upper_level_index)[
2
].max()
lower_level_index = np.where(segmentation_data == self.lower_level_index)[
2
].min()
try:
upper_level_index = np.where(segmentation_data == self.upper_level_index)[
2
].max()
except:
upper_level_index = segmentation_data.shape[2]
try:
lower_level_index = np.where(segmentation_data == self.lower_level_index)[
2
].min()
except:
lower_level_index = 0
segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index]
inference_pipeline.segmentation = segmentation
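A compact sketch of the guard added here: if a requested vertebral label is missing from the segmentation, fall back to the full superior-inferior extent instead of raising (this uses an explicit emptiness check rather than the bare except in the diff; names and labels are assumed).

import numpy as np

def superior_inferior_crop_bounds(segmentation_data, upper_label, lower_label):
    """Crop indices along axis 2, defaulting to the full volume when a label is absent."""
    upper_voxels = np.where(segmentation_data == upper_label)[2]
    lower_voxels = np.where(segmentation_data == lower_label)[2]
    upper = upper_voxels.max() if upper_voxels.size else segmentation_data.shape[2]
    lower = lower_voxels.min() if lower_voxels.size else 0
    return lower, upper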

74 changes: 56 additions & 18 deletions comp2comp/spine/spine_utils.py
@@ -11,6 +11,8 @@
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
import scipy
import time

from comp2comp.spine import spine_visualization

@@ -218,36 +220,52 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
)
return roi
else:
roi_start_time = time.time()

mask = None
inferior_superior_line = seg[int(centroid[0]), int(centroid[1]), :]
# get the center point
updated_z_center = np.mean(np.where(inferior_superior_line == 1))
lower_z_idx = updated_z_center - ((length_k * 1.5) // 2)
upper_z_idx = updated_z_center + ((length_k * 1.5) // 2)
for idx in range(int(lower_z_idx), int(upper_z_idx) + 1):
posterior_anterior_line = slice[:, idx]
updated_posterior_anterior_center = np.mean(
np.where(posterior_anterior_line == 1)
)
# take the center to be the 1/4 percentile
# take multiple to increase robustness
posterior_anterior_lines = [
slice[:, idx],
slice[:, idx + 1],
slice[:, idx - 1],
]
posterior_anterior_sums = [
np.sum(posterior_anterior_lines[0]),
np.sum(posterior_anterior_lines[1]),
np.sum(posterior_anterior_lines[2]),
]
min_idx = np.argmin(posterior_anterior_sums)

posterior_anterior_line = posterior_anterior_lines[min_idx]
updated_posterior_anterior_center = (
np.min(np.where(posterior_anterior_line == 1))
+ np.sum(posterior_anterior_line) * 0.58
)
posterior_anterior_length = (np.sum(posterior_anterior_line) * 0.5) // 2
left_right_line = seg[:, int(updated_posterior_anterior_center), idx]
posterior_anterior_length = (posterior_anterior_sums[min_idx] * 0.5) // 2

left_right_lines = [
seg[:, int(updated_posterior_anterior_center), idx],
seg[:, int(updated_posterior_anterior_center) + 1, idx],
seg[:, int(updated_posterior_anterior_center) - 1, idx],
]

left_right_sums = [
np.sum(left_right_lines[0]),
np.sum(left_right_lines[1]),
np.sum(left_right_lines[2]),
]

min_idx = np.argmin(left_right_sums)
left_right_line = left_right_lines[min_idx]

updated_left_right_center = np.mean(np.where(left_right_line == 1))
left_right_length = (np.sum(left_right_line) * 0.65) // 2
# roi_2d = np.zeros((img_np.shape[0], img_np.shape[1]))
# roi_2d[
# int(updated_left_right_center - left_right_length) : int(
# updated_left_right_center + left_right_length
# ),
# int(
# updated_posterior_anterior_center
# - (posterior_anterior_length * 0.5)
# ) : int(updated_posterior_anterior_center + posterior_anterior_length),
# ] = 1
left_right_length = (left_right_sums[min_idx] * 0.65) // 2

roi_2d = np.zeros((img_np.shape[0], img_np.shape[1]))
h = updated_left_right_center
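Condensed sketch of the robustness tweak above: sample three adjacent posterior-anterior lines through the vertebral body, keep the narrowest one, and place the center 58% of the way from its posterior edge (function name and signature are illustrative only).

import numpy as np

def robust_posterior_anterior_center(slice_mask, idx):
    """Pick the narrowest of three adjacent posterior-anterior lines and locate its center."""
    lines = [slice_mask[:, idx], slice_mask[:, idx + 1], slice_mask[:, idx - 1]]
    sums = [int(line.sum()) for line in lines]
    line = lines[int(np.argmin(sums))]
    # Center sits 58% of the way along the mask, measured from its posterior edge.
    center = np.where(line == 1)[0].min() + line.sum() * 0.58
    half_length = (min(sums) * 0.5) // 2
    return center, half_length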
@@ -276,6 +294,26 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray)
if mask is None:
mask = updated_mask

start_time = time.time()

# Make sure there is no overlap with the cortical bone
num_iteration = 2
if pixel_spacing[2] >= 3:
num_iteration = 1
struct = scipy.ndimage.generate_binary_structure(3, 1)
struct = scipy.ndimage.iterate_structure(struct, num_iteration)
seg = scipy.ndimage.binary_erosion(seg, structure=struct).astype(np.int8)

roi = roi * seg

end_time = time.time()
roi_end_time = time.time()
elapsed_time = end_time - start_time
print(f"Elapsed time for erosion operation: {elapsed_time} seconds")

elapsed_time = roi_end_time - roi_start_time
print(f"Elapsed time for full ROI computation: {elapsed_time} seconds")

return roi, mask
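Roughly, the erosion step added above keeps the ROI on trabecular bone by shaving the cortical shell off the vertebra mask before intersecting it with the ROI, with fewer iterations for thick slices. A minimal sketch, assuming pixel_spacing[2] is the slice thickness in millimeters:

import numpy as np
from scipy import ndimage

def erode_vertebra_mask(seg, slice_thickness_mm):
    """Erode the binary vertebra mask so ROI statistics avoid cortical bone."""
    # Thick slices (>= 3 mm) get one erosion iteration; thinner slices get two.
    num_iteration = 1 if slice_thickness_mm >= 3 else 2
    struct = ndimage.generate_binary_structure(3, 1)
    struct = ndimage.iterate_structure(struct, num_iteration)
    return ndimage.binary_erosion(seg, structure=struct).astype(np.int8)

# Assumed usage, mirroring the diff: roi = roi * erode_vertebra_mask(seg, pixel_spacing[2])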


