# LegalAI-DS / translator.py
from transformers import MarianMTModel, MarianTokenizer
import torch
from langdetect import detect
import re


class Translator:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.language_codes = {
            'arabic': 'ar',
            'english': 'en',
            'chinese': 'zh',
            'hindi': 'hi',
            'urdu': 'ur'
        }
        # Preload the most common language pairs; other pairs are loaded
        # lazily the first time translate() needs them.
        self._load_model('en', 'ar')  # English to Arabic
        self._load_model('ar', 'en')  # Arabic to English

    def _load_model(self, src_lang, tgt_lang):
        """Load translation model for a specific language pair."""
        model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
        key = f'{src_lang}-{tgt_lang}'
        if key not in self.models:
            try:
                self.tokenizers[key] = MarianTokenizer.from_pretrained(model_name)
                self.models[key] = MarianMTModel.from_pretrained(model_name)
            except Exception as e:
                print(f"Error loading model for {key}: {str(e)}")

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate text from source language to target language with improved handling."""
        src_code = self.language_codes.get(source_lang.lower())
        tgt_code = self.language_codes.get(target_lang.lower())
        if not src_code or not tgt_code:
            raise ValueError(f"Unsupported language pair: {source_lang} -> {target_lang}")
        key = f'{src_code}-{tgt_code}'
        if key not in self.models:
            self._load_model(src_code, tgt_code)
        if key not in self.models:
            raise ValueError(f"Translation model not available for {source_lang} to {target_lang}")
        tokenizer = self.tokenizers[key]
        model = self.models[key]
        try:
            # Preprocess text
            text = self.preprocess_text(text)
            # Split text into manageable chunks
            chunks = self._split_text_into_chunks(text)
            translated_chunks = []
            for chunk in chunks:
                # Free cached GPU memory between chunks (a no-op on CPU-only setups)
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Tokenize with truncation so each chunk fits the model's context window
                inputs = tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                    add_special_tokens=True
                )
                # Generate translation
                with torch.no_grad():
                    translated = model.generate(
                        **inputs,
                        num_beams=2,  # Small beam width to keep memory usage low
                        length_penalty=0.6,
                        max_length=512,
                        min_length=0,
                        early_stopping=True
                    )
                # Decode the translation
                result = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
                translated_chunks.append(result)
            # Combine chunks
            final_translation = ' '.join(translated_chunks)
            # Post-process translation
            final_translation = self._post_process_translation(final_translation, target_lang)
            return final_translation
        except Exception as e:
            print(f"Translation error: {str(e)}")
            return text  # Return the (possibly preprocessed) input text if translation fails

    def detect_language(self, text: str) -> str:
        """Detect the language of the input text."""
        try:
            # Strip punctuation for more reliable detection
            cleaned_text = re.sub(r'[^\w\s]', '', text)
            detected = detect(cleaned_text)
            # Map detected language codes to our supported languages
            lang_code_map = {
                'ar': 'arabic',
                'en': 'english',
                'zh-cn': 'chinese',  # langdetect reports Chinese as zh-cn/zh-tw, not zh
                'zh-tw': 'chinese',
                'hi': 'hindi',
                'ur': 'urdu'
            }
            return lang_code_map.get(detected, 'english')  # Default to English if unknown
        except Exception:
            return 'english'  # Default to English if detection fails
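
    # Note: langdetect is non-deterministic by default; for reproducible
    # results one can set langdetect.DetectorFactory.seed = 0 once at
    # import time (see the langdetect README).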

    def preprocess_text(self, text: str) -> str:
        """Preprocess text before translation."""
        # Collapse excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove special characters that might interfere with translation,
        # keeping basic Latin and Arabic punctuation (، ؟) intact
        text = re.sub(r'[^\w\s.,!?؟،-]', '', text)
        return text

    def get_supported_languages(self):
        """Return list of supported languages."""
        return list(self.language_codes.keys())

    def _split_text_into_chunks(self, text: str, max_chunk_size: int = 450) -> list:
        """Split text into manageable chunks for translation."""
        # First try to split by paragraphs
        paragraphs = text.split('\n\n')
        chunks = []
        current_chunk = []
        current_length = 0
        for para in paragraphs:
            # If a single paragraph is too long, split it by sentences
            if len(para) > max_chunk_size:
                sentences = re.split(r'([.!?])\s+', para)
                i = 0
                while i < len(sentences):
                    sentence = sentences[i]
                    if i + 1 < len(sentences):
                        sentence += sentences[i + 1]  # Reattach the captured punctuation
                        i += 2
                    else:
                        i += 1
                    if current_length + len(sentence) > max_chunk_size:
                        if current_chunk:
                            chunks.append(' '.join(current_chunk))
                            current_chunk = []
                            current_length = 0
                    current_chunk.append(sentence)
                    current_length += len(sentence)
            else:
                if current_length + len(para) > max_chunk_size:
                    if current_chunk:  # Guard against appending an empty chunk
                        chunks.append(' '.join(current_chunk))
                    current_chunk = []
                    current_length = 0
                current_chunk.append(para)
                current_length += len(para)
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks
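
    # Example (hypothetical input): with max_chunk_size=450, two short
    # paragraphs totalling under 450 characters end up in a single chunk,
    # while a 1,000-character paragraph is split at sentence boundaries
    # ('.', '!', '?') into chunks of roughly max_chunk_size characters,
    # each intended to fit the 512-token tokenizer limit in translate().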

    def _post_process_translation(self, text: str, target_lang: str) -> str:
        """Post-process translated text based on target language."""
        if target_lang.lower() in ['arabic', 'ar']:
            # Fix Arabic-specific issues
            text = re.sub(r'([ء-ي])\s+([ء-ي])', r'\1\2', text)  # Remove spaces between Arabic letters (also joins adjacent Arabic words)
            text = re.sub(r'[\u200B-\u200F\u202A-\u202E]', '', text)  # Remove Unicode control characters
            # Fix common Arabic punctuation issues
            text = text.replace('،,', '،')
            text = text.replace('.,', '.')
            text = text.replace('؟?', '؟')
            text = text.replace('!!', '!')
            # Ensure proper spacing around numbers and Latin text
            text = re.sub(r'([0-9])([ء-ي])', r'\1 \2', text)
            text = re.sub(r'([ء-ي])([0-9])', r'\1 \2', text)
            text = re.sub(r'([a-zA-Z])([ء-ي])', r'\1 \2', text)
            text = re.sub(r'([ء-ي])([a-zA-Z])', r'\1 \2', text)
        elif target_lang.lower() in ['english', 'en']:
            # Fix English-specific issues
            text = re.sub(r'\s+([.,!?])', r'\1', text)  # Remove space before punctuation
            text = re.sub(r'([.,!?])(?=[^\s])', r'\1 ', text)  # Ensure space after punctuation
            text = re.sub(r'\s+', ' ', text)  # Normalize spaces
            text = text.replace(' ,', ',')
            text = text.replace(' .', '.')
            # Capitalize the first letter of each sentence without lowercasing
            # the rest (str.capitalize() would mangle proper nouns and acronyms)
            text = '. '.join(s[:1].upper() + s[1:] for s in text.split('. '))
        return text.strip()

    def get_language_name(self, code: str) -> str:
        """Get the display name for a language code."""
        names = {
            'ar': 'العربية',
            'en': 'English',
            'zh': '中文',
            'hi': 'हिंदी',
            'ur': 'اردو'
        }
        return names.get(code, code)
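

# Minimal usage sketch (an illustrative example, not part of the original
# module; the first run downloads the Helsinki-NLP model weights and
# therefore needs network access):
if __name__ == '__main__':
    translator = Translator()
    sample = "The contract shall terminate upon written notice."
    print(translator.detect_language(sample))  # expected: 'english'
    print(translator.translate(sample, 'english', 'arabic'))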