The Model() Class
Model(model_type, model_path)
The base class for a text generation model.
Parameters:
model_type (string): Optional. The type of model to use. Defaults to GPT-NEO.
model_path (string): Optional. The path to the model to use. Defaults to EleutherAI/gpt-neo-125m, but can be any model from the Hugging Face Model Hub or a path to a local folder containing the model's files.
Returns:
A Model object.
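The model_path parameter accepts two kinds of value: a Hugging Face Model Hub ID or a local folder. The sketch below shows one way a caller might tell them apart before constructing a Model. The helper describe_model_source is hypothetical, and its rule ("an existing directory means local files, anything else is a Hub ID") is an assumption for illustration, not andromeda's documented logic.

```python
import os
import tempfile


def describe_model_source(model_path: str) -> str:
    """Hypothetical helper: classify a model_path value.

    Assumption: an existing local directory is treated as a folder of
    model files; any other string is treated as a Hugging Face Hub ID.
    """
    if os.path.isdir(model_path):
        return "local"
    return "hub"


print(describe_model_source("EleutherAI/gpt-neo-125m"))

with tempfile.TemporaryDirectory() as folder:
    print(describe_model_source(folder))  # local
```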
Attributes
Model.name
The name of the model.
Type:
string
Model.path
The location where the model's files are stored once downloaded. Defaults to [Python package location]/[andromeda folder]/andromeda-latest.
Type:
string
Model.train_args
Training arguments for the model. A wrapper for HappyTransformer's GENTrainArgs class.
Properties:
- Model.train_args.learning_rate: defaults to 5e-05
- Model.train_args.num_train_epochs: defaults to 3
- Model.train_args.batch_size: defaults to 1
- Model.train_args.adam_beta1: defaults to 0.9
- Model.train_args.adam_beta2: defaults to 0.999
- Model.train_args.adam_epsilon: defaults to 1e-08
- Model.train_args.max_grad_norm: defaults to 1.0
- Model.train_args.save_preprocessed_data: defaults to False
- Model.train_args.save_preprocessed_data_path: defaults to ''
- Model.train_args.load_preprocessed_data: defaults to False
- Model.train_args.load_preprocessed_data_path: defaults to ''
- Model.train_args.preprocessing_processes: defaults to 1
- Model.train_args.mlm_probability: defaults to 0.15
- Model.train_args.fp16: defaults to False
Type:
GENTrainArgs object
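To make the defaults above concrete, here is a stand-in dataclass that mirrors the documented train_args properties and shows the usual pattern of overriding only a few values. TrainArgsSketch is hypothetical. The real object is HappyTransformer's GENTrainArgs, exposed as Model.train_args.

```python
from dataclasses import dataclass, replace


@dataclass
class TrainArgsSketch:
    """Hypothetical mirror of the documented Model.train_args defaults."""
    learning_rate: float = 5e-05
    num_train_epochs: int = 3
    batch_size: int = 1
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    save_preprocessed_data: bool = False
    save_preprocessed_data_path: str = ""
    load_preprocessed_data: bool = False
    load_preprocessed_data_path: str = ""
    preprocessing_processes: int = 1
    mlm_probability: float = 0.15
    fp16: bool = False


defaults = TrainArgsSketch()
# Override only what you need; every other field keeps its documented default.
tuned = replace(defaults, num_train_epochs=1, learning_rate=1e-04)
print(tuned.num_train_epochs, tuned.learning_rate, tuned.batch_size)  # 1 0.0001 1
```

With the real object, the equivalent would presumably be plain attribute assignment before training, for example model.train_args.num_train_epochs = 1.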
Model.eval_args
Evaluation arguments for the model. A wrapper for HappyTransformer's GENEvalArgs class.
Properties:
- Model.eval_args.batch_size: defaults to 1
- Model.eval_args.save_preprocessed_data: defaults to False
- Model.eval_args.save_preprocessed_data_path: defaults to ''
- Model.eval_args.load_preprocessed_data: defaults to False
- Model.eval_args.load_preprocessed_data_path: defaults to ''
- Model.eval_args.preprocessing_processes: defaults to 1
- Model.eval_args.mlm_probability: defaults to 0.15
Type:
GENEvalArgs object
Model.config
The model configuration. A wrapper for HappyTransformer's GENSettings class.
Properties:
- Model.config.min_length: defaults to 10
- Model.config.max_length: defaults to 50
- Model.config.do_sample: defaults to False
- Model.config.early_stopping: defaults to False
- Model.config.num_beams: defaults to 1
- Model.config.temperature: defaults to 1
- Model.config.top_k: defaults to 50
- Model.config.no_repeat_ngram_size: defaults to 0
- Model.config.top_p: defaults to 1
- Model.config.bad_words: defaults to None
Type:
GENSettings object
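Three of these settings, temperature, top_k and top_p, shape the next-token distribution when sampling is enabled. In Hugging Face transformers they take effect only when do_sample is True, and Model.config.do_sample defaults to False. The sketch below is an illustrative re-implementation of the standard temperature, top-k and nucleus (top-p) filtering on a toy list of logits. It is not the code HappyTransformer or transformers actually runs, and the function name filter_distribution is hypothetical.

```python
import math


def filter_distribution(logits, temperature=1.0, top_k=50, top_p=1.0):
    """Return {token_index: probability} after temperature, top-k and top-p filtering.

    Illustrative sketch of the standard technique; temperature must be > 0.
    """
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Top-k: keep only the k most likely tokens (top_k=0 disables this step).
    if top_k > 0:
        ranked = ranked[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise the surviving probabilities so they sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


logits = [2.0, 1.0, 0.5, -1.0]
print(sorted(filter_distribution(logits, top_k=2)))  # [0, 1]
print(filter_distribution(logits, top_p=0.5))  # {0: 1.0}
```

Lower temperatures sharpen the distribution before filtering, while smaller top_k or top_p values cut off more of the unlikely tail.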
Methods
Model.train(input_filepath)
Trains the model. Training arguments can be configured using the train_args attribute.
Parameters:
input_filepath (string): The path to the file containing the training data.
Returns:
None.
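train() takes a path to a data file rather than in-memory text. The sketch below prepares such a file, assuming a plain UTF-8 text file is accepted (as with HappyTransformer's text-generation trainer). The sample sentences and file name are placeholders, and the final call assumes an existing Model instance named model.

```python
import tempfile
from pathlib import Path

# Placeholder corpus; in practice this is your own training text.
samples = [
    "The quick brown fox jumps over the lazy dog.",
    "A journey of a thousand miles begins with a single step.",
]

# Write one sample per line to a UTF-8 text file to pass as input_filepath.
train_file = Path(tempfile.mkdtemp()) / "train.txt"
train_file.write_text("\n".join(samples) + "\n", encoding="utf-8")

print(len(train_file.read_text(encoding="utf-8").splitlines()))  # 2
# model.train(str(train_file))  # assumes an existing Model instance named `model`
```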
Model.evaluate(input_filepath)
Evaluates the model. Evaluation arguments can be configured using the eval_args attribute.
Parameters:
input_filepath (string): The path to the file containing the evaluation data.
Returns:
The evaluation loss (float).
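A raw loss value can be hard to interpret, so language-model evaluations are often reported as perplexity. The sketch below assumes the returned loss is a mean per-token cross-entropy in natural-log units (the convention in Hugging Face transformers). Under that assumption, perplexity is simply exp(loss). The helper name is hypothetical.

```python
import math


def perplexity(loss: float) -> float:
    """Convert a mean cross-entropy loss (natural log) into perplexity."""
    return math.exp(loss)


print(round(perplexity(2.0), 2))  # 7.39
# loss = model.evaluate("eval.txt")  # assumes an existing Model instance and eval file
# print(perplexity(loss))
```

Lower is better: a perplexity of 1.0 corresponds to a loss of 0, meaning the model predicts every token with certainty.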
Model.generate(text, raw=False)
Generates text from a prompt. Generation settings can be configured using the config attribute.
Parameters:
text (string): The prompt to pass to the model.
raw (bool): Optional. If True, return the model's raw output instead of a string. Defaults to False.
Returns:
The generated text as a string or, if raw is True, the raw output from the model (a GenerationResult object).
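Because generate() can return either a string or a GenerationResult, calling code sometimes needs to handle both. The sketch below uses a stand-in GenerationResult dataclass. It assumes that HappyTransformer's GenerationResult exposes the generated text as a text field. The helper as_text is hypothetical.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class GenerationResult:
    """Stand-in for HappyTransformer's GenerationResult (assumed to expose a `text` field)."""
    text: str


def as_text(output: Union[str, GenerationResult]) -> str:
    """Return plain text whether generate() was called with raw=False or raw=True."""
    if isinstance(output, GenerationResult):
        return output.text
    return output


print(as_text("Once upon a time"))  # Once upon a time
print(as_text(GenerationResult(text="Once upon a time")))  # Once upon a time
# raw = model.generate("Once upon a time", raw=True)  # assumes an existing Model instance
```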
Model.save()
Saves the model to disk, overwriting any previously saved files. The model is saved to the location given by the path attribute, along with any configuration changes.
Parameters:
None.
Returns:
None.