import argparse
import itertools
import json
import os
import sys
import flask
import matplotlib
import numpy as np
import plotly.graph_objs as go
import plotly.offline as po
from viskit import core
from viskit.misc import ext
from viskit.misc.ext import flatten
# Select the non-interactive Agg backend; this must run before pyplot is
# first imported (pyplot is imported lazily inside make_plot_eps).
matplotlib.use('Agg')
def sliding_mean(data_array, window=5):
    """Smooth ``data_array`` with a moving average.

    For index ``i`` the mean is taken over the half-open slice
    ``[max(i - window + 1, 0), min(i + window + 1, n))`` — i.e. up to
    ``window - 1`` samples behind and ``window`` samples ahead of ``i``.

    Returns a numpy array with the same length as the input.
    """
    values = np.array(data_array)
    n = len(values)
    smoothed = [
        values[max(i - window + 1, 0):min(i + window + 1, n)].mean()
        for i in range(n)
    ]
    return np.array(smoothed)
# Flask application serving the viskit frontend; assets under /static.
app = flask.Flask(__name__, static_url_path='/static')
# Module-level caches populated by reload_data() before app.run().
exps_data = None  # experiments loaded via core.load_exps_data
plottable_keys = None  # sorted, non-None progress column names
distinct_params = None  # [(param_name, [values...])] from core.extract_distinct_params
@app.route('/js/<path:path>')
def send_js(path):
    """Serve a JavaScript asset from the local ``js`` directory."""
    asset_dir = 'js'
    return flask.send_from_directory(asset_dir, path)
@app.route('/css/<path:path>')
def send_css(path):
    """Serve a stylesheet asset from the local ``css`` directory."""
    asset_dir = 'css'
    return flask.send_from_directory(asset_dir, path)
def make_plot(plot_list,
              use_median=False,
              plot_width=None,
              plot_height=None,
              title=None):
    """Build an HTML div holding a plotly figure for ``plot_list``.

    Each entry of ``plot_list`` is an AttrDict carrying either
    ``means``/``stds`` arrays or (when ``use_median``) the
    ``percentile25/50/75`` arrays, plus a ``legend`` string and an
    optional ``footnote``.

    Returns the plotly div string (plotly.js itself is not inlined); a
    footnote block is appended when the first entry defines one.
    """
    data = []
    p25, p50, p75 = [], [], []
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        if use_median:
            p25.append(np.mean(plt.percentile25))
            p50.append(np.mean(plt.percentile50))
            p75.append(np.mean(plt.percentile75))
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)
        # Shaded error band: the upper curve followed by the reversed
        # lower curve forms a closed polygon.
        data.append(
            go.Scatter(
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='tozerox',
                fillcolor=core.hex_to_rgb(color, 0.2),
                # Plain dict, consistent with the mean trace below;
                # go.Line is deprecated/removed in modern plotly.
                line=dict(color='transparent'),
                showlegend=False,
                legendgroup=plt.legend,
                hoverinfo='none'))
        data.append(
            go.Scatter(
                x=x,
                y=y,
                name=plt.legend,
                legendgroup=plt.legend,
                line=dict(color=core.hex_to_rgb(color)),
            ))
    # Debug dump of the per-curve mean percentiles (only populated in
    # median mode).  join() replaces the former quadratic += building
    # while preserving the exact output, trailing comma included.
    for stats in (p25, p50, p75):
        print('[' + ''.join(str(v) + ',' for v in stats) + ']')
    layout = go.Layout(
        legend=dict(
            x=1,
            y=1,
        ),
        width=plot_width,
        height=plot_height,
        title=title,
    )
    fig = go.Figure(data=data, layout=layout)
    fig_div = po.plot(fig, output_type='div', include_plotlyjs=False)
    if "footnote" in plot_list[0]:
        footnote = "<br />".join([
            r"<span><b>%s</b></span>: <span>%s</span>" %
            (plt.legend, plt.footnote) for plt in plot_list
        ])
        return r"%s<div>%s</div>" % (fig_div, footnote)
    else:
        return fig_div
