
arXiv: 2510.09898 · v2 · submitted 2025-10-10 · cs.LG · cs.AI

Learning Bug Context for PyTorch-to-JAX Translation with LLMs

keywords: code, translation, pytorch-to-jax, bugs, execution, gpt-4o-mini, improve, learning

Large language models (LLMs) have shown strong performance on code translation between widely used programming languages. However, translation becomes much less reliable for domain-specific code, where correctness depends on framework-specific APIs and execution semantics. One example is translating deep-learning code from PyTorch to JAX, where LLM outputs often contain subtle bugs or non-idiomatic usage that prevent execution or change behavior. Prior work suggests that curated bug-fix data from LLM-generated code can help improve code generation quality, but such resources are still limited for PyTorch-to-JAX translation. In this work, we introduce T2J, a benchmark of LLM translation bugs paired with developer-written fixes for PyTorch-to-JAX code. We start from 20 kernels in the TorchLeet dataset, translate them to JAX using the weak LLM gpt-4o-mini, and hire software developers to debug and repair the generated JAX implementations. We then use T2J to improve PyTorch-to-JAX translation with the same model, gpt-4o-mini, via in-context learning. Our evaluation shows that using T2J yields an improvement of up to 20% on our proposed metric, T2J-CodeTrans-Score.
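To make the in-context learning idea concrete, the following is a minimal sketch of how bug-fix pairs could be folded into a few-shot translation prompt. It is not the authors' pipeline: the `BugFixExample` record layout, the `build_prompt` helper, and the prompt template are all hypothetical, and the single example pair is an illustration of a common PyTorch-to-JAX pitfall rather than an entry from T2J.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BugFixExample:
    """Hypothetical record layout; the actual T2J schema is not given in the abstract."""
    pytorch_code: str
    buggy_jax: str
    fixed_jax: str
    note: str


def build_prompt(examples: list[BugFixExample], target_pytorch: str) -> str:
    """Assemble a few-shot translation prompt from bug-fix pairs (illustrative template)."""
    parts = ["Translate the PyTorch code to JAX. Avoid the mistakes shown in the examples."]
    for i, ex in enumerate(examples, start=1):
        parts.append(
            f"### Example {i}\n"
            f"PyTorch:\n{ex.pytorch_code}\n"
            f"Buggy JAX translation:\n{ex.buggy_jax}\n"
            f"Fixed JAX translation:\n{ex.fixed_jax}\n"
            f"Why: {ex.note}"
        )
    # The prompt ends with the new PyTorch snippet and an open "JAX:" slot for the model.
    parts.append(f"### Task\nPyTorch:\n{target_pytorch}\nJAX:")
    return "\n\n".join(parts)


# Illustrative pair (not taken from T2J): JAX arrays are immutable, so
# in-place item assignment must become a functional `.at[...].set(...)` update.
example = BugFixExample(
    pytorch_code="x = torch.zeros(3)\nx[0] = 1.0",
    buggy_jax="x = jnp.zeros(3)\nx[0] = 1.0",
    fixed_jax="x = jnp.zeros(3)\nx = x.at[0].set(1.0)",
    note="JAX arrays do not support in-place assignment.",
)

prompt = build_prompt([example], "y = torch.ones(2)\ny[1] = 0.0")
print(prompt)
```

The resulting string would be sent to the translation model as its prompt. Showing the buggy output next to its repair, rather than the repair alone, is one plausible way to expose the model to the failure mode it should avoid.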
