nroggendorff commited on
Commit
8fa0d42
·
1 Parent(s): 36e0938

more initial commit

Browse files

stick to what I know for now, I guess

Update requirements.txt

Files changed (3) hide show
  1. .gitattributes +1 -0
  2. app.py +207 -0
  3. requirements.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ * text eol=ls
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import time
3
+ import threading
4
+ import gradio as gr
5
+ from spaces import GPU
6
+ import torch
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
8
+ from flask import Flask, request, jsonify
9
+
10
+ gpu = lambda: GPU(duration=120)
11
+
12
+ quantization_config = BitsAndBytesConfig(
13
+ load_in_4bit=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_compute_dtype=torch.bfloat16,
16
+ bnb_4bit_use_double_quant=True,
17
+ )
18
+
19
+ MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ MODEL_ID,
24
+ quantization_config=quantization_config,
25
+ device_map="auto",
26
+ trust_remote_code=True,
27
+ torch_dtype=torch.bfloat16,
28
+ )
29
+
30
+
31
+ @gpu
32
+ def inference(messages: list, temperature: float, max_tokens: int, top_p: float) -> str:
33
+ input_ids = tokenizer.apply_chat_template(
34
+ messages,
35
+ add_generation_prompt=True,
36
+ return_tensors="pt",
37
+ ).to(model.device)
38
+
39
+ do_sample = temperature > 0.0
40
+ generation_kwargs = {
41
+ "input_ids": input_ids,
42
+ "max_new_tokens": max_tokens,
43
+ "do_sample": do_sample,
44
+ "pad_token_id": tokenizer.eos_token_id,
45
+ "eos_token_id": tokenizer.eos_token_id,
46
+ }
47
+ if do_sample:
48
+ generation_kwargs["temperature"] = temperature
49
+ generation_kwargs["top_p"] = top_p
50
+
51
+ with torch.no_grad():
52
+ output_ids = model.generate(**generation_kwargs)
53
+
54
+ new_tokens = output_ids[0][input_ids.shape[1] :]
55
+ return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
56
+
57
+
58
+ def run_inference_safe(messages, temperature, max_tokens, top_p):
59
+ try:
60
+ return inference(messages, temperature, max_tokens, top_p), None
61
+ except Exception as e:
62
+ return None, str(e)
63
+
64
+
65
+ def gradio_inference(payload_json: str) -> str:
66
+ try:
67
+ payload = json.loads(payload_json)
68
+ except json.JSONDecodeError as e:
69
+ return json.dumps({"error": f"Invalid JSON: {e}"})
70
+ content, err = run_inference_safe(
71
+ payload.get("messages", []),
72
+ float(payload.get("temperature", 0.7)),
73
+ int(payload.get("max_tokens", 1024)),
74
+ float(payload.get("top_p", 1.0)),
75
+ )
76
+ if err:
77
+ return json.dumps({"error": err})
78
+ return json.dumps({"content": content})
79
+
80
+
81
+ def make_ollama_response(model_name: str, content: str) -> dict:
82
+ return {
83
+ "model": model_name,
84
+ "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
85
+ "message": {
86
+ "role": "assistant",
87
+ "content": content,
88
+ },
89
+ "done": True,
90
+ }
91
+
92
+
93
+ flask_app = Flask(__name__)
94
+
95
+
96
+ @flask_app.route("/api/chat", methods=["POST"])
97
+ def ollama_chat():
98
+ body = request.get_json(force=True, silent=True) or {}
99
+ if body.get("stream", False):
100
+ return jsonify({"error": "Streaming is not supported."}), 400
101
+
102
+ messages = body.get("messages", [])
103
+ model_name = body.get("model", "llama")
104
+ options = body.get("options", {})
105
+ temperature = float(options.get("temperature", body.get("temperature", 0.7)))
106
+ max_tokens = int(options.get("num_predict", body.get("num_predict", 1024)))
107
+ top_p = float(options.get("top_p", body.get("top_p", 1.0)))
108
+
109
+ content, err = run_inference_safe(messages, temperature, max_tokens, top_p)
110
+ if err:
111
+ return jsonify({"error": err}), 500
112
+ return jsonify(make_ollama_response(model_name, content))
113
+
114
+
115
+ @flask_app.route("/api/generate", methods=["POST"])
116
+ def ollama_generate():
117
+ body = request.get_json(force=True, silent=True) or {}
118
+ if body.get("stream", False):
119
+ return jsonify({"error": "Streaming is not supported."}), 400
120
+
121
+ prompt = body.get("prompt", "")
122
+ model_name = body.get("model", "llama")
123
+ options = body.get("options", {})
124
+ temperature = float(options.get("temperature", 0.7))
125
+ max_tokens = int(options.get("num_predict", 1024))
126
+ top_p = float(options.get("top_p", 1.0))
127
+
128
+ messages = [{"role": "user", "content": prompt}]
129
+ content, err = run_inference_safe(messages, temperature, max_tokens, top_p)
130
+ if err:
131
+ return jsonify({"error": err}), 500
132
+ return jsonify(make_ollama_response(model_name, content))
133
+
134
+
135
+ @flask_app.route("/api/tags", methods=["GET"])
136
+ def ollama_tags():
137
+ return jsonify(
138
+ {
139
+ "models": [
140
+ {
141
+ "name": "llama",
142
+ "model": "llama",
143
+ "modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
144
+ "size": 0,
145
+ "digest": "local",
146
+ "details": {
147
+ "format": "4bit-nf4",
148
+ "family": "llama",
149
+ "parameter_size": "unknown",
150
+ "quantization_level": "Q4_NF4",
151
+ },
152
+ }
153
+ ]
154
+ }
155
+ )
156
+
157
+
158
+ @flask_app.route("/v1/models", methods=["GET"])
159
+ def openai_models():
160
+ return jsonify(
161
+ {
162
+ "object": "list",
163
+ "data": [
164
+ {
165
+ "id": "llama",
166
+ "object": "model",
167
+ "created": int(time.time()),
168
+ "owned_by": "local",
169
+ }
170
+ ],
171
+ }
172
+ )
173
+
174
+
175
+ @flask_app.route("/health", methods=["GET"])
176
+ def health():
177
+ return jsonify({"status": "ok"})
178
+
179
+
180
+ def start_flask():
181
+ flask_app.run(host="0.0.0.0", port=11434, use_reloader=False)
182
+
183
+
184
+ flask_thread = threading.Thread(target=start_flask, daemon=True)
185
+ flask_thread.start()
186
+
187
+
188
+ with gr.Blocks() as demo:
189
+ with gr.Row():
190
+ with gr.Column():
191
+ payload_input = gr.Textbox(
192
+ label="Request payload (JSON)",
193
+ placeholder='{"messages": [{"role": "user", "content": "Hello!"}], "max_tokens": 256}',
194
+ lines=6,
195
+ )
196
+ submit_btn = gr.Button("Run inference", variant="primary")
197
+ with gr.Column():
198
+ output_box = gr.Textbox(label="Response", lines=6)
199
+
200
+ submit_btn.click(
201
+ fn=gradio_inference,
202
+ inputs=payload_input,
203
+ outputs=output_box,
204
+ api_name="predict",
205
+ )
206
+
207
+ demo.launch()
requirements.txt ADDED
Binary file (112 Bytes). View file