Hi everyone. How's it going? I'm Sean.
I'm going to be sharing with you the results of a paper that I wrote and had published at the 37th International Conference on Legal Knowledge and Information Systems.
So the title of this presentation is, essentially: let's investigate the effects that fine-tuning a model on domain data has on its internal reasoning processes.
And because my background is in law, we'll be doing a case study in the legal domain.
So the real question is: when we fine-tune models to try to get them to understand domain-specific concepts at a deeper level, what actually happens inside the model, and what does that show us beyond what benchmarks can show us?
We'll start with a bit of the motivation for the research, with some very pertinent questions that anyone who works with LLMs will have run into, like which model to choose. And we should be aware that there are a lot of lightweight models, such as BERT, that still do really well for certain use cases, even this year.
And then once you've decided on your model, if it needs to be fine-tuned on a domain-specific dataset, what exactly is the training regime going to look like? Am I going to do continued pre-training alone, or am I going to create some sort of instruction-tuning dataset? Or do I do both? And what is the shape of that data, and so on.
All of this has a very big impact on how your model does on benchmarks, and in the real world as well. And let's say we've picked the model and done some fine-tuning on it; there are still a lot of practical issues that arise when we actually deploy these models.
One: the benchmarks always look really good, and your tuning dataset can always be optimized if you try hard enough. But then why does the model still feel off, and why are customers still not happy with its performance?
And why is my prompt not getting the model to do what I need it to do? How do I modify the prompt?
And specifically, when I'm working in the legal domain or the financial domain, there are very specific terms that mean one thing in the domain but are the exact same word as a general term that means something else. How is that going to influence the performance of your model?
So that's why we need to look into what we're actually doing when we fine-tune, so we know how to improve the fine-tuning regime to address all these very important questions.
The specific research questions I looked into were: how does the choice of model impact performance, not in a general domain but in a very specific domain like the legal one?
And how exactly do your choice of dataset and the setup of your training regime impact that understanding?
Here's an example of two medical datasets that look similar but are structured very differently: one is a question-answering-style dataset, and the other is a definition-style dataset.
And we have to note that there are different distributions of information content across these training samples as well. For example, one dataset has much more diversity in the lengths of its labels and inputs than the other. All of that is going to have an impact on your fine-tuning outcomes too.
So, what theoretical foundations do we even have to begin investigating these questions? My proposal is that we can look at the attention scores inside the model as a way to investigate what's happening.
Just as a recap of how attention works: let's say you have an input text sequence, which could be a prompt, and certain words in it represent domain-specific concepts. The model feeds them in, creates hidden states out of them, and applies a weighted summation to produce the final adjusted latent representation.
For example, if the attention score on one word is very high, then in latent space the model is going to pay a lot of attention to that word in your input sequence. So ideally you would want your model to attend to the specific legal (or other domain) concept words, because it will rely on them more when deciding what the best output sequence is.
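(Just to make that "weighted summation" idea concrete, here is a minimal, self-contained sketch of single-head scaled dot-product attention. The names and dimensions are illustrative only, not taken from the paper.)

```python
# Minimal single-head scaled dot-product attention, to make the
# "weighted summation over the input" idea concrete. Purely illustrative.
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # q, k, v: (seq_len, d) hidden states for one head
    scores = q @ k.T / (q.size(-1) ** 0.5)   # token-to-token compatibility
    weights = F.softmax(scores, dim=-1)      # each row sums to 1: attention over the sequence
    return weights @ v, weights              # weighted sum of value vectors, plus the weights

seq_len, d = 6, 16
hidden = torch.randn(seq_len, d)             # stand-in for token hidden states
_, attn = attention(hidden, hidden, hidden)
print(attn[0])  # how strongly token 0 attends to every token, including any domain-concept word
```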
We can develop some intuition from the idea of attention scores, and this is the crux of the paper: domain-specific concept tokens are always going to be a subset of the entire input sequence. So in this example, I have a certain legal concept, and I know that these four words are the only domain-specific concepts I'm concerned with. You can then compute the attention over all of the tokens and measure the proportion of attention that was allocated only to those concepts.
For example, you unroll the entire attention matrix and look at all the individual scores, and because I'm only concerned with the specific domain tokens, I can compute the proportion of attention weight the model gives to them when it's processing the input sequence. That's just the mathematical formalization, and a small sketch of that computation is shown below.
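(As a rough illustration of that formalization, and this is my own sketch rather than the paper's exact equation: given one head's attention matrix and the set of positions belonging to domain-concept tokens, the concept share is the attention mass landing on those positions divided by the total attention mass.)

```python
import numpy as np

def concept_attention_share(attn: np.ndarray, concept_positions: list[int]) -> float:
    """attn: (seq_len, seq_len) attention matrix for one head,
    rows = query tokens, columns = attended-to (key) tokens.
    Returns the fraction of total attention mass landing on the
    domain-concept token positions."""
    total = attn.sum()                              # = seq_len when rows are softmax-normalised
    on_concepts = attn[:, concept_positions].sum()  # mass flowing into the concept tokens
    return float(on_concepts / total)

# Toy example: a 5-token sequence where positions 2 and 3 form the domain concept.
attn = np.random.rand(5, 5)
attn /= attn.sum(axis=1, keepdims=True)             # row-normalise, like softmax
print(concept_attention_share(attn, [2, 3]))
```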
You don't have to worry about the exact details. And once we have that intuition, more questions follow.
One: after I do some fine-tuning, what is that doing to these values? How do they change before and after fine-tuning?
And fine-tuning here means both continued pre-training and instruction fine-tuning. Are there differences between them? Should I go on to do instruction fine-tuning, or should I just stop at continued pre-training?
And just to reiterate, this is really important, or at least you will only think it's important if you agree with me that attention scores measure the degree to which the model is using those domain concepts. If the attention scores on the domain concepts are very low, that suggests the model is barely considering them when it's producing the output sequence.
So I would propose, and this is just my opinion, that attention scores are a sort of dataset-agnostic method for determining your fine-tuned LLM's ability to deal with the concepts you need it to deal with. After I fine-tune a model on financial data, for example, these attention scores can tell me how much the model is actually using those financial concepts.
So the case study in my paper looked at a particular model family called SaulLM, which is basically Mistral 7B with a base legal-domain model derived from continued pre-training, plus an instruction-fine-tuned variant.
The methodology I used was: compute those attention scores on the base model, do the fine-tuning, then run the exact same input sequence and compare how the attention scores changed. We do the comparison across individual attention heads and across all of the attention heads in aggregate, because most of these models have a lot of attention heads.
And we want to look at differences across layers. We'll see why in a minute.
This is a standard attention matrix that you would see when you expose the attention weights of any such model. These numbers are effectively telling you the degree to which these tokens are being used in latent space. After we do the fine-tuning, we can compare how these numbers change.
So this is Mistral 7B. Let's say I do some fine-tuning, for example continued pre-training followed by instruction fine-tuning. If I feed in the exact same text sequence, these numbers are going to change, and I can compute a matrix of differences before and after fine-tuning.
This difference matrix helps me isolate the impact of that training and fine-tuning. And this is where we can really begin to evaluate the impact of domain fine-tuning on the model's internal processing.
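(Here is a sketch, under my own assumptions, of how that before/after comparison might be wired up with Hugging Face transformers. The fine-tuned checkpoint name is a placeholder, I assume both checkpoints share the base Mistral tokenizer, as the SaulLM models do, and I assume a GPU with accelerate installed.)

```python
# Sketch of the before/after attention comparison; not the paper's exact code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_name = "mistralai/Mistral-7B-v0.1"        # base model
tuned_name = "path/to/your-legal-finetune"     # placeholder for the fine-tuned variant

tok = AutoTokenizer.from_pretrained(base_name)  # assumes both checkpoints share this tokenizer
text = "The employee received a notice of retrenchment last week."
inputs = tok(text, return_tensors="pt")

def attention_stack(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="eager",           # ensures attention weights are returned
    )
    model.eval()
    with torch.no_grad():
        out = model(**inputs.to(model.device), output_attentions=True)
    # out.attentions: tuple of (batch, heads, seq, seq), one entry per layer
    return torch.stack(out.attentions).squeeze(1).float().cpu()   # (layers, heads, seq, seq)

delta = attention_stack(tuned_name) - attention_stack(base_name)
print(delta.shape)   # (num_layers, num_heads, seq_len, seq_len)
print(delta[0, 0])   # per-token attention shift for layer 0, head 0
```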
So here's one interesting thing I found. First, we need to observe that in many real-world cases, specific concepts are made up of a series of multiple tokens and subwords.
It's rarely the case that a single token is sufficient to express a concept. In the legal field, for example, 'notice of retrenchment' stands for a very specific thing, but it has to be made up of many different tokens in sequence.
But when you look at the effect of the fine-tuning, you can see that only the first two tokens in that sequence were altered by it; for the second half of the sequence there is basically no effect after fine-tuning.
This is really bad news if, for example, I am an attorney who deals exclusively with notices of retrenchment. It shows me that even though I've fine-tuned on this huge legal dataset, the model's handling of notices of retrenchment hasn't meaningfully improved or changed.
It does pay somewhat more attention to the word 'notice', but not to 'retrenchment', and certainly not to the phrase as an uninterrupted, contiguous sequence of tokens, which has implications for my task. One way to check this token by token is sketched below.
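(A small sketch of that token-by-token check, continuing from the hypothetical `delta` tensor, `tok`, and `inputs` in the earlier sketch. The concept string is just an example and the subsequence search is deliberately naive.)

```python
# Find which token positions make up the multi-token concept and look at how much
# attention toward each of them changed after fine-tuning.
concept = " notice of retrenchment"                   # leading space matters for many subword tokenizers
concept_ids = tok(concept, add_special_tokens=False)["input_ids"]
ids = inputs["input_ids"][0].tolist()

# Naive subsequence search for the concept's token span (may miss if tokenisation differs in context).
start = next(i for i in range(len(ids) - len(concept_ids) + 1)
             if ids[i:i + len(concept_ids)] == concept_ids)
positions = list(range(start, start + len(concept_ids)))

# Mean change in attention flowing into each concept token, averaged over layers, heads and queries.
per_token_shift = delta[..., positions].mean(dim=(0, 1, 2))
for pos, shift in zip(positions, per_token_shift.tolist()):
    print(repr(tok.decode([ids[pos]])), round(shift, 5))
```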
Here's another interesting way to look at the effects of fine-tuning. Let the x-axis be the layer depth, i.e. the different attention layers, where layer 0 is the earliest attention layer and layer 30 or so is the latest. Then let the y-axis be the average increase in attention toward the domain concepts, averaged across all the heads in that layer. When you plot that, you get something like this.
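(A sketch of the per-layer aggregation behind a curve like this, reusing the hypothetical `delta` and `positions` from the sketches above: the mean attention change toward the concept tokens, per layer, averaged over heads and query positions.)

```python
import matplotlib.pyplot as plt

# One value per layer: mean attention change toward the concept tokens,
# averaged over all heads and all query positions in that layer.
per_layer = delta[..., positions].mean(dim=(1, 2, 3))   # shape: (num_layers,)

plt.plot(per_layer.numpy() * 100, label="fine-tuned vs. base (example)")
plt.axhline(0.0, linewidth=0.5)
plt.xlabel("layer depth")
plt.ylabel("mean attention change toward concept (%)")
plt.legend()
plt.show()
```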
The red line is the attention change from the base Mistral 7B after both continued pre-training and IFT, and the blue line is what happens when you stop at continued pre-training. There are some interesting observations. The first is that, quite consistently across many of the legal concepts, the range of values the attention changes take is very small, roughly within plus or minus 1%. That means that overall, the fine-tuning hasn't really increased the model's utilization of those legal concepts during processing, which is interesting.
The second thing is that in many cases, within the first three or four layers, say layers zero, one, and two, you see a small bump, an increase in attention. Some of these increases are quite large, for example 2.5% to 5%, which on average is a sizeable increase.
Why does this matter? It matters because a lot of research is beginning to show that the early layers are what help your model form a broad overview of what your prompt is about. So when you fine-tune your model on domain data, at least in this case, it does seem to increase the utilization of those domain concepts as the model starts to grasp what your prompt means as a whole.
But if you look at the later layers, that's where a lot of the drop-off happens, in some cases with values beyond minus 5%, which is concerning. And again, this is interesting because more research is beginning to show that it's the later attention layers that mediate local, next-word-level context.
So when you put those together, you can see that the impact of fine-tuning is being distributed unevenly across the levels of the hierarchy. At a very high level of understanding, maybe it is increasing its use of the domain concepts, but when it's working out things like what the next word is, or very local, fine-grained context, it's actually doing worse.
Another thing to note is that in most cases the red line, that is, the mean change in attention, is higher after you apply IFT. So if you're doing continued pre-training, you really need further instruction fine-tuning to give you that additional bump of maybe half a percent. As the slide says, further IFT usually results in better representation of the domain concepts.
The last few things we're going to look at are just general summary statistics of the attention shifts across all the attention heads. The first measure is the entropy. Larger values are highlighted in bold.
You can see that almost all the time, you get higher entropy values when you do the additional fine-tuning. If I were to plot the distribution of those changes, it looks a bit better after the additional fine-tuning, because it shows that the change is being spread more evenly across more attention heads, which could suggest, though not necessarily, that all aspects of the model's reasoning are benefiting from the fine-tuning compared to continued pre-training alone. As the slide says, the changes are more evenly distributed.
We'll look at skewness as well. Again, more positive values are in bold, and as you can see, the skewness values are always more positive after you do instruction fine-tuning.
Another thing to note is that if you do only continued pre-training, you get a lot of negative skewness, which is not necessarily what you want. If you plot that, it shows that when you stop at continued pre-training, the effect of that pre-training is actually pulling the model's utilization of domain concepts toward lower values. You need IFT to give it that bump.
The last thing is kurtosis. Again, you can see that legal continued pre-training alone resulted in generally higher kurtosis values across the board. What that means is that the distribution of attention changes after continued pre-training alone has much fatter tails, so you get much more extreme shifts, which suggests some destabilization in how the model mediates those concepts. You then have to use instruction fine-tuning to re-stabilize the model.
But again, these values are still pretty large, even after you do IFT. So it just means that in general, when you're fine-tuning your model, it does introduce a lot of instability.
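(Again as a sketch with the same hypothetical `delta` and `positions`: one number per attention head, namely its mean shift toward the concept tokens, summarised with entropy, skewness and kurtosis.)

```python
import numpy as np
from scipy.stats import skew, kurtosis, entropy

# One value per (layer, head): that head's mean change in attention toward the concept tokens.
per_head = delta[..., positions].mean(dim=(2, 3)).flatten().numpy()

hist, _ = np.histogram(per_head, bins=20)
print("entropy :", entropy(hist + 1e-12))   # higher = shifts spread over more heads
print("skewness:", skew(per_head))          # more positive = heads pulled toward higher attention
print("kurtosis:", kurtosis(per_head))      # higher = fatter tails, i.e. more extreme per-head shifts
```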
So in summary, here are the impacts.
The most important one is that when you do domain training on an LLM, it is not going to learn concepts in the same way that human experts do, unfortunately. In fact, in many cases it actually degrades the model's utilization of those concepts, especially when you don't do IFT.
Importantly, there's also often no demonstrable difference between what the LLM is doing before and after fine-tuning. So be careful when you fine-tune.
That's it.
So for questions, you can connect with me on LinkedIn on the left. The paper itself is linked here on the right if you want to see it. And then I'm open to working on projects as well. So if anyone has any ideas or questions, please feel free to talk to me or shoot me an email.
Thanks.