# Source code for mosfit.modules.observables.photometry

"""Definitions for the `Photometry` class."""
import csv
import json
import os
import shutil
from collections import OrderedDict
from copy import deepcopy

import numpy as np
from astropy import constants as c
from astropy import units as u
from astropy.io.votable import parse as voparse
from mosfit.constants import (ANG_CGS, C_CGS, FOUR_PI, H_C_ANG_CGS, MAG_FAC,
                              MPC_CGS)
from mosfit.modules.module import Module
from mosfit.utils import get_url_file_handle, listify, open_atomic, syst_syns


# Important: Only define one ``Module`` class per file.


class Photometry(Module):
    """Band-pass filters.

    Expands the rules in ``filterrules.json`` into a list of unique bands,
    loads their response curves on request, and converts model SEDs into
    model observations.
    """

    # 3631 Jy in CGS, divided by the CGS scale of an Angstrom and multiplied
    # by C_CGS; used to normalize the filter integrals in `load_bands`.
    FLUX_STD = 3631 * u.Jy.cgs.scale / u.Angstrom.cgs.scale * C_CGS

    def __init__(self, **kwargs):
        """Initialize module.

        Reads the optional ``bands`` keyword (a band name or list of names;
        an empty name matches every filter in the rules), loads the filter
        rules on the master process and sends them to the other ranks, then
        allocates the per-band arrays that `load_bands` fills in later.
        """
        super(Photometry, self).__init__(**kwargs)

        bands = kwargs.get('bands', '')
        bands = listify(bands)

        self._dir_path = os.path.dirname(os.path.realpath(__file__))
        self._filter_run_path = os.path.join('modules', 'observables')
        # Every rank executes this constructor, so another process may create
        # the directory between an existence check and `makedirs`. Attempt
        # the creation and only re-raise if the directory is still missing.
        if not os.path.exists(self._filter_run_path):
            try:
                os.makedirs(self._filter_run_path)
            except OSError:
                if not os.path.isdir(self._filter_run_path):
                    raise

        band_list = []

        if self._pool.is_master():
            # Prefer a rules file in the run directory, falling back to the
            # one shipped next to this module.
            rules_path = os.path.join(
                'modules', 'observables', 'filterrules.json')
            if not os.path.isfile(rules_path):
                rules_path = os.path.join(self._dir_path, 'filterrules.json')
            with open(rules_path) as f:
                filterrules = json.load(f, object_pairs_hook=OrderedDict)
            for rank in range(1, self._pool.size + 1):
                self._pool.comm.send(filterrules, dest=rank, tag=5)
        else:
            filterrules = self._pool.comm.recv(source=0, tag=5)

        for band in bands:
            for rule in filterrules:
                if '@note' in rule:
                    continue
                # One entry per combination of the rule's systems,
                # instruments, bandsets, telescopes and modes.
                sysinstperms = [
                    {
                        'systems': xx,
                        'instruments': yy,
                        'bandsets': zz,
                        'telescopes': tt,
                        'modes': mm
                    }
                    for xx in rule.get('systems', [''])
                    for yy in rule.get('instruments', [''])
                    for zz in rule.get('bandsets', [''])
                    for tt in rule.get('telescopes', [''])
                    for mm in rule.get('modes', [''])
                ]
                for bnd in rule.get('filters', []):
                    if band == bnd or band == '':
                        for perm in sysinstperms:
                            new_band = deepcopy(rule['filters'][bnd])
                            new_band.update(deepcopy(perm))
                            new_band['name'] = bnd
                            band_list.append(new_band)

        self._unique_bands = band_list
        self._band_insts = np.array(
            [x['instruments'] for x in self._unique_bands], dtype=object)
        self._band_bsets = np.array(
            [x['bandsets'] for x in self._unique_bands], dtype=object)
        self._band_systs = np.array(
            [x['systems'] for x in self._unique_bands], dtype=object)
        self._band_teles = np.array(
            [x['telescopes'] for x in self._unique_bands], dtype=object)
        self._band_modes = np.array(
            [x['modes'] for x in self._unique_bands], dtype=object)
        self._band_names = np.array(
            [x['name'] for x in self._unique_bands], dtype=object)
        self._n_bands = len(self._unique_bands)

        # Per-band containers; the curve data is filled in by `load_bands`.
        self._band_wavelengths = [[] for i in range(self._n_bands)]
        self._band_energies = [[] for i in range(self._n_bands)]
        self._transmissions = [[] for i in range(self._n_bands)]
        self._band_areas = [[] for i in range(self._n_bands)]
        self._min_waves = np.full(self._n_bands, 0.0)
        self._max_waves = np.full(self._n_bands, 0.0)
        self._imp_waves = [[0.0, 1.0] for i in range(self._n_bands)]
        self._filter_integrals = np.full(self._n_bands, 0.0)
        self._count_integrals = np.full(self._n_bands, 0.0)
        self._average_wavelengths = np.full(self._n_bands, 0.0)
        self._band_offsets = np.full(self._n_bands, 0.0)
        self._band_xunits = np.full(self._n_bands, 'Angstrom', dtype=object)
        self._band_yunits = np.full(self._n_bands, '', dtype=object)
        self._band_xu = np.full(self._n_bands, u.Angstrom.cgs.scale)
        self._band_yu = np.full(self._n_bands, 1.0)
        self._band_kinds = np.full(self._n_bands, 'magnitude', dtype=object)
        self._band_index_cache = {}
        self._warned_mismatch = False
        self._zps = np.full(self._n_bands, 0.0)

        for i, band in enumerate(self._unique_bands):
            self._band_xunits[i] = band.get('xunit', 'Angstrom')
            self._band_yunits[i] = band.get('yunit', '')
            self._band_xu[i] = u.Unit(self._band_xunits[i]).cgs.scale
            self._band_yu[i] = u.Unit(self._band_yunits[i]).cgs.scale

            # Bands whose response is given in cm2 are count-rate bands.
            if '{0}'.format(self._band_yunits[i]) == 'cm2':
                self._band_kinds[i] = 'countrate'
