# Source code for braviz.readAndFilter.kmc400

##############################################################################
#    Braviz, Brain Data interactive visualization                            #
#    Copyright (C) 2014  Diego Angulo                                        #
#                                                                            #
#    This program is free software: you can redistribute it and/or modify    #
#    it under the terms of the GNU Lesser General Public License as          #
#    published by  the Free Software Foundation, either version 3 of the     #
#    License, or (at your option) any later version.                         #
#                                                                            #
#    This program is distributed in the hope that it will be useful,         #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of          #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the           #
#    GNU Lesser General Public License for more details.                     #
#                                                                            #
#    You should have received a copy of the GNU Lesser General Public License#
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.   #
##############################################################################


from __future__ import division, print_function

import os
import re
from braviz.readAndFilter.cache import memo_ten
from braviz.readAndFilter.config_file import get_host_config

import logging

import nibabel as nib
import numpy as np
from numpy.linalg import inv

from braviz.readAndFilter.images import numpy2vtk_img, nifti_rgb2vtk, nibNii2vtk

from braviz.readAndFilter.kmc_abstract import KmcAbstractReader

from braviz.readAndFilter.readDartelTransform import dartel2GridTransform_cached
from braviz.readAndFilter.transforms import applyTransform, readFreeSurferTransform, readFlirtMatrix
from braviz.visualization.create_lut import get_colorbrewer_lut


