Skip to content

Commit

Permalink
fixed nifty support bug
Browse files Browse the repository at this point in the history
  • Loading branch information
malteekj committed Jan 10, 2025
1 parent 437f5a6 commit a414f3f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 39 deletions.
7 changes: 4 additions & 3 deletions bin/C2C-slurm
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ def python_submit(command, node=None):
print(f'Submitted the command --- "{command}" --- to slurm.')
except subprocess.CalledProcessError:
if node == None:
command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --mem=50gb --time=100-00:00:00 slurm.sh "
command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --mem=60gb --time=100-00:00:00 slurm.sh "
submit_command(command)
print(f'Submitted the command --- "{command}" --- to slurm.')
else:
command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=50gb --time=100-00:00:00 slurm.sh"
# command = f"sbatch -c 8 --gres=gpu:titanrtx:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=60gb --time=100-00:00:00 slurm.sh"
command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=60gb --time=100-00:00:00 slurm.sh"
submit_command(command)
print(f'Submitted the command --- "{command}" --- to slurm.')
os.remove("./slurm.sh")


python_submit(command, node='siena')
python_submit(command, node='amalfi')



59 changes: 39 additions & 20 deletions comp2comp/aortic_calcium/aortic_calcium.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self):
# self.input_path = input_path

def __call__(self, inference_pipeline):
# check if kernels are allowed if agatson is used
if inference_pipeline.args.threshold == 'agatson':
# check if kernels are allowed if agatston is used
if inference_pipeline.args.threshold == 'agatston':
self.reconKernelChecker(inference_pipeline.dcm)

# inference_pipeline.dicom_series_path = self.input_path
Expand Down Expand Up @@ -191,6 +191,7 @@ def __call__(self, inference_pipeline):
exclude_mask=spine_mask,
remove_size=3,
return_dilated_mask=True,
return_eroded_aorta=True,
threshold=inference_pipeline.args.threshold,
dilation_iteration=target_aorta_dil,
dilation_iteration_exclude=target_exclude_dil,
Expand All @@ -208,19 +209,37 @@ def __call__(self, inference_pipeline):
"calcium_segmentations.nii.gz",
),
)

inference_pipeline.saveArrToNifti(
calcification_results["dilated_mask"],
os.path.join(
inference_pipeline.output_dir_segmentation_masks,
"dilated_aorta_mask.nii.gz",
),
)

inference_pipeline.saveArrToNifti(
calcification_results["aorta_eroded"],
os.path.join(
inference_pipeline.output_dir_segmentation_masks,
"eroded_aorta_mask.nii.gz",
),
)

inference_pipeline.saveArrToNifti(
spine_mask,
os.path.join(
inference_pipeline.output_dir_segmentation_masks, "spine_mask.nii.gz"
),
)

inference_pipeline.saveArrToNifti(
aorta_mask,
os.path.join(
inference_pipeline.output_dir_segmentation_masks, "aorta_mask.nii.gz"
),
)