def make_plot_eps(plot_list, use_median=False, counter=0):
    """Render ``plot_list`` with matplotlib and save ``tmp<counter>.pdf``.

    Contains experiment-specific legend rewrites and per-figure axis
    limits (selected via ``counter``) preserved from the original
    authors' workflow.
    """
    import matplotlib.pyplot as _plt
    f, ax = _plt.subplots(figsize=(8, 5))
    # Map fully-qualified algorithm class paths to short labels.
    legend_replacements = [
        ('garage.algos.trpo.TRPO', 'TRPO'),
        ('garage.algos.vpg.VPG', 'REINFORCE'),
        ('garage.algos.erwr.ERWR', 'ERWR'),
        ('sandbox.rein.algos.trpo_vime.TRPO', 'TRPO+VIME'),
        ('sandbox.rein.algos.vpg_vime.VPG', 'REINFORCE+VIME'),
        ('sandbox.rein.algos.erwr_vime.ERWR', 'ERWR+VIME'),
        ('0.0001', '1e-4'),
    ]
    for idx, plt in enumerate(plot_list):
        color = core.color_defaults[idx % len(core.color_defaults)]
        if use_median:
            x = list(range(len(plt.percentile50)))
            y = list(plt.percentile50)
            y_upper = list(plt.percentile75)
            y_lower = list(plt.percentile25)
        else:
            x = list(range(len(plt.means)))
            y = list(plt.means)
            y_upper = list(plt.means + plt.stds)
            y_lower = list(plt.means - plt.stds)
        for old, new in legend_replacements:
            plt.legend = plt.legend.replace(old, new)
        # Hard-coded legends for a specific three-curve figure (kept
        # verbatim from the original authors).
        if idx == 0:
            plt.legend = 'TRPO (0.0)'
        if idx == 1:
            plt.legend = 'TRPO+VIME (103.7)'
        if idx == 2:
            plt.legend = 'TRPO+L2 (0.0)'
        ax.fill_between(
            x,
            y_lower,
            y_upper,
            interpolate=True,
            facecolor=color,
            linewidth=0.0,
            alpha=0.3)
        if idx == 2:
            # Third curve is drawn dashed to distinguish it.
            ax.plot(
                x,
                y,
                color=color,
                label=plt.legend,
                linewidth=2.0,
                linestyle="--")
        else:
            ax.plot(x, y, color=color, label=plt.legend, linewidth=2.0)
    ax.grid(True)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    # Per-figure limits and legend placement.  Default 'best' prevents
    # the UnboundLocalError the original raised for counter > 4 (the
    # caller increments counter once per split figure).
    loc = 'best'
    if counter == 1:
        ax.set_ylim([-3, 60])
        loc = 'upper left'
    elif counter == 2:
        ax.set_ylim([-0.04, 0.4])
        ax.set_xlim([0, 2000])
        loc = 'upper left'
    elif counter == 3:
        loc = 'lower right'
    elif counter == 4:
        loc = 'lower right'
    leg = ax.legend(loc=loc, prop={'size': 12}, ncol=1)
    for legobj in leg.legendHandles:
        legobj.set_linewidth(5.0)
    _plt.savefig('tmp' + str(counter) + '.pdf', bbox_inches='tight')
def summary_name(exp, selector=None):
    """Return the display name of ``exp`` — its ``exp_name`` parameter.

    ``selector`` is accepted for API compatibility but is unused.
    """
    return exp.params["exp_name"]
def check_nan(exp):
    """Return True when no progress series of ``exp`` contains a NaN."""
    for series in exp.progress.values():
        if np.any(np.isnan(series)):
            return False
    return True
