ZhengPeng7 commited on
Commit
1b2860f
·
1 Parent(s): fa76dd7

Polish up README.

Browse files
Files changed (4) hide show
  1. .gitignore +0 -142
  2. README.md +2 -2
  3. birefnet.py +1 -1
  4. handler.py +4 -1
.gitignore DELETED
@@ -1,142 +0,0 @@
1
- # Custom
2
- e_*
3
- .vscode
4
- ckpt
5
- preds
6
- evaluation/eval-*
7
- nohup.out*
8
- tmp*
9
- *.pth
10
- core-*-python-*
11
- .DS_Store
12
- __MACOSX/
13
-
14
- # Byte-compiled / optimized / DLL files
15
- __pycache__/
16
- *.py[cod]
17
- *$py.class
18
-
19
- # C extensions
20
- *.so
21
-
22
- # Distribution / packaging
23
- .Python
24
- build/
25
- develop-eggs/
26
- dist/
27
- downloads/
28
- eggs/
29
- .eggs/
30
- lib/
31
- lib64/
32
- parts/
33
- sdist/
34
- var/
35
- wheels/
36
- pip-wheel-metadata/
37
- share/python-wheels/
38
- *.egg-info/
39
- .installed.cfg
40
- *.egg
41
- MANIFEST
42
-
43
- # PyInstaller
44
- # Usually these files are written by a python script from a template
45
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
46
- *.manifest
47
- *.spec
48
-
49
- # Installer logs
50
- pip-log.txt
51
- pip-delete-this-directory.txt
52
-
53
- # Unit test / coverage reports
54
- htmlcov/
55
- .tox/
56
- .nox/
57
- .coverage
58
- .coverage.*
59
- .cache
60
- nosetests.xml
61
- coverage.xml
62
- *.cover
63
- *.py,cover
64
- .hypothesis/
65
- .pytest_cache/
66
-
67
- # Translations
68
- *.mo
69
- *.pot
70
-
71
- # Django stuff:
72
- *.log
73
- local_settings.py
74
- db.sqlite3
75
- db.sqlite3-journal
76
-
77
- # Flask stuff:
78
- instance/
79
- .webassets-cache
80
-
81
- # Scrapy stuff:
82
- .scrapy
83
-
84
- # Sphinx documentation
85
- docs/_build/
86
-
87
- # PyBuilder
88
- target/
89
-
90
- # Jupyter Notebook
91
- .ipynb_checkpoints
92
-
93
- # IPython
94
- profile_default/
95
- ipython_config.py
96
-
97
- # pyenv
98
- .python-version
99
-
100
- # pipenv
101
- # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102
- # However, in case of collaboration, if having platform-specific dependencies or dependencies
103
- # having no cross-platform support, pipenv may install dependencies that don't work, or not
104
- # install all needed dependencies.
105
- #Pipfile.lock
106
-
107
- # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108
- __pypackages__/
109
-
110
- # Celery stuff
111
- celerybeat-schedule
112
- celerybeat.pid
113
-
114
- # SageMath parsed files
115
- *.sage.py
116
-
117
- # Environments
118
- .env
119
- .venv
120
- env/
121
- venv/
122
- ENV/
123
- env.bak/
124
- venv.bak/
125
-
126
- # Spyder project settings
127
- .spyderproject
128
- .spyproject
129
-
130
- # Rope project settings
131
- .ropeproject
132
-
133
- # mkdocs documentation
134
- /site
135
-
136
- # mypy
137
- .mypy_cache/
138
- .dmypy.json
139
- dmypy.json
140
-
141
- # Pyre type checker
142
- .pyre/
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -17,8 +17,8 @@ license: mit
17
  ### Performance:
18
  | Dataset | Method | maxFm | wFmeasure | MAE | Smeasure | meanEm | HCE | maxEm | meanFm | adpEm | adpFm | mBA | maxBIoU | meanBIoU |
19
  | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
20
- | DIS-VD | BiRefNet_HR-general-epoch_130 | .925 | .894 | .026 | .927 | .952 | 811 | .960 | .909 | .944 | .888 | .828 | .837 | .817 |
21
- | DIS-VD | BiRefNet-general-epoch_244 | .907 | .875 | .033 | .911 | .943 | 1069 | .953 | .892 | .944 | .879 | .000 | .000 | .000 |
22
 
23
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
24
 
 
17
  ### Performance:
18
  | Dataset | Method | maxFm | wFmeasure | MAE | Smeasure | meanEm | HCE | maxEm | meanFm | adpEm | adpFm | mBA | maxBIoU | meanBIoU |
19
  | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: | :------: |
20
+ | DIS-VD | **BiRefNet_HR**-general-epoch_130 | .925 | .894 | .026 | .927 | .952 | 811 | .960 | .909 | .944 | .888 | .828 | .837 | .817 |
21
+ | DIS-VD | [**BiRefNet**-general-epoch_244](https://huggingface.co/ZhengPeng7/BiRefNet) | .907 | .875 | .033 | .911 | .943 | 1069 | .953 | .892 | .944 | .879 | .000 | .000 | .000 |
22
 
23
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
24
 
birefnet.py CHANGED
@@ -52,7 +52,7 @@ class Config():
52
  }[self.task]
53
  ][1] # choose 0 to skip
54
  self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
55
- self.size = 1024 * 2
56
  self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader
57
 
58
  # Backbone settings
 
52
  }[self.task]
53
  ][1] # choose 0 to skip
54
  self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
55
+ self.size = 1024
56
  self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader
57
 
58
  # Backbone settings
handler.py CHANGED
@@ -88,6 +88,7 @@ elif usage in ['General-HR']:
88
  else:
89
  resolution = (1024, 1024)
90
 
 
91
 
92
  class EndpointHandler():
93
  def __init__(self, path=''):
@@ -96,6 +97,8 @@ class EndpointHandler():
96
  )
97
  self.birefnet.to(device)
98
  self.birefnet.eval()
 
 
99
 
100
  def __call__(self, data: Dict[str, Any]):
101
  """
@@ -125,7 +128,7 @@ class EndpointHandler():
125
 
126
  # Prediction
127
  with torch.no_grad():
128
- preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
129
  pred = preds[0].squeeze()
130
 
131
  # Show Results
 
88
  else:
89
  resolution = (1024, 1024)
90
 
91
+ half_precision = True
92
 
93
  class EndpointHandler():
94
  def __init__(self, path=''):
 
97
  )
98
  self.birefnet.to(device)
99
  self.birefnet.eval()
100
+ if half_precision:
101
+ self.birefnet.half()
102
 
103
  def __call__(self, data: Dict[str, Any]):
104
  """
 
128
 
129
  # Prediction
130
  with torch.no_grad():
131
+ preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
132
  pred = preds[0].squeeze()
133
 
134
  # Show Results