inference_pipeline.saveArrToNifti(
ct,
os.path.join(inference_pipeline.output_dir_segmentation_masks, "ct.nii.gz"),
Expand All @@ -247,7 +266,7 @@ def detectCalcifications(
return_eroded_aorta=False,
aorta_erode_iteration=6,
threshold="adaptive",
agatson_failsafe=100,
agatston_failsafe=100,
generate_plots=True,
):
"""
Expand Down Expand Up @@ -290,11 +309,11 @@ def detectCalcifications(
aorta_erode_iteration (int, optional):
Number of iterations for the strcturing element. Defaults to 6.
threshold: (str, int):
Can either be 'adaptive', 'agatson', or int. Choosing 'agatson'
Can either be 'adaptive', 'agatston', or int. Choosing 'agatston'
Will mean a threshold of 130 HU.
agatson_failsafe: (int):
agatston_failsafe: (int):
A fail-safe raising an error if the mean HU of the aorta is too high
to reliably be using the agatson threshold of 130. Defaults to 100 HU.
to reliably be using the agatston threshold of 130. Defaults to 100 HU.
Returns:
results: array of only the mask is returned, or dict if other volumes are also returned.
Expand Down Expand Up @@ -359,11 +378,11 @@ def detectCalcifications(
os.path.join(self.output_dir, "images/histogram_eroded_aorta.png")
)

# Perform the fail-safe check if the method is agatson
if threshold == "agatson" and eroded_ct_points_mean > agatson_failsafe:
# Perform the fail-safe check if the method is agatston
if threshold == "agatston" and eroded_ct_points_mean > agatston_failsafe:
raise ValueError(
"The mean HU in the center aorta is {:.0f}, and the Agatson method will provide unreliable results (fail-safe threshold is {})".format(
eroded_ct_points_mean, agatson_failsafe
"The mean HU in the center aorta is {:.0f}, and the agatston method will provide unreliable results (fail-safe threshold is {})".format(
eroded_ct_points_mean, agatston_failsafe
)
)

Expand All @@ -388,7 +407,7 @@ def detectCalcifications(
)
calc_thres = np.median(aorta_ct_points) + quantile_median_dist * num_std

elif threshold == "agatson":
elif threshold == "agatston":
calc_thres = 130

counter = self.slicedSizeCount(aorta_eroded, ct, remove_size, calc_thres)
Expand All @@ -414,7 +433,7 @@ def detectCalcifications(
except:
raise ValueError(
"Error in threshold value for aortic calcium segmentaiton. \
Should be 'adaptive', 'agatson' or int, but got: "
Should be 'adaptive', 'agatston' or int, but got: "
+ str(threshold)
)

Expand Down Expand Up @@ -714,11 +733,11 @@ def __call__(self, inference_pipeline):

metrics["num_calc"] = num_lesions

if inference_pipeline.args.threshold == "agatson":
if inference_pipeline.args.threshold == "agatston":
if num_lesions == 0:
metrics["agatson_score"] = 0
metrics["agatston_score"] = 0
else:
metrics["agatson_score"] = self.CalculateAgatsonScore(
metrics["agatston_score"] = self.CalculateAgatstonScore(
calc_mask_region, ct, inference_pipeline.pix_dims
)

Expand All @@ -728,9 +747,9 @@ def __call__(self, inference_pipeline):

return {}

def CalculateAgatsonScore(self, calc_mask_region, ct, pix_dims):
def CalculateAgatstonScore(self, calc_mask_region, ct, pix_dims):
"""
Original Agatson papers says need to be >= 1mm^2, other papers
Original Agatston papers says need to be >= 1mm^2, other papers
use at least 3 face-linked pixels.
"""

Expand All @@ -751,7 +770,7 @@ def get_hu_factor(max_hu):

# dims are in mm here
area_per_pixel = pix_dims[0] * pix_dims[1]
agatson = 0
agatston = 0

for i in range(calc_mask_region.shape[2]):
tmp_slice = calc_mask_region[:, :, i]
Expand All @@ -767,8 +786,8 @@ def get_hu_factor(max_hu):
if tmp_area <= 1:
continue
else:
agatson += tmp_area * get_hu_factor(
agatston += tmp_area * get_hu_factor(
int(tmp_ct_slice[tmp_mask].max())
)

return agatson
return agatston
8 changes: 4 additions & 4 deletions comp2comp/aortic_calcium/aortic_calcium_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def __call__(self, inference_pipeline):
"{},{:.3f}\n".format("Min volume (cm³):", np.min(metrics["volume"]))
)

if inference_pipeline.args.threshold == "agatson":
f.write("Agatson score,{:.1f}\n".format(metrics["agatson_score"]))
if inference_pipeline.args.threshold == "agatston":
f.write("Agatston score,{:.1f}\n".format(metrics["agatston_score"]))

distance = 25
print("\n")
Expand Down Expand Up @@ -187,10 +187,10 @@ def __call__(self, inference_pipeline):
inference_pipeline.calcium_threshold,
)
)
if inference_pipeline.args.threshold == "agatson":
if inference_pipeline.args.threshold == "agatston":
print(
"{:<{}}{:.1f}".format(
"Agatson score:", distance, metrics["agatson_score"]
"Agatston score:", distance, metrics["agatston_score"]
)
)

Expand Down
32 changes: 22 additions & 10 deletions comp2comp/contrast_phase/contrast_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,23 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False):
y_pred_proba = model.predict_proba([featureArray])[0]
y_pred = np.argmax(y_pred_proba)

if y_pred == 0:
pred_phase = "non-contrast"
if y_pred == 1:
pred_phase = "arterial"
if y_pred == 2:
pred_phase = "venous"
if y_pred == 3:
pred_phase = "delayed"
phase_dict = {
0: "non-contrast",
1: "arterial",
2: "venous",
3: "delayed"
}

pred_phase = phase_dict[y_pred]

# if y_pred == 0:
# pred_phase = "non-contrast"
# if y_pred == 1:
# pred_phase = "arterial"
# if y_pred == 2:
# pred_phase = "venous"
# if y_pred == 3:
# pred_phase = "delayed"

output_path_metrics = os.path.join(outputPath, "metrics")
if not os.path.exists(output_path_metrics):
Expand All @@ -441,9 +450,12 @@ def predict_phase(TS_path, scan_path, outputPath=None, save_sample=False):

with open(outputTxt, "w") as text_file:
text_file.write('phase,'+pred_phase + '\n')
text_file.write('probability,{:.3f}'.format(y_pred_proba[y_pred]))
for i in range(len(y_pred_proba)):
text_file.write('{},{:.3f}\n'.format(phase_dict[i], y_pred_proba[i]))

print('Predicted phase: ' + pred_phase)
print('Probability: {:.3f}'.format(y_pred_proba[y_pred]))
for i in range(len(y_pred_proba)):
print('{},{:.3f}'.format(phase_dict[i], y_pred_proba[i]))

output_path_images = os.path.join(outputPath, "images")
if not os.path.exists(output_path_images):
Expand Down
6 changes: 4 additions & 2 deletions comp2comp/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ def __init__(self, input_path: Union[str, Path], pipeline_name=None, save=True):
self.pipeline_name = pipeline_name

def __call__(self, inference_pipeline):
dcm_files = [d for d in os.listdir(self.input_path) if d.endswith('.dcm')]
inference_pipeline.dcm = pydicom.read_file(os.path.join(self.input_path, dcm_files[0]))
if os.path.exists(
os.path.join(
inference_pipeline.output_dir, "segmentations", "converted_dcm.nii.gz"
Expand All @@ -96,6 +94,10 @@ def __call__(self, inference_pipeline):

# if self.input_path is a folder
if self.input_path.is_dir():
# store a dcm object for retrieving dicom tags
dcm_files = [d for d in os.listdir(self.input_path) if d.endswith('.dcm')]
inference_pipeline.dcm = pydicom.read_file(os.path.join(self.input_path, dcm_files[0]))

ds = dicom_series_to_nifti(
self.input_path,
output_file=os.path.join(
Expand Down

0 comments on commit a414f3f

Please sign in to comment.