foz commited on
Commit
b53e5c3
·
1 Parent(s): 388359f

First demo

Browse files
Files changed (2) hide show
  1. analyse.py +157 -0
  2. app.py +45 -0
analyse.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import lxml.etree as ET
4
+ import gzip
5
+ import tifffile
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from PIL import Image, ImageDraw
9
+ import pandas as pd
10
+
11
+
12
+ def get_paths_from_traces_file(traces_file):
13
+ tree = ET.parse(traces_file)
14
+
15
+ root = tree.getroot()
16
+ all_paths = []
17
+ path_lengths = []
18
+ for path in root.findall('path'):
19
+ length=path.get('reallength')
20
+ path_points = []
21
+ for point in path:
22
+ path_points.append((int(point.get('x')), int(point.get('y')), int(point.get('z'))))
23
+ all_paths.append(path_points)
24
+ path_lengths.append(length)
25
+ return all_paths, path_lengths
26
+
27
+ def visualise_ordering(points_list, dim):
28
+ rdim, cdim, _ = dim
29
+ vis = np.zeros((rdim, cdim, 3), dtype=np.uint8)
30
+
31
+ def get_col(i):
32
+ r = int(255 * i/len(points_list))
33
+ g = 255 - r
34
+ return r, g, 0
35
+
36
+ for n, p in enumerate(points_list):
37
+ c, r, _ = p
38
+ wr, wc = 5, 5
39
+ vis[max(0,r-wr):min(rdim,r+wr),max(0,c-wc):min(cdim,c+wc)] = get_col(n)
40
+
41
+ return vis
42
+
43
+ col_map = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255), (0,255,255)]
44
+
45
+ def draw_paths(all_paths, foci_stack):
46
+ im = np.max(foci_stack, axis=0)
47
+ im = (im/np.max(im)*255).astype(np.uint8)
48
+ im = np.dstack((im,)*3)
49
+ im = Image.fromarray(im) #.convert('RGB')
50
+ draw = ImageDraw.Draw(im)
51
+ for i, (p, col) in enumerate(zip(all_paths, col_map)):
52
+ draw.line([(u[0], u[1]) for u in p], fill=col)
53
+ draw.text((p[0][0], p[0][1]), str(i+1), fill=col)
54
+ return im
55
+
56
+
57
+ # Sum of measure_stack over regin where mask==1
58
+ def measure_from_mask(mask, measure_stack):
59
+ return np.sum(mask * measure_stack)
60
+
61
+ # Max of measure_stack over region where mask==1
62
+ def max_from_mask(mask, measure_stack):
63
+ return np.max(mask * measure_stack)
64
+
65
+
66
+ # Translate mask to point p, treating makss near stack edges correctly
67
+ def make_mask_s(p, melem, measure_stack):
68
+ mask = melem
69
+
70
+ R = melem.shape[0] // 2
71
+ r, c, z = p
72
+
73
+ m_data = np.zeros(melem.shape)
74
+ s = measure_stack.shape
75
+ o_1, o_2, o_3 = max(R-r, 0), max(R-c, 0), max(R-z,0)
76
+ e_1, e_2, e_3 = min(R-r+s[0], 2*R), min(R-c+s[1], 2*R), min(R-z+s[2], 2*R)
77
+ m_data[o_1:e_1,o_2:e_2,o_3:e_3] = measure_stack[max(r-R,0):min(r+R,s[0]),max(c-R,0):min(c+R,s[1]),max(z-R,0):min(z+R, s[2])]
78
+ return mask, m_data
79
+
80
+ # Measure the (mean/max) value of measure_stack about the point p, using
81
+ # the structuring element melem. op indicates the appropriate measurement (mean/max)
82
+ def measure_at_point(p, melem, measure_stack, op='mean'):
83
+ if op=='mean':
84
+ mask, m_data = make_mask_s(p, melem, measure_stack)
85
+ melem_size = np.sum(melem)
86
+ return float(measure_from_mask(mask, m_data) / melem_size)
87
+ else:
88
+ mask, m_data = make_mask_s(p, melem, measure_stack)
89
+ return float(max_from_mask(mask, m_data))
90
+
91
+ # Generate spherical region
92
+ def make_sphere(R=5, z_scale_ratio=2.3):
93
+ x, y, z = np.ogrid[-R:R, -R:R, -R:R]
94
+ sphere = x**2 + y**2 + (z_scale_ratio * z)**2 < R**2
95
+ return sphere
96
+
97
+ # Measure the values of measure_stack at each of the points of points_list in turn.
98
+ # Measurement is the mean / max (specified by op) on the spherical region about each point
99
+ def measure_all_with_sphere(points_list, measure_stack, op='mean'):
100
+ melem = make_sphere()
101
+ measure_func = lambda p: measure_at_point(p, melem, measure_stack, op)
102
+ return list(map(measure_func, points_list))
103
+
104
+
105
+ # Measure fluorescence levels along ordered skeleton
106
+ def measure_chrom2(path, hei10):
107
+ # single chrom - structure containing skeleton (single_chrom.skel) and
108
+ # fluorecence levels (single_chrom.hei10) as Image3D objects (equivalent to ndarray)
109
+ # Returns list of coordinates in skeleton, the ordered path
110
+ vis = visualise_ordering(path, dim=hei10.shape)
111
+
112
+ measurements = measure_all_with_sphere(path, hei10, op='mean')
113
+ measurements_max = measure_all_with_sphere(path, hei10, op='max')
114
+
115
+ return vis, measurements, measurements_max
116
+
117
+ def extract_peaks(cell_id, all_paths, path_lengths, measured_traces):
118
+
119
+ n = len(all_paths)
120
+
121
+
122
+ #headers = ['Cell_ID', 'Trace', 'Trace_length(um)', 'detection_sphere_radius(um)', 'Foci_ID_threshold', 'Foci_per_trace']
123
+ #for i in range(max_n):
124
+ # headers += [f'Foci{i}_relative_intensity', f'Foci_{i}_position(um)']
125
+
126
+ data_dict = {}
127
+ data_dict['Cell_ID'] = [cell_id]*n
128
+ data_dict['Trace'] = range(1, n+1)
129
+ data_dict['Trace_length(um)'] = path_lengths
130
+ data_dict['Detection_sphere_radius(um)'] = [0.2]*n
131
+ data_dict['Foci_ID_threshold'] = [0.4]*n
132
+
133
+
134
+
135
+ return pd.DataFrame(data_dict)
136
+
137
+
138
+ def analyse_paths(cell_id, foci_file, traces_file):
139
+ foci_stack = tifffile.imread(foci_file)
140
+ all_paths, path_lengths = get_paths_from_traces_file(traces_file)
141
+
142
+ all_trace_vis = []
143
+ all_m = []
144
+ for p in all_paths:
145
+ vis, m, _ = measure_chrom2(p,foci_stack.transpose(2,1,0))
146
+ all_trace_vis.append(vis)
147
+ all_m.append(m)
148
+
149
+ trace_overlay = draw_paths(all_paths, foci_stack)
150
+
151
+ fig, ax = plt.subplots(len(all_paths),1)
152
+ for i, m in enumerate(all_m):
153
+ ax[i].plot(m)
154
+
155
+ extracted_peaks = extract_peaks(cell_id, all_paths, path_lengths, all_m)
156
+
157
+ return trace_overlay, all_trace_vis, fig, extracted_peaks
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ from tifffile import imread
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from analyse import analyse_paths
7
+ import numpy as np
8
+
9
+ def process(cell_id, foci_file, traces_file):
10
+ paths, traces, fig, extracted_peaks = analyse_paths(cell_id, foci_file.name, traces_file.name)
11
+ extracted_peaks.to_csv('tmp')
12
+ return paths, [Image.fromarray(im) for im in traces], fig, extracted_peaks, 'tmp'
13
+
14
+ def preview_image(file1):
15
+ if file1:
16
+ im = imread(file1.name)
17
+ print(im.shape)
18
+ return Image.fromarray(np.max(im, axis=0))
19
+ else:
20
+ return None
21
+
22
+
23
+ with gr.Blocks() as demo:
24
+ with gr.Row():
25
+ with gr.Column():
26
+ cellid_input = gr.Textbox(label="Cell ID", placeholder="Image_1")
27
+ image_input = gr.File(label="Input foci image")
28
+ image_preview = gr.Image(label="Max projection of foci image")
29
+ image_input.change(fn=preview_image, inputs=image_input, outputs=image_preview)
30
+ path_input = gr.File(label="SNT traces file")
31
+
32
+ with gr.Column():
33
+ trace_output = gr.Image(label="Overlayed paths")
34
+ image_output=gr.Gallery(label="Traced paths")
35
+ plot_output=gr.Plot(label="Foci intensity traces")
36
+ data_output=gr.DataFrame(label="Detected peak data")#, "Peak 1 pos", "Peak 1 int"])
37
+ data_file_output=gr.File(label="Output data file (.csv)")
38
+
39
+ with gr.Row():
40
+ greet_btn = gr.Button("Process")
41
+ greet_btn.click(fn=process, inputs=[cellid_input, image_input, path_input], outputs=[trace_output, image_output, plot_output, data_output, data_file_output], api_name="process")
42
+
43
+
44
+ if __name__ == "__main__":
45
+ demo.launch()