toolbox/toolbox/datasets/soda.py

43 lines
1.1 KiB
Python

import os
import pickle
import typing as t
from dataclasses import dataclass
import mashumaro
import pandas as pd
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
@dataclass(frozen=True)
class SodaEpisode(mashumaro.DataClassDictMixin):
    '''
    One immutable record of the SODA dataset.

    Each field carries the value of the identically named column of the
    source table for a single row.
    '''

    # Free-text value of the `narrative` column.
    narrative: str

    # Values of the `dialogue` and `speakers` columns. Presumably parallel
    # lists (one speaker per utterance) -- verify against the dataset card.
    dialogue: t.List[str]
    speakers: t.List[str]

    # Values of the `relation` and `literal` columns; their exact meaning is
    # defined by the upstream dataset, not by this module.
    relation: str
    literal: str
class SodaDataset(BaseDataset[SodaEpisode]):
    '''
    SODA: Million-scale Dialogue Distillation with Social Commonsense
    Contextualization

    https://huggingface.co/datasets/allenai/soda
    '''

    def generator(self) -> t.Generator[SodaEpisode, None, None]:
        '''
        Yield one `SodaEpisode` per row of the SODA `test.parquet` file.

        The file is looked up inside the directory returned by
        `get_data_path("soda")` and is loaded in full before the first
        episode is produced. Errors raised while locating or reading the
        file are propagated unchanged.
        '''
        root_data_path = get_data_path("soda")
        file_path = os.path.join(root_data_path, "test.parquet")
        df = pd.read_parquet(file_path)

        # Iterate through the test part of the SODA dataset. The columns are
        # walked in lockstep so that each row costs one step per column
        # instead of a label-based `df[col][i]` lookup per cell.
        rows = zip(
            df['narrative'],
            df['dialogue'],
            df['speakers'],
            df['relation'],
            df['literal'],
        )
        for narrative, dialogue, speakers, relation, literal in rows:
            yield SodaEpisode(
                narrative=narrative,
                # List-valued parquet columns are usually handed back as
                # numpy arrays; convert them so the episode really holds the
                # `t.List[str]` it declares (and so comparing two episodes
                # does not trip over element-wise array equality).
                dialogue=list(dialogue),
                speakers=list(speakers),
                relation=relation,
                literal=literal,
            )