Untitled

 avatar
unknown
python
2 years ago
6.7 kB
5
Indexable
import re
from typing import List, Dict, Any, Set, Optional
import json
from metadata import *

# Regex used by ``tokenize_docstring``. ``findall`` tries the alternatives
# left to right at each position:
#   * a maximal run of characters outside the punctuation set in the class;
#     note that ``+-/`` inside the class is a character *range*, so it also
#     covers ``,``, ``-`` and ``.``;
#   * an empty bracket pair ``()``, ``{}`` or ``[]`` as a single token;
#   * otherwise a run of one repeated punctuation character (``\``, ``.``,
#     brackets, ``:``, ``=``, ``*``, ``;``, ``>``, ``+``, ``-``, ``/``).
# Whitespace, commas and quote characters match no alternative and are dropped.
# The pieces below concatenate to exactly one pattern string.
DOCSTRING_REGEX_TOKENIZER = re.compile(
    r"[^\s,'\"`.():\[\]=*;>{\}+-/\\]+"
    r"|\\+|\.+"
    r"|\(\)|{\}|\[\]"
    r"|\(+|\)+|:+|\[+|\]+|{+|\}+"
    r"|=+|\*+|;+|>+|\++|-+|/+"
)

def tokenize_docstring(docstring: str) -> List[str]:
    """Split *docstring* into tokens with ``DOCSTRING_REGEX_TOKENIZER``.

    Returns every non-empty regex match, in the order it appears.
    """
    matches = DOCSTRING_REGEX_TOKENIZER.findall(docstring)
    # ``filter(None, ...)`` drops falsy items, i.e. empty strings.
    return list(filter(None, matches))

def load_js(line, idx):
    """Parse one JSON-lines record and normalise its docstring fields.

    Args:
        line: One line of text expected to contain a single JSON object.
        idx: Value stored under the ``"sample_id"`` key of the result.

    Returns:
        The decoded dict with ``"sample_id"`` set to *idx*. Any
        ``"processed_docstring"`` / ``"processed_docstring_tokens"`` keys are
        renamed to ``"docstring"`` / ``"docstring_tokens"``, overwriting any
        existing values under those names. Returns ``None`` when the line is
        not valid JSON or when it decodes to something other than an object.
    """
    try:
        js = json.loads(line.strip())
    except json.decoder.JSONDecodeError:
        return None
    if not isinstance(js, dict):
        # Valid JSON that is not an object (a list, number, string, ...)
        # cannot hold the keys below. Reject it like undecodable input
        # rather than crashing with a TypeError on item assignment.
        return None
    js["sample_id"] = idx
    for processed_key, plain_key in (
        ("processed_docstring", "docstring"),
        ("processed_docstring_tokens", "docstring_tokens"),
    ):
        if processed_key in js:
            js[plain_key] = js.pop(processed_key)
    return js

def check_len(args, docstring):
    """Return whether *docstring* has at least ``args.min_length`` tokens.

    Args:
        args: Object exposing a ``min_length`` attribute (presumably parsed
            command-line arguments; confirm with callers).
        docstring: A raw string, measured in whitespace-separated words, or a
            pre-tokenised list, measured by its length.

    Returns:
        ``True`` if the threshold is met and ``False`` otherwise. Any other
        input type (for example ``None``) also yields ``False`` instead of
        falling through to an implicit ``None``.
    """
    if isinstance(docstring, str):
        # str.split() with no argument already ignores surrounding whitespace.
        return len(docstring.split()) >= args.min_length
    if isinstance(docstring, list):
        return len(docstring) >= args.min_length
    return False

def prefix_target(data_type, docs, param_name=None, symbol=True):
    """Prepend a tag describing *data_type* to the target text *docs*.

    ``"param"`` targets become ``"param <param_name> <docs>"`` (and
    *param_name* must be given); ``"return"`` targets become
    ``"return <docs>"``; any other *data_type* leaves *docs* as is.
    When *symbol* is true the result is additionally prefixed with ``"@"``.
    """
    if data_type == "param":
        assert param_name is not None
        docs = f"{data_type} {param_name} {docs}"
    elif data_type == "return":
        docs = f"{data_type} {docs}"
    return f"@{docs}" if symbol else docs

def get_prompt_tokens(data_type,
                      language,
                      param_name=None,
                      prompt_language=False,
                      add_colon=False,
                      ):
    """Build the prompt token list for one comment type.

    The result has the shape ``[<language>, <type>, param_name, ":"]``, where
    the language token is present only when *prompt_language* is true, and
    ``param_name`` / ``":"`` are present only for ``"param"`` prompts (the
    colon additionally requires *add_colon*). Special tokens are looked up in
    ``SPECIAL_TOKENS_MAP`` (presumably supplied by the ``metadata`` module).
    """
    assert data_type in COMMENT_TYPES

    # Look up the type token first, matching the original lookup order.
    type_token = SPECIAL_TOKENS_MAP[data_type]
    if prompt_language:
        tokens = [SPECIAL_TOKENS_MAP[language], type_token]
    else:
        tokens = [type_token]

    if data_type != "param":
        return tokens

    assert param_name is not None
    tokens.append(param_name)
    if add_colon:
        tokens.append(":")
    return tokens

