# Paste-site header residue (not part of the script), kept as a comment:
#   G / Untitled / public
#   Guest, May 02, 2024, Never, 22 / Clone
#   Plaintext paste1.txt, 79 lines (68 loc) | 3.56 KB
1
import pandas as pd
2
import matplotlib.pyplot as plt
3
import os
4
from pathlib import Path
5
import numpy as np
6
import matplotlib.dates as mdates
7
import ast # Import ast module for safely evaluating strings as Python expressions
8
import statsmodels.api as sm
9
from statsmodels.formula.api import ols
10
from statsmodels.stats.multicomp import pairwise_tukeyhsd
11
12
def find_com_ids(path_parent, cases_runs):
    """Collect the names of all sub-folders that hold an ``a_predictions.csv``.

    Args:
        path_parent: Root results directory (string or path-like).
        cases_runs: Iterable of ``(case, run_number)`` pairs. Each pair names
            the directory ``path_parent / case / run_number`` to scan.

    Returns:
        A set containing the name of every immediate sub-folder, across all
        scanned run directories, that contains a file called
        ``a_predictions.csv``.

    A run directory that does not exist is skipped with a warning instead of
    raising ``FileNotFoundError``, so one misconfigured pair does not abort
    the scan of the remaining pairs.
    """
    import warnings  # function-scope import keeps this block self-contained

    com_ids = set()
    for case, run_number in cases_runs:
        directory_path = Path(path_parent) / case / run_number
        if not directory_path.is_dir():
            warnings.warn(
                f"Skipping missing run directory: {directory_path}",
                stacklevel=2,
            )
            continue
        for folder in directory_path.iterdir():
            # Probe for the one file directly rather than listing the whole
            # folder just to test membership.
            if folder.is_dir() and (folder / 'a_predictions.csv').is_file():
                com_ids.add(folder.name)
    return com_ids
20
21
def _load_predictions(file_path, model_type, scaling_factor=1000):
    """Read one ``a_predictions.csv`` and prepare its columns for analysis.

    The date column is parsed to datetimes, each volatility cell is evaluated
    with ``ast.literal_eval`` and reduced to its first element as a float, a
    ``Model_Type`` label is attached, and the three error-metric columns are
    multiplied by ``scaling_factor``.
    """
    df = pd.read_csv(file_path)
    df['YYYY-MM-DD'] = pd.to_datetime(df['YYYY-MM-DD'])
    for column in ('predicted_volatility', 'true_volatility'):
        # NOTE(review): assumes each cell is the string form of a sequence,
        # e.g. "[0.0123]" -- confirm against the files that produce this CSV.
        df[column] = df[column].apply(
            lambda cell: float(ast.literal_eval(cell)[0])
        )
    df['Model_Type'] = model_type
    for metric in ('MAE', 'MSE', 'MAPE'):
        df[metric] *= scaling_factor
    return df


def _plot_volatility(com_id, labelled_frames, output_directory):
    """Save the predicted-vs-true volatility figure for one ``com_id``.

    ``labelled_frames`` is a list of ``(case, DataFrame)`` pairs. Every frame
    contributes a predicted-volatility line; the true-volatility line is drawn
    once, from the first frame.
    """
    fig = plt.figure(figsize=(20, 10))
    try:
        for position, (case, df) in enumerate(labelled_frames):
            plt.plot(df['YYYY-MM-DD'], df['predicted_volatility'],
                     label=f'{case} Predicted', alpha=0.7)
            if position == 0:
                plt.plot(df['YYYY-MM-DD'], df['true_volatility'],
                         label='True Volatility', color='orange', alpha=0.4)

        plt.title(f'Predicted vs True Volatility for {com_id}')
        plt.xlabel('Date')
        plt.ylabel('Volatility')
        plt.legend()
        plt.grid(True)
        axis = plt.gca().xaxis
        axis.set_major_locator(mdates.YearLocator())
        axis.set_major_formatter(mdates.DateFormatter('%Y'))
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(Path(output_directory) / f"{com_id}_volatility_comparison.png")
    finally:
        # Close even when plotting or saving fails, so figures do not pile up
        # across the loop over com_ids.
        plt.close(fig)


def _analyze_metrics(com_id, all_data):
    """Print one-way ANOVA (and Tukey HSD when significant) per error metric."""
    if all_data['Model_Type'].nunique() < 2:
        # Comparing model types needs at least two groups.
        print(f"Skipping statistical analysis for {com_id}: "
              "fewer than two model types have data.")
        return

    for metric in ('MAE', 'MSE', 'MAPE'):
        model = ols(f'{metric} ~ C(Model_Type)', data=all_data).fit()
        anova_results = sm.stats.anova_lm(model, typ=2)
        print(f"ANOVA results for {com_id} - {metric}:\n{anova_results}")

        # .iloc[0] takes the first row explicitly by position; an integer key
        # on this label-indexed column relies on deprecated pandas behaviour.
        if anova_results['PR(>F)'].iloc[0] < 0.05:
            tukey = pairwise_tukeyhsd(endog=all_data[metric],
                                      groups=all_data['Model_Type'],
                                      alpha=0.05)
            print(f"Tukey HSD results for {com_id} - {metric}:")
            print(tukey.summary())


def plot_and_analyze_metrics(path_parent, cases_runs, output_directory):
    """Plot volatility predictions and compare error metrics across models.

    For every ``com_id`` returned by ``find_com_ids``, the function loads the
    ``a_predictions.csv`` of each ``(case, run_number)`` pair that has one,
    saves ``<com_id>_volatility_comparison.png`` into ``output_directory``,
    and prints ANOVA / Tukey HSD results for the MAE, MSE and MAPE columns
    grouped by ``"<case>_<run_number>"``.

    Args:
        path_parent: Root results directory (string or path-like).
        cases_runs: Re-iterable collection of ``(case, run_number)`` pairs.
        output_directory: Directory for the PNG files; created if missing.

    Returns:
        None. Results are written to disk and printed.
    """
    output_dir = Path(output_directory)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the order of plots and printed reports is reproducible.
    for com_id in sorted(find_com_ids(path_parent, cases_runs)):
        labelled_frames = []
        for case, run_number in cases_runs:
            file_path = (Path(path_parent) / case / run_number / com_id
                         / 'a_predictions.csv')
            if file_path.exists():
                frame = _load_predictions(file_path, f"{case}_{run_number}")
                labelled_frames.append((case, frame))

        if not labelled_frames:
            # Nothing to plot or analyse for this id.
            continue

        _plot_volatility(com_id, labelled_frames, output_dir)

        # Concatenate once instead of growing a DataFrame inside the loop.
        all_data = pd.concat([frame for _, frame in labelled_frames],
                             ignore_index=True)
        _analyze_metrics(com_id, all_data)
74
75
# Example usage. The paths are hard-coded absolute locations; adjust them to
# the machine the script runs on.
path_parent = '/mnt/c/Loralee/working/v2/vv2/results/'
cases_runs = [
    ('LSTM', 'run82'),
    ('LSTM_GARCH', 'run10'),
    ('LSTM_eGARCH', 'run42'),
    ('LSTM_gjrGARCH', 'run36'),
]
output_directory = '/mnt/c/Loralee/working/model_vol_comparisons'

# Run the analysis only when executed as a script, so importing this module
# does not trigger plotting and file I/O.
if __name__ == "__main__":
    plot_and_analyze_metrics(path_parent, cases_runs, output_directory)