def get_plot_instruction(plot_key,
                         split_key=None,
                         group_key=None,
                         filters=None,
                         use_median=False,
                         only_show_best=False,
                         only_show_best_final=False,
                         gen_eps=False,
                         only_show_best_sofar=False,
                         clip_plot_value=None,
                         plot_width=None,
                         plot_height=None,
                         filter_nan=False,
                         smooth_curve=False,
                         custom_filter=None,
                         legend_post_processor=None,
                         normalize_error=False,
                         custom_series_splitter=None):
    """Build the HTML for all requested plots of ``plot_key``.

    The module globals ``exps_data``/``distinct_params`` are filtered
    (``filters``, ``custom_filter``, optionally dropping NaN runs via
    ``filter_nan``), split into one figure per value of ``split_key``
    and grouped into one curve per value of ``group_key`` (default: per
    ``exp_name``).  Curves show mean +/- std, or the 25/50/75
    percentiles when ``use_median``; the ``only_show_best*`` flags keep
    only the best hyperparameter combination per group.

    Returns the concatenated plotly divs as one string.  When
    ``gen_eps`` is set, figures are additionally written to PDF via
    make_plot_eps().
    """
    print(plot_key, split_key, group_key, filters)
    # Optionally drop experiments whose progress contains any NaN.
    if filter_nan:
        nonnan_exps_data = list(filter(check_nan, exps_data))
        selector = core.Selector(nonnan_exps_data)
    else:
        selector = core.Selector(exps_data)
    if legend_post_processor is None:

        def default_legend_post_processor(x):
            # Identity post-processing when the caller supplies none.
            return x

        legend_post_processor = default_legend_post_processor
    if filters is None:
        filters = dict()
    # Narrow the selector by each exact key/value filter.
    for k, v in filters.items():
        selector = selector.where(k, str(v))
    if custom_filter is not None:
        selector = selector.custom_filter(custom_filter)
    # print selector._filters
    # One sub-selector (-> one figure) per distinct value of split_key.
    if split_key is not None:
        vs = [vs for k, vs in distinct_params if k == split_key][0]
        split_selectors = [selector.where(split_key, v) for v in vs]
        split_legends = list(map(str, vs))
    else:
        split_selectors = [selector]
        split_legends = ["Plot"]
    plots = []
    counter = 1
    for split_selector, split_legend in zip(split_selectors, split_legends):
        # Within one figure, build one curve group per series.
        if custom_series_splitter is not None:
            # The user-supplied callable maps each experiment to a
            # series key; experiments sharing a key become one curve.
            exps = split_selector.extract()
            splitted_dict = dict()
            for exp in exps:
                key = custom_series_splitter(exp)
                if key not in splitted_dict:
                    splitted_dict[key] = list()
                splitted_dict[key].append(exp)
            splitted = list(splitted_dict.items())
            group_selectors = [core.Selector(list(x[1])) for x in splitted]
            group_legends = [x[0] for x in splitted]
        else:
            if group_key and group_key != "exp_name":
                vs = [vs for k, vs in distinct_params if k == group_key][0]
                group_selectors = [
                    split_selector.where(group_key, v) for v in vs
                ]
                group_legends = [str(x) for x in vs]
            else:
                # Default grouping: one curve per experiment name.
                group_key = "exp_name"
                vs = sorted(
                    [x.params["exp_name"] for x in split_selector.extract()])
                group_selectors = [
                    split_selector.where(group_key, v) for v in vs
                ]
                group_legends = [
                    summary_name(x.extract()[0], split_selector)
                    for x in group_selectors
                ]
        # group_selectors = [split_selector]
        # group_legends = [split_legend]
        to_plot = []
        for group_selector, group_legend in zip(group_selectors,
                                                group_legends):
            filtered_data = group_selector.extract()
            if filtered_data:
                if (only_show_best or only_show_best_final
                        or only_show_best_sofar):
                    # Group by seed and sort.
                    # -----------------------
                    filtered_params = core.extract_distinct_params(
                        filtered_data, l=0)  # noqa: E741
                    filtered_params2 = [p[1] for p in filtered_params]
                    filtered_params_k = [p[0] for p in filtered_params]
                    # Every combination of the distinct hyperparameter
                    # values observed within this group.
                    product_space = list(itertools.product(*filtered_params2))
                    data_best_regret = None
                    best_regret = -np.inf
                    kv_string_best_regret = None
                    for idx, params in enumerate(product_space):
                        # NOTE(review): this rebinds the outer
                        # ``selector`` and searches all of exps_data,
                        # not just this group's experiments — confirm
                        # intended.
                        selector = core.Selector(exps_data)
                        for k, v in zip(filtered_params_k, params):
                            selector = selector.where(k, str(v))
                        data = selector.extract()
                        if data:
                            progresses = [
                                exp.progress.get(plot_key, np.array([np.nan]))
                                for exp in data
                            ]
                            # progresses = [
                            #     progress[:500] for progress in progresses
                            # ]
                            sizes = list(map(len, progresses))
                            max_size = max(sizes)
                            # Pad shorter runs with NaN so rows align.
                            progresses = [
                                np.concatenate(
                                    [ps,
                                     np.ones(max_size - len(ps)) * np.nan])
                                for ps in progresses
                            ]
                            if only_show_best_final:
                                progresses = np.asarray(progresses)[:, -1]
                            if only_show_best_sofar:
                                progresses = np.max(
                                    np.asarray(progresses), axis=1)
                            # "Regret" here is the mean curve value used
                            # to rank hyperparameter combinations
                            # (higher is better).
                            if use_median:
                                medians = np.nanmedian(progresses, axis=0)
                                regret = np.mean(medians)
                            else:
                                means = np.nanmean(progresses, axis=0)
                                regret = np.mean(means)
                            distinct_params_k = [p[0] for p in distinct_params]
                            distinct_params_v = [
                                v for k, v in zip(filtered_params_k, params)
                                if k in distinct_params_k
                            ]
                            distinct_params_kv = [(k, v) for k, v in zip(
                                distinct_params_k, distinct_params_v)]
                            distinct_params_kv_string = str(
                                distinct_params_kv).replace('), ', ')\t')
                            print('{}\t{}\t{}'.format(
                                regret, len(progresses),
                                distinct_params_kv_string))
                            if regret > best_regret:
                                best_regret = regret
                                best_progress = progresses
                                data_best_regret = data
                                kv_string_best_regret = \
                                    distinct_params_kv_string
                    print(group_selector._filters)
                    print('best regret: {}'.format(best_regret))
                    # -----------------------
                    if best_regret != -np.inf:
                        progresses = [
                            exp.progress.get(plot_key, np.array([np.nan]))
                            for exp in data_best_regret
                        ]
                        # progresses = \
                        #     [progress[:500] for progress in progresses]
                        sizes = list(map(len, progresses))
                        # more intelligent:
                        max_size = max(sizes)
                        progresses = [
                            np.concatenate(
                                [ps, np.ones(max_size - len(ps)) * np.nan])
                            for ps in progresses
                        ]
                        legend = '{} (mu: {:.3f}, std: {:.5f})'.format(
                            group_legend, best_regret, np.std(best_progress))
                        # Smoothing window scales with series length:
                        # about one thousandth, at least 1.
                        window_size = np.maximum(
                            int(np.round(max_size / float(1000))), 1)
                        if use_median:
                            percentile25 = np.nanpercentile(
                                progresses, q=25, axis=0)
                            percentile50 = np.nanpercentile(
                                progresses, q=50, axis=0)
                            percentile75 = np.nanpercentile(
                                progresses, q=75, axis=0)
                            if smooth_curve:
                                percentile25 = sliding_mean(
                                    percentile25, window=window_size)
                                percentile50 = sliding_mean(
                                    percentile50, window=window_size)
                                percentile75 = sliding_mean(
                                    percentile75, window=window_size)
                            if clip_plot_value is not None:
                                percentile25 = np.clip(percentile25,
                                                       -clip_plot_value,
                                                       clip_plot_value)
                                percentile50 = np.clip(percentile50,
                                                       -clip_plot_value,
                                                       clip_plot_value)
                                percentile75 = np.clip(percentile75,
                                                       -clip_plot_value,
                                                       clip_plot_value)
                            to_plot.append(
                                ext.AttrDict(
                                    percentile25=percentile25,
                                    percentile50=percentile50,
                                    percentile75=percentile75,
                                    legend=legend_post_processor(legend)))
                        else:
                            means = np.nanmean(progresses, axis=0)
                            stds = np.nanstd(progresses, axis=0)
                            if normalize_error:  # and len(progresses) > 0:
                                # Standard error: divide by sqrt of the
                                # per-step count of non-NaN runs.
                                stds /= np.sqrt(
                                    np.sum((1. - np.isnan(progresses)),
                                           axis=0))
                            if smooth_curve:
                                means = sliding_mean(means, window=window_size)
                                stds = sliding_mean(stds, window=window_size)
                            if clip_plot_value is not None:
                                means = np.clip(means, -clip_plot_value,
                                                clip_plot_value)
                                stds = np.clip(stds, -clip_plot_value,
                                               clip_plot_value)
                            to_plot.append(
                                ext.AttrDict(
                                    means=means,
                                    stds=stds,
                                    legend=legend_post_processor(legend)))
                        # NOTE(review): ``data`` here is the last
                        # inner-loop extraction, not necessarily the
                        # best one, and the else branch indexes
                        # to_plot[-1] even though to_plot could be
                        # empty — confirm against callers.
                        if to_plot and data:
                            to_plot[-1]["footnote"] = "%s; e.g. %s" % (
                                kv_string_best_regret, data[0].params.get(
                                    "exp_name", "NA"))
                        else:
                            to_plot[-1]["footnote"] = ""
                else:
                    # Plain path: one curve from all runs in the group.
                    progresses = [
                        exp.progress.get(plot_key, np.array([np.nan]))
                        for exp in filtered_data
                    ]
                    sizes = list(map(len, progresses))
                    # more intelligent:
                    max_size = max(sizes)
                    # Pad shorter runs with NaN so rows align.
                    progresses = [
                        np.concatenate(
                            [ps, np.ones(max_size - len(ps)) * np.nan])
                        for ps in progresses
                    ]
                    window_size = np.maximum(
                        int(np.round(max_size / float(1000))), 1)
                    if use_median:
                        percentile25 = np.nanpercentile(
                            progresses, q=25, axis=0)
                        percentile50 = np.nanpercentile(
                            progresses, q=50, axis=0)
                        percentile75 = np.nanpercentile(
                            progresses, q=75, axis=0)
                        if smooth_curve:
                            percentile25 = sliding_mean(
                                percentile25, window=window_size)
                            percentile50 = sliding_mean(
                                percentile50, window=window_size)
                            percentile75 = sliding_mean(
                                percentile75, window=window_size)
                        if clip_plot_value is not None:
                            percentile25 = np.clip(percentile25,
                                                   -clip_plot_value,
                                                   clip_plot_value)
                            percentile50 = np.clip(percentile50,
                                                   -clip_plot_value,
                                                   clip_plot_value)
                            percentile75 = np.clip(percentile75,
                                                   -clip_plot_value,
                                                   clip_plot_value)
                        to_plot.append(
                            ext.AttrDict(
                                percentile25=percentile25,
                                percentile50=percentile50,
                                percentile75=percentile75,
                                legend=legend_post_processor(group_legend)))
                    else:
                        means = np.nanmean(progresses, axis=0)
                        stds = np.nanstd(progresses, axis=0)
                        if smooth_curve:
                            means = sliding_mean(means, window=window_size)
                            stds = sliding_mean(stds, window=window_size)
                        if clip_plot_value is not None:
                            means = np.clip(means, -clip_plot_value,
                                            clip_plot_value)
                            stds = np.clip(stds, -clip_plot_value,
                                           clip_plot_value)
                        to_plot.append(
                            ext.AttrDict(
                                means=means,
                                stds=stds,
                                legend=legend_post_processor(group_legend)))
        if to_plot and not gen_eps:
            fig_title = "%s: %s" % (split_key, split_legend)
            # plots.append("<h3>%s</h3>" % fig_title)
            plots.append(
                make_plot(
                    to_plot,
                    use_median=use_median,
                    title=fig_title,
                    plot_width=plot_width,
                    plot_height=plot_height))
        if gen_eps:
            make_plot_eps(to_plot, use_median=use_median, counter=counter)
        counter += 1
    return "\n".join(plots)