def get_prompt_and_docstring_tokens(data_type,
                                    language,
                                    docs=None,
                                    param_name=None,
                                    prompt_language=False,
                                    prefix_target_sequence=False
                                    ):
    """Return ``(prompt_tokens, docstring_tokens)`` for one comment type.

    ``prompt_tokens`` is ``["<language>", "<data_type>", param_name]``. The
    language tag is included only when *prompt_language* is true, and
    *param_name* is included (and required) only for ``"param"``.

    NOTE(review): *docs* and *prefix_target_sequence* are accepted but not
    used. The second element of the returned pair is always ``None``;
    docstring tokens are not computed here.
    """
    assert data_type in ("function", "param", "return")

    prompt = [f"<{language}>"] if prompt_language else []
    prompt.append(f"<{data_type}>")

    if data_type == "param":
        assert param_name is not None
        prompt.append(param_name)

    return prompt, None

def create_data_sample(js_object,
                       data_type,
                       language,
                       docs=None,
                       param_name=None,
                       prompt_language=False,
                       prefix_target_sequence=False):
    """Return a shallow copy of *js_object* annotated for one training target.

    The copy gets ``"prompt_tokens"`` (``["<language>"]`` when
    *prompt_language* is true, then ``"<data_type>"``, then *param_name* for
    ``"param"`` targets). For ``"function"`` targets only ``"comment_type"``
    is added. For any other type, *docs* is required and is tokenised into
    ``"docstring_tokens"``; it is first tag-prefixed with ``prefix_target``
    when *prefix_target_sequence* is true. ``"comment_type"`` is then set for
    ``"param"`` and ``"return"`` targets.
    """
    assert data_type in COMMENT_TYPES
    sample = js_object.copy()

    prompt = [f"<{data_type}>"]
    if prompt_language:
        prompt.insert(0, f"<{language}>")
    # Assigned before the other keys so the JSON key order stays the same.
    sample["prompt_tokens"] = prompt

    if data_type == "function":
        sample["comment_type"] = "function"
        return sample

    assert docs is not None
    if prefix_target_sequence:
        docs = prefix_target(data_type, docs, param_name, symbol=True)
    sample["docstring_tokens"] = tokenize_docstring(docs)

    if data_type == "param":
        assert param_name is not None
        sample["comment_type"] = "param"
        # ``prompt`` is the same list object stored under "prompt_tokens".
        prompt.append(param_name)
    elif data_type == "return":
        sample["comment_type"] = "return"

    return sample

def write_samples(data_stream, summaries):
    """Write *summaries* to *data_stream* as JSON lines.

    The first summary whose ``"comment_type"`` is ``"return"`` (if any) is
    moved to the end of *summaries*. Each summary then receives its position
    as ``"sentence_id"`` and is written as one JSON line. Both the list and
    the summary dicts are modified in place.
    """
    first_return = next(
        (pos for pos, item in enumerate(summaries) if item["comment_type"] == "return"),
        None,
    )
    if first_return is not None:
        summaries.append(summaries.pop(first_return))

    for position, item in enumerate(summaries):
        item["sentence_id"] = position
        json.dump(item, data_stream)
        data_stream.write("\n")

def get_function_docstring_tokens(js):
    """Return the function-level docstring tokens of record *js*, or ``None``.

    Reads ``"docstring_tokens"``, falling back to ``"docstring_token"`` only
    when the first key is missing or ``None``. Returns ``None`` when the
    tokens are absent, empty, or join (after stripping) to the string
    ``"None"``.
    """
    tokens = js.get("docstring_tokens")
    if tokens is None:
        tokens = js.get("docstring_token")
    if tokens is None or len(tokens) == 0:
        return None
    if " ".join(tokens).strip() == "None":
        return None
    return tokens

def get_param_docstring_tokens(js):
    """Collect per-parameter docstring tokens from ``js["docstring_params"]``.

    Returns two parallel lists: parameter names (entries whose
    ``"identifier"`` is ``"others"`` are skipped) and, for each name, its
    token list. Each token list is read from ``"docstring_tokens"``, falling
    back to ``"docstring_token"`` when that is missing or ``None``. It is
    ``None`` when missing, empty, or joining to ``"None"``. Both lists are
    empty when the record has no ``"docstring_params"``.
    """
    param_names, docstring_tokens_list = [], []
    params_info = js.get("docstring_params")
    entries = params_info["params"] if params_info is not None else []

    for entry in entries:
        name = entry["identifier"]
        if name == "others":
            continue
        param_names.append(name)

        tokens = entry.get("docstring_tokens")
        if tokens is None:
            tokens = entry.get("docstring_token")
        usable = (
            tokens is not None
            and len(tokens) > 0
            and " ".join(tokens).strip() != "None"
        )
        docstring_tokens_list.append(tokens if usable else None)

    assert len(param_names) == len(docstring_tokens_list)
    return param_names, docstring_tokens_list

def get_return_docstring_tokens(js):
    """Return the tokens of the first ``returns`` entry of *js*, or ``None``.

    Looks in ``js["docstring_params"]["returns"][0]``, reading
    ``"docstring_tokens"`` and falling back to ``"docstring_token"`` when that
    is missing or ``None``. Returns ``None`` when there is no such entry or
    when its tokens are absent, empty, or join to the string ``"None"``.
    """
    params_info = js.get("docstring_params")
    if params_info is None:
        return None

    returns = params_info["returns"]
    if len(returns) == 0:
        return None

    first = returns[0]
    tokens = first.get("docstring_tokens")
    if tokens is None:
        tokens = first.get("docstring_token")
    if tokens is None or len(tokens) == 0:
        return None
    if " ".join(tokens).strip() == "None":
        return None
    return tokens
Editor is loading...