ai_ct_scans.model_trainers module
- class ai_ct_scans.model_trainers.InfillTrainer(axial_width=256, coronal_width=256, sagittal_width=256, batch_size=8, batch_width=256, batch_height=256, blank_width=64, num_encoder_convs=3, encoder_filts_per_layer=10, neurons_per_dense=512, num_dense_layers=3, decoder_filts_per_layer=10, num_decoder_convs=3, kernel_size=3, learning_rate=1e-05, save_dir=None, clear_previous_memmaps=False, save_freq=200, blur_kernel=None, show_outline=False)
Bases: object
A class for training an ai_ct_scans.models.Infiller network
- border_mask_builder()
Get a list of logical ndarrays that can be used to mask out the border region of an input image, and extract that local region for a ‘label’ array at the output of the model. Applying these should help with aliasing effects at the edges of the CNN output.
- Returns
A set of masks to apply to the axial, coronal and sagittal views taken during building of a batch
- Return type
(list of 2D ndarrays)
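As an illustration of the kind of mask described above, here is a minimal NumPy sketch. The helper name `border_mask` and its `border` parameter are hypothetical and are not part of InfillTrainer; the real method returns one mask per view and its exact shapes are not stated here.

```python
import numpy as np


def border_mask(height, width, border):
    """Hypothetical helper: True within `border` pixels of the image edge."""
    mask = np.ones((height, width), dtype=bool)
    # Clear the interior so that only the frame around the edge stays True.
    mask[border:height - border, border:width - border] = False
    return mask


mask = border_mask(8, 8, border=2)
image = np.arange(64, dtype=float).reshape(8, 8)
# Discard the border and keep only the interior of the image.
interior = image[~mask].reshape(4, 4)
```

Inverting the mask selects the interior, which is the part of a CNN output least affected by edge effects.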
- build_batch(patient_indices=None, body_part_indices=None, plane_indices=None, scan_num_indices=None, coords_input_array=None, batch_size=None, require_above_thresh=True, allow_off_edge=False)
Get a batch of inputs and labels for a model in ai_ct_scans.models.
- Parameters
patient_indices (optional, ndarray of type int) – The indices of patients to access for each batch element, with zero-indexing
body_part_indices (optional, ndarray of type int) – Indices of body parts to use to build each batch element, 0 for abdomen, 1 for thorax
plane_indices (optional, ndarray of type int) – Indices of plane to view the batch element from, 0 for axial, 1 for coronal, 2 for sagittal
scan_num_indices (optional, ndarray of type int) – Indices of which sequential scan from each patient to use, 0 for the first scan, 1 for the second scan
coords_input_array (optional, list of length 3 of 1D ndarrays of type int) – The coordinates to use when building each batch element. The coordinate corresponding to the plane_index will be the slice along that index, while the other two coordinates will define the top left coordinate of the rectangle extracted from that plane
batch_size (optional, int) – How many slices to return for a batch
require_above_thresh (bool) – Whether to reject random slices that do not have any elements above self.threshold and seek out new slices until one is found
allow_off_edge (bool) – Whether to allow coords_input_array to cause the output slice to overlap the edges of the original scans. Useful to ensure that it is possible for every part of a scan to occur at the central masked region
- Returns
- ‘input_images’: a stack of 2D axial, coronal and sagittal slices
- ‘input_planes’: a stack of one hot vectors that correspond to which view the slice was taken from. Shape [batch size, 3]
- ‘input_body_part’: a stack of one hot vectors that correspond to the body part the slice was taken from. Shape [batch size, 2]
- ‘input_coords’: a stack of 1D vectors describing the original xyz location of the slice taken
- ‘labels’: a stack of 2D axial, coronal and sagittal slices, representing the data that was masked at the centre of each element of input_images
- Return type
(dict of torch.Tensors)
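The one hot stacks in the returned dictionary can be pictured with a short NumPy sketch. The `one_hot` helper is hypothetical, and the assumption that column i corresponds to index i (for example column 0 for axial) is made for illustration only; the real method returns torch Tensors.

```python
import numpy as np


def one_hot(indices, num_classes):
    """Hypothetical helper: encode integer indices as rows of a one hot matrix."""
    indices = np.asarray(indices)
    encoded = np.zeros((indices.size, num_classes), dtype=np.float32)
    # Set a single 1 per row, in the column given by that row's index.
    encoded[np.arange(indices.size), indices] = 1.0
    return encoded


plane_indices = np.array([0, 2, 1, 0])      # axial, sagittal, coronal, axial
body_part_indices = np.array([0, 1, 1, 0])  # abdomen, thorax, thorax, abdomen

input_planes = one_hot(plane_indices, 3)         # shape [batch size, 3]
input_body_part = one_hot(body_part_indices, 2)  # shape [batch size, 2]
```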
- load_model(directory, model='model.pth')
Load a pretrained model, optimiser state, loss at time of saving, iteration at time of saving
- Parameters
directory (pathlib Path) – Directory in which the model is saved
model (str) – Model filename, defaults to ‘model.pth’
- loss(model_out, batch)
Defines a custom loss function for the network. Weights the loss such that reproduction of the masked region (and a small border area around it) contributes to the overall loss on the same order of magnitude as all other pixels that were predicted
- Parameters
model_out – Stack of images that the model has predicted
batch (dict as built by self.build_batch) – The batch that was used for the iteration, which should include at least a 'labels' stack of Tensor images of the same shape as model_out
- Returns
The mean squared error of the output prediction after reweighting the masked region of the prediction
- Return type
(torch Tensor)
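One way to give the masked region a comparable influence on the loss is to average the squared error separately inside and outside the region and add the two means. The sketch below uses NumPy rather than torch, and both the `reweighted_mse` name and this particular weighting scheme are assumptions; the weighting actually used by InfillTrainer.loss is not specified here.

```python
import numpy as np


def reweighted_mse(prediction, labels, region_mask):
    """Hypothetical weighting: mean error inside the region plus mean error outside.

    prediction, labels: arrays of shape [batch, height, width]
    region_mask: boolean array of shape [height, width], True in the masked region
    """
    squared_error = (prediction - labels) ** 2
    inside = squared_error[:, region_mask].mean()
    outside = squared_error[:, ~region_mask].mean()
    return inside + outside


labels = np.zeros((2, 8, 8))
prediction = np.zeros((2, 8, 8))
prediction[:, 2:6, 2:6] = 1.0  # the error is confined to the masked centre

region_mask = np.zeros((8, 8), dtype=bool)
region_mask[2:6, 2:6] = True

plain = float(((prediction - labels) ** 2).mean())
weighted = float(reweighted_mse(prediction, labels, region_mask))
```

With a plain mean the error in the small central region is diluted by the many correct pixels around it, whereas the reweighted version counts both regions equally regardless of their size.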
- plane_mask_builder(blank_width=None)
Get a list of logical ndarrays that can be used to mask out the central region of an input image, and extract that local region for a ‘label’ array at output of the model
- Returns
A set of masks to apply to the axial, coronal and sagittal views taken during building of a batch
- Return type
(list of 2D ndarrays)
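A central mask of this kind can be sketched in NumPy as follows. The `central_mask` helper is hypothetical, and centring the blank square and filling it with zeros are assumptions made for illustration.

```python
import numpy as np


def central_mask(height, width, blank_width):
    """Hypothetical helper: True inside a centred blank_width x blank_width square."""
    mask = np.zeros((height, width), dtype=bool)
    top = (height - blank_width) // 2
    left = (width - blank_width) // 2
    mask[top:top + blank_width, left:left + blank_width] = True
    return mask


image = np.random.default_rng(0).random((256, 256))
mask = central_mask(256, 256, blank_width=64)

# The label is the region the network is asked to reproduce.
label = image[mask].reshape(64, 64)

# The model input is the same image with the central region blanked out.
model_input = image.copy()
model_input[mask] = 0.0
```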
- process_full_scan(patient_index, body_part_index, plane_index, scan_index, batch_size, overlap=1, save_path=None, allow_off_edge=True)
Run a single scan through the model in batches along a chosen axis, patching the prediction of the masked region together and subtracting from the real scan to form an ‘anomaly scan’.
- Parameters
patient_index (int) – The index of a patient as stored in ai_ct_scans.data_loading.MultiPatientLoader, i.e. Patient 1 at index 0, Patient 2 at index 1, etc.
body_part_index (int) – Index of a body part, 0 for abdomen, 1 for thorax currently supported
plane_index (int) – Index of plane to stack inputs from, 0 for axial, 1 for coronal, 2 for sagittal
scan_index (int) – 0 for scan 1, 1 for scan 2
batch_size (int) – How many layers to take along the plane_index stack for each batch. 24 with image sizes of 256x256 seems to work fine with 8GB of VRAM
overlap (int) – Number of times to overlap the masked regions to build up an average predicted view. The remainder of self.blank_width divided by overlap should be 0 for well-stitched output
save_path (pathlib Path) – File path at which to save the anomaly scan as a memmap (shape will be appended into the filename for ease of reloading with ai_ct_scans.data_loading.load_memmap)
allow_off_edge (bool) – Whether to allow the model input to move off the edge of the scans, so that the stitched central blank regions can cover the entire scan. If False, the central square column of masked regions will be returned
- Returns
3D volume of the input minus the predicted output
- Return type
(ndarray)
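The stitching idea can be illustrated in one dimension: predictions for overlapping windows are accumulated, averaged, and subtracted from the input. This is only a sketch under stated assumptions. The function `stitch_predictions`, the `predict_patch` callback, and the choice of stride as blank_width // overlap are all hypothetical; the real method works on 3D volumes with a trained network.

```python
import numpy as np


def stitch_predictions(scan_row, predict_patch, blank_width, overlap):
    """Hypothetical 1D stitching: average overlapping predictions, then subtract."""
    stride = blank_width // overlap  # assumed: blank_width % overlap == 0
    total = np.zeros_like(scan_row, dtype=float)
    counts = np.zeros_like(scan_row, dtype=float)
    for start in range(0, scan_row.size - blank_width + 1, stride):
        stop = start + blank_width
        total[start:stop] += predict_patch(scan_row, start, stop)
        counts[start:stop] += 1
    covered = counts > 0
    averaged = np.zeros_like(total)
    averaged[covered] = total[covered] / counts[covered]
    # Anomaly = input minus averaged prediction, wherever a prediction exists.
    return np.where(covered, scan_row - averaged, 0.0)


scan_row = np.ones(16)
scan_row[5] = 3.0  # a single unexpected value


def predict_ones(row, start, stop):
    """Stand-in for the network: always predicts the 'normal' value 1.0."""
    return np.ones(stop - start)


anomaly = stitch_predictions(scan_row, predict_ones, blank_width=4, overlap=2)
```

Because the stand-in predictor cannot reproduce the unexpected value, it is the only position left in the anomaly output.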
- random_axial_slicer(arr, indices=None, allow_off_edge=False)
Takes a random crop from a random axial plane of 3D array arr
- Parameters
arr (ndarray) – 3D volume
indices (list of 3 ints) – Coordinates at which to take the slice from. 0th index defines the axial slice, 1st and 2nd indices define a top left corner of the view
allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edges of arr
- Returns
2D image
- Return type
(ndarray)
- random_coronal_slicer(arr, indices=None, allow_off_edge=False)
Takes a random crop from a random coronal plane of 3D array arr
- Parameters
arr (ndarray) – 3D volume
indices (list of 3 ints) – Coordinates at which to take the slice from. 1st index defines the coronal slice, 0th and 2nd indices define a top left corner of the view
allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edges of arr
- Returns
2D image
- Return type
(ndarray)
- random_sagittal_slicer(arr, indices=None, allow_off_edge=False)
Takes a random crop from a random sagittal plane of 3D array arr
- Parameters
arr (ndarray) – 3D volume
indices (list of 3 ints) – Coordinates at which to take the slice from. 2nd index defines the sagittal slice, 0th and 1st indices define a top left corner of the view
allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edges of arr
- Returns
2D image
- Return type
(ndarray)
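The index convention shared by the three slicers (one coordinate picks the slice, the other two give the top left corner of the crop) can be sketched with a single hypothetical helper. `crop_from_plane` and its `height` and `width` parameters are not part of InfillTrainer, and the random selection of indices and the off-edge handling are left out.

```python
import numpy as np


def crop_from_plane(arr, indices, plane_index, height, width):
    """Hypothetical helper: crop a 2D view from one plane of a 3D volume.

    indices[plane_index] selects the slice; the remaining two entries, in
    order, give the top left corner of the crop within that slice.
    """
    plane = np.take(arr, indices[plane_index], axis=plane_index)
    row, col = [indices[axis] for axis in range(3) if axis != plane_index]
    return plane[row:row + height, col:col + width]


arr = np.arange(4 * 5 * 6).reshape(4, 5, 6)
indices = [1, 2, 3]

axial = crop_from_plane(arr, indices, plane_index=0, height=2, width=2)
coronal = crop_from_plane(arr, indices, plane_index=1, height=2, width=2)
sagittal = crop_from_plane(arr, indices, plane_index=2, height=2, width=2)
```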
- save_model(directory, bypass_loss_check=False)
Save the model. If it has achieved the best loss, save to ‘model.pth’ within directory, otherwise save to ‘latest_model.pth’
- Parameters
directory (pathlib Path) – A directory in which to save the model
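The choice between the two filenames can be sketched as below. The helper `choose_checkpoint_path` and its `current_loss` and `best_loss` arguments are hypothetical, and how the trainer tracks its best loss is an assumption; the effect of bypass_loss_check is not described above, so it is left out.

```python
from pathlib import Path


def choose_checkpoint_path(directory, current_loss, best_loss):
    """Hypothetical helper: pick 'model.pth' for a new best loss, else 'latest_model.pth'."""
    is_best = best_loss is None or current_loss < best_loss
    filename = "model.pth" if is_best else "latest_model.pth"
    return Path(directory) / filename


improved = choose_checkpoint_path("checkpoints", current_loss=0.1, best_loss=0.2)
not_improved = choose_checkpoint_path("checkpoints", current_loss=0.3, best_loss=0.2)
```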
- train_for_iterations(iterations)
Train the model for a set number of iterations
- Parameters
iterations (int) – Number of iterations to train for
- train_step()
Build a single batch, do a single forward and backward pass.
- ai_ct_scans.model_trainers.blur(src, ksize[, dst[, anchor[, borderType]]]) → dst
Blurs an image using the normalized box filter. The function smooths an image using the kernel:
\[
\texttt{K} = \frac{1}{\texttt{ksize.width*ksize.height}} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 & 1 \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \hdotsfor{6} \\ 1 & 1 & 1 & \cdots & 1 & 1 \end{bmatrix}
\]
The call blur(src, dst, ksize, anchor, borderType) is equivalent to boxFilter(src, dst, src.type(), ksize, anchor, true, borderType).
- Parameters
src – Input image; it can have any number of channels, which are processed independently, but the depth should be CV_8U, CV_16U, CV_16S, CV_32F or CV_64F
dst – Output image of the same size and type as src
ksize – Blurring kernel size
anchor – Anchor point; the default value Point(-1,-1) means that the anchor is at the kernel center
borderType – Border mode used to extrapolate pixels outside of the image, see #BorderTypes. #BORDER_WRAP is not supported
See also: boxFilter, bilateralFilter, GaussianBlur, medianBlur
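The normalized box filter kernel described above can be written out by hand in NumPy to show what it does. This sketch is not the OpenCV implementation: the `box_blur` function is hypothetical, it assumes a single-channel image and odd kernel sizes, and reflect padding at the border is a choice made for the example.

```python
import numpy as np


def box_blur(image, ksize):
    """Hypothetical helper: average each pixel over a ksize = (width, height) window."""
    k_width, k_height = ksize
    # Every kernel entry is 1 / (ksize.width * ksize.height).
    kernel = np.ones((k_height, k_width)) / (k_width * k_height)
    pad_y, pad_x = k_height // 2, k_width // 2
    padded = np.pad(image.astype(float), ((pad_y, pad_y), (pad_x, pad_x)), mode="reflect")
    out = np.zeros(image.shape, dtype=float)
    for y in range(image.shape[0]):
        for x in range(image.shape[1]):
            window = padded[y:y + k_height, x:x + k_width]
            out[y, x] = (window * kernel).sum()
    return out


image = np.zeros((5, 5))
image[2, 2] = 9.0  # a single bright pixel
blurred = box_blur(image, (3, 3))
```

The single bright pixel is spread evenly over the 3x3 neighbourhood around it, which is the smoothing effect the kernel formula expresses.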
- ai_ct_scans.model_trainers.debug_plot(model_out, batch, index=0)
Plot the original image, the masked image, and the infilled version. Useful during debugging.
- Parameters
model_out – The infilled image stack from an Infiller model
batch (dict of tensors) – A dictionary with torch Tensors ‘labels’ and ‘input_images’
index (int) – The index of the model’s output to compare to the original image and masked version, in the range [0, batch size)
- ai_ct_scans.model_trainers.det(tensor)
Detach a torch Tensor and convert it to a NumPy ndarray on the CPU
- Parameters
tensor (torch.Tensor) – A tensor to be detached and turned into an ndarray
- Returns
The same data as an ndarray
- Return type
(ndarray)