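"""MarianMT-based translator with language detection, chunking, and
language-specific post-processing.

Supports Arabic, English, Chinese, Hindi, and Urdu. Each language pair
uses a Helsinki-NLP/opus-mt checkpoint, loaded on demand.
"""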
from transformers import MarianMTModel, MarianTokenizer
import torch
from langdetect import detect
import re

class Translator:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.language_codes = {
            'arabic': 'ar',
            'english': 'en',
            'chinese': 'zh',
            'hindi': 'hi',
            'urdu': 'ur'
        }
        # Preload the most common language pairs; translate() loads
        # any other pair on demand.
        self._load_model('en', 'ar')  # English to Arabic
        self._load_model('ar', 'en')  # Arabic to English

    def _load_model(self, src_lang, tgt_lang):
        """Load the translation model for a specific language pair."""
        model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
        key = f'{src_lang}-{tgt_lang}'
        if key not in self.models:
            try:
                self.tokenizers[key] = MarianTokenizer.from_pretrained(model_name)
                self.models[key] = MarianMTModel.from_pretrained(model_name)
            except Exception as e:
                print(f"Error loading model for {key}: {e}")

    def translate(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate text from the source language to the target language."""
        src_code = self.language_codes.get(source_lang.lower())
        tgt_code = self.language_codes.get(target_lang.lower())
        if not src_code or not tgt_code:
            raise ValueError("Unsupported language")

        key = f'{src_code}-{tgt_code}'
        if key not in self.models:
            self._load_model(src_code, tgt_code)
        if key not in self.models:
            raise ValueError(f"Translation model not available for {source_lang} to {target_lang}")

        tokenizer = self.tokenizers[key]
        model = self.models[key]
        try:
            # Normalize whitespace and strip characters that confuse the model
            text = self.preprocess_text(text)
            # Split long input into chunks that fit the model's context window
            chunks = self._split_text_into_chunks(text)
            translated_chunks = []
            for chunk in chunks:
                # Release cached GPU memory between chunks
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                # Tokenize, truncating to the model's 512-token limit
                inputs = tokenizer(
                    chunk,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                    add_special_tokens=True
                )
                # Generate the translation without tracking gradients
                with torch.no_grad():
                    translated = model.generate(
                        **inputs,
                        num_beams=2,  # Small beam width to limit memory use
                        length_penalty=0.6,
                        max_length=512,
                        min_length=0,
                        early_stopping=True
                    )
                # Decode the translated token IDs back to text
                result = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
                translated_chunks.append(result)
            # Combine chunks and apply language-specific post-processing
            final_translation = ' '.join(translated_chunks)
            final_translation = self._post_process_translation(final_translation, target_lang)
            return final_translation
        except Exception as e:
            print(f"Translation error: {e}")
            return text  # Fall back to the original text if translation fails

    def detect_language(self, text: str) -> str:
        """Detect the language of the input text."""
        try:
            # Strip punctuation for more reliable detection
            cleaned_text = re.sub(r'[^\w\s]', '', text)
            detected = detect(cleaned_text)
            # Map langdetect codes to our supported language names.
            # Note: langdetect reports Chinese as 'zh-cn'/'zh-tw', never bare 'zh'.
            lang_code_map = {
                'ar': 'arabic',
                'en': 'english',
                'zh-cn': 'chinese',
                'zh-tw': 'chinese',
                'hi': 'hindi',
                'ur': 'urdu'
            }
            return lang_code_map.get(detected, 'english')  # Default to English if unknown
        except Exception:
            return 'english'  # Default to English if detection fails

    def preprocess_text(self, text: str) -> str:
        """Preprocess text before translation."""
        # Collapse excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove special characters that might interfere with translation
        # (\w is Unicode-aware, so letters in any script are preserved)
        text = re.sub(r'[^\w\s\.,!?-]', '', text)
        return text

    def get_supported_languages(self):
        """Return the list of supported languages."""
        return list(self.language_codes.keys())

    def _split_text_into_chunks(self, text: str, max_chunk_size: int = 450) -> list:
        """Split text into manageable chunks for translation."""
        # First try to split by paragraphs
        paragraphs = text.split('\n\n')
        chunks = []
        current_chunk = []
        current_length = 0
        for para in paragraphs:
            # If a single paragraph is too long, split it by sentences
            if len(para) > max_chunk_size:
                # re.split with a capturing group alternates text and punctuation
                sentences = re.split(r'([.!?])\s+', para)
                i = 0
                while i < len(sentences):
                    sentence = sentences[i]
                    if i + 1 < len(sentences):
                        sentence += sentences[i + 1]  # Re-attach the punctuation
                        i += 2
                    else:
                        i += 1
                    if current_length + len(sentence) > max_chunk_size:
                        if current_chunk:
                            chunks.append(' '.join(current_chunk))
                        current_chunk = []
                        current_length = 0
                    current_chunk.append(sentence)
                    current_length += len(sentence)
            else:
                if current_length + len(para) > max_chunk_size:
                    if current_chunk:  # Avoid emitting empty chunks
                        chunks.append(' '.join(current_chunk))
                    current_chunk = []
                    current_length = 0
                current_chunk.append(para)
                current_length += len(para)
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        return chunks

    def _post_process_translation(self, text: str, target_lang: str) -> str:
        """Post-process translated text based on the target language."""
        if target_lang.lower() in ['arabic', 'ar']:
            # Remove stray spaces between Arabic letters
            # (note: this also joins adjacent Arabic words)
            text = re.sub(r'([ء-ي])\s+([ء-ي])', r'\1\2', text)
            # Remove Unicode control characters
            text = re.sub(r'[\u200B-\u200F\u202A-\u202E]', '', text)
            # Fix common duplicated-punctuation artifacts
            text = text.replace('،,', '،')
            text = text.replace('.,', '.')
            text = text.replace('؟?', '؟')
            text = text.replace('!!', '!')
            # Ensure proper spacing around numbers and Latin text
            text = re.sub(r'([0-9])([ء-ي])', r'\1 \2', text)
            text = re.sub(r'([ء-ي])([0-9])', r'\1 \2', text)
            text = re.sub(r'([a-zA-Z])([ء-ي])', r'\1 \2', text)
            text = re.sub(r'([ء-ي])([a-zA-Z])', r'\1 \2', text)
        elif target_lang.lower() in ['english', 'en']:
            # Fix spacing before and after punctuation
            text = re.sub(r'\s+([.,!?])', r'\1', text)
            text = re.sub(r'([.,!?])(?=[^\s])', r'\1 ', text)
            text = re.sub(r'\s+', ' ', text)  # Normalize spaces
            text = text.replace(' ,', ',')
            text = text.replace(' .', '.')
            # Capitalize the first letter of each sentence without lowercasing
            # the rest (str.capitalize would break acronyms and proper nouns)
            text = '. '.join(s[:1].upper() + s[1:] for s in text.split('. '))
        return text.strip()

    def get_language_name(self, code: str) -> str:
        """Get the display name for a language code."""
        names = {
            'ar': 'العربية',
            'en': 'English',
            'zh': '中文',
            'hi': 'हिंदी',
            'ur': 'اردو'
        }
        return names.get(code, code)
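

# Minimal usage sketch (assumes the transformers, torch, and langdetect
# packages are installed, plus network access for the first-time download
# of the Helsinki-NLP/opus-mt checkpoints):
if __name__ == "__main__":
    translator = Translator()
    print(translator.get_supported_languages())  # ['arabic', 'english', ...]

    sample = "Hello, how are you today?"
    detected = translator.detect_language(sample)  # 'english'
    print(f"Detected: {detected}")
    print(translator.translate(sample, detected, "arabic"))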