diff --git a/comp2comp/spine/spine.py b/comp2comp/spine/spine.py index a4384fb..91c3c0f 100644 --- a/comp2comp/spine/spine.py +++ b/comp2comp/spine/spine.py @@ -57,51 +57,66 @@ def __call__(self, inference_pipeline): # ) os.environ["TOTALSEG_WEIGHTS_PATH"] = self.model_dir - seg = totalsegmentator( - input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), - output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), - task_ids=[292], - ml=True, - nr_thr_resamp=1, - nr_thr_saving=6, - fast=False, - nora_tag="None", - preview=False, - task="total", - # roi_subset=[ - # "vertebrae_T12", - # "vertebrae_L1", - # "vertebrae_L2", - # "vertebrae_L3", - # "vertebrae_L4", - # "vertebrae_L5", - # ], - roi_subset=None, - statistics=False, - radiomics=False, - crop_path=None, - body_seg=False, - force_split=False, - output_type="nifti", - quiet=False, - verbose=False, - test=0, - skip_saving=True, - device="gpu", - license_number=None, - statistics_exclude_masks_at_border=True, - no_derived_masks=False, - v1_order=False, - ) - mv = nib.load( - os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz") - ) + cache_dir = "/dataNAS/people/lblankem/install_testing_v1/Comp2Comp/outputs/2023-11-30_14-21-29/medstar/abdct/studies" - # save the seg - nib.save( - seg, - os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"), - ) + split_output_dir = str(self.output_dir).split("/") + study_name = split_output_dir[-2] + series_name = split_output_dir[-1] + + cache_image_path = os.path.join(cache_dir, study_name, series_name, "segmentations/converted_dcm.nii.gz") + cache_seg_path = os.path.join(cache_dir, study_name, series_name, "segmentations/spine_seg.nii.gz") + + if os.path.exists(cache_image_path) and os.path.exists(cache_seg_path): + print("Using cached spine segmentation.") + seg = nib.load(cache_seg_path) + mv = nib.load(cache_image_path) + + else: + seg = totalsegmentator( + input=os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), + output=os.path.join(self.output_dir_segmentations, "segmentation.nii"), + task_ids=[292], + ml=True, + nr_thr_resamp=1, + nr_thr_saving=6, + fast=False, + nora_tag="None", + preview=False, + task="total", + # roi_subset=[ + # "vertebrae_T12", + # "vertebrae_L1", + # "vertebrae_L2", + # "vertebrae_L3", + # "vertebrae_L4", + # "vertebrae_L5", + # ], + roi_subset=None, + statistics=False, + radiomics=False, + crop_path=None, + body_seg=False, + force_split=False, + output_type="nifti", + quiet=False, + verbose=False, + test=0, + skip_saving=True, + device="gpu", + license_number=None, + statistics_exclude_masks_at_border=True, + no_derived_masks=False, + v1_order=False, + ) + mv = nib.load( + os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz") + ) + + # save the seg + nib.save( + seg, + os.path.join(self.output_dir_segmentations, "spine_seg.nii.gz"), + ) # inference_pipeline.segmentation = nib.load( # os.path.join(self.output_dir_segmentations, "segmentation.nii") @@ -265,12 +280,18 @@ def __call__(self, inference_pipeline): """ segmentation = inference_pipeline.segmentation segmentation_data = segmentation.get_fdata() - upper_level_index = np.where(segmentation_data == self.upper_level_index)[ - 2 - ].max() - lower_level_index = np.where(segmentation_data == self.lower_level_index)[ - 2 - ].min() + try: + upper_level_index = np.where(segmentation_data == self.upper_level_index)[ + 2 + ].max() + except: + upper_level_index = segmentation_data.shape[2] + try: + lower_level_index = np.where(segmentation_data == self.lower_level_index)[ + 2 + ].min() + except: + lower_level_index = 0 segmentation = segmentation.slicer[:, :, lower_level_index:upper_level_index] inference_pipeline.segmentation = segmentation diff --git a/comp2comp/spine/spine_utils.py b/comp2comp/spine/spine_utils.py index 01fad4e..680341d 100644 --- a/comp2comp/spine/spine_utils.py +++ b/comp2comp/spine/spine_utils.py @@ -11,6 +11,8 @@ import nibabel as nib import numpy as np from scipy.ndimage import zoom +import scipy +import time from comp2comp.spine import spine_visualization @@ -218,6 +220,8 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray) ) return roi else: + roi_start_time = time.time() + mask = None inferior_superior_line = seg[int(centroid[0]), int(centroid[1]), :] # get the center point @@ -225,29 +229,43 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray) lower_z_idx = updated_z_center - ((length_k * 1.5) // 2) upper_z_idx = updated_z_center + ((length_k * 1.5) // 2) for idx in range(int(lower_z_idx), int(upper_z_idx) + 1): - posterior_anterior_line = slice[:, idx] - updated_posterior_anterior_center = np.mean( - np.where(posterior_anterior_line == 1) - ) - # take the center to be the 1/4 percentile + # take multiple to increase robustness + posterior_anterior_lines = [ + slice[:, idx], + slice[:, idx + 1], + slice[:, idx - 1], + ] + posterior_anterior_sums = [ + np.sum(posterior_anterior_lines[0]), + np.sum(posterior_anterior_lines[1]), + np.sum(posterior_anterior_lines[2]), + ] + min_idx = np.argmin(posterior_anterior_sums) + + posterior_anterior_line = posterior_anterior_lines[min_idx] updated_posterior_anterior_center = ( np.min(np.where(posterior_anterior_line == 1)) + np.sum(posterior_anterior_line) * 0.58 ) - posterior_anterior_length = (np.sum(posterior_anterior_line) * 0.5) // 2 - left_right_line = seg[:, int(updated_posterior_anterior_center), idx] + posterior_anterior_length = (posterior_anterior_sums[min_idx] * 0.5) // 2 + + left_right_lines = [ + seg[:, int(updated_posterior_anterior_center), idx], + seg[:, int(updated_posterior_anterior_center) + 1, idx], + seg[:, int(updated_posterior_anterior_center) - 1, idx], + ] + + left_right_sums = [ + np.sum(left_right_lines[0]), + np.sum(left_right_lines[1]), + np.sum(left_right_lines[2]), + ] + + min_idx = np.argmin(left_right_sums) + left_right_line = left_right_lines[min_idx] + updated_left_right_center = np.mean(np.where(left_right_line == 1)) - left_right_length = (np.sum(left_right_line) * 0.65) // 2 - # roi_2d = np.zeros((img_np.shape[0], img_np.shape[1])) - # roi_2d[ - # int(updated_left_right_center - left_right_length) : int( - # updated_left_right_center + left_right_length - # ), - # int( - # updated_posterior_anterior_center - # - (posterior_anterior_length * 0.5) - # ) : int(updated_posterior_anterior_center + posterior_anterior_length), - # ] = 1 + left_right_length = (left_right_sums[min_idx] * 0.65) // 2 roi_2d = np.zeros((img_np.shape[0], img_np.shape[1])) h = updated_left_right_center @@ -276,6 +294,26 @@ def roi_from_mask(img, centroid: np.ndarray, seg: np.ndarray, slice: np.ndarray) if mask is None: mask = updated_mask + start_time = time.time() + + # Make sure there is no overlap with the cortical bone + num_iteration = 2 + if pixel_spacing[2] >= 3: + num_iteration = 1 + struct = scipy.ndimage.generate_binary_structure(3, 1) + struct = scipy.ndimage.iterate_structure(struct, num_iteration) + seg = scipy.ndimage.binary_erosion(seg, structure=struct).astype(np.int8) + + roi = roi * seg + + end_time = time.time() + roi_end_time = time.time() + elapsed_time = end_time - start_time + print(f"Elapsed time for erosion operation: {elapsed_time} seconds") + + elapsed_time = roi_end_time - roi_start_time + print(f"Elapsed time for full ROI computation: {elapsed_time} seconds") + return roi, mask