# Source code for braviz.readAndFilter.kmc40

##############################################################################
#    Braviz, Brain Data interactive visualization                            #
#    Copyright (C) 2014  Diego Angulo                                        #
#                                                                            #
#    This program is free software: you can redistribute it and/or modify    #
#    it under the terms of the GNU Lesser General Public License as          #
#    published by  the Free Software Foundation, either version 3 of the     #
#    License, or (at your option) any later version.                         #
#                                                                            #
#    This program is distributed in the hope that it will be useful,         #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of          #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the           #
#    GNU Lesser General Public License for more details.                     #
#                                                                            #
#    You should have received a copy of the GNU Lesser General Public License#
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.   #
##############################################################################


from __future__ import division, print_function

import os
import re
from braviz.readAndFilter.config_file import get_host_config
import logging

import nibabel as nib

from numpy.linalg import inv


from braviz.readAndFilter.cache import memo_ten
from braviz.readAndFilter.images import numpy2vtk_img, nifti_rgb2vtk, nibNii2vtk

from braviz.readAndFilter.readDartelTransform import dartel2GridTransform_cached
from braviz.readAndFilter.kmc_abstract import KmcAbstractReader
from braviz.readAndFilter.transforms import applyTransform, readFreeSurferTransform, readFlirtMatrix
from braviz.visualization.create_lut import get_colorbrewer_lut


