import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CURL_CA_BUNDLE"] = ""

device = "cuda:0" if torch.cuda.is_available() else "cpu"


def _logsumexp_by_chunk(logits: torch.Tensor, chunk_size: int = 1024) -> torch.Tensor:
    """Compute logsumexp over the vocab dimension in fixed-size chunks to bound peak memory."""
    seq_len = logits.shape[0]
    logsumexp_values = torch.zeros((seq_len,), device=logits.device, dtype=logits.dtype)
    for s_idx in range(0, seq_len, chunk_size):
        end_idx = min(s_idx + chunk_size, seq_len)
        logsumexp_values[s_idx:end_idx] = torch.logsumexp(logits[s_idx:end_idx], dim=-1)
    return logsumexp_values


def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    if temperature != 1.0:
        logits.div_(temperature)
    # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
    if logits.dtype in [torch.float32, torch.float64]:
        batch_dim = logits.shape[:-1]
        last_dim = logits.shape[-1]
        try:
            # Fused Triton kernel, used when flash-attn is available.
            from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

            output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
            log_probs_labels = -output[0].view(*batch_dim)
        except ImportError:
            logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
            logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))
            logsumexp_values = logsumexp_values.view(*batch_dim)
            log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    else:
        log_probs_labels = []
        for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption
            row_log_probs = F.log_softmax(row_logits, dim=-1)
            row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            log_probs_labels.append(row_log_probs_labels)
        log_probs_labels = torch.stack(log_probs_labels)
    return log_probs_labels


def reset_position_ids(attention_mask):
    """Rebuild position ids for packed sequences: the mask tags each packed sample
    with a distinct positive integer, and positions restart from 0 per sample."""
    position_ids = torch.zeros_like(attention_mask, dtype=torch.long)
    for i in range(attention_mask.size(0)):
        mask = attention_mask[i]
        seq_num = mask.max().item()
        for index in range(1, seq_num + 1):
            sample_mask = mask == index
            sample_length = sample_mask.sum().item()
            position_ids[i, sample_mask] = torch.arange(sample_length, device=mask.device)
    return position_ids
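# Illustrative sketch (not part of the original script) of what the helpers above
# compute, assuming flash-attn is absent so the chunked fallback path runs:
#
#   packed = torch.tensor([[1, 1, 1, 2, 2, 0]])  # two samples packed into one row
#   reset_position_ids(packed)                   # -> tensor([[0, 1, 2, 0, 1, 0]])
#
#   logits = torch.randn(2, 5, 32, dtype=torch.float32)
#   labels = torch.randint(0, 32, (2, 5))
#   ref = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
#   assert torch.allclose(log_probs_from_logits(logits.clone(), labels), ref, atol=1e-5)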
class Actor(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir="/data/",
            torch_dtype=torch.bfloat16,
            device_map=None,
            attn_implementation="flash_attention_2",
        )

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor):
        # Default sampling arguments.
        generate_args = {
            "input_ids": input_ids,
            "top_k": None,
            "top_p": None,
            "do_sample": True,
            "temperature": 1.0,
            "use_cache": True,
            "num_beams": 1,
            "attention_mask": None,
            "eos_token_id": 2,  # example eos_token_id, change as needed
            "pad_token_id": 0,  # example pad_token_id, change as needed
            "min_new_tokens": 1,
        }
        sequences = self.model.generate(**generate_args)
        eos_token_id = generate_args["eos_token_id"]
        pad_token_id = generate_args["pad_token_id"]
        return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)

    def process_sequences(self, sequences: torch.Tensor, input_len, eos_token_id, pad_token_id):
        attention_mask = (sequences.ne(eos_token_id) & sequences.ne(pad_token_id)).to(dtype=torch.long)
        seq_length = attention_mask.size(1)

        # Locate the position right after the last non-pad/eos token and force an eos there.
        eos_indices = seq_length - attention_mask.long().fliplr().argmax(dim=1, keepdim=True).clamp(min=1)
        sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id)

        # Attention mask covers [first real token, eos], inclusive on both ends.
        first_token_indices = attention_mask.long().argmax(dim=1, keepdim=True)
        mask = torch.arange(seq_length).unsqueeze(0).expand(sequences.size(0), -1).to(device=sequences.device)
        attention_mask = ((mask >= first_token_indices) & (mask <= eos_indices)).to(dtype=torch.long)

        # Action tokens are the generated continuation; state_seq[t] is the token
        # whose logits produce action t.
        state_seq = sequences[:, input_len - 1 : -1]
        action_mask = state_seq.ne(eos_token_id) & state_seq.ne(pad_token_id)
        action_mask[:, 0] = 1
        return sequences, attention_mask, action_mask

    def forward(self, sequences, num_actions, attention_mask):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        output = self.model(sequences, attention_mask=attention_mask, position_ids=position_ids)
        output["logits"] = output["logits"].to(torch.float32)
        # Shift by one: logits at position t score the token at position t + 1.
        log_probs = log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:], temperature=1.0)
        action_log_probs = log_probs[:, -num_actions:]
        return action_log_probs


class RewardModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir="/data/",
            torch_dtype=torch.bfloat16,
            device_map=None,
            attn_implementation="flash_attention_2",
        )
        self.model_config = AutoConfig.from_pretrained(model_name, cache_dir="/data/")
        self.score = nn.Linear(self.model_config.hidden_size, 1, dtype=torch.bfloat16)

    def forward(self, input_ids, attention_mask):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        outputs = self.model(
            input_ids, attention_mask=attention_mask, position_ids=position_ids, output_hidden_states=True
        )
        last_hidden_states = outputs.hidden_states[-1]
        values = self.score(last_hidden_states)

        # The scalar reward is the value head's output at the last attended (eos) position.
        eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
        eos_indices = eos_indices.unsqueeze(-1)
        reward = values.gather(dim=1, index=eos_indices).squeeze(1)
        # print(f"Reward Model Value: {values}\n Reward: {reward}")
        return reward


class CriticModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            cache_dir="/data/",
            torch_dtype=torch.bfloat16,
            device_map=None,
            attn_implementation="flash_attention_2",
        )
        self.model_config = AutoConfig.from_pretrained(model_name, cache_dir="/data/")
        self.score = nn.Linear(self.model_config.hidden_size, 1, dtype=torch.bfloat16)

    def forward(self, input_ids, attention_mask, num_actions):
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        outputs = self.model(
            input_ids, attention_mask=attention_mask, position_ids=position_ids, output_hidden_states=True
        )
        last_hidden_states = outputs.hidden_states[-1]
        values = self.score(last_hidden_states)
        values = values[:, :-1]  # drop the value after the final token, mirroring the logits shift
        if num_actions is None:
            return outputs
        action_values = values[:, -num_actions:]
        # print(f"Critic Model Value: {values}\n Action Values: {action_values}")
        return action_values
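# Shape summary (illustrative, not from the original script): with batch size B
# and A action tokens,
#   Actor.forward(sequences, A, attention_mask)        -> (B, A)    per-action log probs
#   CriticModel.forward(sequences, attention_mask, A)  -> (B, A, 1) per-action values
#   RewardModel.forward(sequences, attention_mask)     -> (B, 1)    one scalar per sequence
# A PPO-style trainer would mask these with action_mask before computing its loss.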
if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    actor_model = Actor(model_name)
    reward_model = RewardModel(model_name)
    critic_model = CriticModel(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/data/")

    actor_model.to(device)
    reward_model.to(device)
    critic_model.to(device)

    input_text = "Give me a short introduction to large language model"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    sequences, attention_mask, action_mask = actor_model.generate(input_ids)
    generated_text = tokenizer.decode(sequences[0], skip_special_tokens=True)
    print(f"Generate Length: {len(generated_text)}\n Generate Text: {generated_text}")

    # Score the generated token ids directly; re-tokenizing the decoded text would
    # produce a length that no longer matches attention_mask.
    reward_model_out = reward_model(sequences, attention_mask)
    critic_model_out = critic_model(sequences, attention_mask, 3)  # 3 is a toy num_actions
    actor_model_out = actor_model(sequences, 3, attention_mask)
    print(actor_model_out)
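    # Added sanity check (illustrative): in a real rollout num_actions would be
    # action_mask.size(1) rather than the toy value 3 used above.
    print(
        f"reward: {tuple(reward_model_out.shape)}, values: {tuple(critic_model_out.shape)}, "
        f"log_probs: {tuple(actor_model_out.shape)}, action_mask: {tuple(action_mask.shape)}"
    )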