ai_ct_scans.model_trainers module

class ai_ct_scans.model_trainers.InfillTrainer(axial_width=256, coronal_width=256, sagittal_width=256, batch_size=8, batch_width=256, batch_height=256, blank_width=64, num_encoder_convs=3, encoder_filts_per_layer=10, neurons_per_dense=512, num_dense_layers=3, decoder_filts_per_layer=10, num_decoder_convs=3, kernel_size=3, learning_rate=1e-05, save_dir=None, clear_previous_memmaps=False, save_freq=200, blur_kernel=None, show_outline=False)

Bases: object

A class for training an ai_ct_scans.models.Infiller network

border_mask_builder()

Get a list of boolean ndarrays that can be used to mask out the border region of an input image and extract that local region for a ‘label’ array at the output of the model. Applying these masks should help reduce aliasing effects at the edges of the CNN output.

Returns

A set of masks to apply to the axial, coronal and sagittal views taken during building of a batch

Return type

(list of 2D ndarrays)
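
As a rough illustration of what one such mask could look like, the sketch below builds a boolean frame mask for a single 2D view with NumPy. The helper name border_mask and the border parameter are hypothetical. This documentation does not specify the trainer's actual border size or mask layout.

```python
import numpy as np


def border_mask(height, width, border):
    """Boolean mask that is True in a frame `border` pixels wide around the image edge.

    Hypothetical helper for illustration; not part of ai_ct_scans.
    """
    mask = np.ones((height, width), dtype=bool)
    # Clear the interior so only the outer frame remains True
    mask[border:height - border, border:width - border] = False
    return mask


image = np.arange(64, dtype=float).reshape(8, 8)
mask = border_mask(8, 8, 2)
border_pixels = image[mask]  # 1D array of the pixels in the frame
```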

build_batch(patient_indices=None, body_part_indices=None, plane_indices=None, scan_num_indices=None, coords_input_array=None, batch_size=None, require_above_thresh=True, allow_off_edge=False)

Get a batch of inputs and labels for the model being trained.

Parameters
  • patient_indices (optional, ndarray of type int) – The indices of the patients to access for each batch element, with zero-indexing

  • body_part_indices (optional, ndarray of type int) – Indices of the body parts to use to build each batch element, 0 for abdomen, 1 for thorax

  • plane_indices (optional, ndarray of type int) – Indices of the plane to view each batch element from, 0 for axial, 1 for coronal, 2 for sagittal

  • scan_num_indices (optional, ndarray of type int) – Indices of which sequential scan from each patient to use, 0 for the first scan, 1 for the second scan

  • coords_input_array (optional, list of length 3 of 1D ndarrays of type int) – The coordinates to use when building each batch element. The coordinate corresponding to the plane_index will be the slice along that index, while the other two coordinates will define the top left coordinate of the rectangle extracted from that plane

  • batch_size (optional, int) – How many slices to return for a batch

  • require_above_thresh (bool) – Whether to reject random slices that do not have any elements above self.threshold and seek out new slices until one is found

  • allow_off_edge (bool) – Whether to allow coords_input_array to cause the output slice to overlap the edges of the original scans; this is useful to ensure that every part of a scan can occur in the central masked region

Returns

  • ‘input_images’: a stack of 2D axial, coronal and sagittal slices

  • ‘input_planes’: a stack of one-hot vectors indicating which view each slice was taken from. Shape [batch size, 3]

  • ‘input_body_part’: a stack of one-hot vectors indicating which body part each slice was taken from. Shape [batch size, 2]

  • ‘input_coords’: a stack of 1D vectors describing the original xyz location of each slice

  • ‘labels’: a stack of 2D axial, coronal and sagittal slices representing the data that was masked at the centre of each element of input_images

Return type

(dict of torch.Tensors)
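
The documented shapes of ‘input_planes’ ([batch size, 3]) and ‘input_body_part’ ([batch size, 2]) follow from one-hot encoding the integer indices. The NumPy sketch below assumes that the position of the 1 follows the same convention as plane_indices and body_part_indices. The trainer's own construction, which produces torch Tensors, may differ.

```python
import numpy as np

# Example index arrays, as could be passed to build_batch
plane_indices = np.array([0, 1, 2, 0])      # 0 axial, 1 coronal, 2 sagittal
body_part_indices = np.array([0, 1, 1, 0])  # 0 abdomen, 1 thorax

# Row i of np.eye(n) is the one-hot vector for class i
input_planes = np.eye(3)[plane_indices]         # shape (batch_size, 3)
input_body_part = np.eye(2)[body_part_indices]  # shape (batch_size, 2)
```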

load_model(directory, model='model.pth')

Load a pretrained model, together with the optimiser state, the loss at the time of saving and the iteration at the time of saving

Parameters
  • directory (pathlib Path) – Directory in which the model is saved

  • model (str) – Model filename, defaults to ‘model.pth’

loss(model_out, batch)

Defines a custom loss function for the network. Weights the loss such that reproduction of the masked region (and a small border area around it) contributes to the overall loss on the same order of magnitude as all other pixels that were predicted

Parameters
  • model_out – Stack of images that the model has predicted

  • batch (dict as built by self.build_batch) – The batch that was used for the iteration, which should include at least a ‘labels’ stack of Tensor images of the same shape as model_out

Returns

The mean squared error of the output prediction after reweighting the masked region of the prediction

Return type

(torch Tensor)
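
This documentation does not give the exact weighting. The sketch below shows one way to meet the stated goal: each masked pixel is scaled so that the masked region, taken as a whole, carries the same total weight as all other pixels. It is written in NumPy, and both the weighting scheme and the name reweighted_mse are assumptions. The real method works on torch Tensors and also reweights a small border area around the masked region.

```python
import numpy as np


def reweighted_mse(pred, label, mask):
    """Hypothetical reweighted MSE.

    pred, label: arrays of shape (batch, height, width)
    mask: boolean array of shape (height, width), True inside the blanked region
    """
    sq_err = (pred - label) ** 2
    n_in = mask.sum()
    if n_in == 0:
        raise ValueError("mask must contain at least one True pixel")
    n_out = mask.size - n_in
    # Scale masked pixels so their summed weight equals that of the unmasked pixels
    weights = np.where(mask, n_out / n_in, 1.0)
    # weights (H, W) broadcasts against the (B, H, W) error stack
    return float(np.mean(weights * sq_err))


mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True  # 4 masked pixels, 12 unmasked, so masked weight is 3
loss = reweighted_mse(np.zeros((2, 4, 4)), np.ones((2, 4, 4)), mask)
```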

plane_mask_builder(blank_width=None)

Get a list of boolean ndarrays that can be used to mask out the central region of an input image and extract that local region for a ‘label’ array at the output of the model

Returns

A set of masks to apply to the axial, coronal and sagittal views taken during building of a batch

Return type

(list of 2D ndarrays)
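
A minimal NumPy sketch of the idea for one 2D view: a centred blank_width × blank_width square is True in the mask. It is used both to read out the label region and to blank the input. The helper central_mask is hypothetical, and filling the blanked region with zeros is an assumption made for illustration.

```python
import numpy as np


def central_mask(height, width, blank_width):
    """Boolean mask that is True on a centred blank_width x blank_width square."""
    mask = np.zeros((height, width), dtype=bool)
    top = (height - blank_width) // 2
    left = (width - blank_width) // 2
    mask[top:top + blank_width, left:left + blank_width] = True
    return mask


image = np.random.default_rng(0).random((256, 256))
mask = central_mask(256, 256, 64)
label = image[mask].reshape(64, 64)        # the region the model must predict
masked_input = np.where(mask, 0.0, image)  # input with the centre blanked out
```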

process_full_scan(patient_index, body_part_index, plane_index, scan_index, batch_size, overlap=1, save_path=None, allow_off_edge=True)

Run a single scan through the model in batches along a chosen axis, stitching the predictions of the masked regions together and subtracting them from the real scan to form an ‘anomaly scan’.

Parameters
  • patient_index (int) – The index of a patient as stored in ai_ct_scans.data_loading.MultiPatientLoader, i.e. Patient 1 at index 0, Patient 2 at index 1, etc.

  • body_part_index (int) – Index of a body part; 0 for abdomen and 1 for thorax are currently supported

  • plane_index (int) – Index of plane to stack inputs from, 0 for axial, 1 for coronal, 2 for sagittal

  • scan_index (int) – 0 for scan 1, 1 for scan 2

  • batch_size (int) – How many layers to take along the plane_index stack for each batch; 24 with image sizes of 256x256 seems to work fine with 8GB of VRAM

  • overlap (int) – Number of times to overlap the masked regions to build up an average predicted view. For well-stitched output, self.blank_width should be divisible by overlap (i.e. self.blank_width % overlap == 0)

  • save_path (pathlib Path) – File path at which to save the anomaly scan as a memmap (the shape will be appended to the filename for ease of reloading with ai_ct_scans.data_loading.load_memmap)

  • allow_off_edge (bool) – Whether to allow the model input to move off the edge of the scans, so that the stitched central blank regions can cover the entire scan. If False, only the central square column of masked regions will be returned

Returns

3D volume of the input minus the predicted output (the anomaly scan)

Return type

(ndarray)
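
The divisibility requirement on overlap can be illustrated with a simplified one-dimensional sketch. Windows of length blank_width are placed every blank_width // overlap samples, and each sample's prediction is the mean over the windows that cover it. The function stitch_1d and the predict callback are hypothetical. The real method stitches 2D masked regions over a 3D scan and runs the model in batches.

```python
import numpy as np


def stitch_1d(length, blank_width, overlap, predict):
    """Average overlapping predictions of blank_width-long windows along one axis."""
    if blank_width % overlap != 0:
        raise ValueError("blank_width should be divisible by overlap for even stitching")
    step = blank_width // overlap
    total = np.zeros(length)
    counts = np.zeros(length)
    for start in range(0, length - blank_width + 1, step):
        total[start:start + blank_width] += predict(start, blank_width)
        counts[start:start + blank_width] += 1
    # Average where at least one window landed; leave uncovered samples at 0
    return np.divide(total, counts, out=np.zeros(length), where=counts > 0)


signal = np.arange(16, dtype=float)
# A perfect "model" that reproduces the signal inside each window
stitched = stitch_1d(16, 4, 2, lambda start, width: signal[start:start + width])
anomaly = signal - stitched  # input minus prediction, all zeros here
```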

random_axial_slicer(arr, indices=None, allow_off_edge=False)

Takes a crop from an axial plane of the 3D array arr, at a random location unless indices is given

Parameters
  • arr (ndarray) – 3D volume

  • indices (list of 3 ints) – Coordinates at which to take the slice. The 1st and 2nd indices define the top left corner of the view; the 0th index defines the axial slice

  • allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edge of arr

Returns

2D image

Return type

(ndarray)
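
A minimal NumPy sketch of taking an axial view with an explicit top left corner. For illustration only, it assumes that positions off the edge of the volume are filled with zeros when allow_off_edge is set. The name axial_crop and the height and width parameters are hypothetical, and the sketch omits the random coordinate selection used when indices is None.

```python
import numpy as np


def axial_crop(arr, indices, height, width, allow_off_edge=False):
    """Crop a height x width view from axial slice indices[0] of a 3D volume.

    indices[1] and indices[2] give the top left corner of the view.
    """
    z, top, left = indices
    if not allow_off_edge:
        if top < 0 or left < 0 or top + height > arr.shape[1] or left + width > arr.shape[2]:
            raise ValueError("view would extend past the edge of arr")
        return arr[z, top:top + height, left:left + width]
    # Zero-pad the slice so views hanging over the edge can still be taken
    padded = np.pad(arr[z], ((height, height), (width, width)))
    return padded[top + height:top + 2 * height, left + width:left + 2 * width]


volume = np.arange(2 * 4 * 4, dtype=float).reshape(2, 4, 4)
inside = axial_crop(volume, [1, 1, 1], 2, 2)
overhang = axial_crop(volume, [0, 3, 3], 2, 2, allow_off_edge=True)
```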

random_coronal_slicer(arr, indices=None, allow_off_edge=False)

Takes a crop from a coronal plane of the 3D array arr, at a random location unless indices is given

Parameters
  • arr (ndarray) – 3D volume

  • indices (list of 3 ints) – Coordinates at which to take the slice. The 0th and 2nd indices define the top left corner of the view; the 1st index defines the coronal slice

  • allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edge of arr

Returns

2D image

Return type

(ndarray)

random_sagittal_slicer(arr, indices=None, allow_off_edge=False)

Takes a crop from a sagittal plane of the 3D array arr, at a random location unless indices is given

Parameters
  • arr (ndarray) – 3D volume

  • indices (list of 3 ints) – Coordinates at which to take the slice. The 0th and 1st indices define the top left corner of the view; the 2nd index defines the sagittal slice

  • allow_off_edge (bool) – optional, defaults to False. Whether to allow indices which will take the view off the edge of arr

Returns

2D image

Return type

(ndarray)
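
The three slicers differ only in which axis is held fixed. The following NumPy sketch assumes that the slice index sits at position plane (0 axial, 1 coronal, 2 sagittal) and that the other two indices give the top left corner in their original order. Under those assumptions, the crop can be written generically with np.moveaxis. plane_crop is a hypothetical helper, and it leaves out the random coordinate selection and off-edge handling of the real methods.

```python
import numpy as np


def plane_crop(arr, plane, indices, height, width):
    """Crop a view from a 3D volume; plane 0 axial, 1 coronal, 2 sagittal."""
    # Bring the slicing axis to the front; the other two axes keep their order
    view_stack = np.moveaxis(arr, plane, 0)
    top, left = [i for axis, i in enumerate(indices) if axis != plane]
    return view_stack[indices[plane], top:top + height, left:left + width]


volume = np.arange(4 * 5 * 6).reshape(4, 5, 6)
axial = plane_crop(volume, 0, [1, 2, 3], 2, 2)     # volume[1, 2:4, 3:5]
coronal = plane_crop(volume, 1, [1, 2, 3], 2, 2)   # volume[1:3, 2, 3:5]
sagittal = plane_crop(volume, 2, [1, 2, 3], 2, 2)  # volume[1:3, 2:4, 3]
```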

save_model(directory, bypass_loss_check=False)

Save the model. If it has achieved the best loss, save to ‘model.pth’ within directory, otherwise save to ‘latest_model.pth’

Parameters

directory (pathlib Path) – A directory in which to save the model
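
The naming rule above can be expressed as a small pure-Python sketch. The helper choose_checkpoint_path is hypothetical. Treating "achieved the best loss" as strictly lower than the previous best is an assumption, and the sketch ignores bypass_loss_check, whose behaviour is not documented here.

```python
from pathlib import Path


def choose_checkpoint_path(directory, loss, best_loss):
    """Hypothetical helper applying the documented checkpoint naming rule."""
    # A new best loss goes to model.pth; anything else goes to latest_model.pth
    filename = "model.pth" if loss < best_loss else "latest_model.pth"
    return Path(directory) / filename


best_path = choose_checkpoint_path(Path("runs"), 0.1, 0.2)
latest_path = choose_checkpoint_path(Path("runs"), 0.3, 0.2)
```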

train_for_iterations(iterations)

Train the model for a set number of iterations

Parameters

iterations (int) – Number of iterations to train for

train_step()

Build a single batch, do a single forward and backward pass.
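
This documentation does not say how train_for_iterations uses train_step or the constructor's save_freq and save_dir arguments. One plausible structure, offered only as an assumption, is a loop that calls train_step once per iteration and checkpoints every save_freq iterations. DummyTrainer is a hypothetical stand-in that records calls instead of training a network.

```python
def train_for_iterations(trainer, iterations, save_freq=200, save_dir=None):
    """Hypothetical loop: one train_step per iteration, periodic checkpoints."""
    for i in range(iterations):
        trainer.train_step()  # build a batch, forward pass, backward pass
        if save_dir is not None and (i + 1) % save_freq == 0:
            trainer.save_model(save_dir)


class DummyTrainer:
    """Stand-in that counts steps and records when saves happen."""

    def __init__(self):
        self.steps = 0
        self.saves = []

    def train_step(self):
        self.steps += 1

    def save_model(self, directory):
        self.saves.append(self.steps)


trainer = DummyTrainer()
train_for_iterations(trainer, 450, save_freq=200, save_dir="checkpoints")
```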

ai_ct_scans.model_trainers.blur(src, ksize[, dst[, anchor[, borderType]]]) -> dst

Blurs an image using the normalized box filter.

The function smooths an image using the kernel:

\[ \texttt{K} = \frac{1}{\texttt{ksize.width*ksize.height}} \begin{bmatrix} 1 & 1 & 1 & \cdots & 1 & 1 \\ 1 & 1 & 1 & \cdots & 1 & 1 \\ \hdotsfor{6} \\ 1 & 1 & 1 & \cdots & 1 & 1 \end{bmatrix} \]

The call blur(src, dst, ksize, anchor, borderType) is equivalent to boxFilter(src, dst, src.type(), ksize, anchor, true, borderType).

Parameters
  • src – input image; it can have any number of channels, which are processed independently, but the depth should be CV_8U, CV_16U, CV_16S, CV_32F or CV_64F.

  • dst – output image of the same size and type as src.

  • ksize – blurring kernel size.

  • anchor – anchor point; default value Point(-1,-1) means that the anchor is at the kernel center.

  • borderType – border mode used to extrapolate pixels outside of the image, see BorderTypes. BORDER_WRAP is not supported.

See also: boxFilter, bilateralFilter, GaussianBlur, medianBlur
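
The signature and description above are those of OpenCV's cv2.blur. To make the kernel formula concrete, here is a NumPy sketch of a normalised box filter for a single-channel 2D array with odd kernel sizes. It uses numpy's "reflect" pad mode, which matches OpenCV's default BORDER_REFLECT_101 extrapolation. The function box_blur is hypothetical and is an illustration rather than a drop-in replacement: unlike cv2.blur, it always returns float64.

```python
import numpy as np


def box_blur(src, ksize):
    """Normalised box filter for a 2D array, ksize = (width, height), odd sizes only."""
    kw, kh = ksize
    if kw % 2 == 0 or kh % 2 == 0:
        raise ValueError("this sketch only handles odd kernel sizes (centred anchor)")
    # numpy 'reflect' does not repeat the edge pixel, like BORDER_REFLECT_101
    padded = np.pad(src.astype(np.float64), ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode="reflect")
    out = np.zeros(src.shape, dtype=np.float64)
    # Sum one shifted copy of the image per kernel element (all weights are 1)
    for dy in range(kh):
        for dx in range(kw):
            out += padded[dy:dy + src.shape[0], dx:dx + src.shape[1]]
    # Divide by ksize.width * ksize.height, as in the kernel K
    return out / (kw * kh)


impulse = np.zeros((5, 5))
impulse[2, 2] = 9.0
blurred = box_blur(impulse, (3, 3))  # the 9 is spread as 1.0 over a 3x3 block
```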

ai_ct_scans.model_trainers.debug_plot(model_out, batch, index=0)

Plot the original image, the masked image, and the infilled version. Useful during debugging.

Parameters
  • model_out – The infilled image stack from an Infiller model

  • batch (dict of tensors) – A dictionary with torch Tensors ‘labels’ and ‘input_images’

  • index (int) – The index of the model’s output to compare with the original image and the masked version, from 0 to batch size - 1

ai_ct_scans.model_trainers.det(tensor)

Detach a torch Tensor and convert it to a NumPy ndarray on the CPU

Parameters

tensor (torch.Tensor) – A tensor to be detached and turned into an ndarray

Returns

The same data as an ndarray

Return type

(ndarray)