Untitled
public
May 02, 2024
Never
22
"""Compare predicted vs. true volatility across model runs and analyse errors.

For every commodity id found under the results tree, this script plots each
model's predicted volatility against the true series, saves the figure, and
runs a one-way ANOVA (with a Tukey HSD post-hoc test where significant) on
the scaled error metrics (MAE, MSE, MAPE) across model types.
"""
import ast  # safe evaluation of stringified lists such as "[0.0123]"
import os
from pathlib import Path

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from statsmodels.formula.api import ols
from statsmodels.stats.multicomp import pairwise_tukeyhsd


def find_com_ids(path_parent, cases_runs):
    """Return the set of commodity-id folder names that hold predictions.

    Parameters
    ----------
    path_parent : str or Path
        Root directory holding one sub-directory per case.
    cases_runs : iterable of (case, run_number) tuples
        Case/run directory pairs to scan.

    Returns
    -------
    set of str
        Names of sub-folders (commodity ids) containing 'a_predictions.csv'.
    """
    com_ids = set()
    for case, run_number in cases_runs:
        directory_path = Path(path_parent) / case / run_number
        # Fix: a missing case/run directory previously made iterdir() raise
        # FileNotFoundError; skip it instead.
        if not directory_path.is_dir():
            continue
        for folder in directory_path.iterdir():
            # Fix: test the one file directly instead of listing the whole
            # directory with os.listdir — cheaper and the pathlib idiom.
            if folder.is_dir() and (folder / 'a_predictions.csv').is_file():
                com_ids.add(folder.name)
    return com_ids


def plot_and_analyze_metrics(path_parent, cases_runs, output_directory):
    """Plot predicted vs. true volatility per commodity and analyse errors.

    For each commodity id, overlays every model run's predicted volatility on
    the true series, saves the figure to ``output_directory``, then runs a
    one-way ANOVA on each scaled error metric across model types, followed by
    a Tukey HSD test when the ANOVA is significant (p < 0.05).

    Parameters
    ----------
    path_parent : str or Path
        Root directory holding one sub-directory per case.
    cases_runs : iterable of (case, run_number) tuples
        Case/run directory pairs to include in the comparison.
    output_directory : str or Path
        Directory for the saved comparison figures (created if absent).
    """
    output_dir = Path(output_directory)
    # Fix: savefig fails if the target directory does not exist.
    output_dir.mkdir(parents=True, exist_ok=True)

    com_ids = find_com_ids(path_parent, cases_runs)

    for com_id in com_ids:
        plt.figure(figsize=(20, 10))
        all_data = pd.DataFrame()
        true_vol_plotted = False

        for case, run_number in cases_runs:
            file_path = Path(path_parent) / case / run_number / com_id / 'a_predictions.csv'
            if not file_path.exists():
                continue

            df = pd.read_csv(file_path)
            df['YYYY-MM-DD'] = pd.to_datetime(df['YYYY-MM-DD'])
            # Values are stored as stringified one-element lists, e.g.
            # "[0.012]"; ast.literal_eval parses them without executing code.
            df['predicted_volatility'] = df['predicted_volatility'].apply(lambda x: float(ast.literal_eval(x)[0]))
            df['true_volatility'] = df['true_volatility'].apply(lambda x: float(ast.literal_eval(x)[0]))
            df['Model_Type'] = f"{case}_{run_number}"

            # Scale the metrics so very small errors remain legible in output.
            scaling_factor = 1000
            df['MAE'] *= scaling_factor
            df['MSE'] *= scaling_factor
            df['MAPE'] *= scaling_factor

            plt.plot(df['YYYY-MM-DD'], df['predicted_volatility'], label=f'{case} Predicted', alpha=0.7)
            if not true_vol_plotted:
                # Plot the true series only once; it is the same across runs.
                plt.plot(df['YYYY-MM-DD'], df['true_volatility'], label='True Volatility', color='orange', alpha=0.4)
                true_vol_plotted = True

            all_data = pd.concat([all_data, df], ignore_index=True)

        plt.title(f'Predicted vs True Volatility for {com_id}')
        plt.xlabel('Date')
        plt.ylabel('Volatility')
        plt.legend()
        plt.grid(True)
        plt.gca().xaxis.set_major_locator(mdates.YearLocator())
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(output_dir / f"{com_id}_volatility_comparison.png")
        plt.close()

        # Fix: ANOVA needs observations from at least two model types; ols()
        # would raise on an empty frame or be meaningless with one group.
        if all_data.empty or all_data['Model_Type'].nunique() < 2:
            continue

        # Perform statistical analysis on error metrics.
        for metric in ['MAE', 'MSE', 'MAPE']:
            model = ols(f'{metric} ~ C(Model_Type)', data=all_data).fit()
            anova_results = sm.stats.anova_lm(model, typ=2)
            print(f"ANOVA results for {com_id} - {metric}:\n{anova_results}")

            # Fix: ['PR(>F)'][0] relied on deprecated positional fallback
            # indexing on a labelled Series; use .iloc[0] explicitly.
            if anova_results['PR(>F)'].iloc[0] < 0.05:
                tukey = pairwise_tukeyhsd(endog=all_data[metric], groups=all_data['Model_Type'], alpha=0.05)
                print(f"Tukey HSD results for {com_id} - {metric}:")
                print(tukey.summary())


# Example usage — guarded so importing this module no longer triggers the
# full analysis as an import-time side effect.
if __name__ == "__main__":
    path_parent = '/mnt/c/Loralee/working/v2/vv2/results/'
    cases_runs = [('LSTM', 'run82'), ('LSTM_GARCH', 'run10'), ('LSTM_eGARCH', 'run42'), ('LSTM_gjrGARCH', 'run36')]
    output_directory = '/mnt/c/Loralee/working/model_vol_comparisons'
    plot_and_analyze_metrics(path_parent, cases_runs, output_directory)