stanfordnlp/pyvene

This is a beta release (public testing).

A Library for Understanding and Improving PyTorch Models via Interventions

Interventions on model-internal states are fundamental operations in many areas of AI, including model editing, steering, robustness, and interpretability. To facilitate such research, we introduce pyvene, an open-source Python library that supports customizable interventions on a range of different PyTorch modules. pyvene supports complex intervention schemes with an intuitive configuration format, and its interventions can be static or include trainable parameters.

Getting Started: Main pyvene 101

Installation

Since pyvene is currently in beta, we recommend installing it by cloning the repository:

git clone git@github.com:stanfordnlp/pyvene.git

and then adding the cloned directory to your Python path:

import sys
sys.path.append("<Your Path to Pyvene>")

import pyvene as pv

Alternatively, you can install it with pip, either directly from GitHub:

pip install git+https://github.com/stanfordnlp/pyvene.git

or from PyPI:

pip install pyvene

Wrap, Intervene, and Share

You can intervene on any HuggingFace model as follows:

import torch
import pyvene as pv
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf" # your HF model name.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

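def zeroout_intervention_fn(b, s):
    # b: base representation, shape (batch, seq_len, hidden); s is unused here
    b[:,3] = 0. # zero out position index 3, i.e., the 4th token
    return b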

pv_model = pv.IntervenableModel({
    "component": "model.layers[15].mlp.output", # string access
    "intervention": zeroout_intervention_fn}, model=model)

# run the intervened forward pass
orig_outputs, intervened_outputs = pv_model(
    tokenizer("The capital of Spain is", return_tensors="pt").to('cuda'),
    output_original_output=True
)
print(intervened_outputs.logits - orig_outputs.logits)

which returns,

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.4375,  1.0625,  0.3750,  ..., -0.1562,  0.4844,  0.2969],
         [ 0.0938,  0.1250,  0.1875,  ...,  0.2031,  0.0625,  0.2188],
         [ 0.0000, -0.0625, -0.0312,  ...,  0.0000,  0.0000, -0.0156]]],
       device='cuda:0')
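In the printed difference, rows 0 to 2 are all zeros and the changes start at row 3. That pattern follows from causal attention: zeroing the MLP output at position 3 can only affect that position and later ones. The pure-Python sketch below reproduces the pattern under stated assumptions: causal_mix is a hypothetical stand-in for the layers after layer 15 (a running mean over the prefix), and this zeroout_intervention_fn works on nested lists and copies its input instead of editing a tensor in place.

```python
from typing import List, Optional

Tensor3D = List[List[List[float]]]  # nested lists shaped [batch][seq_len][hidden]

def zeroout_intervention_fn(b: Tensor3D, s: Optional[Tensor3D] = None) -> Tensor3D:
    """Zero the hidden state at sequence index 3 (the 4th token) in every batch row."""
    out = [[list(vec) for vec in seq] for seq in b]  # copy so the input stays intact
    for seq in out:
        seq[3] = [0.0] * len(seq[3])
    return out

def causal_mix(h: Tensor3D) -> Tensor3D:
    """Hypothetical downstream layers: position t sees only positions 0..t (running mean)."""
    mixed_batch = []
    for seq in h:
        running = [0.0] * len(seq[0])
        mixed = []
        for t, vec in enumerate(seq):
            running = [r + v for r, v in zip(running, vec)]
            mixed.append([r / (t + 1) for r in running])
        mixed_batch.append(mixed)
    return mixed_batch

def diff(x: Tensor3D, y: Tensor3D) -> Tensor3D:
    """Element-wise x - y."""
    return [[[p - q for p, q in zip(vx, vy)] for vx, vy in zip(sx, sy)]
            for sx, sy in zip(x, y)]

# One sequence of 6 positions with hidden size 2 (arbitrary values).
hidden = [[[1.0, 2.0], [0.5, -1.0], [2.0, 0.0], [3.0, 1.0], [-1.0, 4.0], [0.0, 1.5]]]
orig = causal_mix(hidden)
intervened = causal_mix(zeroout_intervention_fn(hidden, None))
for t, row in enumerate(diff(intervened, orig)[0]):
    print(t, row)  # rows 0-2 are all zeros; rows 3-5 are not
```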

IntervenableModel Loaded from HuggingFace Directly

The following code block can reproduce honest_llama-2-chat from the paper "Inference-Time Intervention: Eliciting Truthful Answers from a Language Model". The added activations take up only ~0.14MB on disk!

# others can download from huggingface and use it directly
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pyvene as pv

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")

pv_model = pv.IntervenableModel.load(
    "zhengxuanzenwu/intervenable_honest_llama2_chat_7B", # the activation diff ~0.14MB
    model,
)

print("llama-2-chat loaded with interventions:")
q = "What's a cure for insomnia that always works?"
prompt = tokenizer(q, return_tensors="pt").to("cuda")
_, iti_response_shared = pv_model.generate(prompt, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(iti_response_shared[0], skip_special_tokens=True))

Once you discover a clever intervention scheme, you can share it with others quickly, without sharing the base LM or the intervention code!
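The shared intervention is small because an additive intervention only needs the vectors it adds, not the base model. As a hedged illustration in the same fn(b, s) style as zeroout_intervention_fn, the sketch below builds an intervention that adds a scaled direction at chosen positions. make_addition_intervention, direction, alpha, and positions are hypothetical names, and this is not how pyvene stores or applies the ITI intervention.

```python
from typing import Callable, List, Optional, Sequence

Tensor3D = List[List[List[float]]]  # nested lists shaped [batch][seq_len][hidden]

def make_addition_intervention(
    direction: Sequence[float],
    alpha: float = 1.0,
    positions: Optional[Sequence[int]] = None,
) -> Callable[[Tensor3D, Optional[Tensor3D]], Tensor3D]:
    """Build an fn(b, s) that adds alpha * direction at `positions` (every position if None)."""
    def addition_intervention_fn(b: Tensor3D, s: Optional[Tensor3D] = None) -> Tensor3D:
        out = [[list(vec) for vec in seq] for seq in b]  # leave the input untouched
        for seq in out:
            targets = range(len(seq)) if positions is None else positions
            for t in targets:
                seq[t] = [v + alpha * d for v, d in zip(seq[t], direction)]
        return out
    return addition_intervention_fn

# Everything worth sharing is (direction, alpha, positions): a few floats, no model weights.
steer = make_addition_intervention(direction=[0.5, -0.25], alpha=2.0, positions=[1])
base = [[[1.0, 1.0], [0.0, 0.0], [3.0, -1.0]]]
print(steer(base, None))  # [[[1.0, 1.0], [1.0, -0.5], [3.0, -1.0]]]
```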

IntervenableModel as Regular nn.Module

You can also use an IntervenableModel, such as the GPT-2 wrapper pv_gpt2 built in the path-patching example below, just like a regular torch module inside another model or pipeline:

import torch
import torch.nn as nn
from typing import Dict, List, Optional

class ModelWithIntervenables(nn.Module):
    def __init__(self, pv_gpt2):
        super().__init__()
        self.pv_gpt2 = pv_gpt2          # an IntervenableModel wrapping GPT-2
        self.relu = nn.ReLU()
        self.fc = nn.Linear(768, 1)     # 768 = hidden size of GPT-2 small
        # Your other downstream components go here

    def forward(
        self,
        base,
        sources: Optional[List] = None,
        unit_locations: Optional[Dict] = None,
        activations_sources: Optional[Dict] = None,
        subspaces: Optional[List] = None,
    ):
        # the IntervenableModel returns (original outputs, intervened outputs)
        _, counterfactual_x = self.pv_gpt2(
            base,
            sources,
            unit_locations,
            activations_sources,
            subspaces
        )
        return self.fc(self.relu(counterfactual_x.last_hidden_state))

Complex Intervention Schema as an Object

One key abstraction that pyvene provides is the encapsulation of the intervention schema. While this abstraction keeps the user interface simple, pyvene can still support relatively complex intervention schemas. The following helper function generates the schema configuration for path patching on individual attention heads at the output of the OV circuit (i.e., analyzing the causal effect of each individual component):

import pyvene as pv

def path_patching_config(
    layer, last_layer,
    component="head_attention_value_output", unit="h.pos",
):
    # group 0: the component whose effect we measure
    intervening_component = [
        {"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    # group 1: downstream components to restore
    restoring_components = []
    if not component.startswith("mlp_"):
        restoring_components += [
            {"layer": layer, "component": "mlp_output", "group_key": 1}]
    for i in range(layer+1, last_layer):
        restoring_components += [
            {"layer": i, "component": "attention_output", "group_key": 1},
            {"layer": i, "component": "mlp_output", "group_key": 1}
        ]
    intervenable_config = pv.IntervenableConfig(intervening_component + restoring_components)
    return intervenable_config

Then you can wrap a model with the config generated by this function. After running your intervention, you can save the path-patching setup and share it with others:

_, tokenizer, gpt2 = pv.create_gpt2()

pv_gpt2 = pv.IntervenableModel(
    path_patching_config(4, gpt2.config.n_layer),
    model=gpt2
)
# save the path-patching intervention
pv_gpt2.save(
    save_directory="./your_gpt2_path/"
)
# load it back onto the same base model
pv_gpt2 = pv.IntervenableModel.load(
    "./your_gpt2_path/",
    model=gpt2)
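To inspect what path_patching_config actually describes, the sketch below builds the same list of dictionaries in plain Python, without wrapping it in pv.IntervenableConfig. The helper name path_patching_representations is hypothetical. group_key 0 marks the patched component, and group_key 1 collects the components the original function gathers in restoring_components.

```python
from typing import Dict, List

def path_patching_representations(
    layer: int, last_layer: int,
    component: str = "head_attention_value_output", unit: str = "h.pos",
) -> List[Dict]:
    """Same dictionaries as path_patching_config, before they go into IntervenableConfig."""
    intervening = [{"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    restoring = []
    if not component.startswith("mlp_"):
        # the MLP of the same layer sits downstream of the attention head
        restoring.append({"layer": layer, "component": "mlp_output", "group_key": 1})
    for i in range(layer + 1, last_layer):
        restoring.append({"layer": i, "component": "attention_output", "group_key": 1})
        restoring.append({"layer": i, "component": "mlp_output", "group_key": 1})
    return intervening + restoring

reps = path_patching_representations(4, 12)  # e.g. 12 layers, as in GPT-2 small
for r in reps[:4]:
    print(r)
print(len(reps), "representations")
```

With layer 4 and 12 layers, the schema has 16 entries: the patched head output, the same-layer MLP output, and the attention and MLP outputs of layers 5 to 11.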

Selected Tutorials

| Level | Tutorial | Description |
|---|---|---|
| Beginner | pyvene 101 | Introduce you to the basics of pyvene |
| Intermediate | ROME Causal Tracing | Reproduce ROME's Results on Factual Associations with GPT2-XL |
| Intermediate | Intervention v.s. Probing | Illustrates how to run trainable interventions and probing with pythia-6.9B |
| Advanced | Trainable Interventions for Causal Abstraction | Illustrates how to train an intervention to discover causal mechanisms of a neural model |

Contributing to This Library

Please see our guidelines about how to contribute to this repository.

Pull requests, bug reports, and all other forms of contribution are welcomed and highly encouraged! :octocat:

A Little Guide for Causal Abstraction: From Interventions to Gain Interpretability Insights

Basic interventions are fun, but on their own they do not let us make systematic causal claims. To gain actual interpretability insights, we want to measure the counterfactual behaviors of a model in a data-driven fashion. In other words, if the model responds systematically to your interventions, you can start to associate certain regions in the network with a high-level concept. We call this process alignment search with model internals.

Understanding Causal Mechanisms with Static Interventions

Here is a more concrete example,

def add_three_numbers(a, b, c):
    var_x = a + b
    return var_x + c

The function solves a three-number sum problem. Suppose we have trained a neural network to solve this problem perfectly. Can we find the representation of (a + b) in the neural network? We can use this library to answer this question. Specifically, we can do the following:

  • Step 1: Form Interpretability (Alignment) Hypothesis: We hypothesize that a set of neurons N aligns with (a + b).
  • Step 2: Counterfactual Testing: If our hypothesis is correct, then swapping neurons N between examples should produce the expected counterfactual behaviors. For instance, if we take the values of N from (2+3)+4 and patch them into the run on (1+2)+3, the output should be (2+3)+3 = 8; swapping in the other direction should give (1+2)+4 = 7.
  • Step 3: Rejection Sampling of Hypotheses: Run these tests many times and aggregate statistics on how often the counterfactual behavior matches. Then propose a new hypothesis based on the results.
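Step 2 can be made concrete with a toy "network" whose hidden state has two neurons, [a + b, c], followed by an output layer that adds them. Everything in the pure-Python sketch below (hidden, head, run, and the neuron layout) is hypothetical and stands in for a trained model; it swaps the candidate neurons N between the two examples from Step 2 and checks the expected outputs.

```python
from typing import List, Optional, Sequence, Tuple

Example = Tuple[int, int, int]  # (a, b, c)

def hidden(x: Example) -> List[int]:
    """Hypothetical network internals: neuron 0 holds a + b, neuron 1 holds c."""
    a, b, c = x
    return [a + b, c]

def head(h: List[int]) -> int:
    """Hypothetical output layer: adds the two neurons."""
    return h[0] + h[1]

def run(base: Example, source: Optional[Example] = None, neurons: Sequence[int] = ()) -> int:
    """Run on `base`; if `source` is given, overwrite `neurons` with the source's values."""
    h = hidden(base)
    if source is not None:
        h_src = hidden(source)
        for n in neurons:
            h[n] = h_src[n]
    return head(h)

x1, x2 = (1, 2, 3), (2, 3, 4)
print(run(x1))                   # 6, no intervention
print(run(x1, x2, neurons=[0]))  # 8 = (2+3)+3, N patched from (2+3)+4 into (1+2)+3
print(run(x2, x1, neurons=[0]))  # 7 = (1+2)+4, the swap in the other direction
print(run(x1, x2, neurons=[1]))  # 7, not 8: neuron 1 fails the test for (a + b)
```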

With the library, the steps above translate into a single API call:

intervenable.eval_alignment(
    train_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you provide test data (the intervention inputs and the counterfactual behavior you are looking for) along with your metrics function. The library then evaluates the alignment defined by the intervention in your config.
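The aggregation in Step 3 amounts to an accuracy over many (base, source) pairs, often called interchange intervention accuracy. The pure-Python sketch below builds such a test set for add_three_numbers and scores two hypotheses on a hypothetical toy network whose hidden state is [a + b, c]. Its compute_metrics signature is our own assumption and not necessarily what eval_alignment passes to your function.

```python
import random
from typing import Dict, List, Sequence, Tuple

Example = Tuple[int, int, int]  # (a, b, c)

def add_three_numbers(a: int, b: int, c: int) -> int:
    var_x = a + b
    return var_x + c

def toy_hidden(x: Example) -> List[int]:
    """Hypothetical network internals: neuron 0 holds a + b, neuron 1 holds c."""
    a, b, c = x
    return [a + b, c]

def toy_intervened_output(base: Example, source: Example, neurons: Sequence[int]) -> int:
    """Interchange intervention: copy `neurons` from the source run into the base run."""
    h, h_src = toy_hidden(base), toy_hidden(source)
    for n in neurons:
        h[n] = h_src[n]
    return sum(h)  # hypothetical output layer

def make_test_set(n: int, seed: int = 0) -> List[Dict]:
    """Each item pairs a base and a source with the label the causal model predicts."""
    rng = random.Random(seed)
    items = []
    for _ in range(n):
        base = tuple(rng.randint(0, 9) for _ in range(3))
        source = tuple(rng.randint(0, 9) for _ in range(3))
        # counterfactual label: var_x = a + b comes from the source, c from the base
        label = add_three_numbers(source[0], source[1], base[2])
        items.append({"base": base, "source": source, "label": label})
    return items

def compute_metrics(preds: Sequence[int], labels: Sequence[int]) -> Dict[str, float]:
    """Fraction of intervened outputs that match the counterfactual labels."""
    correct = sum(p == y for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}

test_set = make_test_set(200)
labels = [item["label"] for item in test_set]
for neurons in ([0], [1]):  # hypothesis: "these neurons encode a + b"
    preds = [toy_intervened_output(it["base"], it["source"], neurons) for it in test_set]
    print(neurons, compute_metrics(preds, labels))
```

Under the correct hypothesis (neuron 0) the accuracy is 1.0; swapping neuron 1 matches the label only by coincidence.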


Understanding Causal Mechanism with Trainable Interventions

The alignment search process outlined above can be tedious when your neural network is large. For a single hypothesized alignment, you need to set up different intervention configs targeting different layers and positions to verify your hypothesis. Instead of this brute-force search, you can turn it into an optimization problem, which has other benefits too, such as finding distributed alignments.

At its core, we want to train an intervention that produces our desired counterfactual behaviors. If we can indeed train such an intervention, we claim that causally informative information lives in the intervened representations! Below, we show one type of trainable intervention, models.interventions.RotatedSpaceIntervention:

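import torch
from pyvene.models.interventions import TrainableIntervention

class RotatedSpaceIntervention(TrainableIntervention):
    """Intervention in the rotated space."""
    # __init__ (omitted) sets self.rotate_layer and self.interchange_dim
    def forward(self, base, source):
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)
        # interchange the first interchange_dim dims of the hidden (last) axis
        rotated_base[..., :self.interchange_dim] = rotated_source[..., :self.interchange_dim]
        # rotate back: the transpose inverts an orthogonal rotation
        output = torch.matmul(rotated_base, self.rotate_layer.weight.T)
        return output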

Instead of swapping activations in the original representation space, we first rotate them, then do the swap, and finally un-rotate the intervened representation. We use SGD to learn a rotation that produces the expected counterfactual behavior; if we can find such a rotation, we claim there is an alignment. The tutorial "If the cost is between X and Y.ipynb" covers this with an advanced version of distributed alignment search, Boundless DAS. There are also recent works outlining potential limitations of distributed alignment search.
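Rotation matters when a variable is spread across several neurons. In the pure-Python sketch below, a hypothetical two-neuron network stores [a + b, c] rotated by an angle THETA_TRUE, so no single neuron holds a + b. Swapping neuron 0 directly misses the counterfactual, while rotating, swapping the first rotated coordinate, and rotating back (the steps of RotatedSpaceIntervention.forward) recovers it. The 2x2 matrix stands in for the learned rotate_layer; the angle and helper names are assumptions.

```python
import math
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]

def rotation(theta: float) -> Mat:
    """2x2 rotation matrix; orthogonal, so its transpose is its inverse."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def transpose(m: Mat) -> Mat:
    return [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]

def vecmat(x: Vec, m: Mat) -> Vec:
    """Row vector times matrix, x @ m."""
    return [x[0] * m[0][0] + x[1] * m[1][0], x[0] * m[0][1] + x[1] * m[1][1]]

THETA_TRUE = 0.7  # hypothetical angle at which the toy network mixes its variables
W_TRUE = rotation(THETA_TRUE)

def hidden(x: Tuple[int, int, int]) -> Vec:
    """Hypothetical network: [a + b, c] spread across both neurons."""
    a, b, c = x
    return vecmat([float(a + b), float(c)], transpose(W_TRUE))

def head(h: Vec) -> float:
    """Hypothetical output layer: un-mixes the neurons and adds the two variables."""
    z = vecmat(h, W_TRUE)
    return z[0] + z[1]

def rotated_interchange(base: Vec, source: Vec, w: Mat, k: int = 1) -> Vec:
    """Same steps as RotatedSpaceIntervention.forward, for a single vector."""
    rb, rs = vecmat(base, w), vecmat(source, w)  # rotate both representations
    rb = rs[:k] + rb[k:]                         # interchange the first k rotated dims
    return vecmat(rb, transpose(w))              # rotate back

x_base, x_src = (1, 2, 3), (2, 3, 4)
expected = (2 + 3) + 3  # a + b from the source run, c from the base run

naive = hidden(x_base)
naive[0] = hidden(x_src)[0]  # plain swap of neuron 0 in the original space
print("neuron swap: ", round(head(naive), 4), "expected", expected)

aligned = rotated_interchange(hidden(x_base), hidden(x_src), W_TRUE)
print("rotated swap:", round(head(aligned), 4), "expected", expected)
```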

You can also train your intervention with a single API call:

intervenable.train_alignment(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator
)

where you need to pass in a training dataset and your custom loss and metrics functions. The trained interventions can later be saved to disk. You can also use intervenable.evaluate() to evaluate your interventions against custom objectives.
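To show what learning the rotation involves, the sketch below fits the angle of a 2x2 rotation in pure Python, on a hypothetical two-neuron network that stores [a + b, c] rotated by THETA_TRUE. Here compute_loss is the mean squared error between the intervened output and the counterfactual label, and the loop runs full-batch gradient descent with a finite-difference gradient and a backtracking step size. This is an illustration, not pyvene code: pyvene's RotatedSpaceIntervention learns a full rotation matrix with PyTorch, and the callback signatures train_alignment expects may differ from this compute_loss.

```python
import math
import random
from typing import List, Tuple

Vec = List[float]
Mat = List[List[float]]
Example = Tuple[int, int, int]  # (a, b, c)

def rotation(theta: float) -> Mat:
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def transpose(m: Mat) -> Mat:
    return [[m[0][0], m[1][0]], [m[0][1], m[1][1]]]

def vecmat(x: Vec, m: Mat) -> Vec:
    """Row vector times matrix, x @ m."""
    return [x[0] * m[0][0] + x[1] * m[1][0], x[0] * m[0][1] + x[1] * m[1][1]]

THETA_TRUE = 0.5  # hypothetical: how the toy network mixes [a + b, c] across two neurons

def hidden(x: Example) -> Vec:
    a, b, c = x
    return vecmat([float(a + b), float(c)], transpose(rotation(THETA_TRUE)))

def head(h: Vec) -> float:
    z = vecmat(h, rotation(THETA_TRUE))
    return z[0] + z[1]

def intervened_output(theta: float, base: Example, source: Example) -> float:
    """Rotate by the learnable angle, swap rotated dim 0, rotate back, then decode."""
    w = rotation(theta)
    rb, rs = vecmat(hidden(base), w), vecmat(hidden(source), w)
    return head(vecmat([rs[0], rb[1]], transpose(w)))

def compute_loss(theta: float, data: List[Tuple[Example, Example, int]]) -> float:
    """Mean squared error between intervened outputs and counterfactual labels."""
    return sum((intervened_output(theta, b, s) - y) ** 2 for b, s, y in data) / len(data)

rng = random.Random(0)
data = []
for _ in range(64):
    base = tuple(rng.randint(0, 9) for _ in range(3))
    source = tuple(rng.randint(0, 9) for _ in range(3))
    data.append((base, source, source[0] + source[1] + base[2]))  # (a + b) from the source

theta, eps = 0.3, 1e-5  # start 0.2 rad away from THETA_TRUE
initial_loss = compute_loss(theta, data)
for _ in range(100):
    loss = compute_loss(theta, data)
    # central finite-difference gradient (a real implementation would use autograd)
    grad = (compute_loss(theta + eps, data) - compute_loss(theta - eps, data)) / (2 * eps)
    if abs(grad) < 1e-9:
        break
    lr = min(1.0, 0.05 / abs(grad))  # cap each step at 0.05 rad
    # backtracking line search: halve lr until the loss drops sufficiently
    while lr > 1e-12 and compute_loss(theta - lr * grad, data) > loss - 0.5 * lr * grad ** 2:
        lr /= 2
    theta -= lr * grad

print(f"initial loss {initial_loss:.3f}, final loss {compute_loss(theta, data):.2e}")
print(f"learned theta {theta:.4f} (true {THETA_TRUE})")
```

Capping the step keeps the search near its starting point. This matters because a rotation by an extra pi also drives the loss to zero (it only flips signs), so an uncapped search could land on that equivalent solution instead.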

Citation

If you use this repository, please consider citing our library paper:

@article{wu2024pyvene,
  title={pyvene: A Library for Understanding and Improving {P}y{T}orch Models via Interventions},
  author={Wu, Zhengxuan and Geiger, Atticus and Arora, Aryaman and Huang, Jing and Wang, Zheng and Goodman, Noah D. and Manning, Christopher D. and Potts, Christopher},
  journal={arXiv:2403.07809},
  url={https://arxiv.org/abs/2403.07809},
  year={2024}
}

Related Works in Discovering Causal Mechanisms of LLMs

If you would like to read more work in this area, here is a list of papers that try to align or discover the causal mechanisms of LLMs.
