Fine-tuning a pre-trained Sequence-to-Sequence model for code translation
In my previous article, A Practical Guide to Building a Seq2Seq Transformer from scratch with PyTorch, I described how I built my own sequence-to-sequence model, trained it, and used it to translate PL/I code into readable Kotlin code.
Today, I want to follow the same principles as I did before but with a few key differences:
- I fine-tune a pre-trained model from Hugging Face instead of training randomly initialised weights from scratch as I previously did. This should hopefully give better results with less training and in less time.
- I also want to try taking advantage of more frameworks to simplify the whole process.
Installing libraries
As always, I start off by installing all of the necessary libraries: transformers, which provides me with a pre-trained model, a pre-built tokenizer and a data collator; datasets, which helps with the handling of my data; evaluate, for computing metrics; torch, the foundational PyTorch library widely used for deep learning applications; and finally antlr4-python3-runtime, which helps with parsing and processing my PL/I code.
!pip install transformers datasets evaluate torch
!pip install antlr4-python3-runtime==4.9.2
Loading the data
I start off the whole process by loading my dataset from my own JSON files. My dataset is split into two files: train.json and test.json. The training set (train.json) is used to train my model, while the validation set (test.json) is used to evaluate the model’s performance during training. Validation ensures that the model is genuinely learning patterns from the data, rather than merely memorising it.
from datasets import load_dataset

data_files = {
    "train": "data/train.json",
    "validation": "data/test.json"
}
data = load_dataset("json", data_files=data_files, field="data")
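For reference, this is roughly the shape the loading call above expects: a top-level "data" key (hence field="data") holding a list of records, each with the PL/I source and Kotlin target stored as lists of tokens. The snippet below is only a hypothetical illustration of that layout with made-up values; my actual files are not reproduced here.

import json

# Hypothetical example of the expected file layout (illustration only)
sample = {
    "data": [
        {"pli": ["n", "=", "5", ";"], "ktl": ["n", "=", "5"]},
        {"pli": ["return", "(", "1", ")", ";"], "ktl": ["return", "1"]},
    ]
}
print(json.dumps(sample, indent=2))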
Preprocessing the Data
I now load a pre-trained BART tokenizer from Hugging Face, which I will use to preprocess my textual data. At the same time, I also define the source and target languages, as well as a simple prefix to guide the model during training.
from transformers import AutoTokenizer

checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, force_download=True)

source_lang = "pli"
target_lang = "ktl"
prefix = "translate PL/I to Kotlin: "
I then define a preprocessing function to prepare the data for the model. This function adds the prefix to the source text, tokenizes both the input and target data, and truncates them to a maximum length, ensuring that the model only processes the relevant parts of the text.
def preprocess_function(examples):
    inputs = [prefix + " ".join(example) for example in examples[source_lang]]
    targets = [" ".join(example) for example in examples[target_lang]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

tokenized_datasets = data.map(preprocess_function, batched=True)
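As a quick sanity check of my own (not strictly necessary), I like to peek at one tokenized example to confirm that the map produced input_ids and an attention_mask for the prefixed PL/I source, plus labels for the Kotlin target:

# Inspect one tokenized training example (illustrative check only)
example = tokenized_datasets["train"][0]
print(example.keys())  # should include input_ids, attention_mask and labels
print(tokenizer.decode(example["input_ids"][:12], skip_special_tokens=True))  # should start with the "translate PL/I to Kotlin:" prefix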
Creating a data collator
Data collators are objects that take a list of dataset elements as input and form them into batches.
So, I create a data collator that dynamically pads the input and target sequences during batching. This ensures that all sequences within a given batch have the same length, which is crucial for efficient training.
from transformers import DataCollatorForSeq2Seq

# Create data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
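To make the padding behaviour concrete, here is a small sketch of my own using made-up token ids rather than real examples: both features are padded to the longest sequence in the batch, and the label padding uses -100 so that those positions are ignored by the loss.

# Illustrative only: two hand-made features of different lengths
features = [
    {"input_ids": [0, 713, 16, 2], "labels": [0, 42, 2]},
    {"input_ids": [0, 713, 2], "labels": [0, 42, 232, 2]},
]
batch = data_collator(features)
print(batch["input_ids"].shape)  # both inputs padded to the longest input in the batch
print(batch["labels"])           # labels padded with -100, which the loss ignores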
Initialising the Model
I now load the BART (Bidirectional and Auto-Regressive Transformers) model from Hugging Face and initialise it. BART is a pre-trained model specifically designed for sequence-to-sequence tasks such as the one I am working on today.
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
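If you are curious about the size of the model being fine-tuned, a quick optional check of my own prints the parameter count; for facebook/bart-base this should be roughly 140 million parameters.

# Optional: report how many parameters the pre-trained model has
print(f"{model.num_parameters():,} parameters")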
Setting Up the Training Arguments
I configure key training parameters to optimise my model’s training process. These include the output directory, evaluation strategy, learning rate, batch sizes, weight decay, number of epochs, and logging strategy. These settings shape the training dynamics and help ensure good performance.
training_args = Seq2SeqTrainingArguments(
    output_dir="pli_to_kotlin",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=150,
    predict_with_generate=True,
    fp16=True,
    logging_strategy="steps",
    logging_steps=10,
)
Initialising the Trainer
I also need to initialise the Seq2SeqTrainer, a specialised class from the Hugging Face Transformers library. This class simplifies and automates the training and evaluation of sequence-to-sequence models. I provide it with all the previously defined elements: the model, training arguments, datasets, tokenizer, and data collator.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
Training the Model
Finally, I can initiate the training process by calling the train() method on the trainer I just initialised.
trainer.train()
[150/150 00:17, Epoch 150/150]
| Epoch | Training Loss | Validation Loss |
|---|---|---|
| 1 | No log | 6.627490 |
| 50 | 0.328100 | 0.042963 |
| 100 | 0.092400 | 0.016886 |
| 150 | 0.053000 | 0.006777 |
TrainOutput(global_step=150, training_loss=0.5374304989973704, metrics={'train_runtime': 18.391, 'train_samples_per_second': 130.499, 'train_steps_per_second': 8.156, 'total_flos': 40013955072000.0, 'train_loss': 0.5374304989973704, 'epoch': 150.0})
Saving the model
Once the training is done and I am satisfied with the results, I can save my now fine-tuned model and tokenizer to a specified directory. This will allow me to use my model later for inference.
model.save_pretrained("./checkpoint")
tokenizer.save_pretrained("./checkpoint")
('./checkpoint/tokenizer_config.json', './checkpoint/special_tokens_map.json', './checkpoint/vocab.json', './checkpoint/merges.txt', './checkpoint/added_tokens.json', './checkpoint/tokenizer.json')
Running inference
And now for the fun part: using my model to see if it can actually translate some code.
I start off by loading my saved model and tokenizer.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load the model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint")
tokenizer = AutoTokenizer.from_pretrained("./checkpoint")
I then have to set up the device (GPU if available, otherwise CPU) and move the model onto it for efficient computation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
I define a translate_sequence function that tokenizes the input using my tokenizer, runs it through the model to generate a translated sequence, and decodes the result back into a readable format.
def translate_sequence(sentence, tokenizer, model, device, max_length=50):
    inputs = tokenizer(prefix + sentence, return_tensors="pt", max_length=512, truncation=True).to(device)
    outputs = model.generate(inputs["input_ids"], max_length=max_length, num_beams=4, early_stopping=True)
    translated_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return translated_sentence
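For example, I can call it directly on a single PL/I statement. The call below is just an illustrative one of my own; the exact output depends on the fine-tuned weights.

# Illustrative call on a single PL/I statement
pli_statement = "dcl (n,result) fixed bin(31);"
print(translate_sequence(pli_statement, tokenizer, model, device))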
I also define a transpile_sequence function that formats the translated code by handling indentation and uses Jinja2 templates to insert context-specific data into the code.
from jinja2 import Template

def transpile_sequence(translated, context, level=0):
    tokens = translated.split()
    lint = []
    current_line = ""
    for t in tokens:
        t = context.get(t, t)
        if t in ["{", "}"]:
            if current_line:
                lint.append("".rjust(level * 4) + current_line.strip())
                current_line = ""
            if t == "{":
                lint.append("".rjust(level * 4) + t)
                level += 1
            elif t == "}":
                level -= 1
                lint.append("".rjust(level * 4) + t)
        else:
            current_line += " " + t
    if current_line:
        lint.append("".rjust(level * 4) + current_line.strip())
    formatted_code = "\n".join(lint)
    template = Template(formatted_code)
    rendered_code = template.render(context)
    return rendered_code, level
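To make the behaviour concrete, here is a small illustration of my own (the translated string and the context values are hypothetical): standalone braces are placed on their own lines and drive the indentation level, tokens found in the context are swapped for their concrete values, and Jinja2 fills in any remaining placeholders.

# Hypothetical translated sequence: "{{name}}" is a Jinja2 placeholder and "N"
# is a token that the context maps directly to a concrete identifier
translated = "fun {{name}} ( n : Int ) : Int { return N }"
context = {"name": "compute_factorial", "N": "n"}
code, level = transpile_sequence(translated, context, level=0)
print(code)
# fun compute_factorial ( n : Int ) : Int
# {
#     return n
# }
print(level)  # back to 0, since the opening brace was closed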
To bring everything together, I define a parse_and_translate function that uses ANTLR to parse a sample PL/I file, extracts all the statements using a visitor pattern, translates them into Kotlin and formats the translated code.
from antlr4 import *
from pli.PLILexer import PLILexer
from pli.PLIParser import PLIParser
from pli.PLIVisitor import PLIVisitor
from jinja2 import Template

def parse_and_translate(filename):
    with open(filename, 'r') as file:
        original_code = file.read()

    print("PL/I Code:")
    print(original_code)
    print()

    # Lexer and parser setup
    input_stream = InputStream(original_code)
    lexer = PLILexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = PLIParser(stream)
    tree = parser.program()

    # Visitor setup
    visitor = PLIVisitor()
    statements = visitor.visit(tree)

    # Translate and transpile each statement
    transpiled_code = ""
    level = 0
    for stmt in statements:
        pli_code = " ".join(stmt['pli'])
        context = stmt.get('context', {})
        translated = translate_sequence(pli_code, tokenizer, model, device)
        transpiled, level = transpile_sequence(translated, context, level)
        transpiled_code += transpiled + "\n"

    print("Kotlin Code:")
    print(transpiled_code)
Example
To finish off, I want to test whether my fine-tuned model meets expectations or requires further training. I invoke the parse_and_translate function with my sample PL/I file as its parameter, allowing me to see some results.
# Example usage
filename = "FIB.PLI"  # Sample file
parse_and_translate(filename)
PL/I Code:
Factorial: proc options (main);
dcl (n,result) fixed bin(31);
n = 5;
result = Compute_factorial(n);
end Factorial;
/***********************************************/
/* Subroutine */
/***********************************************/
Compute_factorial: proc (n) returns (fixed bin(31));
dcl n fixed bin(15);
if n <= 1 then return(1);
return( n*Compute_factorial(n-1) );
end Compute_factorial;

Kotlin Code:
fun main (args: Array<String>)
{
    var n : Int
    var result : Int
    n = 5
    result = compute_factorial(n)
}
fun compute_factorial(n : Int) : Int
{
    if(n<=1)
    {
        return 1
    }
    return n*compute_factorial(n-1)
}
Fine-tuning a pre-trained Seq2Seq model did indeed prove to be an efficient and practical approach to my code translation task. I achieved exactly what I set out to do: similar or better results with less effort and in less time than training from scratch. By combining a pre-trained architecture with simplified preprocessing and modern tooling, I streamlined the entire process while still matching the quality of the output from my first attempt.