[docs] def load_bands(self, band_indices): """Load band files.""" prt = self._printer if self._pool.is_master(): vo_tabs = OrderedDict() per = 0.0 bc = 0 band_set = set(band_indices) for i, band in enumerate(self._unique_bands): if len(band_indices) and i not in band_set: continue if self._pool.is_master(): new_per = np.round(100.0 * float(bc) / len(band_set)) if new_per > per: per = new_per prt.message('loading_bands', [per], inline=True) systems = ['AB'] zps = [0.0] path = None if 'SVO' in band: photsystem = self._band_systs[i] if photsystem in syst_syns: photsystem = syst_syns[photsystem] if photsystem not in systems: systems.append(photsystem) zpfluxes = [] for sys in systems: svopath = band['SVO'] + '/' + sys path = os.path.join(self._filter_run_path, 'filters', svopath.replace('/', '_') + '.dat') xml_path = os.path.join( self._filter_run_path, 'filters', svopath.replace('/', '_') + '.xml') xml_install_path = os.path.join( self._dir_path, 'filters', svopath.replace('/', '_') + '.xml') if not os.path.exists(xml_path): if not os.path.exists(xml_install_path): prt.message('dl_svo', [svopath], inline=True) try: response = get_url_file_handle( 'http://svo2.cab.inta-csic.es' '/svo/theory/fps3/' 'fps.php?PhotCalID=' + svopath, timeout=10) except Exception: prt.message('cant_dl_svo', warning=True) else: with open_atomic(xml_path, 'wb') as f: shutil.copyfileobj(response, f) if os.path.exists(xml_install_path): already_written = svopath in vo_tabs if not already_written: vo_tabs[svopath] = voparse(xml_install_path) vo_tab = vo_tabs[svopath] # need to account for zeropoint type for resource in vo_tab.resources: if len(resource.params) == 0: params = vo_tab.get_first_table().params else: params = resource.params oldzplen = len(zps) for param in params: if param.name == 'ZeroPoint': zpfluxes.append(param.value) if sys != 'AB': # 0th element is AB flux zps.append(2.5 * np.log10( zpfluxes[0] / zpfluxes[-1])) else: continue if sys != 'AB' and len(zps) == oldzplen: raise 
RuntimeError( 'ZeroPoint not found in XML.') vo_dat = vo_tab.get_first_table().array bi = max( next((i for i, x in enumerate(vo_dat) if x[1]), 0) - 1, 0) ei = -max( next((i for i, x in enumerate( reversed(vo_dat)) if x[1]), 0) - 1, 0) vo_dat = vo_dat[bi:ei if ei else len(vo_dat)] vo_string = '\n'.join([ ' '.join([str(y) for y in x]) for x in vo_dat ]) if (not self._model._fitter._prefer_cache or not os.path.exists(path)): with open_atomic(path, 'w') as f: f.write(vo_string) elif os.path.exists(xml_path): already_written = svopath in vo_tabs if not already_written: vo_tabs[svopath] = voparse(xml_path) vo_tab = vo_tabs[svopath] # need to account for zeropoint type for resource in vo_tab.resources: if len(resource.params) == 0: params = vo_tab.get_first_table().params else: params = resource.params oldzplen = len(zps) for param in params: if param.name == 'ZeroPoint': zpfluxes.append(param.value) if sys != 'AB': # 0th element is AB flux zps.append(2.5 * np.log10( zpfluxes[0] / zpfluxes[-1])) else: continue if sys != 'AB' and len(zps) == oldzplen: raise RuntimeError( 'ZeroPoint not found in XML.') if not already_written: vo_dat = vo_tab.get_first_table().array bi = max( next((i for i, x in enumerate(vo_dat) if x[1]), 0) - 1, 0) ei = -max( next((i for i, x in enumerate( reversed(vo_dat)) if x[1]), 0) - 1, 0) vo_dat = vo_dat[bi:ei if ei else len(vo_dat)] vo_string = '\n'.join([ ' '.join([str(y) for y in x]) for x in vo_dat ]) if (not self._model._fitter._prefer_cache or not os.path.exists(path)): with open_atomic(path, 'w') as f: f.write(vo_string) else: raise RuntimeError( prt.string('cant_read_svo')) self._unique_bands[i]['origin'] = band['SVO'] elif all(x in band for x in [ 'min_wavelength', 'max_wavelength', 'delta_wavelength']): nbins = int(np.round(( band['max_wavelength'] - band['min_wavelength']) / band[ 'delta_wavelength'])) + 1 rows = np.array( [np.linspace( band['min_wavelength'], band['max_wavelength'], nbins), np.full(nbins, 1.0)]).T.tolist() 
self._unique_bands[i]['origin'] = 'generated' elif 'path' in band: self._unique_bands[i]['origin'] = band['path'] path = band['path'] else: raise RuntimeError(prt.text('bad_filter_rule')) if path: with open(path, 'r') as f: rows = [] for row in csv.reader( f, delimiter=' ', skipinitialspace=True): rows.append([float(x) for x in row[:2]]) for rank in range(1, self._pool.size + 1): self._pool.comm.send(rows, dest=rank, tag=3) self._pool.comm.send(zps, dest=rank, tag=4) else: rows = self._pool.comm.recv(source=0, tag=3) zps = self._pool.comm.recv(source=0, tag=4) xvals, yvals = list( map(list, zip(*rows))) xvals = np.array(xvals) yvals = np.array(yvals) if '{0}'.format(self._band_yunits[i]) == 'cm2': xscale = (c.h * c.c / u.Angstrom).cgs.value / self._band_xu[i] self._band_energies[ i], self._band_areas[i] = xvals, yvals / xvals self._band_wavelengths[i] = xscale / self._band_energies[i] self._average_wavelengths[i] = np.trapz([ x * y for x, y in zip( self._band_areas[i], self._band_wavelengths[i]) ], self._band_wavelengths[i]) / np.trapz( self._band_areas[i], self._band_wavelengths[i]) else: self._band_wavelengths[ i], self._transmissions[i] = xvals, yvals self._filter_integrals[i] = self.FLUX_STD * np.trapz( np.array(self._transmissions[i]) / np.array(self._band_wavelengths[i]) ** 2, self._band_wavelengths[i]) self._count_integrals[i] = self.FLUX_STD * np.trapz( np.array(self._transmissions[i]) / np.array(self._band_wavelengths[i]) ** 2 / ( H_C_ANG_CGS / self._band_wavelengths[i]), self._band_wavelengths[i]) self._average_wavelengths[i] = np.trapz([ x * y for x, y in zip( self._transmissions[i], self._band_wavelengths[i]) ], self._band_wavelengths[i]) / np.trapz( self._transmissions[i], self._band_wavelengths[i]) if 'offset' in band: self._band_offsets[i] = band['offset'] elif 'SVO' in band: self._band_offsets[i] = zps[-1] # Do some sanity checks. 
if (self._band_offsets[i] != 0.0 and self._band_systs[i] == 'AB'): raise RuntimeError( 'Filters in AB system should always have offset = ' '0.0, not the case for `{}`'.format( self._band_names[i])) self._min_waves[i] = min(self._band_wavelengths[i]) self._max_waves[i] = max(self._band_wavelengths[i]) self._imp_waves[i] = set([self._min_waves[i], self._max_waves[i]]) if len(self._transmissions[i]): new_wave = self._band_wavelengths[i][ np.argmax(self._transmissions[i])] self._imp_waves[i].add(new_wave) elif len(self._band_areas[i]): new_wave = self._band_wavelengths[i][ np.argmax(self._band_areas[i])] self._imp_waves[i].add(new_wave) self._imp_waves[i] = list(sorted(self._imp_waves[i])) bc = bc + 1 if self._pool.is_master(): prt.message('band_load_complete', inline=True)
[docs] def find_band_index( self, band, telescope='', instrument='', mode='', bandset='', system=''): """Find the index corresponding to the provided band information.""" bmatch = 0 bbi = None cache_key = ':'.join([ band, telescope, instrument, mode, bandset, system]) if cache_key in self._band_index_cache: return self._band_index_cache[cache_key] ltele, linst, lmode, lbset, lsyst = tuple([x.lower() for x in [ telescope, instrument, mode, bandset, system]]) for bi, bnd in enumerate(self._unique_bands): # Band name *must* match (case-sensitive), all other matches # optional and case-insensitive. if (band != bnd['name']) and (band != ''): continue nmismatches = sum( [(linst != self._band_insts[bi].lower()) & ( linst != '') & (self._band_insts[bi] != ''), (ltele != self._band_teles[bi].lower()) & ( ltele != '') & (self._band_teles[bi] != '')]) matches = [band == bnd['name'], lsyst == self._band_systs[bi].lower(), lmode == self._band_modes[bi].lower(), linst == self._band_insts[bi].lower(), ltele == self._band_teles[bi].lower(), lbset == self._band_bsets[bi].lower()] lmatch = sum(matches) nbmatch = sum( [(band == bnd['name']) & (band != ''), (lsyst == self._band_systs[bi].lower()) & (lsyst != ''), (lmode == self._band_modes[bi].lower()) & (lmode != ''), (linst == self._band_insts[bi].lower()) & (linst != ''), (ltele == self._band_teles[bi].lower()) & (ltele != ''), (lbset == self._band_bsets[bi].lower()) & (lbset != '')]) if lmatch > bmatch and nbmatch > 0: bmatch = lmatch bbi = bi bmm = nmismatches if lmatch == 6 and nbmatch > 0: break if bbi is not None: if bmm > 0 and not self._warned_mismatch: self._printer.message('potential_mismatch', reps=[ band, instrument, telescope, self._band_insts[bbi], self._band_teles[bbi]], warning=True) self._warned_mismatch = True self._band_index_cache[cache_key] = bbi return bbi raise ValueError( self._printer.text('band_not_found', reps=[ band, bandset, mode, instrument, telescope, system]))
[docs] def preprocess(self, **kwargs): """Preprocess module.""" if not self._preprocessed: self.load_bands(kwargs['all_band_indices']) self._preprocessed = True
[docs] def process(self, **kwargs): """Process module.""" self.preprocess(**kwargs) kwargs = self.prepare_input(self.key('luminosities'), **kwargs) self._band_indices = kwargs['all_band_indices'] self._observation_types = np.array(kwargs['observation_types']) self._observed = kwargs['observed'] self._dist_const = FOUR_PI * (kwargs['lumdist'] * MPC_CGS) ** 2 self._ldist_const = np.log10(self._dist_const) self._luminosities = kwargs[self.key('luminosities')] self._frequencies = kwargs['all_frequencies'] self._zps = kwargs.get('all_zeropoints', np.zeros_like( self._luminosities)) zp1 = 1.0 + kwargs['redshift'] eff_fluxes = np.zeros_like(self._luminosities) offsets = np.zeros_like(self._luminosities) model_observations = np.zeros_like(self._luminosities) for li, lum in enumerate(self._luminosities): bi = self._band_indices[li] if bi >= 0: if (self._observation_types[li] == 'magnitude' or self._observation_types[li] == 'magcount'): offsets[li] = self._band_offsets[bi] wavs = kwargs['sample_wavelengths'][bi] yvals = np.interp( wavs, self._band_wavelengths[bi], self._transmissions[bi]) * kwargs['seds'][li] / zp1 eff_fluxes[li] = np.trapz( yvals, wavs) / self._filter_integrals[bi] elif self._observation_types[li] == 'countrate': wavs = np.array(kwargs['sample_wavelengths'][bi]) yvals = np.interp( wavs, self._band_wavelengths[bi], self._band_areas[bi]) * kwargs['seds'][li] / zp1 / ( H_C_ANG_CGS / wavs) / ANG_CGS eff_fluxes[li] = np.trapz(yvals, wavs) else: raise RuntimeError('Unknown observation kind.') else: eff_fluxes[li] = kwargs['seds'][li][0] / ANG_CGS * ( C_CGS / (self._frequencies[li] ** 2)) nbs = np.logical_or( self._observation_types == 'countrate', self._observation_types == 'fluxdensity') ybs = np.logical_or( self._observation_types == 'magnitude', self._observation_types == 'magcount') cbs = self._observation_types == 'magcount' model_observations[nbs] = eff_fluxes[nbs] / self._dist_const model_observations[ybs] = self.abmag(eff_fluxes[ybs], offsets[ybs]) 
model_observations[cbs] = 10.0 ** (-0.4 * (model_observations[ cbs] - self._zps[cbs])) return {'model_observations': model_observations}
[docs] def average_wavelengths(self, indices=None): """Return average wavelengths for specified band indices.""" if indices: return [x for i, x in enumerate(self._average_wavelengths) if i in indices] return self._average_wavelengths
[docs] def bands(self, indices=None): """Return the list of unique band names.""" if indices: return [x for i, x in enumerate(self._band_names) if i in indices] return self._band_names
[docs] def instruments(self, indices=None): """Return the list of instruments.""" if indices: return [x for i, x in enumerate(self._band_insts) if i in indices] return self._band_insts
[docs] def telescopes(self, indices=None): """Return the list of telescopes.""" if indices: return [x for i, x in enumerate(self._band_teles) if i in indices] return self._band_teles
[docs] def abmag(self, eff_fluxes, offsets): """Convert fluxes into AB magnitude.""" mags = np.full(len(eff_fluxes), np.inf) ef_mask = eff_fluxes != 0.0 mags[ef_mask] = - offsets[ef_mask] - MAG_FAC * ( np.log10(eff_fluxes[ef_mask]) - self._ldist_const) return mags
[docs] def set_variance_bands(self, band_pairs): """Set band (or pair of bands) that variance will be anchored to.""" self._variance_bands = [] for i, wave in enumerate(self._average_wavelengths): match_found = False for pwave, band in band_pairs: if wave == pwave: self._variance_bands.append(band) match_found = True break if not match_found: for bpi, (pwave, band) in enumerate(band_pairs): if wave < pwave: if bpi > 0: frac = ((wave - band_pairs[bpi - 1][0]) / (pwave - band_pairs[bpi - 1][0])) self._variance_bands.append( [frac, [x[1] for x in band_pairs[bpi - 1:bpi + 1]]]) else: self._variance_bands.append(band) break if bpi == len(band_pairs) - 1: self._variance_bands.append(band)
[docs] def send_request(self, request): """Send requests to other modules.""" if request == 'photometry': return self elif request == 'band_wave_ranges': return self._imp_waves elif request == 'average_wavelengths': return self._average_wavelengths elif request == 'variance_bands': return getattr(self, '_variance_bands', []) return []