sitammeur commited on
Commit
8606cde
·
verified ·
1 Parent(s): c7cdc30

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +33 -0
  2. src/__init__.py +0 -0
  3. src/app/__init__.py +0 -0
  4. src/app/predict.py +49 -0
  5. src/exception.py +50 -0
  6. src/logger.py +21 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import gradio as gr
3
+ from src.app.predict import ZeroShotTextClassification
4
+
5
+
6
+ # Examples to display in the interface
7
+ examples = [
8
+ ["I love to play the guitar", "music, artist, food, travel"],
9
+ ["I am a software engineer at Google", "technology, engineering, art, science"],
10
+ ["I am a professional basketball player", "sports, athlete, chef, politics"],
11
+ ]
12
+
13
+ # Title and description and article for the interface
14
+ title = "Zero Shot Text Classification"
15
+ description = "Classify text using zero-shot classification with ModernBERT-large zeroshot model! Provide a text input and a list of candidate labels separated by commas. Read more at the links below."
16
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2412.13663' target='_blank'>Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference</a> | <a href='https://huggingface.co/MoritzLaurer/ModernBERT-large-zeroshot-v2.0' target='_blank'>Model Page</a></p>"
17
+
18
+
19
+ # Launch the interface
20
+ demo = gr.Interface(
21
+ fn=ZeroShotTextClassification,
22
+ inputs=[gr.Textbox(label="Input"), gr.Textbox(label="Candidate Labels")],
23
+ outputs=gr.Label(label="Classification"),
24
+ title=title,
25
+ description=description,
26
+ article=article,
27
+ examples=examples,
28
+ cache_examples=True,
29
+ cache_mode="lazy",
30
+ theme="Soft",
31
+ flagging_mode="never",
32
+ )
33
+ demo.launch(debug=False)
src/__init__.py ADDED
File without changes
src/app/__init__.py ADDED
File without changes
src/app/predict.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import sys
3
+ from typing import Dict
4
+ from src.logger import logging
5
+ from src.exception import CustomExceptionHandling
6
+ from transformers import pipeline
7
+
8
+
9
+ # Load the zero-shot classification model
10
+ classifier = pipeline(
11
+ "zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
12
+ )
13
+
14
+
15
+ def ZeroShotTextClassification(
16
+ text_input: str, candidate_labels: str
17
+ ) -> Dict[str, float]:
18
+ """
19
+ Performs zero-shot classification on the given text input.
20
+
21
+ Args:
22
+ - text_input: The input text to classify.
23
+ - candidate_labels: A comma-separated string of candidate labels.
24
+
25
+ Returns:
26
+ Dictionary containing label-score pairs.
27
+ """
28
+ try:
29
+ # Split and clean the candidate labels
30
+ labels = [label.strip() for label in candidate_labels.split(",")]
31
+
32
+ # Log the classification attempt
33
+ logging.info(f"Attempting classification with {len(labels)} labels")
34
+
35
+ # Perform zero-shot classification
36
+ classifier = pipeline("zero-shot-classification")
37
+ prediction = classifier(text_input, labels)
38
+
39
+ # Return the classification results
40
+ logging.info("Classification completed successfully")
41
+ return {
42
+ prediction["labels"][i]: prediction["scores"][i]
43
+ for i in range(len(prediction["labels"]))
44
+ }
45
+
46
+ # Handle exceptions that may occur during the process
47
+ except Exception as e:
48
+ # Custom exception handling
49
+ raise CustomExceptionHandling(e, sys) from e
src/exception.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module defines a custom exception handling class and a function to get error message with details of the error.
3
+ """
4
+
5
+ # Standard Library
6
+ import sys
7
+
8
+ # Local imports
9
+ from src.logger import logging
10
+
11
+
12
+ # Function Definition to get error message with details of the error (file name and line number) when an error occurs in the program
13
+ def get_error_message(error, error_detail: sys):
14
+ """
15
+ Get error message with details of the error.
16
+
17
+ Args:
18
+ - error (Exception): The error that occurred.
19
+ - error_detail (sys): The details of the error.
20
+
21
+ Returns:
22
+ str: A string containing the error message along with the file name and line number where the error occurred.
23
+ """
24
+ _, _, exc_tb = error_detail.exc_info()
25
+
26
+ # Get error details
27
+ file_name = exc_tb.tb_frame.f_code.co_filename
28
+ return "Error occured in python script name [{0}] line number [{1}] error message[{2}]".format(
29
+ file_name, exc_tb.tb_lineno, str(error)
30
+ )
31
+
32
+
33
+ # Custom Exception Handling Class Definition
34
+ class CustomExceptionHandling(Exception):
35
+ """
36
+ Custom Exception Handling:
37
+ This class defines a custom exception that can be raised when an error occurs in the program.
38
+ It takes an error message and an error detail as input and returns a formatted error message when the exception is raised.
39
+ """
40
+
41
+ # Constructor
42
+ def __init__(self, error_message, error_detail: sys):
43
+ """Initialize the exception"""
44
+ super().__init__(error_message)
45
+
46
+ self.error_message = get_error_message(error_message, error_detail=error_detail)
47
+
48
+ def __str__(self):
49
+ """String representation of the exception"""
50
+ return self.error_message
src/logger.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing the required modules
2
+ import os
3
+ import logging
4
+ from datetime import datetime
5
+
6
+ # Creating a log file with the current date and time as the name of the file
7
+ LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
8
+
9
+ # Creating a logs folder if it does not exist
10
+ logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)
11
+ os.makedirs(logs_path, exist_ok=True)
12
+
13
+ # Setting the log file path and the log level
14
+ LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)
15
+
16
+ # Configuring the logger
17
+ logging.basicConfig(
18
+ filename=LOG_FILE_PATH,
19
+ format="[ %(asctime)s ] %(lineno)d %(name)s - %(levelname)s - %(message)s",
20
+ level=logging.INFO,
21
+ )