Fine-tuning a pre-trained Sequence to Sequence model for code translation

In my previous article, A Practical Guide to Building a Seq2Seq Transformer from scratch with PyTorch, I described how I built my own sequence-to-sequence model, trained it and used it to translate PL/I code into readable Kotlin code.

Today, I want to follow the same principles as I did before but with a few key differences: 

  • I fine-tune a pre-trained model from Hugging Face instead of simply training random weights as I previously did. This way I will hopefully get better results with less training and in less time.
  • I also want to try taking advantage of more frameworks to simplify the whole process.

Installing libraries

As always, I start off by installing all of the necessary libraries: the transformers library, which provides me with a pre-trained model, a pre-built tokenizer and a data collator; the datasets library, which helps with the handling of my data; the evaluate library, which can compute evaluation metrics; the torch library – the foundational PyTorch library widely used for deep learning applications; and finally antlr4-python3-runtime, designed to help with parsing and processing my code.

!pip install transformers datasets evaluate torch
!pip install antlr4-python3-runtime==4.9.2

Loading the data

I start off the whole process by loading my dataset from my own JSON files. My dataset is split into two files: train.json and test.json. The training set (train.json) is used to train my model, while the validation set (test.json) is used to evaluate the model's performance during training. Validation ensures that the model is genuinely learning patterns from the data, rather than merely memorising it.

from datasets import load_dataset

data_files = {
    "train": "data/train.json",
    "validation": "data/test.json"
}
data = load_dataset("json", data_files=data_files, field="data")
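
Since load_dataset is called with field="data", each JSON file is expected to wrap its records under a top-level "data" key. As a minimal sketch, one record might look like the following, assuming the token-list layout that the preprocessing step below relies on (the pli and ktl field names match the source and target languages defined later; the tokens themselves are purely illustrative):

# Illustrative record layout only -- the real train.json/test.json entries
# contain my own tokenised PL/I and Kotlin statements.
example_record = {
    "pli": ["n", "=", "5", ";"],
    "ktl": ["n", "=", "5"]
}
# Each file is expected to look like {"data": [example_record, ...]},
# because load_dataset is called with field="data".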

Preprocessing the Data

I now load a pre-trained BART tokenizer from Hugging Face, which I will use to preprocess my textual data. At the same time, I also define the source and target languages, as well as a simple prefix to guide the model during training.

from transformers import AutoTokenizer

checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, force_download=True)

source_lang = "pli"
target_lang = "ktl"
prefix = "translate PL/I to Kotlin: "

I then define a preprocessing function to prepare the data for the model. This function adds the prefix to the source text, tokenizes both the input and target data, and truncates them to a maximum length so that the model only processes the relevant parts of the text.

def preprocess_function(examples):
    inputs = [prefix + " ".join(example) for example in examples[source_lang]]
    targets = [" ".join(example) for example in examples[target_lang]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

tokenized_datasets = data.map(preprocess_function, batched=True)
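
As a quick sanity check (not part of the pipeline itself), I can peek at a single tokenised example. The exact token ids depend on the vocabulary, but the keys shown here are the ones the tokenizer produces when text_target is passed:

sample = tokenized_datasets["train"][0]
print(sample.keys())              # includes input_ids, attention_mask and labels
print(sample["input_ids"][:10])   # token ids for the prefixed PL/I input
print(sample["labels"][:10])      # token ids for the Kotlin target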

Creating a data collator

Data collators are objects that take a list of dataset elements as input and form them into batches.

So, I create a data collator that dynamically pads the input and target data during batching. This ensures that all sequences within a given batch have the same length, which is crucial for efficient training.

from transformers import DataCollatorForSeq2Seq

# Create data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
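
To illustrate the dynamic padding (the Trainer calls the collator itself, so this step is not needed for training), the collator can be applied to a couple of tokenised examples by hand. I only pass the tokenised fields, since the raw text columns are not something the collator knows how to pad:

features = [
    {k: tokenized_datasets["train"][i][k] for k in ("input_ids", "attention_mask", "labels")}
    for i in range(2)
]
batch = data_collator(features)
print(batch["input_ids"].shape)   # both sequences padded to the same length
print(batch["labels"].shape)      # labels padded with -100 so padding is ignored in the loss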

Initialising the Model

I now load the BART (Bidirectional and Auto-Regressive Transformer) model from Hugging Face and initialise it. BART is a pre-trained model specifically designed for sequence-to-sequence tasks such as the one I am working on today.

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Setting Up the Training Arguments

I configure key training parameters to optimise my model's training process. These include the output directory, evaluation strategy, learning rate, batch sizes, weight decay, number of epochs, and logging strategy. These configurations are designed to fine-tune the training dynamics and ensure optimal performance.

training_args = Seq2SeqTrainingArguments(
    output_dir="pli_to_kotlin",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=150,
    predict_with_generate=True,
    fp16=True,
    logging_strategy="steps",
    logging_steps=10,
)

Initialising the Trainer

I also need to initialise the Seq2SeqTrainer, a specialised class from the Hugging Face Transformers library. This class simplifies and automates the training and evaluation of sequence-to-sequence models. I provide it with all previously defined elements: the model, training arguments, datasets, tokenizer, and data collator.

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
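
The evaluate library installed at the start is not actually used in this run, since I only track the loss. As an optional extension, a compute_metrics function along the following lines could be passed to the Seq2SeqTrainer to report a BLEU-style score during evaluation. This is only a sketch: it assumes the sacrebleu metric backend is available (which may need its own pip install) and is not part of the training shown below.

import numpy as np
import evaluate

sacrebleu = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 marks ignored label positions; swap it for the pad token before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = sacrebleu.compute(predictions=decoded_preds,
                               references=[[label] for label in decoded_labels])
    return {"bleu": result["score"]}

# Would be wired in with: Seq2SeqTrainer(..., compute_metrics=compute_metrics)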

Training the Model

Finally, I can initiate the training process by calling the train() method on the trainer I just initialised.

trainer.train()

[150/150 00:17, Epoch 150/150]

Epoch    Training Loss    Validation Loss
1        No log           6.627490
50       0.328100         0.042963
100      0.092400         0.016886
150      0.053000         0.006777

TrainOutput(global_step=150, training_loss=0.5374304989973704, metrics={'train_runtime': 18.391, 'train_samples_per_second': 130.499, 'train_steps_per_second': 8.156, 'total_flos': 40013955072000.0, 'train_loss': 0.5374304989973704, 'epoch': 150.0})

Saving the model

Once the training is done and I am satisfied with the results, I can save my now fine-tuned model and tokenizer to a specified directory. This will allow me to use my model later for inference.

model.save_pretrained("./checkpoint")
tokenizer.save_pretrained("./checkpoint")

('./checkpoint/tokenizer_config.json',
 './checkpoint/special_tokens_map.json',
 './checkpoint/vocab.json',
 './checkpoint/merges.txt',
 './checkpoint/added_tokens.json',
 './checkpoint/tokenizer.json')

Running inference

And now for the fun part, using my model to see if it can actually translate some code.

I start off by loading in my saved model and tokenizer.

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load the model and tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint")
tokenizer = AutoTokenizer.from_pretrained("./checkpoint")

I then have to set up the device (GPU if available, otherwise CPU) and move the model to it for efficient computation.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

I define a translate_sequence function that tokenizes some input data using my tokenizer, passes it to the model to generate a translated sequence, and decodes the result back into a readable format.

def translate_sequence(sentence, tokenizer, model, device, max_length=50):
    inputs = tokenizer(prefix + sentence, return_tensors="pt", max_length=512, truncation=True).to(device)
    outputs = model.generate(inputs["input_ids"], max_length=max_length, num_beams=4, early_stopping=True)
    translated_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return translated_sentence
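
A quick way to try it out is to feed it a single PL/I statement by hand. The statement below is purely illustrative; the real inputs come from the ANTLR parsing step further down.

# Illustrative single-statement call
kotlin_line = translate_sequence("n = 5 ;", tokenizer, model, device)
print(kotlin_line)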

I also define a transpile_sequence function that formats the translated code by handling indentation and using Jinja2 templates to insert context-specific data into the code.

from jinja2 import Template

def transpile_sequence(translated, context, level=0):
    tokens = translated.split()
    lint = []
    current_line = ""
    for t in tokens:
        t = context.get(t, t)
        if t in ["{", "}"]:
            if current_line:
                lint.append("".rjust(level * 4) + current_line.strip())
                current_line = ""
            if t == "{":
                lint.append("".rjust(level * 4) + t)
                level += 1
            elif t == "}":
                level -= 1
                lint.append("".rjust(level * 4) + t)
        else:
            current_line += " " + t
    if current_line:
        lint.append("".rjust(level * 4) + current_line.strip())

    formatted_code = "\n".join(lint)
    template = Template(formatted_code)
    rendered_code = template.render(context)
    return rendered_code, level
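
To see what the formatter does, here is an illustrative call on a hand-written token string with an empty context; the braces drive the indentation:

formatted, new_level = transpile_sequence("fun main ( ) { var n : Int }", {}, level=0)
print(formatted)
# fun main ( )
# {
#     var n : Int
# }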

To bring everything together, I define a parse_and_translate function that uses ANTLR to parse a sample PL/I file, extract all the statements using a visitor pattern, translate these statements into Kotlin, and format the translated code.

from antlr4 import *
from pli.PLILexer import PLILexer
from pli.PLIParser import PLIParser
from pli.PLIVisitor import PLIVisitor
from jinja2 import Template

def parse_and_translate(filename):
    with open(filename, 'r') as file:
        original_code = file.read()
    print("PL/I Code:")
    print(original_code)
    print()

    # Lexer and parser setup
    input_stream = InputStream(original_code)
    lexer = PLILexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = PLIParser(stream)
    tree = parser.program()

    # Visitor setup
    visitor = PLIVisitor()
    statements = visitor.visit(tree)

    # Translate and transpile each statement
    transpiled_code = ""
    level = 0
    for stmt in statements:
        pli_code = " ".join(stmt['pli'])
        context = stmt.get('context', {})
        translated = translate_sequence(pli_code, tokenizer, model, device)
        transpiled, level = transpile_sequence(translated, context, level)
        transpiled_code += transpiled + "\n"

    print("Kotlin Code:")
    print(transpiled_code)

Example

To finish off, I want to test whether my fine-tuned model meets expectations or requires further training. I invoke the parse_and_translate function with my sample PL/I file as its parameter, allowing me to see some results.

# Example usage
filename = "FIB.PLI"  # Sample file
parse_and_translate(filename)

PL/I Code:
 Factorial: proc options (main);
    dcl (n,result) fixed bin(31);
    n  = 5;
    result = Compute_factorial(n);

 end Factorial;
  /***********************************************/
  /* Subroutine                                  */
  /***********************************************/
  Compute_factorial: proc (n)  returns (fixed bin(31));
     dcl n fixed bin(15);
      if n <= 1 then
        return(1);

     return( n*Compute_factorial(n-1) );

  end Compute_factorial;


Kotlin Code:
fun main (args: Array<String>)
{
    var n : Int
    var result : Int
    n = 5
    result = compute_factorial(n)
}
fun compute_factorial(n : Int) : Int
{
    if(n<=1)
    {
        return 1
    }
    return n*compute_factorial(n-1)
}

Fine-tuning a pre-trained Seq2Seq model did indeed prove to be an efficient and practical approach to my code translation task. I managed to do exactly what I set out to do: achieve similar or better results with less effort and in less time than training from scratch. By combining a pre-trained architecture with simplified preprocessing and modern tooling, I streamlined the entire process while still matching the quality of the output I got the first time round.