def parse_float_arg(args, key):
    """Fetch ``args[key]`` as a float; None when missing or unparsable."""
    raw = args.get(key, "")
    try:
        value = float(raw)
    except Exception:
        return None
    return value
@app.route("/plot_div")
def plot_div():
    """Render the plot div for the current query-string options.

    Parses all plotting options from ``flask.request.args`` and
    delegates to get_plot_instruction().
    """
    args = flask.request.args
    plot_key = args.get("plot_key")
    split_key = args.get("split_key", "")
    group_key = args.get("group_key", "")
    filters_json = args.get("filters", "{}")
    filters = json.loads(filters_json)
    # An empty selection means "no split" / "no grouping".  (The
    # previous condition was inverted and discarded any key the user
    # actually selected.)
    if not split_key:
        split_key = None
    if not group_key:
        group_key = None
    use_median = args.get("use_median", "") == 'True'
    gen_eps = args.get("eps", "") == 'True'
    only_show_best = args.get("only_show_best", "") == 'True'
    only_show_best_final = args.get("only_show_best_final", "") == 'True'
    only_show_best_sofar = args.get("only_show_best_sofar", "") == 'True'
    normalize_error = args.get("normalize_error", "") == 'True'
    filter_nan = args.get("filter_nan", "") == 'True'
    smooth_curve = args.get("smooth_curve", "") == 'True'
    clip_plot_value = parse_float_arg(args, "clip_plot_value")
    plot_width = parse_float_arg(args, "plot_width")
    plot_height = parse_float_arg(args, "plot_height")
    # User-supplied callables arrive as strings and go through the
    # best-effort sandbox in safer_eval() before use.
    custom_filter = args.get("custom_filter", None)
    custom_series_splitter = args.get("custom_series_splitter", None)
    if custom_filter is not None and custom_filter.strip():
        custom_filter = safer_eval(custom_filter)
    else:
        custom_filter = None
    legend_post_processor = args.get("legend_post_processor", None)
    if legend_post_processor and legend_post_processor.strip():
        legend_post_processor = safer_eval(legend_post_processor)
    else:
        legend_post_processor = None
    if custom_series_splitter and custom_series_splitter.strip():
        custom_series_splitter = safer_eval(custom_series_splitter)
    else:
        custom_series_splitter = None
    plot_div = get_plot_instruction(
        plot_key=plot_key,
        split_key=split_key,
        filter_nan=filter_nan,
        group_key=group_key,
        filters=filters,
        use_median=use_median,
        gen_eps=gen_eps,
        only_show_best=only_show_best,
        only_show_best_final=only_show_best_final,
        only_show_best_sofar=only_show_best_sofar,
        clip_plot_value=clip_plot_value,
        plot_width=plot_width,
        plot_height=plot_height,
        smooth_curve=smooth_curve,
        custom_filter=custom_filter,
        legend_post_processor=legend_post_processor,
        normalize_error=normalize_error,
        custom_series_splitter=custom_series_splitter)
    return plot_div
