Source code for class_factory.concept_web.concept_extraction

"""
concept_extraction.py
---------------------

Functions to extract, normalize, and process concept relationships from educational content (lesson readings and objectives).

Features:
- Summarizes lesson readings using LLMs and course-specific prompts.
- Extracts key concepts and their relationships, guided by lesson objectives.
- Normalizes and consolidates concept names using embeddings and inflection.
- Prepares structured data for downstream visualization and analysis.

Dependencies:
- Language Models: OpenAI GPT or similar, DistilBERT (transformers)
- Core Libraries: langchain, torch, inflect

Example:
    from class_factory.concept_web.concept_extraction import extract_relationships
    text = "Democracy relies on voting rights..."
    objectives = "Understand principles of democracy"
    relationships = extract_relationships(text, objectives, "Political Science", llm)
    processed = process_relationships(relationships)

This module is part of the ClassFactory concept mapping pipeline.
"""

# %%
# base libraries
import json
# logger setup
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import inflect
# entity resolution
import torch
import torch.nn.functional as F
# env setup
from dotenv import load_dotenv
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from transformers import DistilBertModel, DistilBertTokenizer

from class_factory.concept_web.prompts import (relationship_prompt,
                                               summary_prompt)
from class_factory.utils.llm_validator import Validator
from class_factory.utils.response_parsers import ExtractedRelations
from class_factory.utils.tools import logger_setup, retry_on_json_decode_error

# logging.basicConfig(
#     level=logging.INFO,  # Set your desired level
#     format='%(name)s - %(levelname)s - %(message)s'
# )

# %%


def summarize_text(text: str, prompt: ChatPromptTemplate, course_name: str, llm: Any,
                   parser: StrOutputParser = StrOutputParser(), verbose: bool = False) -> str:
    """
    Summarize the provided text using a language model and a structured prompt.

    Args:
        text (str): The text to be summarized.
        prompt (ChatPromptTemplate): Prompt template for summarization.
        course_name (str): Name of the course for context.
        llm (Any): Language model instance.
        parser (StrOutputParser, optional): Output parser. Defaults to StrOutputParser().
        verbose (bool, optional): Enable detailed logging. Defaults to False.

    Returns:
        str: The summary generated by the language model.

    Raises:
        ValueError: If validation fails after max retries.
    """
    log_level = logging.INFO if verbose else logging.ERROR
    logger = logger_setup(log_level=log_level)

    chain = prompt | llm | parser

    retries, max_retries = 0, 3
    valid = False
    validator = Validator(llm=llm, log_level=log_level)
    additional_guidance = ""

    while not valid and retries < max_retries:
        summary = chain.invoke({'course_name': course_name,
                                'text': text,
                                'additional_guidance': additional_guidance})
        logger.info(f"Example summary:\n{summary}")

        # Validate the generated summary
        validation_prompt = prompt.format(course_name=course_name, text=text,
                                          additional_guidance=additional_guidance)
        # Use the schema from the parser if available
        task_schema = parser.pydantic_object.model_json_schema() if hasattr(
            parser, 'pydantic_object') and hasattr(parser.pydantic_object, 'model_json_schema') else ""
        val_response = validator.validate(
            task_description=validation_prompt,
            generated_response=summary,
            task_schema=task_schema,
            specific_guidance=additional_guidance
        )
        logger.info(f"validation output: {val_response}")

        if int(val_response['status']) == 1:
            valid = True
        else:
            retries += 1
            additional_guidance = val_response.get("additional_guidance", "")
            logger.warning(f"Summary validation failed on attempt {retries}. Reason: {val_response['reasoning']}")

    if valid:
        logger.debug("Validation succeeded.")
    else:
        raise ValueError("Validation failed after max retries. Ensure correct prompt and input data. "
                         "Consider use of a different LLM.")

    return summary
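
# Illustrative usage of summarize_text (a minimal sketch, not part of the original module;
# assumes `summary_prompt` from class_factory.concept_web.prompts and a LangChain-compatible
# chat model bound to `llm`):
#
#   summary = summarize_text(
#       text="Democracy relies on voting rights...",
#       prompt=summary_prompt,
#       course_name="Political Science",
#       llm=llm,
#       verbose=True,
#   )
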
@retry_on_json_decode_error()
def extract_relationships(text: str, objectives: str, course_name: str, llm: Any,
                          verbose: bool = False, logger: Optional[logging.Logger] = None) -> List[Tuple[str, str, str]]:
    """
    Extract key concepts and their relationships from the provided text using an LLM.

    Args:
        text (str): The summarized text.
        objectives (str): Lesson objectives for context.
        course_name (str): Name of the course.
        llm (Any): Language model instance.
        verbose (bool, optional): Enable detailed logging. Defaults to False.
        logger (Optional[logging.Logger], optional): Logger instance. Defaults to None.

    Returns:
        List[Tuple[str, str, str]]: List of (concept1, relationship, concept2) tuples.

    Raises:
        ValueError: If validation fails after max retries.
        JSONDecodeError: If response parsing fails.
    """
    log_level = logging.INFO if verbose else logging.WARNING
    logger = logger or logging.getLogger(__name__)
    logger.setLevel(log_level)

    # Use llm.with_structured_output for structured output
    if not objectives:
        objectives = "Not provided."
    additional_guidance = ""

    # combined_template = PromptTemplate.from_template(selected_prompt)
    chain = relationship_prompt | llm.with_structured_output(ExtractedRelations)
    logger.debug(f"""Querying with:\n{relationship_prompt.format(course_name=course_name, objectives=objectives, text="placeholder", additional_guidance="")}""")

    validator = Validator(llm=llm, log_level=log_level, score_threshold=7.5)
    retries, max_retries = 0, 3
    valid = False

    while not valid and retries < max_retries:
        logger.debug(f"Invoking chain with input: course_name={course_name}, objectives={objectives}, "
                     f"text=[{text[:100]}...], additional_guidance={additional_guidance}")
        response = chain.invoke({'course_name': course_name,
                                 'objectives': objectives,
                                 'text': text,
                                 'additional_guidance': additional_guidance})
        logger.debug(f"Raw response from chain.invoke: {response}")
        if response is None:
            logger.error(
                "LLM chain.invoke returned None. This may indicate an API failure, bad input, or output parsing error. "
                "Check your LLM logs, API key, and input data.")
            raise RuntimeError("LLM chain.invoke returned None. See logs for details.")

        # Use .model_dump() to get a dict from the pydantic object
        data = response.model_dump()

        # Validate responses
        response_str = json.dumps(data).replace("{", "{{").replace("}", "}}")
        # Use the schema from the ExtractedRelations model
        task_schema = ExtractedRelations.model_json_schema() if hasattr(ExtractedRelations, 'model_json_schema') else ""
        validation_prompt = relationship_prompt.format(course_name=course_name,
                                                       objectives=objectives,
                                                       text=text,
                                                       additional_guidance=additional_guidance
                                                       ).replace("{", "{{").replace("}", "}}")
        val_response = validator.validate(
            task_description=validation_prompt,
            generated_response=response_str,
            task_schema=task_schema,
            specific_guidance=additional_guidance
        )
        logger.info(f"validation output: {val_response}")

        if int(val_response['status']) == 1:
            valid = True
        else:
            retries += 1
            additional_guidance = val_response.get("additional_guidance", "")
            logger.warning(f"Relationship validation failed on attempt {retries}. Reason: {val_response['reasoning']}")

    if valid:
        logger.info("Validation succeeded.")
    else:
        raise ValueError("Validation failed after max retries. Ensure correct prompt and input data. "
                         "Consider use of a different LLM.")

    # Extract concepts and relationships as explicit tuples from dict keys
    relationships = [
        (rel["concept_1"], rel["relationship_type"], rel["concept_2"])
        for rel in data["relationships"]
        if all(k in rel for k in ("concept_1", "relationship_type", "concept_2"))
    ]

    return relationships
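
# Illustrative usage of extract_relationships (a minimal sketch, not part of the original
# module; assumes `llm` supports `.with_structured_output()`, e.g. langchain_openai.ChatOpenAI):
#
#   relationships = extract_relationships(
#       text=summary,
#       objectives="Understand principles of democracy",
#       course_name="Political Science",
#       llm=llm,
#   )
#   # -> e.g. [("democracy", "relies on", "voting rights"), ...]
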
def extract_concepts_from_relationships(relationships: List[Tuple[str, str, str]]) -> List[str]:
    """
    Extract unique concept names from a list of relationships.

    Args:
        relationships (List[Tuple[str, str, str]]): List of relationship tuples or dicts.

    Returns:
        List[str]: Unique concept names.
    """
    concepts = set()  # Use a set to avoid duplicates
    for rel in relationships:
        # Accept both tuple and dict for backward compatibility
        if isinstance(rel, dict):
            c1 = rel.get("concept_1", "").lower().strip()
            c2 = rel.get("concept_2", "").lower().strip()
        else:
            c1 = rel[0].lower().strip()
            c2 = rel[2].lower().strip()
        if c1:
            concepts.add(c1)
        if c2:
            concepts.add(c2)
    return list(concepts)
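
# Illustrative usage of extract_concepts_from_relationships (a minimal sketch, not part of
# the original module; output order is not guaranteed because a set is used internally):
#
#   extract_concepts_from_relationships([
#       ("Democracy", "relies on", "Voting Rights"),
#       {"concept_1": "Voting Rights", "relationship_type": "protected by", "concept_2": "Constitutions"},
#   ])
#   # -> ["democracy", "voting rights", "constitutions"] (in some order)
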
def get_embeddings(concepts: List[str]) -> Dict[str, torch.Tensor]:
    """
    Generate normalized embeddings for a list of concepts using DistilBERT.

    Args:
        concepts (List[str]): List of concept strings.

    Returns:
        Dict[str, torch.Tensor]: Mapping of concept to normalized embedding tensor.
    """
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertModel.from_pretrained("distilbert-base-uncased")

    # Tokenize and process concepts in batches
    inputs = tokenizer(concepts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # Use the [CLS] token's embedding
    embeddings = outputs.last_hidden_state[:, 0, :]  # Shape: (num_concepts, hidden_size)

    # Normalize embeddings
    normalized_embeddings = F.normalize(embeddings, p=2, dim=1)

    # Map concepts directly to embeddings
    return {concept: normalized_embeddings[idx] for idx, concept in enumerate(concepts)}
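
# Illustrative usage of get_embeddings (a minimal sketch, not part of the original module;
# downloads distilbert-base-uncased on first use, hidden size 768):
#
#   embeddings = get_embeddings(["democracy", "voting right"])
#   embeddings["democracy"].shape   # torch.Size([768])
#   embeddings["democracy"].norm()  # ~1.0, since embeddings are L2-normalized
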
def normalize_concept(concept: str) -> str:
    """
    Normalize a single concept string (lowercase, singularize, strip underscores).

    Args:
        concept (str): Concept string.

    Returns:
        str: Normalized concept string.
    """
    p = inflect.engine()
    words = concept.lower().strip().replace('_', ' ').split()
    normalized_words = [p.singular_noun(word) or word for word in words]
    return " ".join(normalized_words)
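
# Illustrative usage of normalize_concept (a minimal sketch, not part of the original module;
# exact singularization depends on inflect's rules):
#
#   normalize_concept("Voting_Rights")  # -> "voting right"
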
def normalize_for_embedding(concepts: Union[str, List[str]]) -> Union[str, List[str]]:
    """
    Normalize one or more concepts for embedding.

    Args:
        concepts (Union[str, List[str]]): Concept or list of concepts.

    Returns:
        Union[str, List[str]]: Normalized concept(s).
    """
    if isinstance(concepts, str):
        return normalize_concept(concepts)
    return [normalize_concept(concept) for concept in concepts]
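
# Illustrative usage of normalize_for_embedding (a minimal sketch, not part of the original
# module; as above, singularization depends on inflect's rules):
#
#   normalize_for_embedding("Checks_and_Balances")          # -> "check and balance"
#   normalize_for_embedding(["Elections", "Political Parties"])  # -> ["election", "political party"]
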
def normalize_for_output(concept: str) -> str:
    """
    Format a concept for output by replacing spaces with underscores and removing 'is'.

    Args:
        concept (str): Concept string.

    Returns:
        str: Output-formatted concept string.
    """
    concept_words = [word for word in concept.split() if word != 'is']
    return "_".join(concept_words)
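
# Illustrative usage of normalize_for_output (a minimal sketch, not part of the original module):
#
#   normalize_for_output("separation of powers")  # -> "separation_of_powers"
#   normalize_for_output("is required for")       # -> "required_for"
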
def replace_similar_concepts(existing_concepts: Set[str], new_concept: str,
                             concept_embeddings: Dict[str, torch.Tensor], threshold: float = 0.995) -> str:
    """
    Replace a new concept with an existing similar concept if cosine similarity exceeds threshold.

    Args:
        existing_concepts (Set[str]): Set of existing concepts.
        new_concept (str): New concept string.
        concept_embeddings (Dict[str, torch.Tensor]): Embeddings for all concepts.
        threshold (float, optional): Similarity threshold. Defaults to 0.995.

    Returns:
        str: Existing or new concept string.
    """
    # Get the embedding of the new concept
    new_embedding = concept_embeddings[new_concept]

    for existing_concept in existing_concepts:
        existing_embedding = concept_embeddings[existing_concept]
        # Compute cosine similarity
        similarity = torch.matmul(new_embedding, existing_embedding).item()
        # If similar, return the existing concept
        if similarity >= threshold:
            return existing_concept

    # If no similar concept is found, return the new concept
    return new_concept
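
# Illustrative usage of replace_similar_concepts (a minimal sketch, not part of the original
# module; whether two strings actually merge depends on their DistilBERT embeddings):
#
#   embeddings = get_embeddings(["democracy", "democracies"])
#   replace_similar_concepts(
#       existing_concepts={"democracy"},
#       new_concept="democracies",
#       concept_embeddings=embeddings,
#       threshold=0.98,
#   )
#   # -> "democracy" if the cosine similarity clears the threshold, else "democracies"
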
def process_relationships(relationships: List[Tuple[str, str, str]], threshold: float = 0.995,
                          max_retries: int = 3) -> List[Tuple[str, str, str]]:
    """
    Normalize and consolidate relationships by merging similar concepts.

    Args:
        relationships (List[Tuple[str, str, str]]): List of (concept1, relationship, concept2) tuples.
        threshold (float, optional): Similarity threshold for merging. Defaults to 0.995.
        max_retries (int, optional): Max attempts to resolve duplicates. Defaults to 3.

    Returns:
        List[Tuple[str, str, str]]: Processed relationships with normalized concepts.
    """
    # Initialize a set to keep track of all unique concepts
    unique_concepts = set()
    processed_relationships = []

    # Accept both dict and tuple for backward compatibility
    if isinstance(relationships[0], dict):
        relationships = [
            (rel["concept_1"], rel["relationship_type"], rel["concept_2"])
            for rel in relationships
            if all(k in rel for k in ("concept_1", "relationship_type", "concept_2"))
        ]

    # get concepts and embeddings
    extracted_concepts = extract_concepts_from_relationships(relationships)
    conceptlist = normalize_for_embedding(extracted_concepts)
    concept_embeddings = get_embeddings(conceptlist)

    for c1, relationship, c2 in relationships:
        c1 = normalize_for_embedding(c1)
        c2 = normalize_for_embedding(c2)

        # Replace similar concepts with existing ones
        retries = 0
        concept1 = replace_similar_concepts(unique_concepts, c1, concept_embeddings, threshold)
        concept2 = replace_similar_concepts(unique_concepts, c2, concept_embeddings, threshold)

        # Retry resolution if concepts are identical.
        # Max threshold increase = 0.0045 (raising the default 0.995 to at most 0.9995)
        while concept1 == concept2 and retries < max_retries:
            retries += 1
            concept1 = replace_similar_concepts(unique_concepts, c1, concept_embeddings, threshold + retries * 0.0015)
            concept2 = replace_similar_concepts(unique_concepts, c2, concept_embeddings, threshold + retries * 0.0015)

        # # If still identical, skip or revert to original
        # if concept1 == concept2:
        #     continue  # Skip self-referential relationships

        # Add concepts to the unique set
        unique_concepts.add(concept1)
        unique_concepts.add(concept2)

        # Normalize concepts
        clean_concept1 = normalize_for_output(concept1)
        clean_concept2 = normalize_for_output(concept2)
        clean_relation = normalize_for_output(relationship)

        # Add the relationship to the processed list
        processed_relationships.append((clean_concept1, clean_relation, clean_concept2))

    return processed_relationships
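
# Illustrative usage of process_relationships (a minimal sketch, not part of the original
# module; the example output assumes no unexpected merges at the default threshold):
#
#   process_relationships([
#       ("Democracy", "relies on", "Voting Rights"),
#       ("Voting Rights", "is protected by", "Constitutions"),
#   ])
#   # -> [("democracy", "relies_on", "voting_right"),
#   #     ("voting_right", "protected_by", "constitution")]
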
# %%
if __name__ == "__main__":
    # llm chain setup
    import yaml
    from langchain_community.llms import Ollama
    from langchain_openai import ChatOpenAI
    from pyprojroot.here import here

    # self-defined utils
    from class_factory.utils.load_documents import LessonLoader

    user_home = Path.home()
    load_dotenv()
    OPENAI_KEY = os.getenv('openai_key')
    OPENAI_ORG = os.getenv('openai_org')

    # Path definitions
    with open("class_config.yaml", "r") as file:
        config = yaml.safe_load(file)

    class_config = config['PS211']
    slide_dir = user_home / class_config['slideDir']
    syllabus_path = user_home / class_config['syllabus_path']
    readingDir = user_home / class_config['reading_dir']
    is_tabular_syllabus = class_config['is_tabular_syllabus']

    projectDir = here()

    # parser = JsonOutputParser(pydantic_object=ExtractedRelations)
    llm = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        api_key=OPENAI_KEY,
        organization=OPENAI_ORG,
    )
    # llm = Ollama(
    #     model="llama3.1",
    #     temperature=0.5,
    # )

    relationship_list = []
    conceptlist = []

    loader = LessonLoader(syllabus_path=syllabus_path, reading_dir=readingDir, slide_dir=None)

    # Load documents and lesson objectives
    for lesson_num in range(3, 4):
        print(f"Lesson {lesson_num}")
        lesson_objectives = loader.extract_lesson_objectives(current_lesson=lesson_num, only_current=True)
        documents = loader.load_lessons(lesson_number_or_range=range(lesson_num, lesson_num + 1))
        if not documents:
            continue

        for lsn, readings in documents.items():
            for reading in readings:
                summary = summarize_text(reading, prompt=summary_prompt,
                                         course_name="American government", llm=llm, verbose=True)
                # print(summary)
                relationships = extract_relationships(summary, lesson_objectives,
                                                      course_name="American government", llm=llm, verbose=True)
                print(relationships)
                relationship_list.extend(relationships)
                concepts = extract_concepts_from_relationships(relationships)
                conceptlist.extend(concepts)

    processed_relationships = process_relationships(relationship_list)

# %%