API Reference

This API reference provides comprehensive documentation for all public classes and functions in NERxiv. For more detailed examples and usage patterns, see the How-to Guides and Tutorial sections.

nerxiv.chunker

Chunker

Bases: BaseChunker

Chunk text into smaller parts for processing and avoiding the token limit of an LLM model.

Source code in nerxiv/chunker.py
class Chunker(BaseChunker):
    """
    Chunk text into smaller parts for processing and avoiding the token limit of an LLM model.
    """

    def __init__(self, text: str = "", **kwargs):
        super().__init__(text=text, **kwargs)
        self.chunk_size = kwargs.get("chunk_size", 1000)
        self.chunk_overlap = kwargs.get("chunk_overlap", 200)

    def chunk_text(self) -> list[Document]:
        """
        Chunk the text into smaller parts.
        This is done to avoid exceeding the token limit of the LLM.

        Returns:
            list[Document]: The list of chunks as `Document` objects.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            add_start_index=True,
        )

        # ! we define a list of `Document` objects in LangChain to use the `split_documents(pages)` method
        pages = [
            Document(
                page_content=self.text,
                metadata={"source": "nerxiv.chunker.Chunker"},
            )
        ]
        chunks = text_splitter.split_documents(pages)
        self.logger.info(f"Text chunked into {len(chunks)} fixed chunks.")
        return chunks

chunk_size = kwargs.get('chunk_size', 1000)

chunk_overlap = kwargs.get('chunk_overlap', 200)

__init__(text='', **kwargs)

Source code in nerxiv/chunker.py
def __init__(self, text: str = "", **kwargs):
    super().__init__(text=text, **kwargs)
    self.chunk_size = kwargs.get("chunk_size", 1000)
    self.chunk_overlap = kwargs.get("chunk_overlap", 200)

chunk_text()

Chunk the text into smaller parts. This is done to avoid exceeding the token limit of the LLM.

Returns:
    list[Document]: The list of chunks as Document objects.

Source code in nerxiv/chunker.py
def chunk_text(self) -> list[Document]:
    """
    Chunk the text into smaller parts.
    This is done to avoid exceeding the token limit of the LLM.

    Returns:
        list[Document]: The list of chunks as `Document` objects.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=self.chunk_size,
        chunk_overlap=self.chunk_overlap,
        add_start_index=True,
    )

    # ! we define a list of `Document` objects in LangChain to use the `split_documents(pages)` method
    pages = [
        Document(
            page_content=self.text,
            metadata={"source": "nerxiv.chunker.Chunker"},
        )
    ]
    chunks = text_splitter.split_documents(pages)
    self.logger.info(f"Text chunked into {len(chunks)} fixed chunks.")
    return chunks
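
A minimal usage sketch (the import path follows this reference; paper_text stands in for the full text of a paper loaded elsewhere):

from nerxiv.chunker import Chunker

paper_text = "..."  # full text of a paper, loaded elsewhere
chunker = Chunker(text=paper_text, chunk_size=500, chunk_overlap=100)
chunks = chunker.chunk_text()  # list of LangChain Document objects
print(len(chunks), chunks[0].metadata["source"])  # -> nerxiv.chunker.Chunker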

SemanticChunker

Bases: BaseChunker

Sentence-level semantic chunker using spaCy.

Source code in nerxiv/chunker.py
class SemanticChunker(BaseChunker):
    """Sentence-level semantic chunker using spaCy."""

    def __init__(self, text: str = "", **kwargs):
        super().__init__(text=text, **kwargs)

    def chunk_text(self) -> list[Document]:
        """
        Chunk the text into smaller parts based on semantic meaning using spaCy.

        Returns:
            list[Document]: The list of chunks as `Document` objects.
        """
        nlp = get_spacy_model()
        doc = nlp(self.text)
        chunks = []
        for sent in doc.sents:
            chunks.append(
                Document(
                    page_content=sent.text.strip(),
                    metadata={"source": "nerxiv.chunker.SemanticChunker"},
                )
            )
        self.logger.info(f"Text chunked into {len(chunks)} semantic chunks.")
        return chunks

__init__(text='', **kwargs)

Source code in nerxiv/chunker.py
def __init__(self, text: str = "", **kwargs):
    super().__init__(text=text, **kwargs)

chunk_text()

Chunk the text into smaller parts based on semantic meaning using spaCy.

Returns:
    list[Document]: The list of chunks as Document objects.

Source code in nerxiv/chunker.py
def chunk_text(self) -> list[Document]:
    """
    Chunk the text into smaller parts based on semantic meaning using spaCy.

    Returns:
        list[Document]: The list of chunks as `Document` objects.
    """
    nlp = get_spacy_model()
    doc = nlp(self.text)
    chunks = []
    for sent in doc.sents:
        chunks.append(
            Document(
                page_content=sent.text.strip(),
                metadata={"source": "nerxiv.chunker.SemanticChunker"},
            )
        )
    self.logger.info(f"Text chunked into {len(chunks)} semantic chunks.")
    return chunks
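
A usage sketch under the same assumptions as above (the spaCy model loaded by get_spacy_model must be installed):

from nerxiv.chunker import SemanticChunker

chunker = SemanticChunker(text=paper_text)
sentence_chunks = chunker.chunk_text()  # one Document per spaCy sentence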

AdvancedSemanticChunker

Bases: BaseChunker

KMeans-based semantic chunker using SentenceTransformer embeddings.

Source code in nerxiv/chunker.py
class AdvancedSemanticChunker(BaseChunker):
    """KMeans-based semantic chunker using SentenceTransformer embeddings."""

    def __init__(self, text: str = "", **kwargs):
        super().__init__(text=text, **kwargs)
        self.n_chunks = kwargs.get("n_chunks", 10)
        self.model = kwargs.get("model", "all-MiniLM-L6-v2")

    def chunk_text(self) -> list[Document]:
        """
        Chunk the text into smaller parts based on semantic meaning using KMeans clustering on sentence embeddings.

        Returns:
            list[Document]: The list of chunks as `Document` objects.
        """
        nlp = get_spacy_model()
        doc = nlp(self.text)
        sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
        # Adjust number of clusters: at most one per sentence (1 <= n_chunks <= len(sentences))
        n_chunks = max(min(self.n_chunks, len(sentences)), 1)

        # Fit KMeans to the sentence embeddings
        model = get_sentence_model(model=self.model)
        embeddings = model.encode(sentences, show_progress_bar=False)
        kmeans = KMeans(n_clusters=n_chunks, random_state=42)
        clusters = kmeans.fit_predict(embeddings)
        chunks = [[] for _ in range(n_chunks)]
        for i, cluster in enumerate(clusters):
            chunks[cluster].append(sentences[i])

        # Combine sentences in each cluster to form chunks
        final_chunks = [
            Document(
                page_content=" ".join(chunk),
                metadata={"source": "nerxiv.chunker.AdvancedSemanticChunker"},
            )
            for chunk in chunks
            if chunk
        ]
        self.logger.info(f"Text chunked into {len(final_chunks)} semantic chunks.")
        return final_chunks

n_chunks = kwargs.get('n_chunks', 10)

model = kwargs.get('model', 'all-MiniLM-L6-v2')

__init__(text='', **kwargs)

Source code in nerxiv/chunker.py
def __init__(self, text: str = "", **kwargs):
    super().__init__(text=text, **kwargs)
    self.n_chunks = kwargs.get("n_chunks", 10)
    self.model = kwargs.get("model", "all-MiniLM-L6-v2")

chunk_text()

Chunk the text into smaller parts based on semantic meaning using KMeans clustering on sentence embeddings.

Returns:
    list[Document]: The list of chunks as Document objects.

Source code in nerxiv/chunker.py
def chunk_text(self) -> list[Document]:
    """
    Chunk the text into smaller parts based on semantic meaning using KMeans clustering on sentence embeddings.

    Returns:
        list[Document]: The list of chunks as `Document` objects.
    """
    nlp = get_spacy_model()
    doc = nlp(self.text)
    sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
    # Adjust number of clusters: at most one per sentence (1 <= n_chunks <= len(sentences))
    n_chunks = max(min(self.n_chunks, len(sentences)), 1)

    # Fit KMeans to the sentence embeddings
    model = get_sentence_model(model=self.model)
    embeddings = model.encode(sentences, show_progress_bar=False)
    kmeans = KMeans(n_clusters=n_chunks, random_state=42)
    clusters = kmeans.fit_predict(embeddings)
    chunks = [[] for _ in range(n_chunks)]
    for i, cluster in enumerate(clusters):
        chunks[cluster].append(sentences[i])

    # Combine sentences in each cluster to form chunks
    final_chunks = [
        Document(
            page_content=" ".join(chunk),
            metadata={"source": "nerxiv.chunker.AdvancedSemanticChunker"},
        )
        for chunk in chunks
        if chunk
    ]
    self.logger.info(f"Text chunked into {len(final_chunks)} semantic chunks.")
    return final_chunks
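
A usage sketch under the same assumptions as above; n_chunks and the model name are illustrative:

from nerxiv.chunker import AdvancedSemanticChunker

chunker = AdvancedSemanticChunker(text=paper_text, n_chunks=8, model="all-MiniLM-L6-v2")
cluster_chunks = chunker.chunk_text()  # at most 8 Documents, one per KMeans cluster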

nerxiv.rag.retriever

RETRIEVER_VERSION = '1.0.0'

Retriever

Bases: ABC

Abstract base class for retrieving relevant chunks of text from a list of documents. This class is designed to be inherited from and implemented by specific retriever classes.

Source code in nerxiv/rag/retriever.py
class Retriever(ABC):
    """
    Abstract base class for retrieving relevant chunks of text from a list of documents. This class
    is designed to be inherited from and implemented by specific retriever classes.
    """

    def __init__(self, **kwargs):
        self.logger = kwargs.get("logger", logger)

        self.model_name = kwargs.get("model", "all-MiniLM-L6-v2")
        self.n_top_chunks = kwargs.get("n_top_chunks", 5)

        self.query = kwargs.get("query")
        if not self.query:
            raise ValueError(
                "`query` is required for the retriever. Please provide a query string."
            )

    @abstractmethod
    def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
        """Find the most relevant chunks describing methods."""
        pass

logger = kwargs.get('logger', logger)

model_name = kwargs.get('model', 'all-MiniLM-L6-v2')

n_top_chunks = kwargs.get('n_top_chunks', 5)

query = kwargs.get('query')

__init__(**kwargs)

Source code in nerxiv/rag/retriever.py
def __init__(self, **kwargs):
    self.logger = kwargs.get("logger", logger)

    self.model_name = kwargs.get("model", "all-MiniLM-L6-v2")
    self.n_top_chunks = kwargs.get("n_top_chunks", 5)

    self.query = kwargs.get("query")
    if not self.query:
        raise ValueError(
            "`query` is required for the retriever. Please provide a query string."
        )

get_relevant_chunks(chunks=[])

Find the most relevant chunks describing methods.

Source code in nerxiv/rag/retriever.py
@abstractmethod
def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
    """Find the most relevant chunks describing methods."""
    pass
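
A minimal sketch of a concrete subclass, assuming LangChain's Document class; FirstNRetriever and its selection strategy are hypothetical and only illustrate the interface:

from langchain_core.documents import Document

from nerxiv.rag.retriever import Retriever


class FirstNRetriever(Retriever):
    """Hypothetical retriever that returns the first n_top_chunks chunks without ranking."""

    def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
        # No similarity ranking: simply keep the first n_top_chunks chunks
        return "\n\n".join(chunk.page_content for chunk in chunks[: self.n_top_chunks])


retriever = FirstNRetriever(query="Which methods are used?", n_top_chunks=3)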

CustomRetriever

Bases: Retriever

A custom retriever class that uses the SentenceTransformer model to retrieve relevant chunks of text from a list of documents.

Source code in nerxiv/rag/retriever.py
class CustomRetriever(Retriever):
    """
    A custom retriever class that uses the `SentenceTransformer` model to retrieve relevant chunks of text
    from a list of documents.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = SentenceTransformer(self.model_name)
        self.logger.info(f"Loaded SentenceTransformer model: {self.model_name}")

    def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
        """
        Retrieves the most relevant chunks of text from a list of documents using the `SentenceTransformer` model.

        Args:
            chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

        Returns:
            str: The top `n_top_chunks` chunks joined in a single string with the highest similarity score with respect to the query.
        """
        if not chunks:
            self.logger.warning("No chunks provided.")
            return ""
        chunks = [chunk.page_content for chunk in chunks]

        # Converting `self.query` and `chunks` to embeddings
        query_embeddings = self.model.encode(self.query, convert_to_tensor=True)
        chunk_embeddings = self.model.encode(chunks, convert_to_tensor=True)

        # TODO check other similarities
        similarities = util.pytorch_cos_sim(query_embeddings, chunk_embeddings).squeeze(
            0
        )
        sorted_similarities = similarities.sort(descending=True)

        # Get the top `n_top_chunks` chunks with the highest similarity score with respect to the query
        top_chunks = [
            chunks[i] for i in sorted_similarities.indices[: self.n_top_chunks]
        ]
        self.logger.info(
            f"Top {self.n_top_chunks} chunks retrieved with similarities of {sorted_similarities.values[: self.n_top_chunks]}"
        )
        return "\n\n".join(top_chunk for top_chunk in top_chunks)

model = SentenceTransformer(self.model_name)

__init__(**kwargs)

Source code in nerxiv/rag/retriever.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.model = SentenceTransformer(self.model_name)
    self.logger.info(f"Loaded SentenceTransformer model: {self.model_name}")

get_relevant_chunks(chunks=[])

Retrieves the most relevant chunks of text from a list of documents using the SentenceTransformer model.

Parameters:
    chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

Returns:
    str: The top n_top_chunks chunks joined in a single string with the highest similarity score with respect to the query.

Source code in nerxiv/rag/retriever.py
def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
    """
    Retrieves the most relevant chunks of text from a list of documents using the `SentenceTransformer` model.

    Args:
        chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

    Returns:
        str: The top `n_top_chunks` chunks joined in a single string with the highest similarity score with respect to the query.
    """
    if not chunks:
        self.logger.warning("No chunks provided.")
        return ""
    chunks = [chunk.page_content for chunk in chunks]

    # Converting `self.query` and `chunks` to embeddings
    query_embeddings = self.model.encode(self.query, convert_to_tensor=True)
    chunk_embeddings = self.model.encode(chunks, convert_to_tensor=True)

    # TODO check other similarities
    similarities = util.pytorch_cos_sim(query_embeddings, chunk_embeddings).squeeze(
        0
    )
    sorted_similarities = similarities.sort(descending=True)

    # Get the top `n_top_chunks` chunks with the highest similarity score with respect to the query
    top_chunks = [
        chunks[i] for i in sorted_similarities.indices[: self.n_top_chunks]
    ]
    self.logger.info(
        f"Top {self.n_top_chunks} chunks retrieved with similarities of {sorted_similarities.values[: self.n_top_chunks]}"
    )
    return "\n\n".join(top_chunk for top_chunk in top_chunks)

LangChainRetriever

Bases: Retriever

Source code in nerxiv/rag/retriever.py
class LangChainRetriever(Retriever):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
        self.logger.info(f"Loaded `HuggingFaceEmbeddings` model: {self.model_name}")

    def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
        """
        Retrieves the most relevant chunks of text from a list of documents using the `HuggingFaceEmbeddings` model.

        Args:
            chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

        Returns:
            str: The top `n_top_chunks` chunks joined in a single string with the highest similarity score with respect to the query.
        """
        vector_store = InMemoryVectorStore(self.embeddings)
        _ = vector_store.add_documents(documents=chunks)
        results = vector_store.similarity_search_with_score(
            self.query, k=self.n_top_chunks
        )
        top_chunks, scores = (
            [r[0].page_content for r in results],
            [r[1] for r in results],
        )
        self.logger.info(
            f"Top {self.n_top_chunks} chunks retrieved with similarities of {scores}"
        )
        return "\n\n".join(top_chunk for top_chunk in top_chunks)

embeddings = HuggingFaceEmbeddings(model_name=self.model_name)

__init__(**kwargs)

Source code in nerxiv/rag/retriever.py
def __init__(self, **kwargs):
    super().__init__(**kwargs)

    self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
    self.logger.info(f"Loaded `HuggingFaceEmbeddings` model: {self.model_name}")

get_relevant_chunks(chunks=[])

Retrieves the most relevant chunks of text from a list of documents using the HuggingFaceEmbeddings model.

Parameters:
    chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

Returns:
    str: The top n_top_chunks chunks joined in a single string with the highest similarity score with respect to the query.

Source code in nerxiv/rag/retriever.py
def get_relevant_chunks(self, chunks: list[Document] = []) -> str:
    """
    Retrieves the most relevant chunks of text from a list of documents using the `HuggingFaceEmbeddings` model.

    Args:
        chunks (list[Document], optional): The chunks to be ranked. Defaults to [].

    Returns:
        str: The top `n_top_chunks` chunks joined in a single string with the highest similarity score with respect to the query.
    """
    vector_store = InMemoryVectorStore(self.embeddings)
    _ = vector_store.add_documents(documents=chunks)
    results = vector_store.similarity_search_with_score(
        self.query, k=self.n_top_chunks
    )
    top_chunks, scores = (
        [r[0].page_content for r in results],
        [r[1] for r in results],
    )
    self.logger.info(
        f"Top {self.n_top_chunks} chunks retrieved with similarities of {scores}"
    )
    return "\n\n".join(top_chunk for top_chunk in top_chunks)

nerxiv.rag.generator

LLMGenerator

LLMGenerator class for generating answers with the generate method, using an LLM model specified by the user. The LLM model is loaded using the OllamaLLM implementation in LangChain.

Read more at https://python.langchain.com/docs/integrations/llms/ollama/

Source code in nerxiv/rag/generator.py
class LLMGenerator:
    """
    LLMGenerator class for generating answers with the `generate` method using a specified LLM model
    specified by the user. The LLM model is loaded using `OllamaLLM` implementation in LangChain.

    Read more in https://python.langchain.com/docs/integrations/llms/ollama/
    """

    def __init__(self, text: str = "", **kwargs):
        if not text:
            raise ValueError("`text` is required for LLM generation.")
        self.text = text
        self.logger = kwargs.get("logger", logger)

        # Define default values for metadata extraction
        defaults = {
            "temperature": 0.2,
        }
        merged_args = {**defaults, **kwargs}

        # Dynamically detect valid OllamaLLM kwargs
        sig = inspect.signature(OllamaLLM)
        valid_params = set(sig.parameters.keys())
        # Filter kwargs
        ollama_kwargs = {k: v for k, v in merged_args.items() if k in valid_params}

        self.llm = OllamaLLM(**ollama_kwargs)
        self.logger.info(f"LLM model: {ollama_kwargs.get('model')}")

    def generate(
        self,
        prompt: str = "",
        regex: str = r"\n\nAnswer\: *",
        del_regex: str = r"\n\nAnswer\: *",
    ) -> str:
        """
        Generates an answer using the specified LLM model and the provided prompt provided that
        the token limit is not exceeded.

        Args:
            prompt (str, optional): The prompt to be used for generating the answer. Defaults to "".
            regex (str, optional): The regex pattern to search for in the answer. Defaults to r"\n\nAnswer\: *".
            del_regex (str, optional): The regex pattern to delete from the answer. Defaults to r"\n\nAnswer\: *".

        Returns:
            str: The generated and cleaned answer from the LLM model.
        """

        def _delete_thinking(answer: str = "") -> str:
            """
            Deletes the thinking process from the answer string by removing the <think> block.

            Args:
                answer (str, optional): The input text to delete the thinking block. Defaults to "".

            Returns:
                str: The answer string with the <think> block removed.
            """
            return re.sub(r"<think>.*?</think>\n*", "", answer, flags=re.DOTALL)

        def _clean_answer(regex: str, del_regex: str, answer: str = "") -> str:
            """
            Cleans the answer by removing unwanted characters and extracting the relevant part of the answer.

            Args:
                regex (str): The regex pattern to search for in the answer.
                del_regex (str): The regex pattern to delete from the answer.
                answer (str, optional): The answer input. Defaults to "".

            Returns:
                str: The cleaned answer.
            """
            match = re.search(regex, answer, flags=re.IGNORECASE)
            if match:
                start = match.start()
                answer = answer[start:]
                answer = re.sub(del_regex, "", answer)
            return answer

        llm_answer = self.llm.invoke(prompt)
        answer_withouth_think_block = _delete_thinking(answer=llm_answer)
        return _clean_answer(
            answer=answer_withouth_think_block, regex=regex, del_regex=del_regex
        )

text = text

logger = kwargs.get('logger', logger)

llm = OllamaLLM(**ollama_kwargs)

__init__(text='', **kwargs)

Source code in nerxiv/rag/generator.py
def __init__(self, text: str = "", **kwargs):
    if not text:
        raise ValueError("`text` is required for LLM generation.")
    self.text = text
    self.logger = kwargs.get("logger", logger)

    # Define default values for metadata extraction
    defaults = {
        "temperature": 0.2,
    }
    merged_args = {**defaults, **kwargs}

    # Dynamically detect valid OllamaLLM kwargs
    sig = inspect.signature(OllamaLLM)
    valid_params = set(sig.parameters.keys())
    # Filter kwargs
    ollama_kwargs = {k: v for k, v in merged_args.items() if k in valid_params}

    self.llm = OllamaLLM(**ollama_kwargs)
    self.logger.info(f"LLM model: {ollama_kwargs.get('model')}")

generate(prompt='', regex='\\n\\nAnswer\\: *', del_regex='\\n\\nAnswer\\: *')

Generates an answer using the specified LLM model and the provided prompt, provided that the token limit is not exceeded.

Parameters:
    prompt (str, optional): The prompt to be used for generating the answer. Defaults to "".
    regex (str, optional): The regex pattern to search for in the answer. Defaults to r"\n\nAnswer\: *".
    del_regex (str, optional): The regex pattern to delete from the answer. Defaults to r"\n\nAnswer\: *".

Returns:
    str: The generated and cleaned answer from the LLM model.

Source code in nerxiv/rag/generator.py
def generate(
    self,
    prompt: str = "",
    regex: str = r"\n\nAnswer\: *",
    del_regex: str = r"\n\nAnswer\: *",
) -> str:
    """
    Generates an answer using the specified LLM model and the provided prompt provided that
    the token limit is not exceeded.

    Args:
        prompt (str, optional): The prompt to be used for generating the answer. Defaults to "".
        regex (str, optional): The regex pattern to search for in the answer. Defaults to r"\n\nAnswer\: *".
        del_regex (str, optional): The regex pattern to delete from the answer. Defaults to r"\n\nAnswer\: *".

    Returns:
        str: The generated and cleaned answer from the LLM model.
    """

    def _delete_thinking(answer: str = "") -> str:
        """
        Deletes the thinking process from the answer string by removing the <think> block.

        Args:
            answer (str, optional): The input text to delete the thinking block. Defaults to "".

        Returns:
            str: The answer string with the <think> block removed.
        """
        return re.sub(r"<think>.*?</think>\n*", "", answer, flags=re.DOTALL)

    def _clean_answer(regex: str, del_regex: str, answer: str = "") -> str:
        """
        Cleans the answer by removing unwanted characters and extracting the relevant part of the answer.

        Args:
            regex (str): The regex pattern to search for in the answer.
            del_regex (str): The regex pattern to delete from the answer.
            answer (str, optional): The answer input. Defaults to "".

        Returns:
            str: The cleaned answer.
        """
        match = re.search(regex, answer, flags=re.IGNORECASE)
        if match:
            start = match.start()
            answer = answer[start:]
            answer = re.sub(del_regex, "", answer)
        return answer

    llm_answer = self.llm.invoke(prompt)
    answer_withouth_think_block = _delete_thinking(answer=llm_answer)
    return _clean_answer(
        answer=answer_withouth_think_block, regex=regex, del_regex=del_regex
    )
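
A usage sketch, assuming a running Ollama server with a pulled model (the model name is illustrative) and that context holds the retrieved text from the previous step:

from nerxiv.rag.generator import LLMGenerator

generator = LLMGenerator(
    text=context,      # required; an empty string raises ValueError
    model="llama3",    # forwarded to OllamaLLM, together with temperature
    temperature=0.2,
)
answer = generator.generate(prompt=f"Summarize the methods used.\n\nText:\n{context}")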

nerxiv.rag.agents

BaseAgent

Bases: ABC

Abstract base class for extraction agents.

All agents should implement the run method which executes the extraction workflow and returns structured results.

Source code in nerxiv/rag/agents.py
class BaseAgent(ABC):
    """Abstract base class for extraction agents.

    All agents should implement the `run` method which executes the
    extraction workflow and returns structured results.
    """

    @abstractmethod
    def run(self, text: str, prompt: BasePrompt | None, **kwargs) -> None:
        """Execute the extraction workflow.

        Args:
            text (str): Input text to process.
            prompt (BasePrompt | None): Prompt template for LLM.
            **kwargs (dict): Additional parameters specific to the agent

        Returns:
            Dictionary containing extraction results
        """
        pass

run(text, prompt, **kwargs)

Execute the extraction workflow.

Parameters:
    text (str): Input text to process.
    prompt (BasePrompt | None): Prompt template for LLM.
    **kwargs (dict): Additional parameters specific to the agent.

Returns:
    Dictionary containing extraction results.

Source code in nerxiv/rag/agents.py
@abstractmethod
def run(self, text: str, prompt: BasePrompt | None, **kwargs) -> None:
    """Execute the extraction workflow.

    Args:
        text (str): Input text to process.
        prompt (BasePrompt | None): Prompt template for LLM.
        **kwargs (dict): Additional parameters specific to the agent

    Returns:
        Dictionary containing extraction results
    """
    pass

RAGExtractorAgent

Bases: BaseAgent

Source code in nerxiv/rag/agents.py
class RAGExtractorAgent(BaseAgent):
    def __init__(
        self,
        chunker: type | object,
        retriever: type | object,
        generator: type | object,
        **kwargs,
    ):
        self.chunker = chunker
        self.retriever = retriever
        self.generator = generator

        self.chunker_params = kwargs.get("chunker_params", {})
        self.retriever_params = kwargs.get("retriever_params", {})
        self.generator_params = kwargs.get("generator_params", {})

        self.logger = kwargs.get("logger", logger)

    def _obj_name(self, obj: type | object) -> str:
        """
        Gets the class name of an object or an object instance.

        Args:
            obj (type | object): The object or class to get the name of.

        Returns:
            str: The class name of the object or the name of the class itself.
        """
        if isinstance(obj, type):
            return obj.__name__
        return obj.__class__.__name__

    def _instantiate(self, component: type | object, required_kwargs: dict) -> Any:
        """I
        nstantiate `component` if it's a class, otherwise return the instance.

        The method merges `required_kwargs` with the preconfigured kwargs for
        that `component` (used by the caller).
        """
        if isinstance(component, type):
            # component is a class; instantiate with kwargs
            return component(**required_kwargs)
        # assume component is already an instance
        return component

    def _get_chunks(
        self,
        chunker_hash: str,
        text: str,
        chunker_name: str,
        cached_chunks_group: h5py.Group,
        global_time: float,
    ) -> list[Document]:
        """
        Gets the chunks when the chunker class needs to be instantiated (not read from cache).

        Args:
            chunker_hash (str): The chunker hash.
            text (str): The text to be chunked.
            chunker_name (str): The name of the chunker.
            cached_chunks_group (h5py.Group): The HDF5 group to store cached chunks.
            global_time (float): The global start time.

        Returns:
            list[Document]: The list of chunks.
        """
        self.logger.info(f"Performing new chunking with hash {chunker_hash}")
        chunker = self._instantiate(self.chunker, {**self.chunker_params, "text": text})
        chunks = chunker.chunk_text()
        # Store chunks in cache
        cached_chunks_group.attrs["chunker"] = f"nerxiv.chunker.{chunker_name}"
        cached_chunks_group.attrs["chunker_params"] = json.dumps(self.chunker_params)
        cached_chunks_group.attrs["run_time"] = time.time() - global_time
        for i, chunk in enumerate(chunks):
            cached_chunks_group.create_dataset(
                f"chunk_{i:04d}", data=chunk.page_content.encode("utf-8")
            )
        return chunks

    def parse(self, answer: str) -> dict[str, Any] | None:
        """
        Parse JSON from LLM answer if the prompt is of `StructuredPrompt` type. This method
        attempts to extract JSON from markdown code blocks (```json...```) and if successful,
        return the parsed data. If no code blocks are found, it tries to find JSON patterns
        directly in the text.

        Args:
            answer (str): Raw LLM output string.

        Returns:
            dict[str, Any] | None: Parsed JSON data as a dictionary, or None if parsing fails.
        """
        try:
            # Try to extract JSON from markdown code block
            json_match = re.search(
                r"```json\s*\n(.*?)\n\s*```", answer, re.DOTALL | re.IGNORECASE
            )

            if json_match:
                json_str = json_match.group(1)
            else:
                # Try to find JSON without code blocks
                # Look for content between { and } or [ and ]
                json_match = re.search(r"(\{.*\}|\[.*\])", answer, re.DOTALL)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    self.logger.error("No JSON found in answer")
                    return None

            # Parse JSON
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            self.logger.error(f"JSON decode error: {e}")
            return None
        return data

    def run(
        self,
        file: h5py.File | None = None,
        text: str = "",
        prompt: BasePrompt | None = None,
    ) -> None:
        """
        Runs the RAG extraction pipeline: chunking, retrieval, and generation.
        Chunking and retrieval results are cached in the provided HDF5 file to avoid redundant computations.
        If the prompt is of type `StructuredPrompt`, the generated answer is parsed into structured data.

        Args:
            file (h5py.File | None, optional): The file in which to store the metainformation. Defaults to None.
            text (str, optional): The text to process. Defaults to "".
            prompt (BasePrompt | None, optional): The prompt used for the LLM prompting. Defaults to None.
        """
        # initial checks
        if not file:
            self.logger.critical("`file` is required for RAGExtractorAgent")
            return None
        if not text:
            self.logger.critical("`text` is required for RAGExtractorAgent")
            return None
        if not prompt:
            self.logger.critical("`prompt` is required for RAGExtractorAgent")
            return None
        query = self.retriever_params.get("query")
        if not query:
            self.logger.critical(
                "`retriever_params` must include a 'query' key for RAGExtractorAgent"
            )
            return None

        # Create group to store RAG pipeline
        global_time = time.time()
        rag_group = file.require_group("rag_extraction")

        ### Chunking
        chunker_name = self._obj_name(self.chunker)

        # Use caching to compute chunker hash and avoid re-chunking if already done
        chunker_hash = compute_chunker_hash(
            text=text,
            chunker_name=chunker_name,
            chunker_params=self.chunker_params,
        )
        chunks_cache_group = rag_group.require_group("chunks_cache")
        if chunker_hash in chunks_cache_group:  # reuse existing chunks
            self.logger.info(f"Reusing chunks from cache with hash {chunker_hash}")
            cached_chunks_group = chunks_cache_group[chunker_hash]
            if len(cached_chunks_group.keys()) == 0:
                chunks = self._get_chunks(
                    chunker_hash,
                    text,
                    chunker_name,
                    cached_chunks_group,
                    global_time,
                )
            else:
                chunks = []
                for key in cached_chunks_group.keys():
                    chunks.append(
                        Document(
                            page_content=cached_chunks_group[key][()].decode("utf-8"),
                            metadata={"source": f"nerxiv.chunker.{chunker_name}"},
                        )
                    )
        else:  # perform new chunking
            chunks = self._get_chunks(
                chunker_hash,
                text,
                chunker_name,
                chunks_cache_group.create_group(chunker_hash),
                global_time,
            )

        ### Retrieval
        start_time = time.time()
        retriever_name = self._obj_name(self.retriever)

        # Use caching to compute retriever hash and avoid re-retrieving if already done
        retriever_hash = compute_retriever_hash(
            chunker_hash=chunker_hash, retriever_params=self.retriever_params
        )
        retrieval_cache_group = rag_group.require_group("retrieval_cache")
        if retriever_hash in retrieval_cache_group:  # reuse existing retrieval
            self.logger.info(
                f"Reusing retrieval results from cache with hash {retriever_hash}"
            )
            cached_retrieval_group = retrieval_cache_group[retriever_hash]
            text = cached_retrieval_group["retrieved_text"][()].decode("utf-8")
        else:  # perform new retrieval
            self.logger.info(f"Performing new retrieval with hash {retriever_hash}")
            retriever = self._instantiate(self.retriever, {**self.retriever_params})
            text = retriever.get_relevant_chunks(chunks=chunks)

            # Store retrieval results in cache
            cached_retrieval_group = retrieval_cache_group.create_group(retriever_hash)
            cached_retrieval_group.attrs["retriever"] = (
                f"nerxiv.rag.retriever.{retriever_name}"
            )
            cached_retrieval_group.attrs["chunker_hash"] = chunker_hash
            cached_retrieval_group.attrs["retriever_hash"] = retriever_hash
            cached_retrieval_group.attrs["retriever_params"] = json.dumps(
                self.retriever_params
            )
            cached_retrieval_group.create_dataset(
                "retrieved_text", data=text.encode("utf-8")
            )

        ### Generation
        start_time = time.time()
        generator = self._instantiate(
            self.generator, {"text": text, **self.generator_params}
        )
        built_prompt = prompt.build(text=text)
        answer = generator.generate(prompt=built_prompt)

        # Store raw answer in HDF5
        raw_answer_group = rag_group.require_group("raw_llm_answers")
        # Define group for the `query` (e.g., raw_llm_answers/filter_material_formula)
        query_group = raw_answer_group.require_group(
            self.retriever_params.get("query_name")
        )
        # Define group for the run ID (e.g., raw_llm_answers/filter_material_formula/run_0000)
        existing_runs = list(query_group.keys())
        run_id = f"run_{len(existing_runs):04d}"  # Auto-increment run ID
        run_group = query_group.create_group(run_id)
        # Store general metainformation
        run_group.attrs["model"] = self.generator_params.get("model", "gpt-oss:20b")
        # Store prompt and answer
        run_group.create_dataset("prompt", data=built_prompt.encode("utf-8"))
        run_group.create_dataset("answer", data=answer.encode("utf-8"))
        # Store references to cached data instead of duplicating
        run_group.attrs["chunker_hash"] = chunker_hash
        run_group.attrs["retriever_hash"] = retriever_hash
        # Store elapsed time and timestamp of the run
        run_group.attrs["elapsed_time"] = time.time() - start_time
        run_group.attrs["timestamp"] = datetime.datetime.now().isoformat()

        # Store total RAG pipeline time
        paper_time = time.time() - global_time
        rag_group.attrs["elapsed_time"] = paper_time
        self.logger.info(f"Prompting completed for {file} in {paper_time:.2f} seconds.")

        ### Return structured result
        # Parse and validate output for structured prompts
        if isinstance(prompt, StructuredPrompt):
            data = self.parse(answer=answer)
            if data is None:
                self.logger.error("Failed to parse LLM answer.")
                return None
            try:
                schema = prompt.output_schema
                data_fields = data[self._obj_name(schema)]
                filled_schema = schema(**data_fields)
                self.logger.info(f"Schema={filled_schema}")
            except Exception as e:
                self.logger.error(f"Validation error: {e}")
                return None

chunker = chunker

retriever = retriever

generator = generator

chunker_params = kwargs.get('chunker_params', {})

retriever_params = kwargs.get('retriever_params', {})

generator_params = kwargs.get('generator_params', {})

logger = kwargs.get('logger', logger)

__init__(chunker, retriever, generator, **kwargs)

Source code in nerxiv/rag/agents.py
def __init__(
    self,
    chunker: type | object,
    retriever: type | object,
    generator: type | object,
    **kwargs,
):
    self.chunker = chunker
    self.retriever = retriever
    self.generator = generator

    self.chunker_params = kwargs.get("chunker_params", {})
    self.retriever_params = kwargs.get("retriever_params", {})
    self.generator_params = kwargs.get("generator_params", {})

    self.logger = kwargs.get("logger", logger)

parse(answer)

Parse JSON from LLM answer if the prompt is of StructuredPrompt type. This method attempts to extract JSON from markdown code blocks (```json ... ```) and, if successful, returns the parsed data. If no code blocks are found, it tries to find JSON patterns directly in the text.

Parameters:
    answer (str): Raw LLM output string.

Returns:
    dict[str, Any] | None: Parsed JSON data as a dictionary, or None if parsing fails.

Source code in nerxiv/rag/agents.py
def parse(self, answer: str) -> dict[str, Any] | None:
    """
    Parse JSON from LLM answer if the prompt is of `StructuredPrompt` type. This method
    attempts to extract JSON from markdown code blocks (```json...```) and if successful,
    return the parsed data. If no code blocks are found, it tries to find JSON patterns
    directly in the text.

    Args:
        answer (str): Raw LLM output string.

    Returns:
        dict[str, Any] | None: Parsed JSON data as a dictionary, or None if parsing fails.
    """
    try:
        # Try to extract JSON from markdown code block
        json_match = re.search(
            r"```json\s*\n(.*?)\n\s*```", answer, re.DOTALL | re.IGNORECASE
        )

        if json_match:
            json_str = json_match.group(1)
        else:
            # Try to find JSON without code blocks
            # Look for content between { and } or [ and ]
            json_match = re.search(r"(\{.*\}|\[.*\])", answer, re.DOTALL)
            if json_match:
                json_str = json_match.group(1)
            else:
                self.logger.error("No JSON found in answer")
                return None

        # Parse JSON
        data = json.loads(json_str)
    except json.JSONDecodeError as e:
        self.logger.error(f"JSON decode error: {e}")
        return None
    return data
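
A sketch of how parse handles a fenced JSON answer, given an already constructed RAGExtractorAgent instance named agent (the MaterialSchema key is illustrative):

raw_answer = 'Here is the result:\n```json\n{"MaterialSchema": {"formula": "SrVO3"}}\n```'
data = agent.parse(answer=raw_answer)
# data == {"MaterialSchema": {"formula": "SrVO3"}}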

run(file=None, text='', prompt=None)

Runs the RAG extraction pipeline: chunking, retrieval, and generation. Chunking and retrieval results are cached in the provided HDF5 file to avoid redundant computations. If the prompt is of type StructuredPrompt, the generated answer is parsed into structured data.

Parameters:
    file (h5py.File | None, optional): The file in which to store the metainformation. Defaults to None.
    text (str, optional): The text to process. Defaults to "".
    prompt (BasePrompt | None, optional): The prompt used for LLM prompting. Defaults to None.

Source code in nerxiv/rag/agents.py
def run(
    self,
    file: h5py.File | None = None,
    text: str = "",
    prompt: BasePrompt | None = None,
) -> None:
    """
    Runs the RAG extraction pipeline: chunking, retrieval, and generation.
    Chunking and retrieval results are cached in the provided HDF5 file to avoid redundant computations.
    If the prompt is of type `StructuredPrompt`, the generated answer is parsed into structured data.

    Args:
        file (h5py.File | None, optional): The file in which to store the metainformation. Defaults to None.
        text (str, optional): The text to process. Defaults to "".
        prompt (BasePrompt | None, optional): The prompt used for the LLM prompting. Defaults to None.
    """
    # initial checks
    if not file:
        self.logger.critical("`file` is required for RAGExtractorAgent")
        return None
    if not text:
        self.logger.critical("`text` is required for RAGExtractorAgent")
        return None
    if not prompt:
        self.logger.critical("`prompt` is required for RAGExtractorAgent")
        return None
    query = self.retriever_params.get("query")
    if not query:
        self.logger.critical(
            "`retriever_params` must include a 'query' key for RAGExtractorAgent"
        )
        return None

    # Create group to store RAG pipeline
    global_time = time.time()
    rag_group = file.require_group("rag_extraction")

    ### Chunking
    chunker_name = self._obj_name(self.chunker)

    # Use caching to compute chunker hash and avoid re-chunking if already done
    chunker_hash = compute_chunker_hash(
        text=text,
        chunker_name=chunker_name,
        chunker_params=self.chunker_params,
    )
    chunks_cache_group = rag_group.require_group("chunks_cache")
    if chunker_hash in chunks_cache_group:  # reuse existing chunks
        self.logger.info(f"Reusing chunks from cache with hash {chunker_hash}")
        cached_chunks_group = chunks_cache_group[chunker_hash]
        if len(cached_chunks_group.keys()) == 0:
            chunks = self._get_chunks(
                chunker_hash,
                text,
                chunker_name,
                cached_chunks_group,
                global_time,
            )
        else:
            chunks = []
            for key in cached_chunks_group.keys():
                chunks.append(
                    Document(
                        page_content=cached_chunks_group[key][()].decode("utf-8"),
                        metadata={"source": f"nerxiv.chunker.{chunker_name}"},
                    )
                )
    else:  # perform new chunking
        chunks = self._get_chunks(
            chunker_hash,
            text,
            chunker_name,
            chunks_cache_group.create_group(chunker_hash),
            global_time,
        )

    ### Retrieval
    start_time = time.time()
    retriever_name = self._obj_name(self.retriever)

    # Use caching to compute retriever hash and avoid re-retrieving if already done
    retriever_hash = compute_retriever_hash(
        chunker_hash=chunker_hash, retriever_params=self.retriever_params
    )
    retrieval_cache_group = rag_group.require_group("retrieval_cache")
    if retriever_hash in retrieval_cache_group:  # reuse existing retrieval
        self.logger.info(
            f"Reusing retrieval results from cache with hash {retriever_hash}"
        )
        cached_retrieval_group = retrieval_cache_group[retriever_hash]
        text = cached_retrieval_group["retrieved_text"][()].decode("utf-8")
    else:  # perform new retrieval
        self.logger.info(f"Performing new retrieval with hash {retriever_hash}")
        retriever = self._instantiate(self.retriever, {**self.retriever_params})
        text = retriever.get_relevant_chunks(chunks=chunks)

        # Store retrieval results in cache
        cached_retrieval_group = retrieval_cache_group.create_group(retriever_hash)
        cached_retrieval_group.attrs["retriever"] = (
            f"nerxiv.rag.retriever.{retriever_name}"
        )
        cached_retrieval_group.attrs["chunker_hash"] = chunker_hash
        cached_retrieval_group.attrs["retriever_hash"] = retriever_hash
        cached_retrieval_group.attrs["retriever_params"] = json.dumps(
            self.retriever_params
        )
        cached_retrieval_group.create_dataset(
            "retrieved_text", data=text.encode("utf-8")
        )

    ### Generation
    start_time = time.time()
    generator = self._instantiate(
        self.generator, {"text": text, **self.generator_params}
    )
    built_prompt = prompt.build(text=text)
    answer = generator.generate(prompt=built_prompt)

    # Store raw answer in HDF5
    raw_answer_group = rag_group.require_group("raw_llm_answers")
    # Define group for the `query` (e.g., raw_llm_answers/filter_material_formula)
    query_group = raw_answer_group.require_group(
        self.retriever_params.get("query_name")
    )
    # Define group for the run ID (e.g., raw_llm_answers/filter_material_formula/run_0000)
    existing_runs = list(query_group.keys())
    run_id = f"run_{len(existing_runs):04d}"  # Auto-increment run ID
    run_group = query_group.create_group(run_id)
    # Store general metainformation
    run_group.attrs["model"] = self.generator_params.get("model", "gpt-oss:20b")
    # Store prompt and answer
    run_group.create_dataset("prompt", data=built_prompt.encode("utf-8"))
    run_group.create_dataset("answer", data=answer.encode("utf-8"))
    # Store references to cached data instead of duplicating
    run_group.attrs["chunker_hash"] = chunker_hash
    run_group.attrs["retriever_hash"] = retriever_hash
    # Store elapsed time and timestamp of the run
    run_group.attrs["elapsed_time"] = time.time() - start_time
    run_group.attrs["timestamp"] = datetime.datetime.now().isoformat()

    # Store total RAG pipeline time
    paper_time = time.time() - global_time
    rag_group.attrs["elapsed_time"] = paper_time
    self.logger.info(f"Prompting completed for {file} in {paper_time:.2f} seconds.")

    ### Return structured result
    # Parse and validate output for structured prompts
    if isinstance(prompt, StructuredPrompt):
        data = self.parse(answer=answer)
        if data is None:
            self.logger.error("Failed to parse LLM answer.")
            return None
        try:
            schema = prompt.output_schema
            data_fields = data[self._obj_name(schema)]
            filled_schema = schema(**data_fields)
            self.logger.info(f"Schema={filled_schema}")
        except Exception as e:
            self.logger.error(f"Validation error: {e}")
            return None
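
A sketch of a full pipeline run, assuming the import paths shown in this reference; the HDF5 file name, query strings, and model name are illustrative:

import h5py

from nerxiv.chunker import Chunker
from nerxiv.prompts.prompts import Prompt
from nerxiv.rag.agents import RAGExtractorAgent
from nerxiv.rag.generator import LLMGenerator
from nerxiv.rag.retriever import CustomRetriever

agent = RAGExtractorAgent(
    chunker=Chunker,
    retriever=CustomRetriever,
    generator=LLMGenerator,
    chunker_params={"chunk_size": 1000, "chunk_overlap": 200},
    retriever_params={
        "query": "Which material is simulated in the paper?",
        "query_name": "filter_material_formula",
        "n_top_chunks": 5,
    },
    generator_params={"model": "gpt-oss:20b", "temperature": 0.2},
)

prompt = Prompt(
    expert="Condensed Matter Physics",
    main_instruction="identify the chemical formula of the material being simulated",
)

with h5py.File("paper.h5", "a") as f:  # caches chunking/retrieval and stores the run
    agent.run(file=f, text=paper_text, prompt=prompt)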

nerxiv.prompts.prompts

Example

Bases: BaseModel

Represents an example for a prompt, containing input text and expected output.

Source code in nerxiv/prompts/prompts.py
class Example(BaseModel):
    """
    Represents an example for a prompt, containing input text and expected output.
    """

    input: str = Field(..., description="Input text for the prompt.")
    output: str = Field(..., description="Expected output from the prompt.")

input = Field(..., description='Input text for the prompt.')

output = Field(..., description='Expected output from the prompt.')

BasePrompt

Bases: BaseModel

Base class used as an interface for other prompt classes. It defines the common fields and methods that all prompts should implement. This class is not meant to be instantiated directly.

Source code in nerxiv/prompts/prompts.py
class BasePrompt(BaseModel):
    """
    Base class used as an interface for other prompt classes. It defines the common fields and methods
    that all prompts should implement. This class is not meant to be instantiated directly.
    """

    expert: str = Field(
        ...,
        description="""
        The expert or main field of expertise for the prompt. For example, 'Condensed Matter Physics'.
        """,
    )

    sub_field_expertise: str | None = Field(
        None,
        description="""
        The sub-field of expertise for the prompt. For example, 'many-body physics simulations'.
        """,
    )

    examples: list[Example] = Field(
        [],
        description="""
        Examples to illustrate the prompt. These are formatted as:

        'Examples of how to answer the prompt:
        Example 1:
            - Input text: `example.input`
            - Answer: `example.output`'

        They are used to guide the model on how to answer the prompt.
        """,
    )

    constraints: list[str] = Field(
        [],
        description="""
        Constraints to be followed in the output of the prompt. These are formatted as

        'Important constraints when generating the output: `constraints`'.

        They are mainly used as instructions to avoid unused text, broken formats or sentences, etc.
        """,
    )

    def _build_intro(self) -> str:
        """
        Builds the introduction for the prompt, which includes the `expert` and `sub_field_expertise` of the LLM.

        Returns:
            str: The formatted introduction string.
        """
        expert_lines = f"You are a {self.expert} assistant"
        if self.sub_field_expertise:
            expert_lines = (
                f"{expert_lines} with expertise in {self.sub_field_expertise}"
            )
        return expert_lines

    def _build_examples(self) -> str:
        """
        Builds the examples for the prompt, which illustrate how to answer the prompt.

        Returns:
            str: The formatted examples string.
        """
        example_lines = "Examples of how to answer the prompt:"
        for i, example in enumerate(self.examples):
            example_lines += f"\nExample {i + 1}:\n- Input text: {example.input}\n  Answer: {example.output}"
        return example_lines

    def _build_constraints(self) -> str:
        """
        Builds the constraints for the prompt, which are important instructions to follow when generating the output.

        Returns:
            str: The formatted constraints string.
        """
        constraint_lines = "Important constraints when generating the output:"
        for constraint in self.constraints:
            constraint_lines += f"\n- {constraint}"
        return constraint_lines

    def build(self) -> str:
        """
        Builds the prompt based on the fields defined in this class. This is used to format the prompt
        and append the `text` to be sent to the LLM for generation.

        Raises:
            NotImplementedError: This method should be implemented in subclasses.

        Returns:
            str: The formatted prompt ready to be sent to the LLM.
        """
        raise NotImplementedError("This method should be implemented in subclasses.")

expert = Field(..., description="\n The expert or main field of expertise for the prompt. For example, 'Condensed Matter Physics'.\n ")

sub_field_expertise = Field(None, description="\n The sub-field of expertise for the prompt. For example, 'many-body physics simulations'.\n ")

examples = Field([], description="\n Examples to illustrate the prompt. These are formatted as:\n\n 'Examples of how to answer the prompt:\n Example 1:\n - Input text: `example.input`\n - Answer: `example.output`'\n\n They are used to guide the model on how to answer the prompt.\n ")

constraints = Field([], description="\n Constraints to be followed in the output of the prompt. These are formatted as\n\n 'Important constraints when generating the output: `constraints`'.\n\n They are mainly used as instructions to avoid unused text, broken formats or sentences, etc.\n ")

build()

Builds the prompt based on the fields defined in this class. This is used to format the prompt and append the text to be sent to the LLM for generation.

Raises:
    NotImplementedError: This method should be implemented in subclasses.

Returns:
    str: The formatted prompt ready to be sent to the LLM.

Source code in nerxiv/prompts/prompts.py
def build(self) -> str:
    """
    Builds the prompt based on the fields defined in this class. This is used to format the prompt
    and append the `text` to be sent to the LLM for generation.

    Raises:
        NotImplementedError: This method should be implemented in subclasses.

    Returns:
        str: The formatted prompt ready to be sent to the LLM.
    """
    raise NotImplementedError("This method should be implemented in subclasses.")

Prompt

Bases: BasePrompt

Represents a prompt object with various fields to define its structure and content. The final prompt is built using the build() method, which formats the prompt based on the provided text and the fields defined in this class.

Source code in nerxiv/prompts/prompts.py
class Prompt(BasePrompt):
    """
    Represents a prompt object with various fields to define its structure and content. The final prompt
    is built using the `build()` method, which formats the prompt based on the provided text and the fields defined in this class.
    """

    # instruction fields
    main_instruction: str = Field(
        ...,
        description="""
        Main instruction for the prompt. This has to be written in the imperative form, e.g. 'identify all mentions of the system being simulated'.
        The format in the prompt is "Given the following scientific text, your task is `main_instruction`",
        """,
    )

    secondary_instructions: list[str] = Field(
        [],
        description="""
        Secondary instructions for the prompt. These are additional instructions that complement `main_instruction`
        and are formatted as "Additionally, you also need to follow these instructions: `secondary_instructions`".
        """,
    )

    def _build_instructions(self) -> str:
        """
        Builds the instructions for the prompt using the `main_instruction` and `secondary_instructions` fields. This is
        used to format the instructions that will be sent to the LLM for generation.

        Returns:
            str: The formatted instructions string.
        """
        instruction_lines = f"Given the following scientific text, your task is: {self.main_instruction}"
        if self.secondary_instructions:
            instruction_lines = f"{instruction_lines}\nAdditionally, you also need to follow these instructions:"
            for sec_instruction in self.secondary_instructions:
                instruction_lines += f"\n- {sec_instruction}"
        return instruction_lines

    def build(self, text: str) -> str:
        """
        Builds the prompt based on the fields defined in this class. This is used to format the prompt
        and append the `text` to be sent to the LLM for generation.

        Args:
            text (str): The text to append to the prompt.

        Returns:
            str: The formatted prompt ready to be sent to the LLM.
        """
        lines = []

        # Expertise lines
        if self.expert:
            lines.append(self._build_intro())

        # Instructions
        if self.main_instruction:
            lines.append(self._build_instructions())

        # Constraints
        if self.constraints:
            lines.append(self._build_constraints())

        # Examples
        if self.examples:
            lines.append(self._build_examples())

        # Appending text
        lines.append(f"\nText:\n{text}")
        return "\n".join(lines)

main_instruction = Field(..., description='\n Main instruction for the prompt. This has to be written in the imperative form, e.g. \'identify all mentions of the system being simulated\'.\n The format in the prompt is "Given the following scientific text, your task is `main_instruction`".\n ')

secondary_instructions = Field([], description='\n Secondary instructions for the prompt. These are additional instructions that complement `main_instruction`\n and are formatted as "Additionally, you also need to follow these instructions: `secondary_instructions`".\n ')

build(text)

Builds the prompt based on the fields defined in this class. This is used to format the prompt and append the text to be sent to the LLM for generation.

PARAMETER DESCRIPTION
text

The text to append to the prompt.

TYPE: str

RETURNS DESCRIPTION
str

The formatted prompt ready to be sent to the LLM.

TYPE: str

Source code in nerxiv/prompts/prompts.py
def build(self, text: str) -> str:
    """
    Builds the prompt based on the fields defined in this class. This is used to format the prompt
    and append the `text` to be sent to the LLM for generation.

    Args:
        text (str): The text to append to the prompt.

    Returns:
        str: The formatted prompt ready to be sent to the LLM.
    """
    lines = []

    # Expertise lines
    if self.expert:
        lines.append(self._build_intro())

    # Instructions
    if self.main_instruction:
        lines.append(self._build_instructions())

    # Constraints
    if self.constraints:
        lines.append(self._build_constraints())

    # Examples
    if self.examples:
        lines.append(self._build_examples())

    # Appending text
    lines.append(f"\nText:\n{text}")
    return "\n".join(lines)

StructuredPrompt

Bases: BasePrompt

Represents a prompt object with various fields to define its structure and content. The final prompt is built using the build() method, which formats the prompt based on the provided text and the fields defined in this class.

Note: The main difference from the Prompt class is that StructuredPrompt is designed to work with a specific output schema, so instead of using main_instruction, secondary_instructions and constraints, the instructions are automatically defined by output_schema and target_fields.

Source code in nerxiv/prompts/prompts.py
class StructuredPrompt(BasePrompt):
    """
    Represents a prompt object with various fields to define its structure and content. The final prompt
    is built using the `build()` method, which formats the prompt based on the provided text and the fields defined in this class.

    **Note**: The main difference from the `Prompt` class is that `StructuredPrompt` is designed to work with a specific output schema,
    so instead of using `main_instruction`, `secondary_instructions` and `constraints`, the instructions are automatically defined by `output_schema`
    and `target_fields`.
    """

    output_schema: type[BaseModel] = Field(
        ...,
        description="""
        The target `BaseModel` schema in which the fields to be extracted are defined.
        """,
    )

    target_fields: list[str] = Field(
        ...,
        description="""
        The fields within `output_schema` that the prompt should extract. If set to `all`, all fields defined in `output_schema` will be extracted.
        """,
    )

    @model_validator(mode="after")
    @classmethod
    def validate_target_fields_in_schema(cls, data: Any) -> Any:
        """
        Validates that the `target_fields` are defined in the `output_schema` and that they are of type `Field`.

        Args:
            data (Any): The data containing the fields values to validate.

        Returns:
            Any: The data with the validated fields.
        """
        model_properties = data.output_schema.model_json_schema().get("properties", {})
        for field in data.target_fields:
            if field == "all":
                data.target_fields = list(model_properties.keys())
                break
            if field not in model_properties:
                raise ValueError(
                    f"Field '{field}' is not defined in the output schema '{data.output_schema.__name__}'."
                )
        return data

    def _build_instructions(self) -> str:
        """
        Builds the instructions for the prompt using the `output_schema` and `target_fields` fields. This is
        used to format the instructions that will be sent to the LLM for generation.

        Returns:
            str: The formatted instructions string.
        """
        # gets the model schema metadata as a dictionary
        model = self.output_schema.model_json_schema()

        # name and description of the class which inherits from BaseModel
        name = model.get("title")
        description = clean_description(
            model.get("description", "<<no definition provided>>")
        )
        instruction_lines = f"Given the following scientific text, your task is: to identify all mentions of the {name} section. "
        instruction_lines += f"This is defined as a {description} "

        instruction_lines += "You must extract the values of the following fields:"
        # getting the fields defined for the class and matching them with `target_fields`
        properties = model.get("properties", {})
        for field in self.target_fields:
            prop = properties.get(field, {})
            prop_description = clean_description(prop.get("description"))
            prop_types = [
                p.get("type") for p in prop.get("anyOf", []) if p.get("type") != "null"
            ]  # only non-null types
            if not prop_types:
                instruction_lines += f"\n- {field} defined as {prop_description}"
            else:
                prop_type = prop_types[0]
                if prop_type == "object":
                    prop_type = "dictionary"
                instruction_lines += f"\n- {field} defined as {prop_description} and which is of type {prop_type}"
            # TODO add data type

        instruction_lines += (
            "\nYou must return the extracted values in JSON format:"
            "\n```json\n"
            "{\n"
            f"  '{name}': " + "{\n"
        )
        for field in self.target_fields:
            instruction_lines += f"    '{field}': <parsed-value>,\n"

        instruction_lines += "  }\n}\n```\n"
        instruction_lines += "Note that <parsed-value> means a value of the correct type defined for that field."
        return instruction_lines

    def build(self, text: str) -> str:
        """
        Builds the prompt based on the fields defined in this class. This is used to format the prompt
        and append the `text` to be sent to the LLM for generation.

        Args:
            text (str): The text to append to the prompt.

        Returns:
            str: The formatted prompt ready to be sent to the LLM.
        """
        lines = []

        # Expertise lines
        if self.expert:
            lines.append(self._build_intro())

        # Instructions
        lines.append(self._build_instructions())

        # Constraints
        if self.constraints:
            lines.append(self._build_constraints())

        # Examples
        if self.examples:
            lines.append(self._build_examples())

        # Appending text
        lines.append(f"\nText:\n{text}")
        return "\n".join(lines)

output_schema = Field(..., description='\n The target `BaseModel` schema in which the fields to be extracted are defined.\n ')

target_fields = Field(..., description='\n The fields within `output_schema` that the prompt should extract. If set to `all`, all fields defined in `output_schema` will be extracted.\n ')

validate_target_fields_in_schema(data)

Validates that the target_fields are defined in the output_schema and that they are of type Field.

PARAMETER DESCRIPTION
data

The data containing the fields values to validate.

TYPE: Any

RETURNS DESCRIPTION
Any

The data with the validated fields.

TYPE: Any

Source code in nerxiv/prompts/prompts.py
@model_validator(mode="after")
@classmethod
def validate_target_fields_in_schema(cls, data: Any) -> Any:
    """
    Validates that the `target_fields` are defined in the `output_schema` and that they are of type `Field`.

    Args:
        data (Any): The data containing the fields values to validate.

    Returns:
        Any: The data with the validated fields.
    """
    model_properties = data.output_schema.model_json_schema().get("properties", {})
    for field in data.target_fields:
        if field == "all":
            data.target_fields = list(model_properties.keys())
            break
        if field not in model_properties:
            raise ValueError(
                f"Field '{field}' is not defined in the output schema '{data.output_schema.__name__}'."
            )
    return data

build(text)

Builds the prompt based on the fields defined in this class. This is used to format the prompt and append the text to be sent to the LLM for generation.

PARAMETER DESCRIPTION
text

The text to append to the prompt.

TYPE: str

RETURNS DESCRIPTION
str

The formatted prompt ready to be sent to the LLM.

TYPE: str

Source code in nerxiv/prompts/prompts.py
def build(self, text: str) -> str:
    """
    Builds the prompt based on the fields defined in this class. This is used to format the prompt
    and append the `text` to be sent to the LLM for generation.

    Args:
        text (str): The text to append to the prompt.

    Returns:
        str: The formatted prompt ready to be sent to the LLM.
    """
    lines = []

    # Expertise lines
    if self.expert:
        lines.append(self._build_intro())

    # Instructions
    lines.append(self._build_instructions())

    # Constraints
    if self.constraints:
        lines.append(self._build_constraints())

    # Examples
    if self.examples:
        lines.append(self._build_examples())

    # Appending text
    lines.append(f"\nText:\n{text}")
    return "\n".join(lines)

PromptRegistryEntry

Bases: BaseModel

Represents a registry entry for a prompt, containing the retriever query and the prompt itself. This is used to register prompts in the PROMPT_REGISTRY defined in nerxiv.prompts.prompts_registry.py.

Source code in nerxiv/prompts/prompts.py
class PromptRegistryEntry(BaseModel):
    """
    Represents a registry entry for a prompt, containing the retriever query and the prompt itself. This
    is used to register prompts in the `PROMPT_REGISTRY` defined in `nerxiv.prompts.prompts_registry.py`.
    """

    retriever_query: str = Field(..., description="The query used in the retriever.")
    prompt: BasePrompt = Field(..., description="The prompt to use for the query.")

    @field_validator("retriever_query", mode="before")
    @classmethod
    def clean_retriever_query(cls, value: str) -> str:
        """
        Cleans the retriever query by removing extra whitespace and newlines.

        Args:
            value (str): The retriever query to clean.

        Returns:
            str: The cleaned retriever query.
        """
        return " ".join(value.split())

retriever_query = Field(..., description='The query used in the retriever.')

prompt = Field(..., description='The prompt to use for the query.')

clean_retriever_query(value)

Cleans the retriever query by removing extra whitespace and newlines.

PARAMETER DESCRIPTION
value

The retriever query to clean.

TYPE: str

RETURNS DESCRIPTION
str

The cleaned retriever query.

TYPE: str

Source code in nerxiv/prompts/prompts.py
@field_validator("retriever_query", mode="before")
@classmethod
def clean_retriever_query(cls, value: str) -> str:
    """
    Cleans the retriever query by removing extra whitespace and newlines.

    Args:
        value (str): The retriever query to clean.

    Returns:
        str: The cleaned retriever query.
    """
    return " ".join(value.split())

nerxiv.utils.utils

answer_to_dict(answer='', logger=logger)

Converts the answer string into a list of dictionaries by parsing it as JSON. This is useful when prompting the LLM to return a list of objects containing meta-information in a structured way.

PARAMETER DESCRIPTION
answer

The answer string to be converted to a list of dictionaries. Defaults to "".

TYPE: str DEFAULT: ''

logger

The logger to log messages. Defaults to logger.

TYPE: BoundLoggerLazyProxy DEFAULT: logger

RETURNS DESCRIPTION
list[dict]

list[dict]: The list of dictionaries extracted from the answer string.

Source code in nerxiv/utils/utils.py
def answer_to_dict(
    answer: str = "", logger: "BoundLoggerLazyProxy" = logger
) -> list[dict]:
    """
    Converts the answer string into a list of dictionaries by parsing it as JSON. This is useful when
    prompting the LLM to return a list of objects containing meta-information in a structured way.

    Args:
        answer (str, optional): The answer string to be converted to a list of dictionaries. Defaults to "".
        logger (BoundLoggerLazyProxy, optional): The logger to log messages. Defaults to logger.

    Returns:
        list[dict]: The list of dictionaries extracted from the answer string.
    """
    # Return empty list if answer is empty or the loaded list of dictionaries
    dict_answer = []
    try:
        dict_answer = json.loads(answer)
    except json.JSONDecodeError:
        logger.critical(
            f"Answer is not a valid JSON: {answer}. Please check the answer format."
        )
    return dict_answer
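
For example (the answer strings are invented): valid JSON is parsed into Python objects, anything else logs a critical message and yields an empty list.

answer_to_dict('[{"method": "DMFT"}, {"method": "DFT"}]')
# [{'method': 'DMFT'}, {'method': 'DFT'}]

answer_to_dict("not valid json")
# logs a critical message and returns []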

clean_description(description)

Cleans the description by removing extra spaces and leading/trailing whitespace.

PARAMETER DESCRIPTION
description

The description string to be cleaned.

TYPE: str

RETURNS DESCRIPTION
str

The cleaned description string with extra spaces removed.

TYPE: str

Source code in nerxiv/utils/utils.py
def clean_description(description: str) -> str:
    """
    Cleans the description by removing extra spaces and leading/trailing whitespace.

    Args:
        description (str): The description string to be cleaned.

    Returns:
        str: The cleaned description string with extra spaces removed.
    """
    return re.sub(r"\s+", " ", description).strip()
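
For example, runs of whitespace (including newlines) collapse to single spaces:

clean_description("  Computational   method\n  used in the simulation.  ")
# 'Computational method used in the simulation.'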

filter_material_formula_predicate(answer)

Predicate function to determine if the answer indicates the presence of a material formula.

PARAMETER DESCRIPTION
answer

The answer string to be evaluated.

TYPE: str

RETURNS DESCRIPTION
bool

True if the answer is "model", indicating a material formula is present; False otherwise.

TYPE: bool

Source code in nerxiv/utils/utils.py
def filter_material_formula_predicate(answer: str) -> bool:
    """
    Predicate function to determine if the answer indicates the presence of a material formula.

    Args:
        answer (str): The answer string to be evaluated.

    Returns:
        bool: True if the answer is "model", indicating a material formula is present; False otherwise.
    """
    return answer == "model"

filter_only_dmft_predicate(answer)

Predicate function to determine if the answer indicates the absence of DMFT method.

PARAMETER DESCRIPTION
answer

The answer string to be evaluated.

TYPE: str

RETURNS DESCRIPTION
bool

True if the answer is not "True", indicating DMFT is not used; False if DMFT is used.

TYPE: bool

Source code in nerxiv/utils/utils.py
def filter_only_dmft_predicate(answer: str) -> bool:
    """
    Predicate function to determine if the answer indicates the absence of DMFT method.

    Args:
        answer (str): The answer string to be evaluated.

    Returns:
        bool: True if the answer is not "True", indicating DMFT is not used; False if DMFT is used.
    """
    return answer != "True"

files_to_subfolder_answer(path='./data', run='run_0000', predicate=None)

Moves each .hdf5 file found under path into a subfolder named after the single answer group stored in raw_llm_answers/<run> of that file. A file is only moved if its stored answer passes predicate; if no predicate is given, the answer must be the string "True".

Source code in nerxiv/utils/utils.py
def files_to_subfolder_answer(
    path: str = "./data",
    run: str = "run_0000",
    predicate: Callable[[str], bool] | None = None,
) -> None:
    files = list(Path(path).rglob("*.hdf5"))
    for file in files:
        with h5py.File(file, "a") as f:
            run_group = f["raw_llm_answers"][run]
            # run_group only has one key associated with what we are naming the subfolder
            subfolder_name = next(iter(run_group.keys()))

            # Apply the provided predicate (e.g. `filter_material_formula_predicate()`), or by default check whether the answer is the string "True"
            answer = run_group[subfolder_name]["answer"][()].decode("utf-8").strip()
            if (predicate or (lambda a: a == "True"))(answer):
                # Create subfolder and move file
                target_dir = file.parent / subfolder_name
                target_dir.mkdir(exist_ok=True)
                target_path = target_dir / file.name
                file.rename(target_path)
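
An illustrative sketch combining this helper with the predicates above; the paths and run names are placeholders:

# Move files whose stored answer equals "model" into a subfolder named after the answer group
files_to_subfolder_answer(
    path="./data",
    run="run_0000",
    predicate=filter_material_formula_predicate,
)

# Default behaviour: only move files whose stored answer is the string "True"
files_to_subfolder_answer(path="./data", run="run_0001")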