[docs]class Kmc40Reader(KmcAbstractReader): """ Braviz reader class designed to work with the file structure and data from the KMC pilot This project contains 50 subjects (40 preterms). Data is organized into folders, and path and names for the different files can be derived from data type and id. The constructor requires the root to this structure """ def __init__(self, path, max_cache=2000): """The path pointing to the __root of the file structure must be set here""" KmcAbstractReader.__init__(self, path, path, max_cache) self._available_images = frozenset(("MRI", "FA", "MD")) self._functional_paradigms = frozenset(("PRECISION", "POWERGRIP")) self._named_bundles = frozenset(("cortico_spinal_l", "cortico_spinal_r", "cortico_spinal_n", "cortico_spinal_d", "corpus_callosum")) def _getIds(self): """Auxiliary function to get the available ids""" contents = os.listdir(self.get_data_root()) numbers = re.compile('[0-9]+$') ids = [c for c in contents if numbers.match(c) is not None] ids.sort(key=int) return ids def _decode_subject(self, subj): subj = str(subj) if len(subj) < 3: subj = "0" * (3 - len(subj)) + subj return subj def _get_img(self, image_name, subj, space, **kw): """Auxiliary function to read nifti images""" # path=self.getDataRoot()+'/'+subj+'/MRI' if image_name == 'MRI': path = os.path.join(self.get_data_root(), subj, 'MRI') filename = '%s-MRI-full.nii.gz' % subj elif image_name == 'FA': path = os.path.join(self.get_data_root(), subj, 'camino') if space.startswith('diff'): filename = 'FA_masked.nii.gz' else: filename = 'FA_mri_masked.nii.gz' elif image_name == "MD": path = os.path.join(self.get_data_root(), subj, 'camino') if space.startswith('diff'): filename = 'MD_masked.nii.gz' else: filename = 'MD_mri_masked.nii.gz' elif image_name == "DTI": path = os.path.join(self.get_data_root(), subj, 'camino') if space.startswith('diff'): filename = 'rgb_dti_masked.nii.gz' else: filename = 'rgb_dti_mri_masked.nii.gz' elif image_name == 'APARC': path = 
os.path.join(self.get_data_root(), subj, 'Models') if kw.get("wm"): log = logging.getLogger(__name__) log.warning("deprecated, use WMPARC instead") path = os.path.join(self.get_data_root(), subj, 'Models3') filename = 'wmparc.nii.gz' else: filename = 'aparc+aseg.nii.gz' elif image_name == "WMPARC": path = os.path.join(self.get_data_root(), subj, 'Models3') filename = 'wmparc.nii.gz' else: log = logging.getLogger(__name__) log.error('Unknown image type %s' % image_name) raise Exception('Unknown image type %s' % image_name) wholeName = os.path.join(path, filename) try: img = nib.load(wholeName) except IOError as e: log = logging.getLogger(__name__) log.error(e.message) log.error("File %s not found" % wholeName) raise (Exception('File not found')) if kw.get('format', '').upper() == 'VTK': if image_name == "MD": img_data = img.get_data() img_data *= 1e12 vtkImg = numpy2vtk_img(img_data) elif image_name == "DTI": vtkImg = nifti_rgb2vtk(img) else: vtkImg = nibNii2vtk(img) if space == 'native': return vtkImg interpolate = True if image_name in {'APARC', "WMPARC"}: interpolate = False # print "turning off interpolate" img2 = applyTransform( vtkImg, transform=inv(img.get_affine()), interpolate=interpolate) if space == "diff" and (image_name in {"FA", "MD", "DTI"}): return img2 return self._move_img_from_world(subj, img2, interpolate, space=space) if space == "diff" and (image_name in {"FA", "MD", "DTI"}): return img elif space == "subject": return img elif space == "diff": # read transform: path = os.path.join(self.get_data_root(), subj, 'camino') #matrix = readFlirtMatrix('surf2diff.mat', 'FA.nii.gz', 'orig.nii.gz', path) matrix = readFlirtMatrix( 'diff2surf.mat', 'FA.nii.gz', 'orig.nii.gz', path) matrix = inv(matrix) affine = img.get_affine() aff2 = matrix.dot(affine) img2 = nib.Nifti1Image(img.get_data(), aff2) return img2 log = logging.getLogger(__file__) log.error("Returned nifti image is in native space") raise NotImplementedError def _move_img_from_world(self, subj, 
img2, interpolate=False, space='subject'): """moves an image from the subject coordinate space to talairach or dartel spaces""" if space == 'subject': return img2 elif space in ('template', 'dartel'): dartel_warp = self._get_spm_grid_transform(subj, "dartel", "back") img3 = applyTransform(img2, dartel_warp, origin2=(90, -126, -72), dimension2=(121, 145, 121), spacing2=(-1.5, 1.5, 1.5), interpolate=interpolate) # origin, dimension and spacing come from template return img3 elif space[:2].lower() == 'ta': talairach_file = self._get_talairach_transform_name(subj) transform = readFreeSurferTransform(talairach_file) img3 = applyTransform(img2, inv(transform), (-100, -120, -110), (190, 230, 230), (1, 1, 1), interpolate=interpolate) return img3 elif space[:4] in ('func', 'fmri'): # functional space paradigm = space[5:] # print paradigm paradigm = self._get_paradigm_name(paradigm) transform = self._read_func_transform(subj, paradigm, True) img3 = applyTransform(img2, transform, origin2=(78, -112, -50), dimension2=(79, 95, 68), spacing2=(-2, 2, 2), interpolate=interpolate) return img3 elif space == "diff": path = self._get_base_fibs_dir_name(subj) # notice we are reading the inverse transform diff -> world trans = readFlirtMatrix( 'diff2surf.mat', 'FA.nii.gz', 'orig.nii.gz', path) img3 = applyTransform(img2, trans, interpolate=interpolate) return img3 else: log = logging.getLogger(__name__) log.error('Unknown space %s' % space) raise Exception('Unknown space %s' % space) def _move_img_to_subject(self, subj, img2, interpolate=False, space='subject'): """moves an image from the subject coordinate space to talairach or dartel spaces""" if space == 'subject': return img2 elif space in ('template', 'dartel'): dartel_warp = self._get_spm_grid_transform("dartel", "forw") img3 = applyTransform(img2, dartel_warp, origin2=(90, -126, -72), dimension2=(121, 145, 121), spacing2=(-1.5, 1.5, 1.5), interpolate=interpolate) # origin, dimension and spacing come from template return img3 elif 
space[:2].lower() == 'ta': talairach_file = self._get_talairach_transform_name(subj) transform = readFreeSurferTransform(talairach_file) img3 = applyTransform(img2, inv(transform), (-100, -120, -110), (190, 230, 230), (1, 1, 1), interpolate=interpolate) return img3 elif space[:4] in ('func', 'fmri'): # functional space paradigm = space[5:] paradigm = self._get_paradigm_name(paradigm) transform = self._read_func_transform(subj, paradigm, True) img3 = applyTransform(img2, transform, origin2=(78, -112, -50), dimension2=(79, 95, 68), spacing2=(-2, 2, 2), interpolate=interpolate) return img3 elif space == "diff": path = self._get_base_fibs_dir_name(subj) # notice we are reading the inverse transform diff -> world trans = readFlirtMatrix( 'diff2surf.mat', 'FA.nii.gz', 'orig.nii.gz', path) img3 = applyTransform(img2, trans, interpolate=interpolate) return img3 else: log = logging.getLogger(__name__) log.error('Unknown space %s' % space) raise Exception('Unknown space %s' % space) #==========Free Surfer================ def _get_free_surfer_models_dir_name(self, subject): return os.path.join(self.get_data_root(), subject, 'Models3') def _get_talairach_transform_name(self, subject): """xfm extension""" return os.path.join(self.get_data_root(), subject, 'Surf', 'talairach.xfm') def _get_free_surfer_stats_dir_name(self, subject): return os.path.join(self.get_data_root(), subject, 'Models', 'stats') def _get_freesurfer_lut_name(self): return os.path.join(self.get_data_root(), 'FreeSurferColorLUT.txt') def _get_free_surfer_morph_path(self, subj): return os.path.join(self.get_data_root(), str(subj), 'Surf') def _get_free_surfer_labels_path(self, subj): return os.path.join(self.get_data_root(), str(subj), 'Surf') def _get_freesurfer_surf_name(self, subj, name): return os.path.join(self.get_data_root(), str(subj), "Surf", name) def _get_tracula_map_name(self, subj): raise IOError("Tracula data not available") #=============Camino================== def _get_base_fibs_name(self, 
subj): return os.path.join(self.get_data_root(), subj, 'camino', 'streams.vtk') def _get_base_fibs_dir_name(self, subj): """ Must contain 'diff2surf.mat', 'fa.nii.gz', 'orig.nii.gz' """ return os.path.join(self.get_data_root(), subj, 'camino') def _get_fa_img_name(self): return "FA.nii.gz" def _get_orig_img_name(self): return "orig.nii.gz" def _get_md_lut(self): lut = get_colorbrewer_lut(6e-10, 11e-10, "YlGnBu", 9, invert=True) return lut #==========SPM================ def _get_paradigm_name(self, paradigm_name): return paradigm_name.upper() def _get_paradigm_dir(self, subject, name, spm=False): """If spm is True return the directory containing spm.mat, else return its parent""" if not spm: return os.path.join(self.get_data_root(), subject, 'spm', name) else: return os.path.join(self.get_data_root(), subject, 'spm', name) def _get_spm_grid_transform(self, subject, paradigm, direction, assume_bad_matrix=False): """ Get the spm non linear registration transform grid associated to the paradigm Use paradigm=dartel to get the transform associated to the dartel normalization """ assert direction in {"forw", "back"} cache_key = "y_%s_%s_%s.vtk" % (paradigm, subject, direction) if paradigm == "dartel": y_file = os.path.join( self.get_data_root(), 'Dartel', "y_%s-%s.nii.gz" % (subject, direction)) else: y_file = os.path.join( self.get_data_root(), subject, 'spm', paradigm, 'y_seg_%s.nii.gz' % direction) return dartel2GridTransform_cached(y_file, cache_key, self, assume_bad_matrix) def _read_func_transform(self, subject, paradigm_name, inverse=False): paradigm_name = self._get_paradigm_name(paradigm_name) path = os.path.join(self.get_data_root(), subject, "spm") T1_func = os.path.join(path, paradigm_name, 'T1.nii.gz') T1_world = os.path.join(path, 'T1', 'T1.nii.gz') return self._read_func_transform_internal(subject, paradigm_name, inverse, path, T1_func, T1_world) @staticmethod @memo_ten def get_auto_data_root(): project_name = os.path.basename(__file__).split('.')[0] log = 
logging.getLogger(__name__) try: config = get_host_config(project_name) except KeyError as e: log.exception(e) raise data_root = config["data root"] if not os.path.isabs(data_root): data_root = os.path.join(os.path.dirname(__file__),"../applications",data_root) return data_root @staticmethod @memo_ten def get_auto_dyn_data_root(): return Kmc40Reader.get_auto_data_root() @staticmethod def get_auto_reader(**kw_args): """Initialized a kmc40Reader based on the computer name""" project_name = os.path.basename(__file__).split('.')[0] log = logging.getLogger(__name__) try: config = get_host_config(project_name) except KeyError as e: log.exception(e) raise data_root = config["data root"] if not os.path.isabs(data_root): data_root = os.path.join(os.path.dirname(__file__),"../applications",data_root) if kw_args.get('max_cache', 0) > 0: max_cache = kw_args.pop('max_cache') log.info("Max cache set to %.2f MB" % max_cache) else: max_cache = config["memory (mb)"] return Kmc40Reader(data_root, max_cache=max_cache)