File size: 5,103 Bytes
eea405a
 
 
 
2b02896
eea405a
 
cc6d57f
eea405a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
add58d5
eea405a
 
add58d5
 
 
eea405a
 
 
 
 
 
 
2b02896
 
 
 
 
 
 
 
 
 
 
b4cf22f
2b02896
 
eea405a
 
 
 
 
add58d5
 
 
b4cf22f
add58d5
 
 
 
 
 
 
 
 
b4cf22f
add58d5
eea405a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4cf22f
add58d5
eea405a
 
 
b4cf22f
eea405a
 
 
 
 
cc6d57f
b4cf22f
 
 
 
 
eea405a
 
 
 
 
 
 
b4cf22f
eea405a
 
 
 
 
b4cf22f
 
 
 
eea405a
 
 
 
b4cf22f
eea405a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# pylint: disable=no-member
import gradio as gr
import requests
from huggingface_hub import HfApi
from huggingface_hub.errors import RepositoryNotFoundError
import pandas as pd
import plotly.express as px
from gradio_huggingfacehub_search import HuggingfaceHubSearch

HF_API = HfApi()


def format_repo_size(r_size: int) -> str:
    """Render a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it drops below 1024 or the largest
    supported unit (PB) is reached, so anything beyond the petabyte range is
    still expressed in PB.

    Args:
        r_size: Size in bytes.

    Returns:
        The scaled value followed by its unit, e.g. ``"1.50 KB"``.
    """
    *smaller_units, largest_unit = ("B", "KB", "MB", "GB", "TB", "PB")
    value = r_size
    for unit in smaller_units:
        if value >= 1024:
            # Too big for this unit: scale down and try the next one.
            value /= 1024
            continue
        return f"{value:.2f} {unit}"
    return f"{value:.2f} {largest_unit}"


def repo_files(r_type: str, r_id: str) -> dict:
    """Aggregate a repository's files by extension.

    Fetches the repository info with file metadata through ``HF_API`` and
    groups every listed file by the text after the last ``.`` of its file
    name.

    Args:
        r_type: Repository type, forwarded as ``repo_type``.
        r_id: Repository ID, forwarded as ``repo_id``.

    Returns:
        A dict mapping each extension to ``{"size": <total>, "count": <n>}``.
        The inner keys are always inserted in the order ``size``, ``count``.
        A file whose name contains no ``.`` is keyed by its file name.
    """
    r_info = HF_API.repo_info(repo_id=r_id, repo_type=r_type, files_metadata=True)
    files = {}
    # Tolerate a missing sibling list instead of failing on iteration.
    for sibling in r_info.siblings or []:
        # Look only at the final path component: splitting the full path would
        # let a dot in a directory name (e.g. "ckpt.v1/README") leak path
        # fragments into the extension.
        file_name = sibling.rfilename.rsplit("/", 1)[-1]
        ext = file_name.split(".")[-1]
        stats = files.setdefault(ext, {"size": 0, "count": 0})
        # A file without size metadata counts as zero bytes rather than
        # raising a TypeError during accumulation.
        stats["size"] += sibling.size or 0
        stats["count"] += 1
    return files


def repo_size(r_type, r_id):
    """Collect the tree size of every branch of a repository.

    Branches are listed through ``HF_API.list_repo_refs``; for each one the
    ``treesize`` endpoint on huggingface.co is queried over HTTP.

    Args:
        r_type: Repository type; an ``s`` is appended to build the API path.
        r_id: Repository ID.

    Returns:
        A dict mapping branch name to the ``size`` value reported by the
        endpoint. Branches whose response carries no ``size`` are omitted.
        An empty dict is returned (after a UI warning) as soon as any branch
        responds with an error mentioning "restricted".
    """
    sizes_by_branch = {}
    refs = HF_API.list_repo_refs(repo_id=r_id, repo_type=r_type)
    for ref in refs.branches:
        try:
            payload = requests.get(
                f"https://huggingface.co/api/{r_type}s/{r_id}/treesize/{ref.name}",
                timeout=1000,
            ).json()
        except Exception:
            # Best effort: a failed request or unparsable body is treated like
            # a response without size information.
            payload = {}
        error = payload.get("error")
        if error and "restricted" in error:
            gr.Warning(f"Branch information for {r_id} not available.")
            return {}
        branch_bytes = payload.get("size")
        if branch_bytes is None:
            continue
        sizes_by_branch[ref.name] = branch_bytes
    return sizes_by_branch


def _hidden_results():
    """Return updates that hide all five result components."""
    return (
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
        gr.Plot(visible=False),
        gr.Row(visible=False),
        gr.Dataframe(visible=False),
    )