def safer_eval(some_string):
    """Evaluate ``some_string`` with builtins stripped.

    Not fool-proof — merely a best-effort guard, following
    https://nedbatchelder.com/blog/201206/eval_really_is_dangerous.html
    """
    looks_suspicious = "__" in some_string or "import" in some_string
    if looks_suspicious:
        raise Exception("string to eval looks suspicious")
    return eval(some_string, {'__builtins__': {}})
@app.route("/")
def index():
    """Serve the main page with a default plot selection."""
    # Prefer the conventional "AverageReturn" metric when present.  The
    # extra truthiness check keeps this from raising TypeError when the
    # module globals have not been populated by reload_data() yet.
    if plottable_keys and "AverageReturn" in plottable_keys:
        plot_key = "AverageReturn"
    elif plottable_keys:
        plot_key = plottable_keys[0]
    else:
        plot_key = None
    if distinct_params:
        group_key = distinct_params[0][0]
    else:
        group_key = None
    plot_div = get_plot_instruction(
        plot_key=plot_key, split_key=None, group_key=group_key)
    # ``distinct_params or []`` guards against an uninitialized None.
    return flask.render_template(
        "main.html",
        plot_div=plot_div,
        plot_key=plot_key,
        group_key=group_key,
        plottable_keys=plottable_keys,
        distinct_param_keys=[str(k) for k, v in (distinct_params or [])],
        distinct_params=dict([(str(k), list(map(str, v)))
                              for k, v in (distinct_params or [])]),
    )