class Kmc400Reader(KmcAbstractReader):
    """
    A Braviz reader designed to work with the file structure and data from the KMC saving-brains project

    This project contains data from around 450 subjects but only 250 of them have images.
    Data is organized into folders, and path and names for the different files can be derived
    from data type and id.

    This project reads data from a non-writable directory and writes braviz specific data to a
    different directory. This is done to protect raw data and to allow to share it between
    different users.
    """

    def __init__(self, static_root, dynamic_route, max_cache=2000):
        """The path pointing to the __root of the file structure must be set here

        Args:
            static_root: forwarded unchanged to ``KmcAbstractReader.__init__``
            dynamic_route: forwarded unchanged to ``KmcAbstractReader.__init__``
            max_cache: forwarded unchanged to ``KmcAbstractReader.__init__``
                (``get_auto_reader`` logs this value as megabytes)
        """
        KmcAbstractReader.__init__(self, static_root, dynamic_route, max_cache)
        self._available_images = frozenset(("MRI", "FA", "MD", "FLAIR", "T2"))
        self._functional_paradigms = frozenset(
            ('ATENCION', 'COORDINACION', 'MEMORIA', 'MIEDO', 'PRENSION'))
        self._tracula_bundles = ['CC-ForcepsMajor',
                                 'CC-ForcepsMinor',
                                 'LAntThalRadiation',
                                 'LCingulumAngBundle',
                                 'LCingulumCingGyrus',
                                 'LCorticospinalTract',
                                 'LInfLongFas',
                                 'LSupLongFasParietal',
                                 'LSupLongFasTemporal',
                                 'LUncinateFas',
                                 'RAntThalRadiation',
                                 'RCingulumAngBundle',
                                 'RCingulumCingGyrus',
                                 'RCorticospinalTract',
                                 'RInfLongFas',
                                 'RSupLongFasParietal',
                                 'RSupLongFasTemporal',
                                 'RUncinateFas']

    def _getIds(self):
        """Auxiliary function to get the available ids

        Returns:
            The entries of ``<data root>/freeSurfer_Tracula`` whose names are made only of
            digits, sorted by their integer value.
        """
        contents = os.listdir(
            os.path.join(self.get_data_root(), "freeSurfer_Tracula"))
        # ``match`` anchors at the start and ``$`` at the end, so only all-digit names pass
        numbers = re.compile('[0-9]+$')
        ids = [c for c in contents if numbers.match(c) is not None]
        ids.sort(key=int)
        return ids

    def _decode_subject(self, subj):
        """Return the subject identifier as a string"""
        return str(subj)

    def _move_img_from_world(self, subj, img2, interpolate=False, space='subject'):
        """moves an image from the subject coordinate space to talairach, dartel,
        functional or diffusion spaces

        Args:
            subj: subject id, used to locate the subject specific transforms
            img2: image to move, it is passed to ``applyTransform``
            interpolate: forwarded to ``applyTransform``
            space: target space (case insensitive). One of ``subject`` (image is returned
                untouched), ``template`` / ``dartel``, any name starting with ``ta``
                (talairach), ``func<sep><paradigm>`` / ``fmri<sep><paradigm>`` or ``diff``

        Returns:
            The transformed image

        Raises:
            Exception: if the space is not recognized
        """
        space = space.lower()
        if space == 'subject':
            return img2
        elif space in ('template', 'dartel'):
            dartel_warp = self._get_spm_grid_transform(subj, "dartel", "back")
            # origin, dimension and spacing come from template
            img3 = applyTransform(img2, dartel_warp,
                                  origin2=(90, -126, -72),
                                  dimension2=(121, 145, 121),
                                  spacing2=(-1.5, 1.5, 1.5),
                                  interpolate=interpolate)
            return img3
        elif space[:2] == 'ta':
            talairach_file = self._get_talairach_transform_name(subj)
            transform = readFreeSurferTransform(talairach_file)
            img3 = applyTransform(img2, inv(transform),
                                  (-100, -120, -110), (190, 230, 230), (1, 1, 1),
                                  interpolate=interpolate)
            return img3
        elif space[:4] in ('func', 'fmri'):
            # functional space, the paradigm name follows the 5 characters long prefix
            paradigm = space[5:]
            paradigm = self._get_paradigm_name(paradigm)
            transform = self._read_func_transform(subj, paradigm, True)
            img3 = applyTransform(img2, transform,
                                  origin2=(78, -112, -50),
                                  dimension2=(79, 95, 68),
                                  spacing2=(-2, 2, 2),
                                  interpolate=interpolate)
            return img3
        elif space == "diff":
            # TODO: Check, looks wrong
            path = self._get_base_fibs_dir_name(subj)
            # notice we are reading the inverse transform diff -> world
            trans = readFlirtMatrix(
                'diff2surf.mat', 'FA.nii.gz', 'orig.nii.gz', path)
            img3 = applyTransform(img2, trans, interpolate=interpolate)
            return img3
        else:
            log = logging.getLogger(__name__)
            log.error('Unknown space %s' % space)
            raise Exception('Unknown space %s' % space)

    def _move_img_to_subject(self, subj, img2, interpolate=False, space='subject'):
        """moves an image from talairach, dartel, functional or diffusion spaces
        to the subject coordinate space

        The output is resampled on the grid (origin, spacing and dimensions) of the
        subject's MRI image in subject space.

        Args:
            subj: subject id, used to locate the subject specific transforms
            img2: image to move, it is passed to ``applyTransform``
            interpolate: forwarded to ``applyTransform``
            space: space the image is currently in (case insensitive), accepts the same
                names as ``_move_img_from_world``

        Returns:
            The transformed image (the input itself when space is ``subject``)

        Raises:
            Exception: if the space is not recognized
        """
        space = space.lower()
        if space == 'subject':
            return img2
        # The subject MRI defines the output grid for every other space
        ref = self.get("image", subj, space="subject", format="vtk", name="mri")
        origin = ref.GetOrigin()
        spacing = ref.GetSpacing()
        dims = ref.GetDimensions()
        if space in ('template', 'dartel'):
            dartel_warp = self._get_spm_grid_transform(subj, "dartel", "forw")
            # origin, dimension and spacing come from the subject MRI
            img3 = applyTransform(img2, dartel_warp,
                                  origin2=origin, dimension2=dims, spacing2=spacing,
                                  interpolate=interpolate)
            return img3
        elif space[:2].lower() == 'ta':
            talairach_file = self._get_talairach_transform_name(subj)
            transform = readFreeSurferTransform(talairach_file)
            img3 = applyTransform(img2, transform, origin, dims, spacing,
                                  interpolate=interpolate)
            return img3
        elif space[:4] in ('func', 'fmri'):
            # functional space, the paradigm name follows the 5 characters long prefix
            paradigm = space[5:]
            paradigm = self._get_paradigm_name(paradigm)
            transform = self._read_func_transform(subj, paradigm, False)
            img3 = applyTransform(img2, transform,
                                  origin2=origin, dimension2=dims, spacing2=spacing,
                                  interpolate=interpolate)
            return img3
        elif space == "diff":
            path = self._get_base_fibs_dir_name(subj)
            # notice we are reading the inverse transform diff -> world
            trans = readFlirtMatrix(
                'diff2surf.mat', 'FA.nii.gz', 'orig.nii.gz', path)
            img3 = applyTransform(
                img2, trans, interpolate=interpolate,
                origin2=origin, spacing2=spacing, dimension2=dims)
            return img3
        else:
            log = logging.getLogger(__name__)
            log.error('Unknown space %s' % space)
            raise Exception('Unknown space %s' % space)

    def _get_img(self, image_name, subj, space, **kw):
        """Auxiliary function to read nifti images

        Args:
            image_name: one of ``MRI``, ``FA``, ``MD``, ``DTI``, ``APARC``, ``WMPARC``,
                ``T2`` or ``FLAIR``
            subj: subject id
            space: requested coordinate space. For FA, MD and DTI a different file is
                read when the space starts with ``diff``
            **kw: ``format="vtk"`` (case insensitive) requests a vtk image instead of a
                nibabel image; ``wm=True`` together with ``APARC`` is a deprecated way
                of asking for ``WMPARC``

        Returns:
            A vtk image if the vtk format was requested, otherwise a nibabel image

        Raises:
            Exception: if the image type is unknown or the file can not be read
            NotImplementedError: if a nibabel image is requested in a space other than
                ``subject``, ``diff`` or talairach
        """
        # Resolve the directory and file name for the requested image type
        if image_name == 'MRI':
            path = os.path.join(self.get_data_root(), "nii", str(subj))
            filename = 'MPRAGEmodifiedSENSE.nii.gz'
        elif image_name == 'FA':
            path = os.path.join(
                self.get_data_root(), 'tractography_w_cerebellum', str(subj))
            if space.startswith('diff'):
                filename = 'fa.nii.gz'
            else:
                filename = 'fa_mri.nii.gz'
        elif image_name == "MD":
            path = os.path.join(
                self.get_data_root(), 'tractography_w_cerebellum', str(subj))
            if space.startswith('diff'):
                filename = 'md.nii.gz'
            else:
                filename = 'md_mri.nii.gz'
        elif image_name == "DTI":
            path = os.path.join(
                self.get_data_root(), 'tractography_w_cerebellum', str(subj))
            if space.startswith('diff'):
                filename = 'rgb_dti.nii.gz'
                #filename = 'rgb_dti_masked.nii.gz'
            else:
                filename = 'rgb_dti_mri.nii.gz'
                #filename = 'rgb_dti_mri_masked.nii.gz'
        elif image_name == 'APARC':
            path = os.path.join(
                self.get_data_root(), "slicer_models", str(subj))
            if kw.get("wm"):
                filename = 'wmparc.nii.gz'
                log = logging.getLogger(__name__)
                log.warning("Warning... deprecated, use WMPARC instead")
            else:
                filename = 'aparc+aseg.nii.gz'
        elif image_name == "WMPARC":
            path = os.path.join(
                self.get_data_root(), "slicer_models", str(subj))
            filename = 'wmparc.nii.gz'
        elif image_name == 'T2':
            path = os.path.join(self.get_data_root(), "nii", str(subj))
            filename = 'eT2WTSEPEBCLEAR.nii.gz'
        elif image_name == 'FLAIR':
            path = os.path.join(self.get_data_root(), "nii", str(subj))
            filename = 'eFLAIRLongTRSENSE.nii.gz'
        else:
            log = logging.getLogger(__name__)
            log.error('Unknown image type %s' % image_name)
            raise Exception('Unknown image type %s' % image_name)

        wholeName = os.path.join(path, filename)
        try:
            img = nib.load(wholeName)
        except IOError as e:
            log = logging.getLogger(__name__)
            log.exception(e)
            log.error("File %s not found" % wholeName)
            raise (Exception('File not found'))

        if kw.get('format', '').upper() == 'VTK':
            if image_name == "MD":
                img_data = img.get_data()
                # Rescale in place, then clamp the values to the [0, 1000] range
                img_data *= 1e5
                # remove lower than 0
                img_data[img_data < 0] = 0
                # remove bigger than 1000
                img_data[img_data > 1000] = 1000
                vtkImg = numpy2vtk_img(img_data)
            elif image_name == "DTI":
                vtkImg = nifti_rgb2vtk(img)
            else:
                vtkImg = nibNii2vtk(img)
            if space == 'native':
                return vtkImg

            # Label maps must not be interpolated
            interpolate = True
            if image_name in {'APARC', 'WMPARC'}:
                interpolate = False

            img2 = applyTransform(
                vtkImg, transform=inv(img.get_affine()), interpolate=interpolate)
            # Diffusion images requested in diffusion space were read from the
            # diffusion-space files above, so no further transform is applied
            if space == "diff" and (image_name in {"FA", "MD", "DTI"}):
                return img2
            return self._move_img_from_world(subj, img2, interpolate, space=space)

        # Nibabel output: images are never resampled, only the affine is changed
        if space == "diff" and (image_name in {"FA", "MD", "DTI"}):
            return img
        elif space == "subject":
            return img
        elif space == "diff":
            # read transform:
            path = self._get_base_fibs_dir_name(subj)
            #matrix = readFlirtMatrix('surf2diff.mat', 'orig.nii.gz', 'FA.nii.gz', path)
            matrix = readFlirtMatrix(
                'diff2surf.mat', 'fa.nii.gz', 'orig.nii.gz', path)
            matrix = np.linalg.inv(matrix)
            affine = img.get_affine()
            aff2 = matrix.dot(affine)
            img2 = nib.Nifti1Image(img.get_data(), aff2)
            return img2
        elif space[:2] == "ta":
            talairach_file = self._get_talairach_transform_name(subj)
            # TODO needs more testing
            transform = readFreeSurferTransform(talairach_file)
            affine = img.get_affine()
            aff2 = transform.dot(affine)
            img2 = nib.Nifti1Image(img.get_data(), aff2)
            return img2
        raise NotImplementedError("Returned nifti image is in subject space")

    #==========Free Surfer================
    def _get_free_surfer_models_dir_name(self, subject):
        """Return ``<data root>/slicer_models/<subject>``"""
        return os.path.join(self.get_data_root(), 'slicer_models', subject)

    def _get_talairach_transform_name(self, subject):
        """xfm extension

        Returns the path of the subject's ``talairach.xfm`` file
        """
        return os.path.join(self.get_data_root(), "freeSurfer_Tracula", subject,
                            "mri", "transforms", 'talairach.xfm')

    def _get_free_surfer_stats_dir_name(self, subject):
        """Return ``<data root>/freeSurfer_Tracula/<subject>/stats``"""
        return os.path.join(self.get_data_root(), 'freeSurfer_Tracula', subject, 'stats')

    def _get_freesurfer_lut_name(self):
        """Return the path of ``FreeSurferColorLUT.txt``"""
        return os.path.join(self.get_data_root(), "freeSurfer_Tracula",
                            'FreeSurferColorLUT.txt')

    def _get_free_surfer_morph_path(self, subj):
        """Return ``<data root>/freeSurfer_Tracula/<subj>/surf``"""
        return os.path.join(self.get_data_root(), "freeSurfer_Tracula", str(subj), 'surf')

    def _get_free_surfer_labels_path(self, subj):
        """Return ``<data root>/freeSurfer_Tracula/<subj>/label``"""
        return os.path.join(self.get_data_root(), "freeSurfer_Tracula", str(subj), 'label')

    def _get_freesurfer_surf_name(self, subj, name):
        """Return the path of the file called *name* inside the subject's ``surf`` directory"""
        return os.path.join(self.get_data_root(), "freeSurfer_Tracula", str(subj),
                            "surf", name)

    def _get_tracula_map_name(self, subj):
        """Return the path of ``merged_avg33_mni_bbr.mgz`` inside the subject's ``dpath`` directory"""
        data_dir = os.path.join(
            self.get_data_root(), "freeSurfer_Tracula", "%s" % subj, "dpath")
        tracks_file = "merged_avg33_mni_bbr.mgz"
        tracks_full_file = os.path.join(data_dir, tracks_file)
        return tracks_full_file

    #=============Camino==================
    def _get_base_fibs_name(self, subj):
        """Return the path of the subject's ``CaminoTracts.vtk`` file"""
        # return os.path.join(self.get_data_root(), "tractography",subj,
        # 'CaminoTracts.vtk')
        return os.path.join(self.get_data_root(), "tractography_w_cerebellum", subj,
                            'CaminoTracts.vtk')

    def _get_base_fibs_dir_name(self, subj):
        """
        Must contain 'diff2surf.mat', 'fa.nii.gz', 'orig.nii.gz'
        """
        # return os.path.join(self.get_data_root(), "tractography",subj)
        return os.path.join(self.get_data_root(), "tractography_w_cerebellum", subj)

    def _get_fa_img_name(self):
        """Return the file name of the FA image"""
        return "fa.nii.gz"

    def _get_orig_img_name(self):
        """Return the file name of the orig image"""
        return "orig.nii.gz"

    def _get_md_lut(self):
        """Return the lookup table built by ``get_colorbrewer_lut`` for MD images"""
        lut = get_colorbrewer_lut(491e-6, 924e-6, "YlGnBu", 9, invert=True)
        return lut

    #==========SPM================
    def _get_paradigm_name(self, paradigm_name):
        """Translate a short paradigm name into the name used in the file system

        Names already ending in ``SENSE`` are returned untouched. Otherwise the name is
        upper-cased, checked against the known paradigms (``MIEDO`` becomes
        ``MIEDOSofTone``) and ``SENSE`` is appended.

        Raises:
            AssertionError: if the upper-cased name is not a known paradigm
        """
        if paradigm_name.endswith("SENSE"):
            return paradigm_name
        paradigm_name = paradigm_name.upper()
        assert paradigm_name in self._functional_paradigms
        if paradigm_name == "MIEDO":
            paradigm_name = "MIEDOSofTone"
        paradigm_name += "SENSE"
        return paradigm_name

    def _get_paradigm_dir(self, subject, name, spm=False):
        """If spm is True return the directory containing spm.mat, else return its parent"""
        if not spm:
            return os.path.join(self.get_data_root(), "spm", subject, name)
        else:
            return os.path.join(self.get_data_root(), "spm", subject, name, "FirstLevel")

    def _get_spm_grid_transform(self, subject, paradigm, direction, assume_bad_matrix=False):
        """
        Get the spm non linear registration transform grid associated to the paradigm

        Use paradigm=dartel to get the transform associated to the dartel normalization

        Args:
            subject: subject id
            paradigm: paradigm directory name, or ``dartel``
            direction: ``forw`` or ``back``
            assume_bad_matrix: forwarded to ``dartel2GridTransform_cached``

        Raises:
            AssertionError: if direction is not ``forw`` or ``back``
        """
        assert direction in {"forw", "back"}
        cache_key = "y_%s_%s_%s.vtk" % (paradigm, subject, direction)
        if paradigm == "dartel":
            y_file = os.path.join(
                self.get_data_root(), "spm", subject, "T1", "y_dartel_%s.nii" % direction)
        else:
            y_file = os.path.join(
                self.get_data_root(), "spm", subject, paradigm, "y_seg_%s.nii.gz" % direction)
        return dartel2GridTransform_cached(y_file, cache_key, self, assume_bad_matrix)

    def _read_func_transform(self, subject, paradigm_name, inverse=False):
        """Read the transform for a functional paradigm

        Locates the ``T1.nii`` files under the subject's ``spm`` directory (one in the
        paradigm directory, one in ``T1``) and delegates to
        ``_read_func_transform_internal``.
        """
        paradigm_name = self._get_paradigm_name(paradigm_name)
        path = os.path.join(self.get_data_root(), 'spm', subject)
        T1_func = os.path.join(path, paradigm_name, 'T1.nii')
        T1_world = os.path.join(path, 'T1', 'T1.nii')
        return self._read_func_transform_internal(subject, paradigm_name, inverse,
                                                  path, T1_func, T1_world)

    @staticmethod
    @memo_ten
    def get_auto_data_root():
        """Return the ``data root`` path from the host configuration

        The project name is taken from this file's name. Relative paths are joined to
        the ``../applications`` directory next to this file.

        Raises:
            KeyError: re-raised (after logging) when ``get_host_config`` raises it
        """
        project_name = os.path.basename(__file__).split('.')[0]
        log = logging.getLogger(__name__)
        try:
            config = get_host_config(project_name)
        except KeyError as e:
            log.exception(e)
            raise
        data_root = config["data root"]
        if not os.path.isabs(data_root):
            data_root = os.path.join(os.path.dirname(__file__), "../applications", data_root)
        return data_root

    @staticmethod
    @memo_ten
    def get_auto_dyn_data_root():
        """Return the ``dynamic data root`` path from the host configuration

        The project name is taken from this file's name. Relative paths are joined to
        the ``../applications`` directory next to this file.

        Raises:
            KeyError: re-raised (after logging) when ``get_host_config`` raises it
        """
        project_name = os.path.basename(__file__).split('.')[0]
        log = logging.getLogger(__name__)
        try:
            config = get_host_config(project_name)
        except KeyError as e:
            log.exception(e)
            raise
        data_root = config["dynamic data root"]
        if not os.path.isabs(data_root):
            data_root = os.path.join(os.path.dirname(__file__), "../applications", data_root)
        return data_root

    @staticmethod
    def get_auto_reader(**kw_args):
        """Initialized a kmc400Reader based on the computer name

        Args:
            **kw_args: only ``max_cache`` is read. A positive value overrides the
                ``memory (mb)`` entry of the host configuration; a missing, ``None`` or
                non-positive value falls back to the configuration.

        Returns:
            A new ``Kmc400Reader``

        Raises:
            KeyError: re-raised (after logging) when ``get_host_config`` raises it
        """
        project_name = os.path.basename(__file__).split('.')[0]
        log = logging.getLogger(__name__)
        try:
            config = get_host_config(project_name)
        except KeyError as e:
            # log.exception already records the message and the traceback;
            # ``e.message`` does not exist in Python 3 and would mask the KeyError
            log.exception(e)
            raise
        static_data_root = config["data root"]
        if not os.path.isabs(static_data_root):
            static_data_root = os.path.join(
                os.path.dirname(__file__), "../applications", static_data_root)
        dyn_data_root = config["dynamic data root"]
        if not os.path.isabs(dyn_data_root):
            dyn_data_root = os.path.join(
                os.path.dirname(__file__), "../applications", dyn_data_root)

        # An explicit None must not be compared against 0 (TypeError in Python 3)
        requested_cache = kw_args.get('max_cache')
        if requested_cache is not None and requested_cache > 0:
            max_cache = requested_cache
            log.info("Max cache set to %.2f MB" % max_cache)
        else:
            max_cache = config["memory (mb)"]
        return Kmc400Reader(static_data_root, dyn_data_root, max_cache=max_cache)