#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 1 06:43:54 2023
@author: daniel
"""
import os
import numpy as np
import pandas as pd
from pathlib import Path
from operator import itemgetter
import matplotlib.pyplot as plt ## To plot the spectrum
import matplotlib.colors as mcolors
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
from optuna.importance import get_param_importances, FanovaImportanceEvaluator
from astropy.io import fits
from astropy.wcs import WCS
from progress.bar import FillingSquaresBar
from LineDetect.continuum_finder import Continuum
from LineDetect.detect_elements import MgII
class Spectrum:
    """
    A class for processing spectral data, searching each spectrum for Mg II absorption.

    Can process either a directory of FITS files (process_files) or a single
    spectrum supplied as arrays (process_spectrum).

    Note:
        Every detected line is appended as a new row to the DataFrame `df` attribute.
        If no line is detected nothing is added, unless save_all=True, in which case
        a row holding the object name and 'None' entries is appended instead.

    Args:
        halfWindow (int, list, np.ndarray): The half-size of the window/kernel (in Angstroms) used to compute the continuum.
            If this is a list/array of integers, then the continuum will be calculated
            as the median curve across the fits across all half-window sizes in the list/array.
            Defaults to 25.
        resolution_range (tuple): A tuple of the minimum and maximum resolution (in km/s) used to detect MgII absorption.
            Can also be an integer or a float, in which case the same resolution is used at every wavelength.
        region_size (int): Forwarded unchanged to the Continuum class. Defaults to 100.
        resolution_element (int): Forwarded unchanged to the Continuum class and the MgII function. Defaults to 3.
        savgol_window_size (int): Forwarded unchanged to the Continuum class. Defaults to 100.
        savgol_poly_order (int): The order of the polynomial used for smoothing the spectrum. Defaults to 5.
        N_sig_limits (float): Forwarded unchanged to the Continuum class and the MgII function. Defaults to 0.5.
        N_sig_line1 (float): Forwarded unchanged to the MgII function. Defaults to 5.
        N_sig_line2 (float): Forwarded unchanged to the Continuum class and the MgII function. Defaults to 3.
        rest_wavelength_1 (float): Rest wavelength of the first line of the doublet; also used to convert
            the detected wavelength into a redshift. Defaults to 2796.35.
        rest_wavelength_2 (float): Rest wavelength of the second line of the doublet. Defaults to 2803.53.
        directory (str): The path to the directory containing the FITS files. Defaults to None.
        save_all (bool): Parameter to control whether to save the non-detections. If the spectral feature is not detected
            and save_all=True, the qso_name will be appended alongside 'None' entries. Defaults to False to save only positive detections.

    Attributes:
        df (pd.DataFrame): Detections, with columns 'QSO', 'Wavelength', 'z', 'W' and 'deltaW'.
        Lambda, flux, flux_err, z, qso_name: The last spectrum handled by process_spectrum (None until then).
        continuum, continuum_err: The continuum fit of that spectrum (None until then).
        Mg2796, Mg2803 (np.ndarray): Integer pixel indices returned by the last line search (None until then).
        optimization_results, best_params: The Optuna study and its best parameters (None until optimize is run).

    Methods:
        process_files(): Process the FITS files in the directory.
        process_spectrum(Lambda, flux, flux_err, z, qso_name): Process a single instance of spectral data.
        _reprocess(qso_name): Re-runs the processing using the saved spectrum attributes.
        find_MgII_absorption(Lambda, y, yC, sig_y, sig_yC, z, qso_name): Find the MgII lines, if present.
        optimize(...): Tune the detection parameters with Optuna.
        plot(include, highlight, xlim, ylim, xlog, ylog, savefig, path): Plots the spectrum and/or continuum.
        plot_param_opt(xlim, ylim, xlog, ylog, savefig): Plots the optimization history.
        plot_param_importance(plot_time, savefig): Plots the parameter importances.
    """

    # Only pixels with (1 + z) * min < Lambda <= (1 + z) * max are analysed.
    _REST_FRAME_CUT = (1310, 1554)

    def __init__(self, halfWindow=25, resolution_range=(1400, 1700), region_size=100, resolution_element=3,
        savgol_window_size=100, savgol_poly_order=5, N_sig_limits=0.5, N_sig_line1=5, N_sig_line2=3,
        rest_wavelength_1=2796.35, rest_wavelength_2=2803.53, directory=None, save_all=False):

        self.halfWindow = halfWindow
        self.resolution_range = resolution_range
        self.region_size = region_size
        self.resolution_element = resolution_element
        self.savgol_window_size = savgol_window_size
        self.savgol_poly_order = savgol_poly_order
        self.N_sig_limits = N_sig_limits
        self.N_sig_line1 = N_sig_line1
        self.N_sig_line2 = N_sig_line2
        self.rest_wavelength_1 = rest_wavelength_1
        self.rest_wavelength_2 = rest_wavelength_2
        self.directory = directory
        self.save_all = save_all

        #Declare a dataframe to hold the info
        self.df = pd.DataFrame(columns=['QSO', 'Wavelength', 'z', 'W', 'deltaW'])

        #Placeholders so the plotting/reprocessing methods can detect that nothing has been processed yet
        self.Lambda = self.flux = self.flux_err = None
        self.z = self.qso_name = None
        self.continuum = self.continuum_err = None
        self.Mg2796 = self.Mg2803 = None
        self.optimization_results = self.best_params = None

    @staticmethod
    def _prepare_spectrum(Lambda, flux, flux_err, z):
        """
        Drops the non-finite pixels and trims the spectrum to the analysed wavelength window.

        Args:
            Lambda (array-like): Wavelength values.
            flux (array-like): Flux values.
            flux_err (array-like): Flux error values.
            z (float): The redshift of the QSO associated with the spectrum.

        Returns:
            Tuple of three np.ndarray: the cleaned and trimmed (Lambda, flux, flux_err).
        """

        Lambda, flux, flux_err = np.asarray(Lambda), np.asarray(flux), np.asarray(flux_err)

        finite = np.isfinite(Lambda) & np.isfinite(flux) & np.isfinite(flux_err)
        if not finite.all():
            print('WARNING: Non-finite values detected, these values will be omitted...')
            Lambda, flux, flux_err = Lambda[finite], flux[finite], flux_err[finite]

        lower, upper = Spectrum._REST_FRAME_CUT
        in_window = (Lambda > (z + 1) * lower) & (Lambda <= (z + 1) * upper)

        return Lambda[in_window], flux[in_window], flux_err[in_window]

    def _fit_continuum(self, Lambda, flux, flux_err):
        """
        Builds a Continuum object from the current class attributes and runs its estimate() method.

        Returns:
            The Continuum instance, after estimate() has been called.
        """

        continuum = Continuum(Lambda, flux, flux_err, halfWindow=self.halfWindow, region_size=self.region_size,
            resolution_element=self.resolution_element, savgol_window_size=self.savgol_window_size,
            savgol_poly_order=self.savgol_poly_order, N_sig_limits=self.N_sig_limits, N_sig_line2=self.N_sig_line2)
        continuum.estimate()

        return continuum

    def _resolution_array(self, n_pixels):
        """
        Returns a 1D array with one resolution value per pixel.

        A scalar resolution_range is repeated for every pixel, otherwise the values are
        linearly spaced between resolution_range[0] and resolution_range[1].

        Args:
            n_pixels (int): The number of pixels in the spectrum.

        Returns:
            np.ndarray of length n_pixels.
        """

        if np.ndim(self.resolution_range) == 0:
            return np.full(n_pixels, self.resolution_range, dtype=float)

        return np.linspace(self.resolution_range[0], self.resolution_range[1], n_pixels)

    def process_files(self):
        """
        Processes each FITS file in the directory, detecting any Mg II absorption that may be present.

        The method walks through the directory specified during initialization (including
        sub-directories). For every file it reads the flux from HDU 0, takes the square root of
        the HDU 1 data as the flux error, reads the redshift from the 'Z' keyword of the primary
        header and rebuilds the wavelength array from the WCS of that header. The continuum is
        then fit and the Mg II search is run, with the results stored in the `df` attribute.

        Files that cannot be opened (OSError) or whose continuum fit raises a ValueError
        are reported and skipped.

        Note:
            Unlike when processing single spectra, this method does not save
            the continuum and continuum_err attributes, therefore the plot()
            method cannot be called. Load a single spectrum using process_spectrum
            to save the continuum attributes.

        Returns:
            None
        """

        for root, _dirs, files in os.walk(os.path.abspath(self.directory)):
            progress_bar = FillingSquaresBar('Processing files......', max=len(files))
            for file in files:
                #Read each file in the directory
                try:
                    hdu = fits.open(os.path.join(root, file))
                except OSError:
                    print(); print('Invalid file type, skipping file: {}'.format(file))
                    progress_bar.next(); continue

                #Copy what is needed and close the file straight away so handles do not pile up
                with hdu:
                    #NOTE(review): the square root suggests HDU 1 stores the variance -- confirm against the data format
                    flux, flux_err, z = np.array(hdu[0].data), np.sqrt(hdu[1].data), hdu[0].header['Z']
                    #Recreate the wavelength spectrum from the info given in the WCS of the header
                    w = WCS(hdu[0].header, naxis=1, relax=False, fix=False)
                    Lambda = w.wcs_pix2world(np.arange(len(flux)), 0)[0]

                Lambda, flux, flux_err = self._prepare_spectrum(Lambda, flux, flux_err, z)

                try:
                    #Generate the continuum
                    continuum = self._fit_continuum(Lambda, flux, flux_err)
                except ValueError: #This will catch the failed to fit message!
                    print(); print('Failed to fit the continuum, skipping file: {}'.format(file))
                    progress_bar.next(); continue

                #Find the MgII Absorption
                self.find_MgII_absorption(Lambda, flux, continuum.continuum, flux_err, continuum.continuum_err, z=z, qso_name=file)
                progress_bar.next()

            progress_bar.finish()

        return

    def process_spectrum(self, Lambda, flux, flux_err, z, qso_name=None):
        """
        Processes a single spectrum, detecting any Mg II absorption that may be present.

        The cleaned spectrum, its continuum fit and the object name are saved as
        class attributes so that plot() and _reprocess() can be called afterwards.

        Args:
            Lambda (array-like): An array-like object containing the wavelength values of the spectrum.
            flux (array-like): An array-like object containing the flux values of the spectrum.
            flux_err (array-like): An array-like object containing the flux error values of the spectrum.
            z (float): The redshift of the QSO associated with the spectrum.
            qso_name (str, optional): The name of the QSO associated with the spectrum, will be
                saved in the DataFrame. Defaults to None, in which case 'No_Name' is used.

        Returns:
            None
        """

        qso_name = 'No_Name' if qso_name is None else qso_name
        Lambda, flux, flux_err = self._prepare_spectrum(Lambda, flux, flux_err, z)

        #Generate the continuum and save its attributes
        continuum = self._fit_continuum(Lambda, flux, flux_err)
        self.continuum, self.continuum_err = continuum.continuum, continuum.continuum_err

        #Find the MgII Absorption
        self.find_MgII_absorption(Lambda, flux, self.continuum, flux_err, self.continuum_err, z=z, qso_name=qso_name)
        self.Lambda, self.flux, self.flux_err, self.z, self.qso_name = Lambda, flux, flux_err, z, qso_name #For plotting

        return

    def _reprocess(self, qso_name=None):
        """
        Reprocesses the data, intended to be used after running process_spectrum().
        Useful for changing the attributes and quickly re-running the same sample without
        having to re-input the spectra.

        Note:
            This will update the DataFrame by appending the new object line features (if found).

        Args:
            qso_name (str, optional): Name to save in the DataFrame for this run. Defaults to None,
                in which case the name stored by process_spectrum is re-used.

        Raises:
            ValueError: If no spectrum has been loaded with process_spectrum yet.

        Returns:
            None
        """

        if self.Lambda is None or self.flux is None or self.flux_err is None:
            raise ValueError('This method only works after a single spectrum has been processed via the process_spectrum method.')

        qso_name = self.qso_name if qso_name is None else qso_name
        self.Lambda, self.flux, self.flux_err = self._prepare_spectrum(self.Lambda, self.flux, self.flux_err, self.z)

        #Generate the continuum and save its attributes
        continuum = self._fit_continuum(self.Lambda, self.flux, self.flux_err)
        self.continuum, self.continuum_err = continuum.continuum, continuum.continuum_err

        #Keep the stored name in sync so that plot() titles the figure with the name saved in the DataFrame
        self.qso_name = qso_name

        #Find the MgII Absorption
        self.find_MgII_absorption(self.Lambda, self.flux, self.continuum, self.flux_err, self.continuum_err, z=self.z, qso_name=qso_name)

        return

    def find_MgII_absorption(self, Lambda, y, yC, sig_y, sig_yC, z, qso_name=None):
        """
        Finds Mg II absorption features in the QSO spectrum and adds the line information to the DataFrame,
        including the Equivalent Width and the corresponding error.

        The indices returned by the MgII function are saved as integer arrays in the
        Mg2796 and Mg2803 attributes. They are read in consecutive pairs, with the line
        wavelength taken as the mean of the wavelengths at the two indices of a pair.

        Args:
            Lambda (array-like): Wavelength array.
            y (array-like): Observed flux array.
            yC (array-like): Estimated continuum flux array.
            sig_y (array-like): Observed flux error array.
            sig_yC (array-like): Estimated continuum flux error array.
            z (float): The redshift of the QSO associated with the spectrum. Not used by the search itself.
            qso_name (str, optional): The name of the QSO associated with the spectrum, will be
                saved in the DataFrame. Defaults to None, in which case 'No_Name' is used.

        Returns:
            None
        """

        qso_name = 'No_Name' if qso_name is None else qso_name

        #Declare an array to hold the resolution at each wavelength
        R = self._resolution_array(len(Lambda))

        #The MgII function finds the lines
        Mg2796, Mg2803, EW2796, EW2803, deltaEW2796, deltaEW2803 = MgII(Lambda, y, yC, sig_y, sig_yC, R,
            N_sig_line1=self.N_sig_line1, N_sig_line2=self.N_sig_line2, N_sig_limits=self.N_sig_limits,
            resolution_element=self.resolution_element, rest_wavelength_1=self.rest_wavelength_1,
            rest_wavelength_2=self.rest_wavelength_2)

        self.Mg2796, self.Mg2803 = np.array(Mg2796).astype(int), np.array(Mg2803).astype(int)

        new_rows = []
        if len(self.Mg2796) != 0:
            for i in range(0, len(self.Mg2796) - 1, 2):
                wavelength = (Lambda[self.Mg2796[i]] + Lambda[self.Mg2796[i+1]])/2
                #NOTE(review): the widths are indexed with the pair's first index (0, 2, 4, ...);
                #confirm that MgII returns widths with that layout rather than one entry per line.
                new_rows.append({'QSO': qso_name, 'Wavelength': wavelength, 'z': wavelength/self.rest_wavelength_1 - 1,
                    'W': EW2796[i], 'deltaW': deltaEW2796[i]})
        elif self.save_all:
            new_rows.append({'QSO': qso_name, 'Wavelength': 'None', 'z': 'None', 'W': 'None', 'deltaW': 'None'})

        if new_rows:
            new_df = pd.DataFrame(new_rows)
            #Concatenating onto an empty frame is deprecated in recent pandas versions, so start from the new rows instead
            self.df = new_df if self.df.empty else pd.concat([self.df, new_df], ignore_index=True)

        return

    def optimize(self, Lambda, flux, flux_err, z_qso, z_element, halfWindow, region_size, resolution_element,
        savgol_window_size, savgol_poly_order, N_sig_limits, N_sig_line1, N_sig_line2, n_trials=100, threshold=0.1, show_progress_bar=False):
        """
        Optimizes the element detection parameters with Optuna, so that the redshift of the
        detected line is as close as possible to z_element.

        Every tunable parameter can be input either as a (min, max) tuple, in which case it
        is included in the search, or as a single number, in which case it is hard coded.
        When the search ends the class attributes are re-set to the best values and the
        spectrum is re-processed with them under the name 'Optimized'.

        Args:
            Lambda (array-like): An array-like object containing the wavelength values of the spectrum.
            flux (array-like): An array-like object containing the flux values of the spectrum.
            flux_err (array-like): An array-like object containing the flux error values of the spectrum.
            z_qso (float): The redshift of the QSO associated with the spectrum.
            z_element (float): The redshift the detected line should have.
            halfWindow (tuple or int): (min, max) range to search through, or a fixed value.
            region_size (tuple or int): (min, max) range to search through, or a fixed value.
            resolution_element (tuple or int): (min, max) range to search through, or a fixed value.
            savgol_window_size (tuple or int): (min, max) range to search through, or a fixed value.
            savgol_poly_order (tuple or int): (min, max) range to search through, or a fixed value.
            N_sig_limits (tuple or float): (min, max) range to search through, or a fixed value.
            N_sig_line1 (tuple or float): (min, max) range to search through, or a fixed value.
            N_sig_line2 (tuple or float): (min, max) range to search through, or a fixed value.
            n_trials (int): Number of trials to run the optimizer for. Defaults to 100.
            threshold (float): The desired tolerance. The optimization stops early as soon as the
                best objective value exceeds 1 - threshold. Defaults to 0.1.
            show_progress_bar (bool): Whether Optuna should display its progress bar. Defaults to False.

        Returns:
            None
        """

        def check_threshold(study, trial):
            """
            Optuna callback that stops the study once the best value is good enough.

            Args:
                study (optuna.study.Study): The Optuna study object containing the optimization results.
                trial (optuna.trial.FrozenTrial): The Optuna trial object representing the current trial.

            Raises:
                ThresholdExceeded: If the best value in the study is greater than 1 - threshold.
            """

            if study.best_value > 1 - threshold:
                raise ThresholdExceeded()

        search_space = {'halfWindow': halfWindow, 'region_size': region_size, 'resolution_element': resolution_element,
            'savgol_window_size': savgol_window_size, 'savgol_poly_order': savgol_poly_order,
            'N_sig_limits': N_sig_limits, 'N_sig_line1': N_sig_line1, 'N_sig_line2': N_sig_line2}

        sampler = optuna.samplers.TPESampler(seed=1909)
        study = optuna.create_study(direction='maximize', sampler=sampler)
        objective = objective_spec(Lambda, flux, flux_err, z_qso=z_qso, z_element=z_element,
            rest_wavelength_1=self.rest_wavelength_1, rest_wavelength_2=self.rest_wavelength_2,
            resolution_range=self.resolution_range, **search_space)

        try:
            study.optimize(objective, n_trials=n_trials, show_progress_bar=show_progress_bar, callbacks=[check_threshold])
        except ThresholdExceeded:
            print(); print(f"The optimal value is within the {threshold} tolerance threshold, stopping optimization.")

        self.optimization_results, self.best_params = study, study.best_trial.params
        print(); print(f"Optimization complete with a final score of {study.best_value}! Optimal parameters : {self.best_params}")
        print(); print('Re-setting class attributes to optimal values...')

        #Only the parameters that were input as tuples were searched, the rest keep the value that was passed in
        for name, setting in search_space.items():
            setattr(self, name, self.best_params[name] if isinstance(setting, tuple) else setting)

        print('Re-fitting spectra with optimal values, setting qso name to "Optimized".')
        self.process_spectrum(Lambda, flux, flux_err, z_qso, qso_name='Optimized')

        return

    def plot(self, include='both', highlight=True, xlim=None, ylim=None, xlog=False, ylog=False,
        savefig=False, path=None):
        """
        Plots the spectrum and/or continuum.

        Args:
            include (str): Designates what to plot, options include 'spectrum', 'continuum', or 'both'.
            highlight (bool): If True then the line will be highlighted with accompanying
                vertical lines to visualize the equivalent width. Defaults to True.
            xlim: Limits for the x-axis. Ex) xlim = (4000, 6000)
            ylim: Limits for the y-axis. Ex) ylim = (0.9, 0.94)
            xlog (boolean): If True the x-axis will be log-scaled.
                Defaults to False.
            ylog (boolean): If True the y-axis will be log-scaled.
                Defaults to False.
            savefig (bool): If True the figure will not display but will be saved instead.
                Defaults to False.
            path (str, optional): Path in which the figure should be saved, defaults to None
                in which case the image is saved in the local home directory.

        Raises:
            ValueError: If no spectrum has been processed with process_spectrum yet.

        Returns:
            None
        """

        if self.continuum is None or self.flux is None:
            raise ValueError('This method only works after a single spectrum has been processed via the process_spectrum method.')

        if savefig:
            _set_style_()
        else:
            plt.style.use('default')

        if include == 'continuum' or include == 'both':
            plt.errorbar(self.Lambda, self.continuum, yerr=None, label='Continuum', fmt='r--', linewidth=0.6)
        if include == 'spectrum' or include == 'both':
            plt.errorbar(self.Lambda, self.flux, yerr=None, fmt='k-.', linewidth=0.2)

        plt.title(self.qso_name, size=14)
        plt.xlabel('Wavelength [\u00c5]', size=12); plt.ylabel('Flux', alpha=1, color='k', size=12)
        plt.xticks(fontsize=10); plt.yticks(fontsize=12)

        if xlog:
            plt.xscale('log')
        if ylog:
            plt.yscale('log')
        if xlim is not None:
            plt.xlim(xlim)
        if ylim is not None:
            plt.ylim(ylim)
        if include != 'spectrum':
            plt.legend(prop={'size': 12})

        if highlight:
            if len(self.Mg2796) == 0:
                print('The highlight parameter is enabled but no line was detected!')
            else:
                #The indices come in pairs that mark the two edges of each line
                for i in range(0, len(self.Mg2796) - 1, 2):
                    plt.axvline(x = self.Lambda[self.Mg2796[i]], color = 'orange')
                    plt.axvline(x = self.Lambda[self.Mg2796[i+1]], color = 'orange')
                    plt.axvline(x = self.Lambda[self.Mg2803[i]], color = 'red')
                    plt.axvline(x = self.Lambda[self.Mg2803[i+1]], color = 'red')

        if savefig:
            path = str(Path.home()) if path is None else path
            plt.savefig(os.path.join(path, 'Spectrum_'+self.qso_name+'.png'), dpi=300, bbox_inches='tight')
            print('Figure saved in: {}'.format(path))
            plt.clf(); plt.style.use('default')
        else:
            plt.show()

        return

    def plot_param_opt(self, xlim=None, ylim=None, xlog=True, ylog=False, savefig=False):
        """
        Plots the parameter optimization history.

        Every trial that produced a value is shown, together with the best value found
        up to that trial. The trials are numbered from 1 so that the first one can be
        displayed when the x-axis is log-scaled.

        Note:
            The Optuna API has its own plot function: plot_optimization_history(self.optimization_results)

        Args:
            xlim: Limits for the x-axis. Ex) xlim = (0, 1000)
            ylim: Limits for the y-axis. Ex) ylim = (0.9, 0.94)
            xlog (boolean): If True the x-axis will be log-scaled.
                Defaults to True.
            ylog (boolean): If True the y-axis will be log-scaled.
                Defaults to False.
            savefig (bool): If True the figure will not display but will be saved instead.
                Defaults to False.

        Raises:
            ValueError: If optimize has not been run, or if no trial produced a value.

        Returns:
            None
        """

        if self.optimization_results is None:
            raise ValueError('No optimization results found, run the optimize method first.')

        trials = self.optimization_results.get_trials()
        scored = [trial for trial in trials if trial.values is not None] #Trials without a value cannot be plotted
        if len(scored) == 0:
            raise ValueError('The study contains no trials with an objective value.')

        trial_numbers = np.array([trial.number + 1 for trial in scored])
        trial_values = np.array([trial.values[0] for trial in scored], dtype=float)
        best_value = np.maximum.accumulate(trial_values) #The study maximizes, so the running best is the running maximum

        if savefig:
            _set_style_()
        else:
            plt.style.use('default')

        plt.plot(trial_numbers, best_value, color='r', alpha=0.83, linestyle='-', label='Optimal')
        plt.scatter(trial_numbers, trial_values, c='b', marker='+', s=35, alpha=0.45, label='Trial')
        plt.xlabel('Trial #', alpha=1, color='k'); plt.ylabel('Objective Value', alpha=1, color='k')
        plt.title('Parameter Optimization History')
        plt.legend(loc='upper center', ncol=2, frameon=False)
        plt.rcParams['axes.facecolor']='white'; plt.grid(False)
        plt.xlim(xlim) if xlim is not None else plt.xlim((1, max(2, len(trials))))

        if ylim is not None:
            plt.ylim(ylim)
        if xlog:
            plt.xscale('log')
        if ylog:
            plt.yscale('log')

        if savefig:
            plt.savefig('Ensemble_Hyperparameter_Optimization.png', bbox_inches='tight', dpi=300)
            plt.clf(); plt.style.use('default')
        else:
            plt.show()

        return

    def plot_param_importance(self, plot_time=False, savefig=False):
        """
        Plots the importance of each optimized parameter.

        Note:
            The Optuna API provides its own plotting function: plot_param_importances(self.optimization_results)

        Args:
            plot_time (bool): If True, the importance on the duration will also be included. Defaults to False.
            savefig (bool): If True the figure will not display but will be saved instead. Defaults to False.

        Raises:
            ValueError: If optimize has not been run.

        Returns:
            None
        """

        if self.optimization_results is None:
            raise ValueError('No optimization results found, run the optimize method first.')

        param_importances = get_param_importances(self.optimization_results)
        params = list(param_importances)
        importance = [param_importances[name] for name in params]

        #NOTE(review): format_labels is not defined in the portion of the module shown here -- confirm it is available
        ytick_labels = format_labels(params)

        if savefig:
            _set_style_()
        else:
            plt.style.use('default')

        fig, ax = plt.subplots()
        ax.barh(ytick_labels, importance, label='Importance for Line Detection', color=mcolors.TABLEAU_COLORS["tab:blue"], alpha=0.87)

        if plot_time:
            #Only evaluated on request, as it requires a second importance evaluation
            duration_importances = FanovaImportanceEvaluator().evaluate(self.optimization_results, target=lambda t: t.duration.total_seconds())
            duration_importance = [duration_importances[name] for name in params]
            ax.barh(ytick_labels, duration_importance, label='Impact on Detection Speed', color=mcolors.TABLEAU_COLORS["tab:orange"], alpha=0.7, hatch='/')

        ax.set_ylabel("Hyperparameter"); ax.set_xlabel("Importance Evaluation")
        ax.legend(ncol=2, frameon=False, handlelength=2, bbox_to_anchor=(0.5, 1.1), loc='upper center')
        #A log axis cannot start at zero, so only the upper limit is fixed
        ax.set_xscale('log'); ax.set_xlim(right=1.)
        plt.gca().invert_yaxis()

        if savefig:
            if plot_time:
                plt.savefig('Element_Detection_Parameter_Duration_Importance.png', bbox_inches='tight', dpi=300)
            else:
                plt.savefig('Element_Detection_Parameter_Importance.png', bbox_inches='tight', dpi=300)
            plt.clf(); plt.style.use('default')
        else:
            plt.show()

        return
class objective_spec(object):
    """
    This class is used for optimizing the spectra processing parameters,
    using the high-level Optuna API.

    An instance is a callable objective: for every trial it builds a Spectrum with the
    suggested parameters, processes the input spectrum and scores the result as
    1 - |z_element - z_detected|, so that a perfect redshift match scores 1.

    Every tunable parameter can be input either as a (min, max) tuple, in which case
    it is suggested by the trial, or as a single number, in which case it is hard coded.

    Args:
        Lambda (array-like): An array-like object containing the wavelength values of the spectrum.
        flux (array-like): An array-like object containing the flux values of the spectrum.
        flux_err (array-like): An array-like object containing the flux error values of the spectrum.
        z_qso (float): The redshift of the QSO associated with the spectrum.
        z_element (float): The redshift the detected line is scored against.
        rest_wavelength_1 (float): Forwarded unchanged to the Spectrum class.
        rest_wavelength_2 (float): Forwarded unchanged to the Spectrum class.
        resolution_range (tuple): Forwarded unchanged to the Spectrum class.
        halfWindow (tuple or int): (min, max) range to search through, or a fixed value.
        region_size (tuple or int): (min, max) range to search through, or a fixed value.
        resolution_element (tuple or int): (min, max) range to search through, or a fixed value.
        savgol_window_size (tuple or int): (min, max) range to search through, or a fixed value.
        savgol_poly_order (tuple or int): (min, max) range to search through, or a fixed value.
        N_sig_limits (tuple or float): (min, max) range to search through, or a fixed value.
        N_sig_line1 (tuple or float): (min, max) range to search through, or a fixed value.
        N_sig_line2 (tuple or float): (min, max) range to search through, or a fixed value.
    """

    def __init__(self, Lambda, flux, flux_err, z_qso, z_element, rest_wavelength_1, rest_wavelength_2, resolution_range,
        halfWindow, region_size, resolution_element, savgol_window_size, savgol_poly_order, N_sig_limits, N_sig_line1, N_sig_line2):

        # Spectra
        self.Lambda = Lambda
        self.flux = flux
        self.flux_err = flux_err
        self.z_qso = z_qso
        self.z_element = z_element

        # Presets
        self.rest_wavelength_1 = rest_wavelength_1
        self.rest_wavelength_2 = rest_wavelength_2
        self.resolution_range = resolution_range

        # Tunable
        self.halfWindow = halfWindow
        self.region_size = region_size
        self.resolution_element = resolution_element
        self.savgol_window_size = savgol_window_size
        self.savgol_poly_order = savgol_poly_order
        self.N_sig_limits = N_sig_limits
        self.N_sig_line1 = N_sig_line1
        self.N_sig_line2 = N_sig_line2

    @staticmethod
    def _suggest(trial, name, setting, as_int):
        """
        Returns the value to use for one parameter in the current trial.

        A tuple is treated as the (min, max) search range and sampled from the trial,
        with a step of 1 for integers and 0.05 for floats. Anything else is returned as is.
        """

        if not isinstance(setting, tuple):
            return setting
        if as_int:
            return trial.suggest_int(name, setting[0], setting[1], step=1)

        return trial.suggest_float(name, setting[0], setting[1], step=0.05)

    @staticmethod
    def _score(z_element, detected_z):
        """
        Scores a set of detected redshifts against the target redshift.

        Args:
            z_element (float): The target redshift.
            detected_z (array-like): The redshift of every detected line.

        Returns:
            float: 1 - |z_element - z|, using the detection that lies closest to the target.
        """

        offsets = np.abs(np.asarray(detected_z, dtype=float) - z_element)

        return float(1 - offsets.min())

    def __call__(self, trial):
        """
        Runs one trial and returns its objective value.

        Args:
            trial (optuna.trial.Trial): The trial used to suggest the parameters input as tuples.

        Returns:
            float: The score of the closest detection (see _score), or a sentinel value:
            -999 if N_sig_limits is larger than N_sig_line1 or N_sig_line2,
            -99 if processing the spectrum raised an exception,
            -9 if no line was detected.
        """

        # Suggest the parameters that have been input as tuples
        halfWindow = self._suggest(trial, 'halfWindow', self.halfWindow, as_int=True)
        region_size = self._suggest(trial, 'region_size', self.region_size, as_int=True)
        resolution_element = self._suggest(trial, 'resolution_element', self.resolution_element, as_int=True)
        savgol_window_size = self._suggest(trial, 'savgol_window_size', self.savgol_window_size, as_int=True)
        savgol_poly_order = self._suggest(trial, 'savgol_poly_order', self.savgol_poly_order, as_int=True)
        N_sig_limits = self._suggest(trial, 'N_sig_limits', self.N_sig_limits, as_int=False)
        N_sig_line1 = self._suggest(trial, 'N_sig_line1', self.N_sig_line1, as_int=False)
        N_sig_line2 = self._suggest(trial, 'N_sig_line2', self.N_sig_line2, as_int=False)

        # This combination is rejected outright, without processing the spectrum
        if N_sig_limits > N_sig_line1 or N_sig_limits > N_sig_line2:
            return -999

        trial_spectrum = Spectrum(halfWindow=halfWindow, resolution_range=self.resolution_range, region_size=region_size,
            resolution_element=resolution_element, savgol_window_size=savgol_window_size, savgol_poly_order=savgol_poly_order,
            N_sig_limits=N_sig_limits, N_sig_line1=N_sig_line1, N_sig_line2=N_sig_line2,
            rest_wavelength_1=self.rest_wavelength_1, rest_wavelength_2=self.rest_wavelength_2, save_all=False)

        try:
            trial_spectrum.process_spectrum(self.Lambda, self.flux, self.flux_err, self.z_qso)
        except Exception: # Any failure is scored as a bad trial, but interrupts are allowed through
            return -99

        if len(trial_spectrum.df) == 0:
            return -9

        # Optuna requires a single float, so score the detection closest to the target
        return self._score(self.z_element, trial_spectrum.df.z)
class ThresholdExceeded(optuna.exceptions.OptunaError):
    """
    Signals that an Optuna optimization should stop because the threshold was passed.

    Subclasses `optuna.exceptions.OptunaError` and adds no behaviour of its own;
    it only provides a distinct exception type that can be raised and caught.

    Args:
        None
    """
def _set_style_():
    """
    Function to configure the matplotlib.pyplot style by overwriting the global rcParams.

    The plotting methods call this before a figure is saved, and reset the
    style to the default afterwards.

    Args:
        None

    Returns:
        None
    """

    # Imported here because the module only imports the pyplot and colors sub-modules
    from matplotlib import cycler

    color_cycle = ['#bc80bd', '#fb8072', '#b3de69', '#fdb462', '#fccde5', '#8dd3c7',
        '#ffed6f', '#bebada', '#80b1d3', '#ccebc5', '#d9d9d9']

    plt.rcParams.update({
        # Ticks, text and lines
        "xtick.color": "323034",
        "ytick.color": "323034",
        "text.color": "323034",
        "text.usetex": False,
        "lines.markeredgecolor": "black",
        "lines.linewidth": 2,
        "lines.markersize": 6,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        # Patches, scatter points and grid
        "patch.facecolor": "#bc80bd",
        "patch.force_edgecolor": True,
        "patch.linewidth": 0.8,
        "scatter.edgecolors": "black",
        "grid.color": "#b1afb5",
        # Fonts
        "font.size": 15,
        "font.family": "STIXGeneral",
        "mathtext.fontset": "stix",
        # Legend
        "legend.title_fontsize": 12,
        "legend.frameon": True,
        "legend.framealpha": 0.8,
        "legend.fontsize": 13,
        "legend.edgecolor": "black",
        "legend.borderpad": 0.2,
        "legend.columnspacing": 1.5,
        "legend.labelspacing": 0.4,
        # Axes
        "axes.prop_cycle": cycler('color', color_cycle),
        "axes.titlesize": 16,
        "axes.labelsize": 17,
        "axes.titlelocation": "center",
        "axes.formatter.use_mathtext": True,
        "axes.autolimit_mode": "round_numbers",
        "axes.labelpad": 3,
        "axes.formatter.limits": (-4, 4),
        "axes.labelcolor": "black",
        "axes.edgecolor": "black",
        "axes.linewidth": 1,
        "axes.grid": False,
        "axes.spines.right": True,
        "axes.spines.left": True,
        "axes.spines.top": True,
        # Figure
        "figure.titlesize": 18,
        "figure.autolayout": True,
        "figure.dpi": 300,
    })

    return