def reload_data():
    """(Re)load experiment data and refresh the derived module globals."""
    global exps_data, plottable_keys, distinct_params
    exps_data = core.load_exps_data(args.data_paths, args.disable_variant)
    # Union of every progress column name seen across experiments.
    key_set = set(flatten(list(exp.progress.keys()) for exp in exps_data))
    plottable_keys = sorted(k for k in key_set if k is not None)
    distinct_params = sorted(core.extract_distinct_params(exps_data))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data_paths", type=str, nargs='*')
    parser.add_argument("--prefix", type=str, nargs='?', default="???")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--port", type=int, default=5000)
    parser.add_argument(
        "--disable-variant", default=False, action='store_true')
    args = parser.parse_args(sys.argv[1:])
    # When --prefix is given, replace data_paths with every directory
    # next to the prefix whose name contains the prefix's basename.
    if args.prefix != "???":
        prefix_dir = os.path.dirname(args.prefix)
        prefix_base = os.path.basename(args.prefix)
        args.data_paths = [
            os.path.join(prefix_dir, entry)
            for entry in os.listdir(prefix_dir)
            if os.path.isdir(os.path.join(prefix_dir, entry))
            and prefix_base in entry
        ]
    print("Importing data from {path}...".format(path=args.data_paths))
    reload_data()
    print("Done! View http://localhost:%d in your browser" % args.port)
    app.run(host='0.0.0.0', port=args.port, debug=args.debug)