山越貴耀 commited on
Commit
f0842b2
·
1 Parent(s): a71e38a

fixed axis lims

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -113,12 +113,14 @@ def pre_render_images(df,input_sent_id):
113
  ymax,ymin = (max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit
114
  color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
115
  sent_list = []
 
116
  fig_production = st.progress(0)
117
  for fig_id,sent_id in enumerate(sent_id_options):
118
  fig_production.progress(fig_id+1)
119
- plot_fig(fig_id,x_tsne,y_tsne,sent_id,[xmin,xmax],[ymin,ymax],color_list)
120
  sent_list.append(df.cleaned_sentence.to_list()[sent_id])
121
- return sent_list
 
122
 
123
 
124
  if __name__=='__main__':
@@ -190,19 +192,19 @@ if __name__=='__main__':
190
  x_tsne, y_tsne = df.x_tsne, df.y_tsne
191
  xscale_unit = (max(x_tsne)-min(x_tsne))/10
192
  yscale_unit = (max(y_tsne)-min(y_tsne))/10
193
- xmax,xmin = (max(x_tsne)//xscale_unit+1)*xscale_unit,(min(x_tsne)//xscale_unit-1)*xscale_unit
194
- ymax,ymin = (max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit
195
  color_list = sns.color_palette('flare',n_colors=1200)
196
  fig_production = st.progress(0)
197
 
198
- img = plot_fig(df,0,[xmin,xmax],[ymin,ymax],color_list)
199
  #img = cv2.imread('figures/0.png')
200
  height, width, layers = img.shape
201
  size = (width,height)
202
  out = cv2.VideoWriter('sampling_video.mp4',cv2.VideoWriter_fourcc(*'H264'), 3, size)
203
  for sent_id in range(1000):
204
  fig_production.progress((sent_id+1)/1000)
205
- img = plot_fig(df,sent_id,[xmin,xmax],[ymin,ymax],color_list)
206
  #img = cv2.imread(f'figures/{sent_id}.png')
207
  out.write(img)
208
  out.release()
@@ -223,8 +225,8 @@ if __name__=='__main__':
223
  x_tsne, y_tsne = df.x_tsne, df.y_tsne
224
  xscale_unit = (max(x_tsne)-min(x_tsne))/10
225
  yscale_unit = (max(y_tsne)-min(y_tsne))/10
226
- xmax,xmin = (max(x_tsne)//xscale_unit+1)*xscale_unit,(min(x_tsne)//xscale_unit-1)*xscale_unit
227
- ymax,ymin = (max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit
228
  color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
229
 
230
  fig = plt.figure(figsize=(5,5),dpi=200)
 
113
  ymax,ymin = (max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit
114
  color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
115
  sent_list = []
116
+ fig_list = []
117
  fig_production = st.progress(0)
118
  for fig_id,sent_id in enumerate(sent_id_options):
119
  fig_production.progress(fig_id+1)
120
+ img = plot_fig(df,sent_id,[xmin,xmax],[ymin,ymax],color_list)
121
  sent_list.append(df.cleaned_sentence.to_list()[sent_id])
122
+ fig_list.append(img)
123
+ return sent_list,fig_list
124
 
125
 
126
  if __name__=='__main__':
 
192
  x_tsne, y_tsne = df.x_tsne, df.y_tsne
193
  xscale_unit = (max(x_tsne)-min(x_tsne))/10
194
  yscale_unit = (max(y_tsne)-min(y_tsne))/10
195
+ xlims = [(max(x_tsne)//xscale_unit+1)*xscale_unit,(min(x_tsne)//xscale_unit-1)*xscale_unit]
196
+ ylims = [(max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit]
197
  color_list = sns.color_palette('flare',n_colors=1200)
198
  fig_production = st.progress(0)
199
 
200
+ img = plot_fig(df,0,xlims,ylims,color_list)
201
  #img = cv2.imread('figures/0.png')
202
  height, width, layers = img.shape
203
  size = (width,height)
204
  out = cv2.VideoWriter('sampling_video.mp4',cv2.VideoWriter_fourcc(*'H264'), 3, size)
205
  for sent_id in range(1000):
206
  fig_production.progress((sent_id+1)/1000)
207
+ img = plot_fig(df,sent_id,xlims,ylims,color_list)
208
  #img = cv2.imread(f'figures/{sent_id}.png')
209
  out.write(img)
210
  out.release()
 
225
  x_tsne, y_tsne = df.x_tsne, df.y_tsne
226
  xscale_unit = (max(x_tsne)-min(x_tsne))/10
227
  yscale_unit = (max(y_tsne)-min(y_tsne))/10
228
+ xlims = [(max(x_tsne)//xscale_unit+1)*xscale_unit,(min(x_tsne)//xscale_unit-1)*xscale_unit]
229
+ ylims = [(max(y_tsne)//yscale_unit+1)*yscale_unit,(min(y_tsne)//yscale_unit-1)*yscale_unit]
230
  color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
231
 
232
  fig = plt.figure(figsize=(5,5),dpi=200)