def get_repo_info(r_type, r_id):
    """Build the UI updates describing a repository's files and branches.

    Args:
        r_type: Repository type selected in the UI.
        r_id: Repository ID entered in the UI; surrounding whitespace is
            ignored.

    Returns:
        A 5-tuple of component updates, in this order: results row, file
        table, file-distribution pie chart, branch row, branch-size table.
        All five are hidden when the ID is blank, the repository is not
        found, or no files are listed. The branch row and table are hidden
        when no branch sizes are available.
    """
    # A blank ID can never resolve; warn instead of letting the Hub client
    # raise an unhandled error.
    r_id = (r_id or "").strip()
    if not r_id:
        gr.Warning("Please enter a repository ID.")
        return _hidden_results()

    try:
        repo_sizes = repo_size(r_type, r_id)
        repo_files_info = repo_files(r_type, r_id)
    except RepositoryNotFoundError:
        gr.Warning(
            "Repository not found. Make sure you've entered a valid repo ID and type that corresponds to the repository."
        )
        return _hidden_results()

    # With no files the frame below would have no "size" column and sorting
    # on it would raise a KeyError.
    if not repo_files_info:
        gr.Warning(f"No files found in {r_id}.")
        return _hidden_results()

    # One row per extension, largest total size first.
    rf_sizes_df = (
        pd.DataFrame(repo_files_info)
        .T.reset_index(names="ext")
        .sort_values(by="size", ascending=False)
    )
    rf_sizes_df["formatted_size"] = rf_sizes_df["size"].apply(format_repo_size)
    # Positional rename: ext, size, count, formatted_size.
    rf_sizes_df.columns = ["Extension", "bytes", "Count", "Size"]

    if repo_sizes:
        r_sizes_df = pd.DataFrame(repo_sizes, index=["size"]).T.reset_index(
            names="branch"
        )
        r_sizes_df["formatted_size"] = r_sizes_df["size"].apply(format_repo_size)
        # Positional rename: branch, size, formatted_size.
        r_sizes_df.columns = ["Branch", "bytes", "Size"]
        r_sizes_component = gr.Dataframe(
            value=r_sizes_df[["Branch", "Size"]], visible=True
        )
        b_block = gr.Row(visible=True)
    else:
        # No branch sizes available: keep the branch section hidden.
        r_sizes_component = gr.Dataframe(visible=False)
        b_block = gr.Row(visible=False)

    rf_sizes_plot = px.pie(
        rf_sizes_df,
        values="bytes",
        names="Extension",
        hover_data=["Size"],
        title=f"File Distribution in {r_id}",
        hole=0.3,
    )
    return (
        gr.Row(visible=True),
        gr.Dataframe(
            value=rf_sizes_df[["Extension", "Count", "Size"]],
            visible=True,
        ),
        gr.Plot(rf_sizes_plot, visible=True),
        b_block,
        r_sizes_component,
    )


# Build the page: a header, the search inputs, and an initially hidden results
# area that the search handler reveals.
with gr.Blocks(theme="ocean") as demo:
    gr.Markdown("# Repository Information")
    gr.Markdown(
        "Enter a repository ID and repository type and get back information about the repository's files and branches."
    )
    # Input section: repository search box, type selector and search button.
    with gr.Blocks():
        # The repository ID comes from the Hub search component.
        # NOTE(review): search_type lists only "model" and "dataset", while the
        # radio below also offers "space" -- confirm whether that is intended.
        repo_id = HuggingfaceHubSearch(
            label="Hub Model ID",
            placeholder="Search for model id on Huggingface",
            search_type=["model", "dataset"],
        )
        repo_type = gr.Radio(
            choices=["model", "dataset", "space"],
            label="Repository Type",
            value="model",
        )
        search_button = gr.Button(value="Search")
    # Results section: every component starts with visible=False.
    with gr.Blocks():
        with gr.Row(visible=False) as results_block:
            with gr.Column():
                gr.Markdown("## File Information")
                with gr.Row():
                    file_info = gr.Dataframe(visible=False)
                    file_info_plot = gr.Plot(visible=False)
                # Branch sizes have their own row so they can stay hidden
                # independently of the file information.
                with gr.Row(visible=False) as branch_block:
                    with gr.Column():
                        gr.Markdown("## Branch Sizes")
                        branch_sizes = gr.Dataframe(visible=False)

    # Inputs are ordered to match get_repo_info(r_type, r_id); outputs are
    # ordered to match the 5-tuple it returns.
    search_button.click(
        get_repo_info,
        inputs=[repo_type, repo_id],
        outputs=[results_block, file_info, file_info_plot, branch_block, branch_sizes],
    )

# Start the app as soon as this module is executed.
demo.launch()